|
import gradio as gr |
|
import tensorflow as tf |
|
from tensorflow import keras |
|
from custom_model import ImageClassifier |
|
from resnet_model import ResNetClassifier |
|
from vgg16_model import VGG16Classifier |
|
from inception_v3_model import InceptionV3Classifier |
|
from mobilevet_v2 import MobileNetClassifier |
|
import os |
|
|
|
# Display names indexed by the class ids that the custom CNN returns
# (make_prediction looks names up as CLASS_NAMES[cls_id]).
# NOTE(review): the ten names look like the CIFAR-10 label set - confirm the
# order against the dataset the custom model is trained on.
CLASS_NAMES = [
    'Airplane',
    'Automobile',
    'Bird',
    'Cat',
    'Deer',
    'Dog',
    'Frog',
    'Horse',
    'Ship',
    'Truck',
]
|
|
|
|
|
# Custom CNN: restored at import time from the same file name that
# train_model() saves to.
custom_model = ImageClassifier()
custom_model.load_model("image_classifier_model.h5")

# The remaining classifiers are constructed with their default arguments;
# each one backs one entry of the "Model Type" dropdown.
resnet_model = ResNetClassifier()
vgg16_model = VGG16Classifier()
inceptionV3_model = InceptionV3Classifier()
mobilenet_model = MobileNetClassifier()
|
|
|
def make_prediction(image, model_type="CNN (Custom)"):
    """Classify ``image`` with the selected model, formatted for ``gr.Label``.

    Args:
        image: Image forwarded unchanged to the selected model's
            ``classify_image`` method.
        model_type: One of the "Model Type" dropdown choices. Defaults to
            "CNN (Custom)", which is what is used when only an image is
            passed (as the examples gallery does).

    Returns:
        A dict mapping class name to confidence as a Python ``float``, or a
        plain message string when ``model_type`` is not a known choice.
    """
    if model_type == "CNN (Custom)":
        # The custom CNN reports class indices; translate them to names.
        top_classes, top_probs = custom_model.classify_image(image, top_k=3)
        # Confidences stay numeric: gr.Label ranks entries by value, and
        # string values would be ordered lexicographically instead.
        return {
            CLASS_NAMES[cls_id]: float(prob)
            for cls_id, prob in zip(top_classes, top_probs)
        }

    if model_type == "ResNet50":
        classifier = resnet_model
    elif model_type == "VGG16":
        classifier = vgg16_model
    elif model_type == "Inception v3":
        classifier = inceptionV3_model
    elif model_type == "Mobile Net v2":
        classifier = mobilenet_model
    else:
        # gr.Label accepts a bare string label; a set (as returned before)
        # is not a supported value.
        return "Select a model to classify image"

    # These classifiers yield 3-tuples whose last two items are the class
    # name and its score; the first item is not used.
    predictions = classifier.classify_image(image)
    return {class_name: float(prob) for _, class_name, prob in predictions}
|
|
|
def train_model(epochs, batch_size, validation_split):
    """Train a new custom CNN, save it, and use it for later predictions.

    Args:
        epochs: Number of epochs; converted with ``int`` before use.
        batch_size: Batch size; converted with ``int`` before use.
        validation_split: Value passed as ``validation_split`` to the
            classifier's ``train_model``; converted with ``float``.

    Raises:
        ValueError: If an argument cannot be converted to its numeric type.
    """
    # Without this declaration the assignment at the end only binds a local
    # name, and make_prediction() keeps using the previously loaded model.
    global custom_model

    # The UI supplies these as textbox strings; convert first so malformed
    # input fails before the dataset is loaded.
    epochs = int(epochs)
    batch_size = int(batch_size)
    validation_split = float(validation_split)

    print("Training model")

    classifier = ImageClassifier()

    (x_train, y_train), (x_test, y_test) = classifier.load_dataset()

    classifier.build_model(x_train)
    classifier.train_model(
        x_train,
        y_train,
        batch_size=batch_size,
        epochs=epochs,
        validation_split=validation_split,
    )

    classifier.evaluate_model(x_test, y_test)

    print("Saving model ...")
    classifier.save_model("image_classifier_model.h5")

    # Swap the freshly trained classifier in for subsequent predictions.
    custom_model = classifier
|
|
|
|
|
def update_train_param_display(model_type):
    """Build the visibility updates for the two button columns.

    Args:
        model_type: Currently selected value of the "Model Type" dropdown.

    Returns:
        A two-item list of ``gr.update`` results. The first is visible only
        when ``model_type`` is "CNN (Custom)"; the second is visible in every
        other case. The caller wires them to ``[train_col, no_train_col]``.
    """
    custom_selected = model_type == "CNN (Custom)"
    return [
        gr.update(visible=custom_selected),
        gr.update(visible=not custom_selected),
    ]
|
|
|
if __name__ == "__main__":
    # NOTE(review): the nesting of the layout containers below was
    # reconstructed from a source whose indentation was lost - confirm it
    # matches the intended page layout.

    # Example images are resolved relative to this script's directory.
    script_dir = os.path.dirname(__file__)
    sample_images = [
        os.path.join(script_dir, relative_path)
        for relative_path in (
            "assets/dog_2.jpg",
            "assets/truck.jpg",
            "assets/car.jpg",
            "assets/car_32x32.jpg",
        )
    ]

    model_choices = [
        "CNN (Custom)",
        "ResNet50",
        "VGG16",
        "Inception v3",
        "Mobile Net v2",
    ]

    with gr.Blocks() as my_app:
        gr.Markdown("<h1><center>Image Classification using TensorFlow</center></h1>")
        gr.Markdown("<h3><center>This model classifies image using different models.</center></h3>")

        with gr.Row():
            # Left side: image input, model selection and action buttons.
            with gr.Column(scale=1):
                img_input = gr.Image()
                model_type = gr.Dropdown(
                    model_choices,
                    label="Model Type",
                    value="CNN (Custom)",
                    info="Select the inference model before running predictions!",
                )

                # Shown while the custom CNN is selected (initial state):
                # training parameters plus Train / Predict buttons.
                with gr.Column() as train_col:
                    gr.Markdown("Train Parameters")
                    with gr.Row():
                        epochs_inp = gr.Textbox(label="Epochs", value="10")
                        validation_split = gr.Textbox(label="Validation Split", value="0.1")

                    with gr.Row():
                        batch_size = gr.Textbox(label="Batch Size", value="64")

                    with gr.Row():
                        train_btn = gr.Button(value="Train")
                        predict_btn_1 = gr.Button(value="Predict")

                # Initially hidden; shown for every other model type, which
                # only offers a Predict button.
                with gr.Column(visible=False) as no_train_col:
                    predict_btn_2 = gr.Button(value="Predict")

            # Right side: prediction output.
            with gr.Column(scale=1):
                output_label = gr.Label()

        gr.Markdown("## Sample Images")
        # Only the image is passed as input here, so make_prediction runs
        # with its default model_type for the examples.
        gr.Examples(
            examples=sample_images,
            inputs=img_input,
            outputs=output_label,
            fn=make_prediction,
            cache_examples=True,
        )

        # Both Predict buttons trigger the same handler with the same wiring.
        for predict_btn in (predict_btn_1, predict_btn_2):
            predict_btn.click(
                make_prediction,
                inputs=[img_input, model_type],
                outputs=[output_label],
            )

        # Switch between the two button columns when the model changes.
        model_type.change(
            update_train_param_display,
            inputs=model_type,
            outputs=[train_col, no_train_col],
        )
        train_btn.click(
            train_model,
            inputs=[epochs_inp, batch_size, validation_split],
            outputs=[],
        )

    my_app.queue(concurrency_count=5, max_size=20).launch(debug=True)