# imagetxt / app.py
# Source file by Mohuu0601, from commit 56744d1 ("Create app.py").
import gradio as gr
import torch
from PIL import Image
from transformers import (
    AutoFeatureExtractor,
    AutoImageProcessor,
    AutoModelForImageClassification,
)
# Load the image preprocessor and the classification model once, at import
# time, so every request reuses the same objects instead of reloading them.
# Both must come from the same checkpoint; naming it once keeps them in sync.
MODEL_ID = "google/vit-base-patch16-224"
# AutoImageProcessor supersedes the deprecated vision feature-extractor classes
# and is called the same way. The object keeps the historical
# ``feature_extractor`` name because the prediction function refers to it.
feature_extractor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
# Define prediction function
def classify_image(image):
    """Return the most likely class label for one uploaded image.

    Args:
        image: The value delivered by the ``gr.Image(type="numpy")`` input,
            i.e. a NumPy array, or ``None`` when the user submits without
            providing an image.

    Returns:
        The label string from ``model.config.id2label`` for the class with
        the highest logit.

    Raises:
        gr.Error: If no image was provided; Gradio shows the message to the
            user instead of a stack trace.
    """
    if image is None:
        raise gr.Error("Please upload an image to classify.")
    # convert("RGB") guarantees three channels whatever mode fromarray infers
    # (e.g. grayscale or RGBA arrays).
    pil_image = Image.fromarray(image).convert("RGB")
    inputs = feature_extractor(images=pil_image, return_tensors="pt")
    # Inference only: disable autograd so no computation graph is built.
    with torch.no_grad():
        logits = model(**inputs).logits
    # A single image is processed, so argmax over the last axis yields one index.
    predicted_class = logits.argmax(-1).item()
    return model.config.id2label[predicted_class]
# Create a Gradio app: a NumPy image goes in, the predicted label text comes out.
app = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="numpy"),
    outputs="text",
    title="Image Classifier",
)

# Launch the web UI only when this file is run as a script (e.g.
# `python app.py`), so importing the module does not start a server.
if __name__ == "__main__":
    app.launch()