# Snapshot metadata from the original file listing: commit 56744d1, 809 bytes.
import gradio as gr
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
# Hugging Face Hub checkpoint that provides both the preprocessing config and
# the classification weights. Keeping it in one place guarantees the
# preprocessor and the model are always loaded from the same checkpoint.
_CHECKPOINT = "google/vit-base-patch16-224"

# Loaded once at import time so every prediction request reuses the same
# preprocessor and model instances instead of reloading them per call.
feature_extractor = AutoFeatureExtractor.from_pretrained(_CHECKPOINT)
model = AutoModelForImageClassification.from_pretrained(_CHECKPOINT)
# Define prediction function
def classify_image(image):
image = Image.fromarray(image).convert("RGB")
inputs = feature_extractor(images=image, return_tensors="pt")
outputs = model(**inputs)
predicted_class = outputs.logits.argmax(-1).item()
return model.config.id2label[predicted_class]
# Build the web UI: a single image input, delivered to classify_image as a
# NumPy array, and a plain-text output that shows the predicted label.
app = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="numpy"),
    outputs="text",
    title="Image Classifier",
)

# Start the server only when this file is run as a script (`python app.py`),
# so importing the module (e.g. from tests or another app) neither blocks on
# launch() nor opens a port as a side effect.
if __name__ == "__main__":
    app.launch()