Image-Classifier-TensorFlow / resnet_model.py
vaishanthr's picture
initial commit
31607dc
raw
history blame contribute delete
922 Bytes
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
class ResNetClassifier:
    """ImageNet image classifier backed by Keras' pretrained ResNet50.

    The network, including its 1000-way classification head, is built once
    per instance with ImageNet weights. Each call to ``classify_image`` then
    does the following for one image:

    * resizes it to ``INPUT_SIZE``;
    * applies ResNet50's own ``preprocess_input``;
    * decodes the highest-scoring ImageNet labels.
    """

    # (width, height) passed to PIL's ``Image.resize`` before inference;
    # matches the 224x224 input the ImageNet-pretrained ResNet50 head needs.
    INPUT_SIZE = (224, 224)

    def __init__(self):
        # Arguments are kept explicit. The 1000-class top is required so that
        # decode_predictions can map output indices to ImageNet labels.
        self.model = keras.applications.ResNet50(
            include_top=True,
            weights="imagenet",
            input_tensor=None,
            input_shape=None,
            pooling=None,
            classes=1000,
        )

    def preprocess_image(self, image, *, scale=True):
        """Turn one image array into a batch of size 1 ready for ResNet50.

        Args:
            image: A single image as a 3-D array accepted by
                ``array_to_img``. It is presumably ``(height, width,
                channels)``, e.g. the uint8 RGB arrays an upload widget
                returns; confirm against the caller. Arrays with 1, 3 or 4
                channels are accepted, and all are converted to RGB.
            scale: Forwarded to ``array_to_img``. When true (the historical
                behaviour), pixel values are min-max stretched to
                ``[0, 255]`` before conversion. This also stretches the
                contrast of a uint8 image that does not already span the
                full range. Pass ``False`` when ``image`` already holds
                values in ``[0, 255]``, so the pixels reach the model
                unmodified.

        Returns:
            A float tensor with a leading batch dimension of 1. With the
            default channels-last data format its shape is
            ``(1, 224, 224, 3)``. It has been processed by
            ``keras.applications.resnet50.preprocess_input``.
        """
        img = keras.preprocessing.image.array_to_img(image, scale=scale)
        # Grayscale ("L") and RGBA inputs would otherwise produce 1 or 4
        # channels, which the 3-channel ResNet50 input rejects.
        if img.mode != "RGB":
            img = img.convert("RGB")
        img = img.resize(self.INPUT_SIZE)
        img_array = keras.preprocessing.image.img_to_array(img)
        # Add the leading batch dimension the model expects.
        img_array = tf.expand_dims(img_array, 0)
        return keras.applications.resnet50.preprocess_input(img_array)

    def classify_image(self, image, *, top=3, scale=True):
        """Classify one image and return its ``top`` most likely labels.

        Args:
            image: A single image array; see ``preprocess_image``.
            top: Number of predictions to return, highest score first.
                Must be at least 1.
            scale: Forwarded to ``preprocess_image``.

        Returns:
            The ``decode_predictions`` result for the single image in the
            batch: a list of ``(class_name, class_description, score)``
            tuples.

        Raises:
            ValueError: If ``top`` is less than 1.
        """
        # Reject a meaningless request before paying for preprocessing and
        # inference.
        if top < 1:
            raise ValueError(f"top must be >= 1, got {top}")
        img_array = self.preprocess_image(image, scale=scale)
        predictions = self.model.predict(img_array)
        # decode_predictions returns one list per batch entry; the batch
        # holds exactly one image.
        return keras.applications.imagenet_utils.decode_predictions(
            predictions, top=top
        )[0]