"""Gradio app that classifies an image, fetched from a URL, with an NSFW detector.

Uses the ``AdamCodd/vit-base-nsfw-detector`` checkpoint from the Hugging Face
Hub and returns the label of the highest-scoring class.
"""

import logging
from io import BytesIO

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, ViTImageProcessor

logger = logging.getLogger(__name__)

MODEL_ID = "AdamCodd/vit-base-nsfw-detector"
# Seconds to wait on the remote server (connect and read) before giving up.
# Without a timeout, a stalled server would block the request indefinitely.
REQUEST_TIMEOUT = 10

# Load the model and processor once at import time so every request reuses them.
processor = ViTImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)


def predict_image(image_url):
    """Download the image at *image_url* and return its predicted class label.

    Args:
        image_url: HTTP(S) URL of the image to classify.

    Returns:
        The label from ``model.config.id2label`` for the highest-scoring
        class. If the URL is empty, or the image cannot be fetched, decoded or
        classified, returns a message starting with ``"Error: "`` instead.
    """
    image_url = (image_url or "").strip()
    if not image_url:
        return "Error: please provide an image URL."
    try:
        response = requests.get(image_url, timeout=REQUEST_TIMEOUT)
        # Report 4xx/5xx responses instead of trying to decode an error page
        # as an image.
        response.raise_for_status()
        # Decode from the fully read body; ``.content`` applies any
        # Content-Encoding, whereas ``.raw`` does not. Normalise to 3-channel
        # RGB because RGBA, palette or grayscale images may otherwise fail in
        # the processor.
        image = Image.open(BytesIO(response.content)).convert("RGB")

        inputs = processor(images=image, return_tensors="pt")
        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            logits = model(**inputs).logits

        predicted_class_idx = logits.argmax(-1).item()
        return model.config.id2label[predicted_class_idx]
    except Exception as e:
        # UI boundary: log the full traceback, and show a readable message in
        # the output box instead of failing the request.
        logger.exception("Prediction failed for %s", image_url)
        return f"Error: {e}"


# Create the Gradio interface.
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Textbox(label="Image URL"),
    outputs=gr.Textbox(label="Predicted Class"),
    title="NSFW Image Classifier",
)

if __name__ == "__main__":
    # Only start the server when run as a script, not when imported.
    iface.launch()