er1t0 commited on
Commit
8870220
·
1 Parent(s): 9b87d5a

torch autocast

Browse files
Files changed (2) hide show
  1. .gitignore +3 -0
  2. app.py +55 -46
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ temp_frames
2
+ temp_frames_30
3
+ segmented_video.mp4
app.py CHANGED
@@ -41,6 +41,8 @@ florence_model = load_model_without_flash_attn(load_florence_model)
41
  florence_processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
42
 
43
 
 
 
44
  def apply_color_mask(frame, mask, obj_id):
45
  cmap = plt.get_cmap("tab10")
46
  color = np.array(cmap(obj_id % 10)[:3]) # Use modulo 10 to cycle through colors
@@ -61,25 +63,26 @@ def apply_color_mask(frame, mask, obj_id):
61
  colored_mask = mask * color
62
  return frame * (1 - mask) + colored_mask * 255
63
 
 
 
64
  def run_florence(image, text_input):
65
- with torch.amp.autocast(dtype=torch.bfloat16):
66
- task_prompt = '<OPEN_VOCABULARY_DETECTION>'
67
- prompt = task_prompt + text_input
68
- inputs = florence_processor(text=prompt, images=image, return_tensors="pt").to('cuda', torch.bfloat16)
69
- generated_ids = florence_model.generate(
70
- input_ids=inputs["input_ids"].cuda(),
71
- pixel_values=inputs["pixel_values"].cuda(),
72
- max_new_tokens=1024,
73
- early_stopping=False,
74
- do_sample=False,
75
- num_beams=3,
76
- )
77
- generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
78
- parsed_answer = florence_processor.post_process_generation(
79
- generated_text,
80
- task=task_prompt,
81
- image_size=(image.width, image.height)
82
- )
83
  return parsed_answer[task_prompt]['bboxes'][0]
84
 
85
  def remove_directory_contents(directory):
@@ -89,7 +92,8 @@ def remove_directory_contents(directory):
89
  for name in dirs:
90
  os.rmdir(os.path.join(root, name))
91
 
92
-
 
93
  def process_video(video_path, prompt):
94
  try:
95
  # Get video info
@@ -123,14 +127,13 @@ def process_video(video_path, prompt):
123
  print("Reshaped mask box:", mask_box)
124
 
125
  # SAM2 segmentation on first frame
126
- with torch.cuda.amp.autocast(dtype=torch.bfloat16):
127
- image_predictor.set_image(first_frame)
128
- masks, _, _ = image_predictor.predict(
129
- point_coords=None,
130
- point_labels=None,
131
- box=mask_box[None, :],
132
- multimask_output=False,
133
- )
134
  print("masks.shape", masks.shape)
135
 
136
  mask = masks.squeeze().astype(bool)
@@ -145,21 +148,20 @@ def process_video(video_path, prompt):
145
 
146
  print(f"Saved {len(frames)} temporary frames")
147
 
148
- with torch.cuda.amp.autocast(dtype=torch.bfloat16):
149
- inference_state = video_predictor.init_state(video_path=temp_dir)
150
- _, _, _ = video_predictor.add_new_mask(
151
- inference_state=inference_state,
152
- frame_idx=0,
153
- obj_id=1,
154
- mask=mask
155
- )
156
-
157
- video_segments = {}
158
- for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
159
- video_segments[out_frame_idx] = {
160
- out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
161
- for i, out_obj_id in enumerate(out_obj_ids)
162
- }
163
 
164
  print('Segmenting for main vid done')
165
  print(f"Number of segmented frames: {len(video_segments)}")
@@ -216,12 +218,19 @@ def segment_video(video_file, prompt):
216
  demo = gr.Interface(
217
  fn=segment_video,
218
  inputs=[
219
- gr.Video(label="Upload Video"),
220
- gr.Textbox(label="Enter prompt (e.g., 'a gymnast')")
221
  ],
222
  outputs=gr.Video(label="Segmented Video"),
223
- title="Video Object Segmentation with Florence and SAM2",
224
- description="Upload a video and provide a text prompt to segment a specific object throughout the video."
 
 
 
 
 
 
 
225
  )
226
 
227
  demo.launch()
 
41
  florence_processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
42
 
43
 
44
+
45
+
46
  def apply_color_mask(frame, mask, obj_id):
47
  cmap = plt.get_cmap("tab10")
48
  color = np.array(cmap(obj_id % 10)[:3]) # Use modulo 10 to cycle through colors
 
63
  colored_mask = mask * color
64
  return frame * (1 - mask) + colored_mask * 255
65
 
66
+ @torch.inference_mode()
67
+ @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
68
  def run_florence(image, text_input):
69
+ task_prompt = '<OPEN_VOCABULARY_DETECTION>'
70
+ prompt = task_prompt + text_input
71
+ inputs = florence_processor(text=prompt, images=image, return_tensors="pt").to('cuda', torch.bfloat16)
72
+ generated_ids = florence_model.generate(
73
+ input_ids=inputs["input_ids"].cuda(),
74
+ pixel_values=inputs["pixel_values"].cuda(),
75
+ max_new_tokens=1024,
76
+ early_stopping=False,
77
+ do_sample=False,
78
+ num_beams=3,
79
+ )
80
+ generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
81
+ parsed_answer = florence_processor.post_process_generation(
82
+ generated_text,
83
+ task=task_prompt,
84
+ image_size=(image.width, image.height)
85
+ )
 
86
  return parsed_answer[task_prompt]['bboxes'][0]
87
 
88
  def remove_directory_contents(directory):
 
92
  for name in dirs:
93
  os.rmdir(os.path.join(root, name))
94
 
95
+ @torch.inference_mode()
96
+ @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
97
  def process_video(video_path, prompt):
98
  try:
99
  # Get video info
 
127
  print("Reshaped mask box:", mask_box)
128
 
129
  # SAM2 segmentation on first frame
130
+ image_predictor.set_image(first_frame)
131
+ masks, _, _ = image_predictor.predict(
132
+ point_coords=None,
133
+ point_labels=None,
134
+ box=mask_box[None, :],
135
+ multimask_output=False,
136
+ )
 
137
  print("masks.shape", masks.shape)
138
 
139
  mask = masks.squeeze().astype(bool)
 
148
 
149
  print(f"Saved {len(frames)} temporary frames")
150
 
151
+ inference_state = video_predictor.init_state(video_path=temp_dir)
152
+ _, _, _ = video_predictor.add_new_mask(
153
+ inference_state=inference_state,
154
+ frame_idx=0,
155
+ obj_id=1,
156
+ mask=mask
157
+ )
158
+
159
+ video_segments = {}
160
+ for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
161
+ video_segments[out_frame_idx] = {
162
+ out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
163
+ for i, out_obj_id in enumerate(out_obj_ids)
164
+ }
 
165
 
166
  print('Segmenting for main vid done')
167
  print(f"Number of segmented frames: {len(video_segments)}")
 
218
  demo = gr.Interface(
219
  fn=segment_video,
220
  inputs=[
221
+ gr.Video(label="Upload Video (Keep it under 10 seconds for this demo)"),
222
+ gr.Textbox(label="Enter text prompt for object detection")
223
  ],
224
  outputs=gr.Video(label="Segmented Video"),
225
+ title="Text-Prompted Video Object Segmentation",
226
+ description="""
227
+ This demo uses [Florence-2](https://huggingface.co/microsoft/Florence-2-large), a vision-language model, to enable text-prompted object detection for [SAM2](https://github.com/facebookresearch/segment-anything).
228
+ Florence-2 interprets your text prompt, allowing SAM2 to segment the described object in the video.
229
+
230
+ 1. Upload a short video (< 10 sec)
231
+ 2. Describe the object to segment
232
+ 3. Get your segmented video!
233
+ """
234
  )
235
 
236
  demo.launch()