File size: 4,873 Bytes
31607dc
 
 
 
 
 
 
 
b8b4db9
31607dc
 
 
 
 
 
 
 
 
 
 
4756ba0
dfa7438
31607dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfa7438
31607dc
 
 
 
 
 
 
 
 
 
 
 
 
dfa7438
31607dc
 
 
 
a29d2c5
31607dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfa7438
 
 
 
 
6e3d375
 
dfa7438
4756ba0
 
9ae02b2
dfa7438
 
4756ba0
 
31607dc
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import gradio as gr
import tensorflow as tf
from tensorflow import keras
from custom_model import ImageClassifier
from resnet_model import ResNetClassifier
from vgg16_model import VGG16Classifier
from inception_v3_model import InceptionV3Classifier
from mobilevet_v2 import MobileNetClassifier
import os

# Label names for the custom CNN, indexed by the class ids it returns
# (make_prediction maps custom_model's top_classes through this list).
# These ten names appear to be the CIFAR-10 classes — confirm against
# ImageClassifier.load_dataset.
CLASS_NAMES =['Airplane', 'Automobile', 'Bird', 'Cat', 'Deer', 'Dog', 'Frog', 'Horse', 'Ship', 'Truck']

# models
# Built once at import time and shared by every prediction request.
# Only the custom CNN explicitly loads weights from disk here; this is the same
# file that train_model writes. The pretrained wrappers presumably load their
# own weights in their constructors — TODO confirm.
custom_model = ImageClassifier()
custom_model.load_model("image_classifier_model.h5")
resnet_model = ResNetClassifier()
vgg16_model = VGG16Classifier()
inceptionV3_model = InceptionV3Classifier()
mobilenet_model = MobileNetClassifier()

def make_prediction(image, model_type="CNN (Custom)"):
  if "CNN (Custom)" == model_type:
    top_classes, top_probs = custom_model.classify_image(image, top_k=3)
    return {CLASS_NAMES[cls_id]:str(prob) for cls_id, prob in zip(top_classes, top_probs)}
  elif "ResNet50" == model_type:
    predictions = resnet_model.classify_image(image)
    return {class_name:str(prob) for _, class_name, prob in predictions}
  elif "VGG16" == model_type:
    predictions = vgg16_model.classify_image(image)
    return {class_name:str(prob) for _, class_name, prob in predictions}
  elif "Inception v3" == model_type:
    predictions = inceptionV3_model.classify_image(image)
    return {class_name:str(prob) for _, class_name, prob in predictions}
  elif "Mobile Net v2" == model_type:
    predictions = mobilenet_model.classify_image(image)
    return {class_name:str(prob) for _, class_name, prob in predictions}
  else:
    return {"Select a model to classify image"}

def train_model(epochs, batch_size, validation_split):
  """Train a fresh custom CNN, save it, and use it for subsequent predictions.

  This is wired to the "Train" button with the raw textbox values, so each
  argument arrives as text and is converted here.

  Args:
    epochs: number of training epochs (anything ``int()`` accepts).
    batch_size: mini-batch size (anything ``int()`` accepts).
    validation_split: fraction of the training data held out for validation
      (anything ``float()`` accepts).

  Raises:
    ValueError / TypeError: if an argument cannot be converted to a number.
  """
  # Without this declaration the assignment at the end only binds a local
  # name, and make_prediction keeps using the previously loaded model.
  global custom_model

  # Convert first so malformed input fails before the slow dataset load.
  epochs = int(epochs)
  batch_size = int(batch_size)
  validation_split = float(validation_split)

  print("Training model")

  # Create an instance of the ImageClassifier
  classifier = ImageClassifier()

  # Load the dataset
  (x_train, y_train), (x_test, y_test) = classifier.load_dataset()

  # Build and train the model
  classifier.build_model(x_train)
  classifier.train_model(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=validation_split)

  # Evaluate the model
  classifier.evaluate_model(x_test, y_test)

  # Save the trained model to the same file the module loads at startup
  print("Saving model ...")
  classifier.save_model("image_classifier_model.h5")

  # Swap in the new model only after training and saving have completed.
  custom_model = classifier


def update_train_param_display(model_type):
  """Choose which button group the UI shows for the selected model.

  The training panel (train_col) is visible only for "CNN (Custom)". Every
  other selection shows the predict-only panel (no_train_col) instead.

  Returns:
    A two-element list of ``gr.update`` objects: the first for the training
    panel, the second for the predict-only panel.
  """
  is_custom = "CNN (Custom)" == model_type
  return [gr.update(visible=is_custom), gr.update(visible=not is_custom)]

if __name__ == "__main__":
  # The UI is built and launched only when run as a script. Importing this
  # module (e.g. to reuse make_prediction) must not need `my_app`.
  assets_dir = os.path.join(os.path.dirname(__file__), "assets")

  # gradio gui app
  with gr.Blocks() as my_app:
    gr.Markdown("<h1><center>Image Classification using TensorFlow</center></h1>")
    gr.Markdown("<h3><center>This model classifies image using different models.</center></h3>")

    with gr.Row():
      with gr.Column(scale=1):
        img_input = gr.Image()
        model_type = gr.Dropdown(
            ["CNN (Custom)",
             "ResNet50",
             "VGG16",
             "Inception v3",
             "Mobile Net v2"],
            label="Model Type", value="CNN (Custom)",
            info="Select the inference model before running predictions!")

        # Visible only while "CNN (Custom)" is selected (see
        # update_train_param_display); it is the only model trained here.
        with gr.Column() as train_col:
          gr.Markdown("Train Parameters")
          with gr.Row():
            epochs_inp = gr.Textbox(label="Epochs", value="10")
            validation_split = gr.Textbox(label="Validation Split", value="0.1")

          with gr.Row():
            batch_size = gr.Textbox(label="Batch Size", value="64")

          with gr.Row():
            train_btn = gr.Button(value="Train")
            predict_btn_1 = gr.Button(value="Predict")

        # Predict-only panel shown for the pretrained models.
        with gr.Column(visible=False) as no_train_col:
          predict_btn_2 = gr.Button(value="Predict")

      with gr.Column(scale=1):
        output_label = gr.Label()

    gr.Markdown("## Sample Images")
    # Only the image is passed as input, so examples are classified with
    # make_prediction's default model type ("CNN (Custom)").
    gr.Examples(
        examples=[os.path.join(assets_dir, "dog_2.jpg"),
                  os.path.join(assets_dir, "truck.jpg"),
                  os.path.join(assets_dir, "car.jpg"),
                  os.path.join(assets_dir, "car_32x32.jpg")],
        inputs=img_input,
        outputs=output_label,
        fn=make_prediction,
        cache_examples=True,
    )

    # app logic
    predict_btn_1.click(make_prediction, inputs=[img_input, model_type], outputs=[output_label])
    predict_btn_2.click(make_prediction, inputs=[img_input, model_type], outputs=[output_label])
    model_type.change(update_train_param_display, inputs=model_type, outputs=[train_col, no_train_col])
    train_btn.click(train_model, inputs=[epochs_inp, batch_size, validation_split], outputs=[])

  # Previously at module level, where `my_app` is undefined on import.
  my_app.queue(concurrency_count=5, max_size=20).launch(debug=True)