shaoyent committed on
Commit
ec2fb01
·
1 Parent(s): b345be7
Files changed (2) hide show
  1. app.py +25 -14
  2. bridgetower_custom.py +2 -2
app.py CHANGED
@@ -87,7 +87,7 @@ def time_to_frame(time, fps):
87
  '''
88
  convert time in seconds into frame number
89
  '''
90
- return time * fps - 1
91
 
92
  def str2time(strtime):
93
  strtime = strtime.strip('"')
@@ -105,7 +105,7 @@ def collate_fn(batch_list):
105
  batch['pixel_mask'] = torch.cat([encoding['pixel_mask'] for encoding in batch_list], dim=0)
106
  return batch
107
 
108
- def extract_images_and_embeds(video_id, video_path, subtitles, output, expanded=False, batch_size=2):
109
  if os.path.exists(os.path.join(output, 'embeddings.pkl')):
110
  return
111
 
@@ -123,7 +123,7 @@ def extract_images_and_embeds(video_id, video_path, subtitles, output, expanded=
123
  # Get the total number of frames in the video.
124
  frame_count = vidcap.get(cv2.CAP_PROP_FRAME_COUNT)
125
 
126
- print(fps, frame_count)
127
 
128
  frame_number = 0
129
 
@@ -132,8 +132,9 @@ def extract_images_and_embeds(video_id, video_path, subtitles, output, expanded=
132
 
133
  embeddings = []
134
  batch_list = []
 
135
 
136
- for idx, caption in enumerate(webvtt.read(subtitles)):
137
  st_time = str2time(caption.start)
138
  ed_time = str2time(caption.end)
139
 
@@ -144,9 +145,10 @@ def extract_images_and_embeds(video_id, video_path, subtitles, output, expanded=
144
  raise NotImplementedError
145
 
146
  frame_no = time_to_frame(mid_time, fps)
147
-
 
 
148
  print('Read a new frame: ', idx, mid_time, frame_no, text)
149
- vidcap.set(1, frame_no) # added this line
150
  success, frame = vidcap.read()
151
  if success:
152
  frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
@@ -161,7 +163,7 @@ def extract_images_and_embeds(video_id, video_path, subtitles, output, expanded=
161
  'image_id': idx,
162
  'img_fname': img_fname,
163
  'caption': text,
164
- 'time': mid_time,
165
  'frame_no': frame_no
166
  })
167
 
@@ -169,6 +171,7 @@ def extract_images_and_embeds(video_id, video_path, subtitles, output, expanded=
169
  encoding['text'] = text
170
  encoding['image_filepath'] = img_fpath
171
  encoding['start_time'] = caption.start
 
172
 
173
  batch_list.append(encoding)
174
 
@@ -186,7 +189,7 @@ def extract_images_and_embeds(video_id, video_path, subtitles, output, expanded=
186
  'text': batch_list[i]['text'],
187
  'image_filepath': batch_list[i]['image_filepath'],
188
  'start_time': batch_list[i]['start_time'],
189
- 'frame_no': frame_no,
190
  })
191
  batch_list = []
192
 
@@ -201,9 +204,11 @@ def extract_images_and_embeds(video_id, video_path, subtitles, output, expanded=
201
  'text': batch_list[i]['text'],
202
  'image_filepath': batch_list[i]['image_filepath'],
203
  'start_time': batch_list[i]['start_time'],
204
- 'frame_no': frame_no,
205
  })
206
 
 
 
207
  with open(os.path.join(output, 'annotations.json'), 'w') as fh:
208
  json.dump(anno, fh)
209
 
@@ -240,10 +245,14 @@ def run_query(video_path, text_query, path='/tmp'):
240
  clip_images = []
241
  transcripts = []
242
  for idx in I[0]:
243
- frame_no = embeddings[idx]['frame_no']
244
- vidcap.set(1, frame_no) # added this line
 
 
 
245
  success, frame = vidcap.read()
246
  if success:
 
247
  frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
248
  frame = Image.fromarray(frame)
249
  clip_images.append(frame)
@@ -277,7 +286,7 @@ def get_video_id_from_url(video_url):
277
  return None
278
 
279
 
280
- def process(video_url, text_query):
281
  tmp_dir = os.environ.get('TMPDIR', '/tmp')
282
  video_id = get_video_id_from_url(video_url)
283
  output_dir = os.path.join(tmp_dir, video_id)
@@ -289,6 +298,7 @@ def process(video_url, text_query):
289
  output=output_dir,
290
  expanded=False,
291
  batch_size=8,
 
292
  )
293
  frame_paths, transcripts = run_query(video_file, text_query, path=output_dir)
294
  return video_file, [(image, caption) for image, caption in zip(frame_paths, transcripts)]
@@ -311,8 +321,8 @@ with gr.Blocks() as demo:
311
  gr.Examples(
312
  examples=[
313
  ['https://www.youtube.com/watch?v=CvjoXdC-WkM','wedding'],
314
- ['https://www.youtube.com/watch?v=fWs2dWcNGu0', 'cheesecake on floor'],
315
- ['https://www.youtube.com/watch?v=rmPpNsx4yAk', 'cat woman'],
316
  ['https://www.youtube.com/watch?v=KCFYf4TJdN0' ,'sandwich'],
317
  ],
318
  inputs=[video_url, text_query],
@@ -324,6 +334,7 @@ with gr.Blocks() as demo:
324
  )
325
 
326
  try:
 
327
  demo.launch(share=True)
328
  except:
329
  demo.launch()
 
87
  '''
88
  convert time in seconds into frame number
89
  '''
90
+ return int(time * fps - 1)
91
 
92
  def str2time(strtime):
93
  strtime = strtime.strip('"')
 
105
  batch['pixel_mask'] = torch.cat([encoding['pixel_mask'] for encoding in batch_list], dim=0)
106
  return batch
107
 
108
+ def extract_images_and_embeds(video_id, video_path, subtitles, output, expanded=False, batch_size=2, progress=gr.Progress()):
109
  if os.path.exists(os.path.join(output, 'embeddings.pkl')):
110
  return
111
 
 
123
  # Get the total number of frames in the video.
124
  frame_count = vidcap.get(cv2.CAP_PROP_FRAME_COUNT)
125
 
126
+ # print(fps, frame_count)
127
 
128
  frame_number = 0
129
 
 
132
 
133
  embeddings = []
134
  batch_list = []
135
+ vtt = webvtt.read(subtitles)
136
 
137
+ for idx, caption in progress.tqdm(enumerate(vtt), total=vtt.total_length, desc="Generating embeddings"):
138
  st_time = str2time(caption.start)
139
  ed_time = str2time(caption.end)
140
 
 
145
  raise NotImplementedError
146
 
147
  frame_no = time_to_frame(mid_time, fps)
148
+ mid_time_ms = mid_time * 1000
149
+ # vidcap.set(1, frame_no) # added this line
150
+ vidcap.set(cv2.CAP_PROP_POS_MSEC, mid_time_ms)
151
  print('Read a new frame: ', idx, mid_time, frame_no, text)
 
152
  success, frame = vidcap.read()
153
  if success:
154
  frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
 
163
  'image_id': idx,
164
  'img_fname': img_fname,
165
  'caption': text,
166
+ 'time': mid_time_ms,
167
  'frame_no': frame_no
168
  })
169
 
 
171
  encoding['text'] = text
172
  encoding['image_filepath'] = img_fpath
173
  encoding['start_time'] = caption.start
174
+ encoding['time'] = mid_time_ms
175
 
176
  batch_list.append(encoding)
177
 
 
189
  'text': batch_list[i]['text'],
190
  'image_filepath': batch_list[i]['image_filepath'],
191
  'start_time': batch_list[i]['start_time'],
192
+ 'time': batch_list[i]['time'],
193
  })
194
  batch_list = []
195
 
 
204
  'text': batch_list[i]['text'],
205
  'image_filepath': batch_list[i]['image_filepath'],
206
  'start_time': batch_list[i]['start_time'],
207
+ 'time': batch_list[i]['time'],
208
  })
209
 
210
+ batch_list = []
211
+
212
  with open(os.path.join(output, 'annotations.json'), 'w') as fh:
213
  json.dump(anno, fh)
214
 
 
245
  clip_images = []
246
  transcripts = []
247
  for idx in I[0]:
248
+ # frame_no = embeddings[idx]['frame_no']
249
+ # vidcap.set(1, frame_no) # added this line
250
+ frame_timestamp = embeddings[idx]['time']
251
+ vidcap.set(cv2.CAP_PROP_POS_MSEC, frame_timestamp)
252
+
253
  success, frame = vidcap.read()
254
  if success:
255
+ frame = maintain_aspect_ratio_resize(frame, height=400)
256
  frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
257
  frame = Image.fromarray(frame)
258
  clip_images.append(frame)
 
286
  return None
287
 
288
 
289
+ def process(video_url, text_query, progress=gr.Progress()):
290
  tmp_dir = os.environ.get('TMPDIR', '/tmp')
291
  video_id = get_video_id_from_url(video_url)
292
  output_dir = os.path.join(tmp_dir, video_id)
 
298
  output=output_dir,
299
  expanded=False,
300
  batch_size=8,
301
+ progress=gr.Progress(),
302
  )
303
  frame_paths, transcripts = run_query(video_file, text_query, path=output_dir)
304
  return video_file, [(image, caption) for image, caption in zip(frame_paths, transcripts)]
 
321
  gr.Examples(
322
  examples=[
323
  ['https://www.youtube.com/watch?v=CvjoXdC-WkM','wedding'],
324
+ ['https://www.youtube.com/watch?v=fWs2dWcNGu0', 'cheesecake'],
325
+ ['https://www.youtube.com/watch?v=rmPpNsx4yAk', 'bunny'],
326
  ['https://www.youtube.com/watch?v=KCFYf4TJdN0' ,'sandwich'],
327
  ],
328
  inputs=[video_url, text_query],
 
334
  )
335
 
336
  try:
337
+ demo.queue(concurrency_count=3)
338
  demo.launch(share=True)
339
  except:
340
  demo.launch()
bridgetower_custom.py CHANGED
@@ -96,8 +96,8 @@ class BridgeTowerTextFeatureExtractor(BridgeTowerPreTrainedModel):
96
  labels: Optional[torch.LongTensor] = None,
97
  ):
98
 
99
- outputs = self.bridgetower(input_ids=input_ids, attention_mask=attention_mask)
100
- final_hidden_cls = outputs.last_hidden_state[:,0,:]
101
  final_hidden_cls = F.normalize(self.itc_text_head(final_hidden_cls), dim=-1, p=2)
102
 
103
  return final_hidden_cls
 
96
  labels: Optional[torch.LongTensor] = None,
97
  ):
98
 
99
+ outputs = self.bridgetower(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
100
+ final_hidden_cls = outputs.hidden_states[-1][:,0,:]
101
  final_hidden_cls = F.normalize(self.itc_text_head(final_hidden_cls), dim=-1, p=2)
102
 
103
  return final_hidden_cls