merve HF staff committed on
Commit
e147ef4
·
verified ·
1 Parent(s): f78ccb4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -62
app.py CHANGED
@@ -29,93 +29,67 @@ def sample_frames(video_file, num_frames):
29
  frames.append(pil_img)
30
  video.release()
31
  return frames
32
-
33
  @spaces.GPU
34
  def bot_streaming(message, history):
35
 
36
- txt = message["text"]
37
- ext_buffer = f"USER: {txt} ASSISTANT: "
38
 
39
- if message["files"]:
40
- if len(message["files"]) == 1:
41
  image = [message.files[0].path]
42
  # interleaved images or video
43
- elif len(message["files"]) > 1:
44
- image = [msg["path"] for msg in message["files"]]
45
  else:
46
-
47
- def has_file_data(lst):
48
- return any(isinstance(item, FileData) for sublist in lst if isinstance(sublist, tuple) for item in sublist)
49
-
50
- def extract_paths(lst):
51
- return [item["path"] for sublist in lst if isinstance(sublist, tuple) for item in sublist if isinstance(item, FileData)]
52
-
53
- latest_text_only_index = -1
54
 
55
- for i, item in enumerate(history):
56
- if all(isinstance(sub_item, str) for sub_item in item):
57
- latest_text_only_index = i
58
-
59
- image = [path for i, item in enumerate(history) if i < latest_text_only_index and has_file_data(item) for path in extract_paths(item)]
60
-
61
- if message["files"] is None:
62
  gr.Error("You need to upload an image or video for LLaVA to work.")
63
 
64
  video_extensions = ("avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg")
65
  image_extensions = Image.registered_extensions()
66
  image_extensions = tuple([ex for ex, f in image_extensions.items()])
67
- image_list = []
68
- video_list = []
69
-
70
- print("media", image)
71
  if len(image) == 1:
72
  if image[0].endswith(video_extensions):
73
 
74
- video_list = sample_frames(image[0], 12)
75
-
76
- prompt = f"USER: <video> {message.text} ASSISTANT:"
77
  elif image[0].endswith(image_extensions):
78
- image_list.append(Image.open(image[0]).convert("RGB"))
79
- msg = message["text"]
80
- prompt = f"USER: <image> {message.text} ASSISTANT:"
81
 
82
  elif len(image) > 1:
83
- user_prompt = message["text"]
 
84
 
85
  for img in image:
86
  if img.endswith(image_extensions):
87
  img = Image.open(img).convert("RGB")
88
  image_list.append(img)
89
 
90
- elif img.endswith(video_extensions):
91
- video_list.append(sample_frames(img, 7))
92
- #for frame in sample_frames(img, 6):
93
- #video_list.append(frame)
94
-
95
- image_tokens = ""
96
- video_tokens = ""
97
-
98
- if image_list != []:
99
- image_tokens = "<image>" * len(image_list)
100
- if video_list != []:
101
 
102
- toks = len(video_list)
103
- video_tokens = "<video>" * toks
104
-
105
-
106
 
107
- prompt = f"USER: {image_tokens}{video_tokens} {user_prompt} ASSISTANT:"
 
108
 
109
- if image_list != [] and video_list != []:
110
- inputs = processor(text=prompt, images=image_list, videos=video_list, padding=True, return_tensors="pt").to("cuda",torch.float16)
111
- elif image_list != [] and video_list == []:
112
- inputs = processor(text=prompt, images=image_list, padding=True, return_tensors="pt").to("cuda", torch.float16)
113
- elif image_list == [] and video_list != []:
114
- inputs = processor(text=prompt, videos=video_list, padding=True, return_tensors="pt").to("cuda", torch.float16)
115
-
116
-
117
- streamer = TextIteratorStreamer(processor, **{"max_new_tokens": 200, "skip_special_tokens": True, "clean_up_tokenization_spaces":True})
118
- generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=200)
119
  generated_text = ""
120
 
121
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
@@ -127,10 +101,10 @@ def bot_streaming(message, history):
127
  for new_text in streamer:
128
 
129
  buffer += new_text
130
- print("new_text", new_text)
131
- #generated_text_without_prompt = buffer[len(ext_buffer):][:-1]
132
  time.sleep(0.01)
133
- yield buffer #generated_text_without_prompt
134
 
135
 
136
  demo = gr.ChatInterface(fn=bot_streaming, title="LLaVA Onevision", examples=[
 
29
  frames.append(pil_img)
30
  video.release()
31
  return frames
32
+
33
  @spaces.GPU
34
  def bot_streaming(message, history):
35
 
36
+ txt = message.text
37
+ ext_buffer = f"user\n{txt} assistant"
38
 
39
+ if message.files:
40
+ if len(message.files) == 1:
41
  image = [message.files[0].path]
42
  # interleaved images or video
43
+ elif len(message.files) > 1:
44
+ image = [msg.path for msg in message.files]
45
  else:
46
+ # if there's no image uploaded for this turn, look for images in the past turns
47
+ # kept inside tuples, take the last one
48
+ for hist in history:
49
+ if type(hist[0])==tuple:
50
+ image = hist[0][0]
 
 
 
51
 
52
+ if message.files is None:
 
 
 
 
 
 
53
  gr.Error("You need to upload an image or video for LLaVA to work.")
54
 
55
  video_extensions = ("avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg")
56
  image_extensions = Image.registered_extensions()
57
  image_extensions = tuple([ex for ex, f in image_extensions.items()])
 
 
 
 
58
  if len(image) == 1:
59
  if image[0].endswith(video_extensions):
60
 
61
+ video = sample_frames(image[0], 32)
62
+ image = None
63
+ prompt = f"<|im_start|>user <video>\n{message.text}<|im_end|><|im_start|>assistant"
64
  elif image[0].endswith(image_extensions):
65
+ image = Image.open(image[0]).convert("RGB")
66
+ video = None
67
+ prompt = f"<|im_start|>user <image>\n{message.text}<|im_end|><|im_start|>assistant"
68
 
69
  elif len(image) > 1:
70
+ image_list = []
71
+ user_prompt = message.text
72
 
73
  for img in image:
74
  if img.endswith(image_extensions):
75
  img = Image.open(img).convert("RGB")
76
  image_list.append(img)
77
 
78
+ elif img.endswith(video_extensions):
79
+ frames = sample_frames(img, 6)
80
+ for frame in frames:
81
+ image_list.append(frame)
 
 
 
 
 
 
 
82
 
83
+ toks = "<image>" * len(image_list)
84
+ prompt = "<|im_start|>user"+ toks + f"\n{user_prompt}<|im_end|><|im_start|>assistant"
 
 
85
 
86
+ image = image_list
87
+ video = None
88
 
89
+
90
+ inputs = processor(text=prompt, images=image, videos=video, return_tensors="pt").to("cuda", torch.float16)
91
+ streamer = TextIteratorStreamer(processor, **{"max_new_tokens": 200, "skip_special_tokens": True})
92
+ generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=100)
 
 
 
 
 
 
93
  generated_text = ""
94
 
95
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
 
101
  for new_text in streamer:
102
 
103
  buffer += new_text
104
+
105
+ generated_text_without_prompt = buffer[len(ext_buffer):]
106
  time.sleep(0.01)
107
+ yield generated_text_without_prompt
108
 
109
 
110
  demo = gr.ChatInterface(fn=bot_streaming, title="LLaVA Onevision", examples=[