ademibeh commited on
Commit
0fbc8da
·
verified ·
1 Parent(s): eed7a37

Upload 4 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ mobilenet_best_model.keras filter=lfs diff=lfs merge=lfs -text
37
+ resnet_best_model.keras filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import numpy as np
4
+ from tensorflow.keras.preprocessing import image as keras_image
5
+ from tensorflow.keras.applications.resnet50 import preprocess_input as resnet_preprocess_input
6
+ from tensorflow.keras.applications.mobilenet_v2 import preprocess_input as mobilenet_preprocess_input
7
+ from tensorflow.keras.models import load_model
8
+
9
+ # Load your trained models
10
+ resnet_model = load_model('/path/to/resnet_model.h5') # Update path
11
+ mobilenet_model = load_model('/path/to/mobilenet_model.h5') # Update path
12
+
13
+ def predict_comic_character(img, model_type):
14
+ img = Image.fromarray(img.astype('uint8'), 'RGB')
15
+ img = img.resize((224, 224)) # Resize the image to fit model input
16
+ img_array = keras_image.img_to_array(img)
17
+ img_array = np.expand_dims(img_array, axis=0)
18
+
19
+ if model_type == 'ResNet50':
20
+ img_array = resnet_preprocess_input(img_array)
21
+ prediction = resnet_model.predict(img_array)
22
+ elif model_type == 'MobileNetV2':
23
+ img_array = mobilenet_preprocess_input(img_array)
24
+ prediction = mobilenet_model.predict(img_array)
25
+ else:
26
+ return {"error": "Invalid model type selected"}
27
+
28
+ classes = ['Superman', 'Batman', 'WonderWoman', 'Riddler', 'Spider-Man', 'Iron-Man',
29
+ 'Hulk', 'The Joker', 'Magneto', 'Wolverine', 'Deadpool', 'Catwoman']
30
+
31
+ return {classes[i]: float(prediction[0][i]) for i in range(len(classes))}
32
+
33
+ # Define the Gradio interface
34
+ interface = gr.Interface(
35
+ fn=predict_comic_character,
36
+ inputs=[
37
+ gr.inputs.Image(type="numpy", label="Upload an image of a comic character"),
38
+ gr.inputs.Radio(['ResNet50', 'MobileNetV2'], label="Choose Model")
39
+ ],
40
+ outputs=gr.outputs.Label(num_top_classes=5, label="Prediction"),
41
+ title="Comic Character Classifier",
42
+ description="Upload an image of a comic character and the classifier will predict the character.",
43
+ theme="huggingface" # Optional: Adds a nice theme from Gradio's gallery
44
+ )
45
+
46
+ # Launch the interface
47
+ interface.launch()
mobilenet_best_model.keras ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a5a5a70d275d9a8512fc8552489ded8cc6d8b4bdd52327eba11f23bc1a6d1f37
3
+ size 25466457
requirements.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ blinker==1.7.0
2
+ click==8.1.7
3
+ Flask==3.0.2
4
+ Flask-Cors==4.0.0
5
+ itsdangerous==2.1.2
6
+ Jinja2==3.1.3
7
+ joblib==1.3.2
8
+ MarkupSafe==2.1.5
9
+ numpy==1.26.4
10
+ pandas==2.2.1
11
+ python-dateutil==2.8.2
12
+ pytz==2024.1
13
+ scikit-learn==1.4.1.post1
14
+ scipy==1.12.0
15
+ six==1.16.0
16
+ threadpoolctl==3.3.0
17
+ tzdata==2024.1
18
+ Werkzeug==3.0.1
19
+ tensorflow==2.16.1
resnet_best_model.keras ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d209662d425c49e7f4573de16252bd3724b9c319af87bf6717ada837636694f4
3
+ size 120289188