import time

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForObjectDetection

# Device setup (CUDA GPU, Apple MPS, or CPU fallback)
device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')

# Load the pre-trained model and image processor from Hugging Face
ckpt = 'yainage90/fashion-object-detection'
image_processor = AutoImageProcessor.from_pretrained(ckpt)
model = AutoModelForObjectDetection.from_pretrained(ckpt).to(device)


def detect_objects(frame):
    """Detect objects in a BGR frame and return (score, label, box) tuples."""
    # Convert the BGR frame to an RGB PIL image
    image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

    # Run inference without tracking gradients
    with torch.no_grad():
        inputs = image_processor(images=[image], return_tensors="pt")
        outputs = model(**inputs.to(device))
        target_sizes = torch.tensor([[image.size[1], image.size[0]]])
        results = image_processor.post_process_object_detection(
            outputs, threshold=0.4, target_sizes=target_sizes
        )[0]

    # Extract the detected items
    items = []
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        score = score.item()
        label = label.item()
        box = [i.item() for i in box]
        print(f"{model.config.id2label[label]}: {round(score, 3)} at {box}")
        items.append((score, label, box))

    return items


def process_image(image):
    """Process the image uploaded via Gradio and return detections plus a message."""
    # Convert the RGB PIL image to a BGR numpy array for OpenCV
    frame = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    # Detect objects (e.g., helmets) in the frame
    items = detect_objects(frame)

    # Check whether a helmet is detected (adapt "helmet" to the actual class name in your model)
    helmet_detected = any(
        model.config.id2label[label] == "helmet" for _, label, _ in items
    )

    # If no helmet is detected, report a traffic violation
    if helmet_detected:
        violation_message = "Helmet detected: No violation."
    else:
        violation_message = "Serious Traffic Violation: Rider not wearing a helmet!"

    # If any objects were detected, save the frame and associated data
    if items:
        save_data(frame, items)

    # Return two values to match the two Gradio output components (JSON, Textbox)
    return {"items_detected": items}, violation_message


def save_data(frame, items):
    """Save the image and extract the plate number."""
    filename = f"helmet_violation_{int(time.time())}.jpg"
    cv2.imwrite(filename, frame)

    # Here you'd extract the plate number or process further
    plate_number = extract_plate_number(frame)
    save_to_database(filename, plate_number, items)


def extract_plate_number(frame):
    """Extract the license plate number (simplified placeholder)."""
    plate_number = "XYZ 1234"  # Replace with an actual license plate recognition method
    return plate_number


def save_to_database(image_filename, plate_number, items):
    """Save the data (for simplicity, we just print it here)."""
    print(f"Plate Number: {plate_number}, Image saved as {image_filename}")
    print("Detected items:", items)


# Define the Gradio interface: JSON output for detections, Textbox for the violation message
interface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs=[gr.JSON(), gr.Textbox()],
    live=True,
)

# Launch the Gradio app
interface.launch(debug=True)
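
# --- Optional: a minimal OCR-based sketch for extract_plate_number ---
# Illustrative only, not part of the pipeline above. It assumes the third-party
# `easyocr` package is installed (pip install easyocr) and simply returns the
# highest-confidence text found in the frame, which is a crude stand-in for a
# real plate detector + recognizer. Uncomment and move it above
# `extract_plate_number` (or swap it in) to experiment.
#
# import easyocr
#
# _reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())
#
# def extract_plate_number_ocr(frame):
#     """Return the highest-confidence text EasyOCR finds in the frame."""
#     detections = _reader.readtext(frame)  # list of (bbox, text, confidence)
#     if not detections:
#         return "UNKNOWN"
#     _, text, _ = max(detections, key=lambda d: d[2])
#     return text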