# Source: app.py from 0-ma's Hugging Face Space.
# Last commit: 7f73a40 "Update app.py" (verified), file size 1.54 kB.
import gradio as gr
import numpy as np
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
import requests
# Human-readable class names, ordered like the model's output scores:
# predict() pairs the i-th score with labels[i].
labels = [
    "None",  # index 0
    "Circle",  # index 1
    "Triangle",  # index 2
    "Square",  # index 3
    "Pentagon",  # index 4
    "Hexagon",  # index 5
]
# Hugging Face Hub checkpoint; the same id is used for the preprocessor and the
# model so the two always stay in sync.
MODEL_ID = "0-ma/vit-geometric-shapes-tiny"
feature_extractor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)


def predict(img):
    """Classify the geometric shape shown in a single image.

    Args:
        img: Image from the Gradio input component. It may be a PIL image,
            a NumPy array (Gradio's default image type), or None when the
            user submitted without uploading anything.

    Returns:
        A dict mapping each name in ``labels`` to its softmax probability.
        This is the format ``gr.Label`` renders. If ``img`` is None, an
        empty dict is returned.
    """
    if img is None:
        return {}
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)
    # Uploads can be RGBA or grayscale. The image processor normalizes 3
    # channels, so force RGB before preprocessing.
    img = img.convert("RGB")

    inputs = feature_extractor(images=img, return_tensors="pt")
    # Batch of one image: keep only the first (and only) row of logits.
    logits = model(**inputs)["logits"].cpu().detach().numpy()[0]

    # Numerically stable softmax, so the scores are probabilities in [0, 1].
    exp = np.exp(logits - logits.max())
    probs = exp / exp.sum()

    # zip pairs score i with labels[i]. If the model ever has a different
    # number of outputs than `labels`, the extra entries are silently dropped.
    return {label: float(p) for label, p in zip(labels, probs)}
# UI metadata shown by the Gradio interface.
title = "Geometric Shape Classifier"
description = "A geometric shape detector."
examples = ['A.jpg']
# When True, requests go through Gradio's queue (replaces the removed
# `Interface(enable_queue=...)` argument).
enable_queue = True

# The legacy `gr.inputs` / `gr.outputs` namespaces, `gr.Image(shape=...)` and
# `interpretation=` were removed in Gradio 4. The components below work on
# Gradio 3.x and 4.x. Resizing is left to the model's image processor inside
# predict.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=3),
    title=title,
    description=description,
    examples=examples,
)
if enable_queue:
    demo.queue()
demo.launch()