ElodieA commited on
Commit
6a0ef93
·
verified ·
1 Parent(s): 0351b19

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -41
app.py CHANGED
@@ -1,58 +1,109 @@
1
- import gradio as gr
2
  import cv2
 
3
  import tempfile
4
- import numpy as np
5
  from ultralytics import YOLO
6
 
7
- # Load the YOLOv8 model
8
- model = YOLO('yolov8m.pt') # Ensure you have the correct model path
9
-
10
  def process_video(video_file):
11
- # Create a temporary directory to store processed frames
12
- temp_dir = tempfile.TemporaryDirectory()
13
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  # Open the video file
15
- cap = cv2.VideoCapture(video_file.name)
16
-
17
- # Get video properties
18
- fps = int(cap.get(cv2.CAP_PROP_FPS))
19
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
20
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
21
- codec = cv2.VideoWriter_fourcc(*'mp4v')
22
-
23
- # Create a VideoWriter object to save the processed video
24
- output_path = f"{temp_dir.name}/output.mp4"
25
- out = cv2.VideoWriter(output_path, codec, fps, (width, height))
26
 
 
 
 
 
 
 
 
 
 
 
 
27
  while cap.isOpened():
28
- ret, frame = cap.read()
29
- if not ret:
30
- break
 
 
 
 
 
31
 
32
- # Use YOLO model to detect objects in the frame
33
- results = model(frame)
 
 
34
 
35
- # Draw bounding boxes and labels on the frame
36
- annotated_frame = results[0].plot()
 
37
 
38
- # Write the frame to the output video
39
- out.write(annotated_frame)
40
-
41
- # Release the VideoCapture and VideoWriter objects
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  cap.release()
 
 
 
 
 
 
 
 
43
  out.release()
44
 
45
- return output_path
 
 
 
 
 
 
 
 
 
 
46
 
47
- # Define the Gradio interface
48
- iface = gr.Interface(
49
- fn=process_video,
50
- inputs=gr.Video(),
51
- outputs=gr.Video(),
52
- title="YOLOv8 Video Object Detection",
53
- description="Upload a video and apply YOLOv8 object detection."
54
- )
55
 
56
- # Launch the Gradio app
57
  if __name__ == "__main__":
58
- iface.launch()
 
 
1
  import cv2
2
+ import csv
3
  import tempfile
4
+ import gradio as gr
5
  from ultralytics import YOLO
6
 
 
 
 
7
  def process_video(video_file):
8
+ # Define colors for each class (8 classes)
9
+ colors = [
10
+ (255, 0, 0), # Class 0 - Blue
11
+ (50, 205, 50), # Class 1 - Green
12
+ (0, 0, 255), # Class 2 - Red
13
+ (255, 255, 0), # Class 3 - Cyan
14
+ (255, 0, 255), # Class 4 - Magenta
15
+ (255, 140, 0), # Class 5 - Orange
16
+ (128, 0, 128), # Class 6 - Purple
17
+ (0, 128, 128) # Class 7 - Teal
18
+ ]
19
+
20
+ # Define class names (example names, replace with actual class names if available)
21
+ class_names = ['Hymenoptera', 'Mantodea', 'Odonata', 'Orthoptera', 'Coleoptera', 'Lepidoptera', 'Hemiptera']
22
+
23
+ # Load the YOLOv8 model
24
+ model = YOLO("insect_detection4.pt")
25
+
26
  # Open the video file
27
+ cap = cv2.VideoCapture(video_file)
 
 
 
 
 
 
 
 
 
 
28
 
29
+ # Prepare CSV file for writing
30
+ csv_file = tempfile.NamedTemporaryFile(mode='w', delete=False)
31
+ writer = csv.writer(csv_file)
32
+ writer.writerow(["frame", "id", "class", "x", "y", "w", "h"])
33
+
34
+ frame_id = 0
35
+
36
+ # Initialize a list to store annotated frames
37
+ annotated_frames = []
38
+
39
+ # Loop through the video frames
40
  while cap.isOpened():
41
+ # Read a frame from the video
42
+ success, frame = cap.read()
43
+
44
+ if success:
45
+ frame_id += 1
46
+
47
+ # Run YOLOv8 tracking on the frame, persisting tracks between frames
48
+ results = model.track(frame, persist=True)
49
 
50
+ for result in results:
51
+ boxes = result.boxes.cpu().numpy()
52
+ confidences = boxes.conf
53
+ class_ids = boxes.cls
54
 
55
+ for i, box in enumerate(boxes):
56
+ class_id = int(class_ids[i])
57
+ confidence = confidences[i]
58
 
59
+ color = colors[class_id % len(colors)] # Use the color corresponding to the class
60
+ label = f'{class_names[class_id]}: {confidence:.2f}'
61
+
62
+ # Draw the rectangle
63
+ cv2.rectangle(frame, (int(box.xyxy[0][0]), int(box.xyxy[0][1])),
64
+ (int(box.xyxy[0][2]), int(box.xyxy[0][3])), color, 2)
65
+
66
+ # Display the label above the rectangle
67
+ label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 2.0, 2)
68
+ label_y = max(int(box.xyxy[0][1]) - label_size[1], 0)
69
+ cv2.rectangle(frame, (int(box.xyxy[0][0]), label_y - label_size[1]),
70
+ (int(box.xyxy[0][0]) + label_size[0], label_y + label_size[1]), color, -1)
71
+ cv2.putText(frame, label, (int(box.xyxy[0][0]), label_y), cv2.FONT_HERSHEY_SIMPLEX, 2.0, (255, 255, 255), 2)
72
+
73
+ # Write detection data to CSV
74
+ writer.writerow([frame_id, box.id, int(box.cls[0]), box.xywh[0][0], box.xywh[0][1],
75
+ box.xywh[0][2], box.xywh[0][3]])
76
+
77
+ # Add annotated frame to the list
78
+ annotated_frames.append(frame)
79
+
80
+ else:
81
+ break
82
+
83
+ # Release the video capture
84
  cap.release()
85
+
86
+ # Compile annotated frames into a video
87
+ output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
88
+ height, width, _ = annotated_frames[0].shape
89
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
90
+ out = cv2.VideoWriter(output_video_path, fourcc, 30.0, (width, height))
91
+ for frame in annotated_frames:
92
+ out.write(frame)
93
  out.release()
94
 
95
+ # Close CSV file
96
+ csv_file.close()
97
+
98
+ return output_video_path, csv_file.name
99
+
100
+ # Create a Gradio interface
101
+ inputs = gr.Video(label="Input Video")
102
+ outputs = [gr.Video(label="Annotated Video"), gr.File(label="CSV File")]
103
+
104
+ gradio_app=gr.Interface(fn=process_video, inputs=inputs, outputs=outputs)
105
+
106
 
 
 
 
 
 
 
 
 
107
 
 
108
  if __name__ == "__main__":
109
+ gradio_app.launch()