"""Gradio app that classifies images with a custom CNN or one of several other models.

The custom CNN can be retrained from the UI; the retrained model is saved to
MODEL_PATH and becomes the model used for subsequent "CNN (Custom)" predictions.
"""
import os

import gradio as gr
import tensorflow as tf
from tensorflow import keras

from custom_model import ImageClassifier
from resnet_model import ResNetClassifier
from vgg16_model import VGG16Classifier
from inception_v3_model import InceptionV3Classifier
# NOTE(review): "mobilevet_v2" looks like a typo for "mobilenet_v2"; kept as-is
# because it must match the module's actual file name — confirm.
from mobilevet_v2 import MobileNetClassifier

# Weights file for the custom CNN: loaded at startup, overwritten by train_model().
MODEL_PATH = "image_classifier_model.h5"
# Dropdown label for the custom CNN (the only model with training controls).
CUSTOM_MODEL_TYPE = "CNN (Custom)"

# Label for each output index of the custom CNN (CLASS_NAMES[cls_id]).
# These are the CIFAR-10 class names — presumably the dataset that
# ImageClassifier.load_dataset() returns; confirm the order matches.
CLASS_NAMES = ['Airplane', 'Automobile', 'Bird', 'Cat', 'Deer',
               'Dog', 'Frog', 'Horse', 'Ship', 'Truck']

# Sample images shown under the UI, relative to the "assets" directory.
SAMPLE_IMAGES = ["dog_2.jpg", "truck.jpg", "car.jpg", "car_32x32.jpg"]

# models
custom_model = ImageClassifier()
custom_model.load_model(MODEL_PATH)
resnet_model = ResNetClassifier()
vgg16_model = VGG16Classifier()
inceptionV3_model = InceptionV3Classifier()
mobilenet_model = MobileNetClassifier()

# Dropdown label -> model wrapper for every model other than the custom CNN.
# Each wrapper's classify_image(image) yields (_, class_name, probability) triples.
_STANDARD_MODELS = {
    "ResNet50": resnet_model,
    "VGG16": vgg16_model,
    "Inception v3": inceptionV3_model,
    "Mobile Net v2": mobilenet_model,
}
# Dropdown choices, derived from the dispatch table so the two cannot drift apart.
MODEL_TYPES = [CUSTOM_MODEL_TYPE, *_STANDARD_MODELS]


def make_prediction(image, model_type=CUSTOM_MODEL_TYPE):
    """Classify ``image`` with the model selected in the dropdown.

    Args:
        image: Value produced by the ``gr.Image`` input (``None`` when no
            image has been uploaded).
        model_type: One of ``MODEL_TYPES``.

    Returns:
        A ``{class_name: probability}`` dict for ``gr.Label``, or a message
        string when there is no image or ``model_type`` is not recognised.
    """
    if image is None:
        return "Upload an image to classify"

    if model_type == CUSTOM_MODEL_TYPE:
        # Look up the module-level model at call time so a model swapped in by
        # train_model() is used.
        top_classes, top_probs = custom_model.classify_image(image, top_k=3)
        # gr.Label ranks entries by value, so confidences must be numbers; as
        # strings they would sort lexicographically (e.g. "9e-05" > "0.5").
        return {CLASS_NAMES[cls_id]: float(prob)
                for cls_id, prob in zip(top_classes, top_probs)}

    model = _STANDARD_MODELS.get(model_type)
    if model is None:
        # gr.Label renders a plain string as a message (a set would be rejected).
        return "Select a model to classify image"
    return {class_name: float(prob)
            for _, class_name, prob in model.classify_image(image)}


def train_model(epochs, batch_size, validation_split):
    """Train a new custom CNN, save it to MODEL_PATH and make it the active model.

    Args:
        epochs: Number of training epochs (textbox value, parsed as int).
        batch_size: Training batch size (textbox value, parsed as int).
        validation_split: Fraction of training data held out for validation
            (textbox value, parsed as float).

    Raises:
        ValueError: If a training parameter is not a valid number.
    """
    # Rebinds the module-level model that make_prediction() reads.
    global custom_model

    # Parse up front so bad input fails before the (slow) dataset load.
    epochs = int(epochs)
    batch_size = int(batch_size)
    validation_split = float(validation_split)

    print("Training model")
    classifier = ImageClassifier()
    (x_train, y_train), (x_test, y_test) = classifier.load_dataset()

    # Build and train the model
    classifier.build_model(x_train)
    classifier.train_model(x_train, y_train, batch_size=batch_size,
                           epochs=epochs, validation_split=validation_split)

    # Evaluate the model
    classifier.evaluate_model(x_test, y_test)

    print("Saving model ...")
    classifier.save_model(MODEL_PATH)
    # Swap in the new model so subsequent predictions use it.
    custom_model = classifier


def update_train_param_display(model_type):
    """Toggle between the training panel and the standalone Predict button.

    Returns visibility updates for ``[train_col, no_train_col]``: the training
    panel is shown only for the custom CNN.
    """
    is_custom = model_type == CUSTOM_MODEL_TYPE
    return [gr.update(visible=is_custom), gr.update(visible=not is_custom)]


def _build_app():
    """Build the Gradio Blocks UI and wire up its event handlers."""
    with gr.Blocks() as app:
        gr.Markdown("# Image Classification using TensorFlow")
        gr.Markdown("This model classifies image using different models.")
        with gr.Row():
            with gr.Column(scale=1):
                img_input = gr.Image()
                model_type = gr.Dropdown(
                    MODEL_TYPES,
                    label="Model Type",
                    value=CUSTOM_MODEL_TYPE,
                    info="Select the inference model before running predictions!",
                )
                # Training controls: only relevant to the custom CNN.
                with gr.Column() as train_col:
                    gr.Markdown("Train Parameters")
                    with gr.Row():
                        epochs_inp = gr.Textbox(label="Epochs", value="10")
                        validation_split = gr.Textbox(label="Validation Split", value="0.1")
                    with gr.Row():
                        batch_size = gr.Textbox(label="Batch Size", value="64")
                    with gr.Row():
                        train_btn = gr.Button(value="Train")
                        predict_btn_1 = gr.Button(value="Predict")
                # Shown instead of the training panel for non-custom models.
                with gr.Column(visible=False) as no_train_col:
                    predict_btn_2 = gr.Button(value="Predict")
            with gr.Column(scale=1):
                output_label = gr.Label()

        gr.Markdown("## Sample Images")
        assets_dir = os.path.join(os.path.dirname(__file__), "assets")
        # NOTE(review): with cache_examples=True the outputs are computed once at
        # startup via make_prediction's default model; they do not follow the
        # dropdown or reflect retraining.
        gr.Examples(
            examples=[os.path.join(assets_dir, name) for name in SAMPLE_IMAGES],
            inputs=img_input,
            outputs=output_label,
            fn=make_prediction,
            cache_examples=True,
        )

        # app logic
        for predict_btn in (predict_btn_1, predict_btn_2):
            predict_btn.click(make_prediction, inputs=[img_input, model_type],
                              outputs=[output_label])
        model_type.change(update_train_param_display, inputs=model_type,
                          outputs=[train_col, no_train_col])
        train_btn.click(train_model, inputs=[epochs_inp, batch_size, validation_split],
                        outputs=[])
    return app


if __name__ == "__main__":
    my_app = _build_app()
    my_app.queue(concurrency_count=5, max_size=20).launch(debug=True)