# nsfw_detection / app.py
import gradio as gr
from transformers import ViTImageProcessor, AutoModelForImageClassification
from PIL import Image
from flask import Flask, request, jsonify

# Load the model and processor
processor = ViTImageProcessor.from_pretrained('AdamCodd/vit-base-nsfw-detector')
model = AutoModelForImageClassification.from_pretrained('AdamCodd/vit-base-nsfw-detector')
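# The checkpoint above is a ViT-based binary image classifier; its id2label
# mapping is expected to provide labels along the lines of 'sfw' / 'nsfw'
# (an assumption based on the model name, not verified here).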

# Define prediction function
def predict_image(image):
    try:
        # Process the image and make prediction
        inputs = processor(images=image, return_tensors="pt")
        outputs = model(**inputs)
        logits = outputs.logits

        # Get predicted class
        predicted_class_idx = logits.argmax(-1).item()
        predicted_label = model.config.id2label[predicted_class_idx]
        return predicted_label
    except Exception as e:
        return str(e)
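
# Quick local sanity check (a sketch, assuming an image file 'example.jpg'
# exists next to this script; the filename is a placeholder):
#
#     print(predict_image(Image.open('example.jpg')))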

# Create Gradio interface (launched in the __main__ block below, since
# launch() blocks the interpreter and would prevent the Flask app from
# being defined)
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Textbox(label="Predicted Class"),
    title="NSFW Image Classifier"
)

# Flask app for API endpoint
app = Flask(__name__)

@app.route('/predict', methods=['POST'])
def predict():
    if 'file' not in request.files:
        return jsonify({'error': 'No file part'}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': 'No selected file'}), 400
    try:
        # Load image from the uploaded file
        image = Image.open(file.stream)

        # Process the image and make prediction
        inputs = processor(images=image, return_tensors="pt")
        outputs = model(**inputs)
        logits = outputs.logits

        # Get predicted class
        predicted_class_idx = logits.argmax(-1).item()
        predicted_label = model.config.id2label[predicted_class_idx]
        return jsonify({'predicted_class': predicted_label})
    except Exception as e:
        return jsonify({'error': str(e)}), 500

# Run the app
if __name__ == '__main__':
    # iface.launch() and app.run() both block, so start one server or the
    # other: the Gradio UI by default, or the Flask API by swapping the lines.
    iface.launch()
    # app.run(port=5000)
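
# Example request against the Flask API (a sketch, assuming it is running on
# localhost:5000; 'example.jpg' is a placeholder filename):
#
#     import requests
#     with open('example.jpg', 'rb') as f:
#         resp = requests.post('http://localhost:5000/predict', files={'file': f})
#     print(resp.json())  # e.g. {'predicted_class': 'sfw'}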