vaishanthr's picture
Update app.py
6e3d375
import gradio as gr
import tensorflow as tf
from tensorflow import keras
from custom_model import ImageClassifier
from resnet_model import ResNetClassifier
from vgg16_model import VGG16Classifier
from inception_v3_model import InceptionV3Classifier
from mobilevet_v2 import MobileNetClassifier
import os
# Human-readable labels for the custom CNN; make_prediction indexes this list
# with the class ids returned by custom_model.classify_image.
CLASS_NAMES = [
    'Airplane', 'Automobile', 'Bird', 'Cat', 'Deer',
    'Dog', 'Frog', 'Horse', 'Ship', 'Truck',
]

# Models are created once at import time and shared by every prediction request.
# The custom CNN is restored from the same file that train_model saves to.
custom_model = ImageClassifier()
custom_model.load_model("image_classifier_model.h5")

# Pretrained classifiers need no extra setup here beyond construction.
resnet_model = ResNetClassifier()
vgg16_model = VGG16Classifier()
inceptionV3_model = InceptionV3Classifier()
mobilenet_model = MobileNetClassifier()
def make_prediction(image, model_type="CNN (Custom)"):
    """Classify ``image`` with the model chosen in the UI dropdown.

    Args:
        image: The image delivered by the ``gr.Image`` component (presumably a
            numpy array, Gradio's default -- TODO confirm against the
            classifier implementations).
        model_type: Display name of the model; one of "CNN (Custom)",
            "ResNet50", "VGG16", "Inception v3" or "Mobile Net v2".

    Returns:
        A ``{label: confidence}`` dict with float confidences, as expected by
        ``gr.Label``; or a plain message string when ``model_type`` is not a
        recognised model (e.g. the dropdown was cleared).
    """
    if model_type == "CNN (Custom)":
        # The custom CNN returns parallel sequences of class ids (indices into
        # CLASS_NAMES) and their probabilities for the top 3 classes.
        top_classes, top_probs = custom_model.classify_image(image, top_k=3)
        return {
            CLASS_NAMES[cls_id]: float(prob)
            for cls_id, prob in zip(top_classes, top_probs)
        }

    if model_type == "ResNet50":
        model = resnet_model
    elif model_type == "VGG16":
        model = vgg16_model
    elif model_type == "Inception v3":
        model = inceptionV3_model
    elif model_type == "Mobile Net v2":
        model = mobilenet_model
    else:
        # gr.Label renders a plain string; a set (as previously returned)
        # is not an accepted value and makes the Label component raise.
        return "Select a model to classify image"

    # Pretrained classifiers yield 3-tuples whose 2nd/3rd items are the class
    # name and its score (presumably Keras decode_predictions output -- the
    # first item is ignored). float() makes gr.Label sort confidences
    # numerically instead of lexicographically as strings.
    predictions = model.classify_image(image)
    return {class_name: float(prob) for _, class_name, prob in predictions}
def train_model(epochs, batch_size, validation_split):
    """Train a fresh custom CNN from the UI parameters and make it live.

    The three arguments arrive as strings from ``gr.Textbox`` components and
    are converted here. The trained model is evaluated on the test split,
    saved to ``image_classifier_model.h5``, and then replaces the module-level
    ``custom_model`` so later "CNN (Custom)" predictions use it.

    Args:
        epochs: Number of training epochs (string or number convertible by int()).
        batch_size: Training batch size (convertible by int()).
        validation_split: Fraction of training data held out for validation
            (convertible by float()).

    Raises:
        ValueError: If any argument cannot be converted to its numeric type.
            Raised before the dataset is loaded.
    """
    # Without this declaration the final assignment only creates a local and
    # make_prediction keeps using the previously loaded model.
    global custom_model

    print("Training model")

    # Parse UI strings up front so a typo fails immediately instead of after
    # the (slow) dataset load.
    n_epochs = int(epochs)
    n_batch = int(batch_size)
    val_fraction = float(validation_split)

    classifier = ImageClassifier()
    (x_train, y_train), (x_test, y_test) = classifier.load_dataset()

    classifier.build_model(x_train)
    classifier.train_model(
        x_train,
        y_train,
        batch_size=n_batch,
        epochs=n_epochs,
        validation_split=val_fraction,
    )
    classifier.evaluate_model(x_test, y_test)

    print("Saving model ...")
    classifier.save_model("image_classifier_model.h5")

    # Swap in the new model for subsequent predictions.
    custom_model = classifier
def update_train_param_display(model_type):
    """Return visibility updates for ``[train_col, no_train_col]``.

    Only the custom CNN can be trained from the UI. Selecting "CNN (Custom)"
    shows the training-parameters column, and any other choice shows the
    predict-only column instead.
    """
    is_trainable = bool("CNN (Custom)" == model_type)
    return [gr.update(visible=is_trainable), gr.update(visible=not is_trainable)]
if __name__ == "__main__":
    # gradio gui app
    # Layout is determined by the order components are created inside the
    # nested `with` contexts below.
    with gr.Blocks() as my_app:
        gr.Markdown("<h1><center>Image Classification using TensorFlow</center></h1>")
        gr.Markdown("<h3><center>This model classifies image using different models.</center></h3>")

        with gr.Row():
            # Left column: image input, model picker and model-specific controls.
            with gr.Column(scale=1):
                img_input = gr.Image()
                model_type = gr.Dropdown(
                    ["CNN (Custom)",
                     "ResNet50",
                     "VGG16",
                     "Inception v3",
                     "Mobile Net v2"],
                    label="Model Type", value="CNN (Custom)",
                    info="Select the inference model before running predictions!")

                # Visible while "CNN (Custom)" is selected (the initial value).
                # update_train_param_display toggles it. The textboxes deliver
                # strings, and train_model converts them.
                with gr.Column() as train_col:
                    gr.Markdown("Train Parameters")
                    with gr.Row():
                        epochs_inp = gr.Textbox(label="Epochs", value="10")
                        validation_split = gr.Textbox(label="Validation Split", value="0.1")
                    with gr.Row():
                        batch_size = gr.Textbox(label="Batch Size", value="64")
                    with gr.Row():
                        train_btn = gr.Button(value="Train")
                        predict_btn_1 = gr.Button(value="Predict")

                # Predict-only column, shown instead of train_col for the
                # pretrained models.
                with gr.Column(visible=False) as no_train_col:
                    predict_btn_2 = gr.Button(value="Predict")

            # Right column: prediction output.
            with gr.Column(scale=1):
                output_label = gr.Label()

        gr.Markdown("## Sample Images")
        # Only img_input is wired as an input, so examples run make_prediction
        # with its default model_type ("CNN (Custom)").
        # cache_examples=True asks Gradio to cache these example outputs.
        gr.Examples(
            examples=[os.path.join(os.path.dirname(__file__), "assets/dog_2.jpg"),
                      os.path.join(os.path.dirname(__file__), "assets/truck.jpg"),
                      os.path.join(os.path.dirname(__file__), "assets/car.jpg"),
                      os.path.join(os.path.dirname(__file__), "assets/car_32x32.jpg")
                      ],
            inputs=img_input,
            outputs=output_label,
            fn=make_prediction,
            cache_examples=True,
        )

        # app logic
        # Both Predict buttons run the same handler; only one is visible at a time.
        predict_btn_1.click(make_prediction, inputs=[img_input, model_type], outputs=[output_label])
        predict_btn_2.click(make_prediction, inputs=[img_input, model_type], outputs=[output_label])
        model_type.change(update_train_param_display, inputs=model_type, outputs=[train_col, no_train_col])
        # Training updates no UI component (outputs=[]); progress is only printed.
        train_btn.click(train_model, inputs=[epochs_inp, batch_size, validation_split], outputs=[])

    # NOTE(review): `concurrency_count` is a Gradio 3.x `queue()` argument and
    # was removed in Gradio 4 -- confirm the pinned Gradio version.
    my_app.queue(concurrency_count=5, max_size=20).launch(debug=True)