"""InsectSpy: a Gradio app that detects, tracks and catalogues insects in trap videos.

Every frame of the uploaded video is run through a YOLO model with a tracker.
The app returns an interactive plot of which tracked insect was seen in which
frame, a gallery with one crop per tracked insect, and a CSV of all detections.
"""

import tempfile

import cv2
import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from ultralytics import YOLO

# Map the model's integer class ids to insect order names.
label_mapping = {
    0: 'Hymenoptera',
    1: 'Mantodea',
    2: 'Odonata',
    3: 'Orthoptera',
    4: 'Coleoptera',
    5: 'Lepidoptera',
    6: 'Hemiptera',
}

# Columns of the per-detection table written to the downloadable CSV.
CSV_COLUMNS = ["frame", "insect_id", "class", "x", "y", "w", "h"]

# Insect id recorded when the tracker did not assign an id to a detection.
UNTRACKED_ID = -1


def _class_name(class_id):
    """Return the order name for *class_id*, falling back to a generic label.

    The fallback avoids a KeyError if the model emits an id missing from
    ``label_mapping``.
    """
    return label_mapping.get(class_id, f"class {class_id}")


def _temp_path(suffix):
    """Create an empty, uniquely named temporary file and return its path.

    This replaces the deprecated ``tempfile.mktemp``, which only picks a name
    and so lets another process claim it before we write. The file is
    intentionally not deleted because Gradio serves it to the browser.
    """
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as handle:
        return handle.name


def _crop_box(frame, xywh):
    """Return the region of *frame* inside a centre-format box, clipped to the frame.

    Without clipping, a box that extends past the top/left edge gives negative
    slice starts. NumPy treats those as offsets from the end, so the crop comes
    out wrong or empty (and ``cv2.imwrite`` raises on an empty image).

    Returns None when the clipped box has no area.
    """
    x_center, y_center, width, height = (float(v) for v in xywh)
    frame_h, frame_w = frame.shape[:2]
    x1 = max(0, int(x_center - width / 2))
    y1 = max(0, int(y_center - height / 2))
    x2 = min(frame_w, int(x_center + width / 2))
    y2 = min(frame_h, int(y_center + height / 2))
    if x2 <= x1 or y2 <= y1:
        return None
    return frame[y1:y2, x1:x2]


def _track_video(video_file):
    """Run detection and tracking over every frame of *video_file*.

    Returns a tuple ``(rows, crops)``:
      rows  -- one dict per detection, keyed by CSV_COLUMNS (frames are 1-based).
      crops -- maps each tracked insect id to ``(crop_path, class_name)``, taken
               from the first frame where a non-empty crop of it was available.
    """
    # The model is created per call, as in the original, so that tracker state
    # kept across frames via ``persist=True`` starts fresh for each video.
    # NOTE(review): confirm tracker reset semantics before hoisting this to
    # module level.
    model = YOLO("insect_detection4.pt")
    cap = cv2.VideoCapture(video_file)
    rows = []
    crops = {}
    frame_id = 0
    try:
        while cap.isOpened():
            success, frame = cap.read()
            if not success:
                break
            frame_id += 1

            # Persist tracks between frames so the same insect keeps its id.
            results = model.track(frame, persist=True, tracker="insect_tracker.yaml")
            for result in results:
                for box in result.boxes.cpu().numpy():
                    class_id = int(box.cls[0])
                    insect_id = int(box.id[0]) if box.id is not None else UNTRACKED_ID
                    x, y, w, h = box.xywh[0]
                    rows.append({
                        "frame": frame_id,
                        "insect_id": insect_id,
                        "class": class_id,
                        "x": x,
                        "y": y,
                        "w": w,
                        "h": h,
                    })

                    # Untracked detections do not identify a unique insect, so
                    # they get no gallery crop. Tracked insects keep their first
                    # usable crop.
                    if insect_id == UNTRACKED_ID or insect_id in crops:
                        continue
                    crop = _crop_box(frame, box.xywh[0])
                    if crop is None:
                        continue  # try again on a later frame
                    crop_path = _temp_path(".png")
                    cv2.imwrite(crop_path, crop)
                    crops[insect_id] = (crop_path, _class_name(class_id))
    finally:
        # Release the capture even if tracking raises.
        cap.release()
    return rows, crops


def _build_plot(df):
    """Plot frame (x) against insect id (y), one coloured trace per insect id."""
    fig = go.Figure()
    for insect_id, group in df.groupby('insect_id'):
        class_name = _class_name(int(group['class'].iloc[0]))
        color = 'rgb({}, {}, {})'.format(*np.random.randint(0, 256, 3))
        hover_text = [
            f'Insect ID: {int(iid)}, Class: {class_name}, Frame: {int(frame)}'
            for iid, frame in zip(group['insect_id'], group['frame'])
        ]
        fig.add_trace(go.Scatter(
            x=group['frame'],
            y=group['insect_id'],
            mode='markers',
            marker=dict(color=color),
            name=f'{class_name} {insect_id}',
            hoverinfo='text',
            hovertext=hover_text,
        ))
    fig.update_layout(
        title='Temporal distribution of insects',
        xaxis_title='Frame',
        yaxis_title='Insect ID',
        hovermode='closest',
    )
    return fig


def process_video(video_file):
    """Detect and track insects in *video_file* for the Gradio interface.

    Returns a tuple ``(fig, gallery_items, csv_path)``:
      fig           -- Plotly figure of insect ids over frames.
      gallery_items -- list of ``(crop_path, "<class> <insect_id>")`` pairs.
      csv_path      -- path to a CSV with one row per detection (CSV_COLUMNS).

    Raises gr.Error when no video was provided.
    """
    if not video_file:
        raise gr.Error("Please upload a video first.")

    rows, crops = _track_video(video_file)

    # Build the table once instead of concatenating per detection, which is
    # quadratic. Passing the columns keeps the header for videos with no
    # detections.
    df = pd.DataFrame(rows, columns=CSV_COLUMNS)
    csv_path = _temp_path(".csv")
    df.to_csv(csv_path, index=False)

    fig = _build_plot(df)
    gallery_items = [
        (crop_path, f'{label} {insect_id}')
        for insect_id, (crop_path, label) in crops.items()
    ]
    return fig, gallery_items, csv_path


# Create a Gradio interface
example_video = "insect_trap_video_example.mp4"  # Replace with the actual path to your example video
inputs = gr.Video(label="Input Insect Trap Video", value=example_video)
outputs = [
    gr.Plot(label="Insect Detection Plot"),
    gr.Gallery(label="Insect Gallery"),  # one labelled crop per tracked insect
    gr.File(label="Download CSV"),
]

description = """
Uncover the Secret Lives of Insects of the Amazonian Forest! 🐝🦋🕷️
Upload your video now to track, visualize, and explore insect activity with our cutting-edge detection tool.
You can get started with the example video.
"""

demo = gr.Interface(
    fn=process_video,
    inputs=inputs,
    outputs=outputs,
    title='InsectSpy 🕵️‍♂️🦗',
    description=description,
    examples=[example_video],
)
demo.launch()