InsectSpy / app.py
ElodieA's picture
Update app.py
6a0ef93 verified
raw
history blame
3.85 kB
import cv2
import csv
import tempfile
import gradio as gr
from ultralytics import YOLO
def process_video(video_file):
    """Run YOLOv8 tracking on a video and export an annotated copy plus a CSV.

    Each frame is passed through the tracker, detections are drawn onto the
    frame, and one CSV row per detection is written as
    (frame, track id, class id, center x, center y, width, height).

    Parameters
    ----------
    video_file : str
        Path to the input video file (as provided by the Gradio Video input).

    Returns
    -------
    tuple[str, str]
        (path to the annotated .mp4 video, path to the detections CSV).

    Raises
    ------
    ValueError
        If the video cannot be opened or yields no readable frames.
    """
    # BGR colors, one per class; indexed modulo so ids beyond the table
    # still get a deterministic color instead of raising.
    colors = [
        (255, 0, 0),     # Blue
        (50, 205, 50),   # Green
        (0, 0, 255),     # Red
        (255, 255, 0),   # Cyan
        (255, 0, 255),   # Magenta
        (255, 140, 0),   # Orange
        (128, 0, 128),   # Purple
        (0, 128, 128),   # Teal
    ]
    # NOTE: only 7 names for 8 color slots — ids past the end of this list
    # fall back to a generic "class N" label instead of an IndexError.
    class_names = ['Hymenoptera', 'Mantodea', 'Odonata', 'Orthoptera',
                   'Coleoptera', 'Lepidoptera', 'Hemiptera']

    # Load the YOLOv8 weights shipped alongside this app.
    model = YOLO("insect_detection4.pt")

    cap = cv2.VideoCapture(video_file)
    if not cap.isOpened():
        raise ValueError(f"Could not open video file: {video_file}")

    # Preserve the source frame rate; fall back to 30 fps if the container
    # does not report one (cap.get returns 0.0 in that case).
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0

    # delete=False so the files survive after close() for Gradio to serve.
    # newline='' prevents the csv module from emitting blank rows on Windows.
    csv_file = tempfile.NamedTemporaryFile(mode='w', delete=False,
                                           suffix='.csv', newline='')
    writer = csv.writer(csv_file)
    writer.writerow(["frame", "id", "class", "x", "y", "w", "h"])

    output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4',
                                                    delete=False).name
    out = None  # VideoWriter is created lazily once the frame size is known
    frame_id = 0

    try:
        while True:
            success, frame = cap.read()
            if not success:
                break
            frame_id += 1

            if out is None:
                height, width = frame.shape[:2]
                fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                out = cv2.VideoWriter(output_video_path, fourcc, fps,
                                      (width, height))

            # Track with persist=True so ids carry across frames.
            results = model.track(frame, persist=True)
            for result in results:
                boxes = result.boxes.cpu().numpy()
                confidences = boxes.conf
                class_ids = boxes.cls
                for i, box in enumerate(boxes):
                    class_id = int(class_ids[i])
                    confidence = confidences[i]
                    color = colors[class_id % len(colors)]
                    name = (class_names[class_id]
                            if class_id < len(class_names)
                            else f"class {class_id}")
                    label = f'{name}: {confidence:.2f}'

                    x1, y1, x2, y2 = (int(v) for v in box.xyxy[0])
                    cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)

                    # Filled label background above the box, clamped to the
                    # top edge so text never goes off-screen.
                    label_size, _ = cv2.getTextSize(
                        label, cv2.FONT_HERSHEY_SIMPLEX, 2.0, 2)
                    label_y = max(y1 - label_size[1], 0)
                    cv2.rectangle(frame,
                                  (x1, label_y - label_size[1]),
                                  (x1 + label_size[0], label_y + label_size[1]),
                                  color, -1)
                    cv2.putText(frame, label, (x1, label_y),
                                cv2.FONT_HERSHEY_SIMPLEX, 2.0,
                                (255, 255, 255), 2)

                    # The tracker may not have assigned an id yet on early
                    # frames; record -1 rather than writing None/array reprs.
                    track_id = int(box.id[0]) if box.id is not None else -1
                    writer.writerow([frame_id, track_id, class_id,
                                     float(box.xywh[0][0]),
                                     float(box.xywh[0][1]),
                                     float(box.xywh[0][2]),
                                     float(box.xywh[0][3])])

            # Stream frames straight to disk instead of buffering the whole
            # video in memory.
            out.write(frame)
    finally:
        # Release everything even if tracking/drawing raised.
        cap.release()
        if out is not None:
            out.release()
        csv_file.close()

    if out is None:
        raise ValueError("No frames could be read from the video.")
    return output_video_path, csv_file.name
# Wire up the Gradio UI: one video input, two outputs (annotated video + CSV).
video_input = gr.Video(label="Input Video")
result_components = [
    gr.Video(label="Annotated Video"),
    gr.File(label="CSV File"),
]
gradio_app = gr.Interface(
    fn=process_video,
    inputs=video_input,
    outputs=result_components,
)

if __name__ == "__main__":
    gradio_app.launch()