"""Gradio UI and Flask JSON API for the AdamCodd/vit-base-nsfw-detector model.

Running this module as a script starts the Gradio interface in the background
and then serves the Flask API (``POST /predict``) on port 5000.
"""
import io

import gradio as gr
import requests
import torch
from flask import Flask, jsonify, request
from PIL import Image
from transformers import AutoModelForImageClassification, ViTImageProcessor

MODEL_ID = "AdamCodd/vit-base-nsfw-detector"

# Load the model and processor once; the Gradio UI and the Flask API share them.
processor = ViTImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)


def _classify(image):
    """Return the predicted class label for a PIL image.

    The image is converted to RGB first: the processor normalizes with
    3-channel statistics, so RGBA / palette / grayscale uploads would
    otherwise fail or be mis-processed.
    """
    inputs = processor(images=image.convert("RGB"), return_tensors="pt")
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]


def predict_image(image):
    """Gradio callback: classify *image* and return its label as text.

    Errors are returned as a message string rather than raised, so they are
    shown in the output textbox.
    """
    # Gradio passes None when the user submits without uploading an image.
    if image is None:
        return "No image provided"
    try:
        return _classify(image)
    except Exception as e:
        return str(e)


# Create Gradio interface
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Textbox(label="Predicted Class"),
    title="NSFW Image Classifier",
)

# Flask app for API endpoint
app = Flask(__name__)


@app.route("/predict", methods=["POST"])
def predict():
    """Classify an image uploaded as the multipart form field ``file``.

    Returns:
        ``{"predicted_class": label}`` with status 200 on success;
        ``{"error": ...}`` with 400 when the file is missing or is not a
        readable image; ``{"error": ...}`` with 500 when inference fails.
    """
    if "file" not in request.files:
        return jsonify({"error": "No file part"}), 400
    file = request.files["file"]
    if file.filename == "":
        return jsonify({"error": "No selected file"}), 400

    try:
        image = Image.open(file.stream)
        # Image.open is lazy; force decoding so corrupt/truncated data is
        # reported as a client error instead of failing later during inference.
        # (PIL.UnidentifiedImageError is a subclass of OSError.)
        image.load()
    except OSError:
        return jsonify({"error": "Uploaded file is not a valid image"}), 400

    try:
        predicted_label = _classify(image)
    except Exception as e:
        app.logger.exception("Prediction failed")
        return jsonify({"error": str(e)}), 500
    return jsonify({"predicted_class": predicted_label})


if __name__ == "__main__":
    # launch() blocks the calling thread by default, which previously made the
    # Flask server below unreachable. prevent_thread_lock=True starts Gradio in
    # the background so both servers run in this process.
    iface.launch(prevent_thread_lock=True)
    app.run(port=5000)