lixinhao commited on
Commit
3e81358
·
verified ·
1 Parent(s): dc71015

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +290 -60
app.py CHANGED
@@ -1,64 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
-
9
-
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
- ):
18
- messages = [{"role": "system", "content": system_message}]
19
-
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
25
-
26
- messages.append({"role": "user", "content": message})
27
-
28
- response = ""
29
-
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
38
-
39
- response += token
40
- yield response
41
-
42
-
43
- """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- demo = gr.ChatInterface(
47
- respond,
48
- additional_inputs=[
49
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
50
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
51
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
52
- gr.Slider(
53
- minimum=0.1,
54
- maximum=1.0,
55
- value=0.95,
56
- step=0.05,
57
- label="Top-p (nucleus sampling)",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  ),
59
- ],
60
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
 
 
62
 
63
- if __name__ == "__main__":
64
- demo.launch()
 
1
+
2
+ import os
3
+ import spaces
4
+ import time
5
+ try:
6
+ token =os.environ['HF_TOKEN']
7
+ except:
8
+ print("paste your hf token here!")
9
+ token = "hf_xxxxxxxxxxxxxxxxxxx"
10
+ os.environ['HF_TOKEN'] = token
11
+ import torch
12
  import gradio as gr
13
+ from gradio.themes.utils import colors, fonts, sizes
14
+
15
+ from faster_whisper import WhisperModel
16
+ from moviepy.editor import VideoFileClip
17
+ from transformers import AutoTokenizer, AutoModel
18
+
19
+ # ========================================
20
+ # Model Initialization
21
+ # ========================================
22
+
23
+ if gr.NO_RELOAD:
24
+ if torch.cuda.is_available():
25
+ speech_model = WhisperModel("large-v3", device="cuda", compute_type="float16")
26
+ else:
27
+ speech_model = WhisperModel("large-v3", device="cpu")
28
+
29
+ model_path = 'OpenGVLab/VideoChat-Flash-Qwen2-7B_res448'
30
+
31
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
32
+ model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda()
33
+
34
+
35
+ model.config.mm_llm_compress = False
36
+
37
+
38
+ # ========================================
39
+ # Define Utils
40
+ # ========================================
41
+
42
+
43
+ def extract_audio(name):
44
+ with VideoFileClip(name) as video:
45
+ if video.audio == None:
46
+ return None
47
+ audio = video.audio
48
+ audio_name = name[:-4] + '.wav'
49
+ audio.write_audiofile(audio_name, fps=16000)
50
+ return audio_name
51
+
52
+ @spaces.GPU
53
+ def audio2text(audio):
54
+ segments, _ = speech_model.transcribe(audio)
55
+ text = ""
56
+ for segment in segments:
57
+ # print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
58
+ text += ("[%.2fs -> %.2fs] %s " % (segment.start, segment.end, segment.text))
59
+ # print(text)
60
+ return text
61
+
62
+
63
+ # ========================================
64
+ # Gradio Setting
65
+ # ========================================
66
+ def gradio_reset():
67
+ return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your video first', interactive=False), gr.update(interactive=False) , gr.update(interactive=False), gr.update(value="Upload & Start Chat", interactive=True), [], ""
68
+
69
+
70
+
71
+
72
+ def upload_video(gr_video, text_input="Type and press Enter"):
73
+ if gr_video is None:
74
+ return gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True), gr.update(interactive=False), gr.update(value="Upload & Start Chat", interactive=True), ""
75
+
76
+
77
+ # if check_asr: #表示需要提取音频
78
+ audio_name = extract_audio(gr_video)
79
+ if audio_name != None:
80
+ asr_msg = audio2text(audio_name)
81
+ else:
82
+ asr_msg = ""
83
+ # else:
84
+ # asr_msg = ""
85
+
86
+ return gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True, placeholder=text_input), gr.update(value="Start Chatting", interactive=False), asr_msg
87
+
88
+ def clear_():
89
+ return [], []
90
+
91
+ def gradio_ask(user_message, chatbot):
92
+ # if len(user_message) == 0:
93
+ # return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot
94
+ chatbot = chatbot + [[user_message, None]]
95
+ return user_message, chatbot
96
+
97
+ @spaces.GPU
98
+ def gradio_answer(chatbot, text_input, video_path, max_num_frames, check_asr, asr_msg, chat_state, max_new_tokens, do_sample, num_beams, top_p, temperature):
99
+
100
+
101
+
102
+ if chat_state is None or len(chat_state) == 0:
103
+ if asr_msg is None or len(asr_msg) == 0:
104
+ # text_input = f"Answer the question based on the video content.\n{text_input}"
105
+ pass
106
+ elif check_asr:
107
+ text_input = f"The speech extracted from the video via ASR is as follows: {asr_msg}\n{text_input}"
108
+
109
+ print(f"\033[91m== text_input: \033[0m\n{text_input}\n")
110
+
111
+ response, chat_state = model.chat(video_path=video_path, tokenizer=tokenizer, user_prompt=text_input, chat_history=chat_state, return_history=True, max_num_frames=max_num_frames, generation_config={
112
+ 'max_new_tokens': max_new_tokens, 'do_sample':do_sample,
113
+ 'num_beams':num_beams, 'top_p':top_p, 'temperature':temperature
114
+ })
115
+
116
+ current_response = ""
117
+
118
+ for char in response:
119
+ current_response += char
120
+ chatbot[-1][1] = current_response + "▌"
121
+ yield chatbot, chat_state
122
+ time.sleep(0.008)
123
+ chatbot[-1][1] = current_response
124
+ yield chatbot, chat_state
125
+
126
+
127
+ class OpenGVLab(gr.themes.base.Base):
128
+ def __init__(
129
+ self,
130
+ *,
131
+ primary_hue=colors.blue,
132
+ secondary_hue=colors.sky,
133
+ neutral_hue=colors.gray,
134
+ spacing_size=sizes.spacing_md,
135
+ radius_size=sizes.radius_sm,
136
+ text_size=sizes.text_md,
137
+ font=(
138
+ fonts.GoogleFont("Noto Sans"),
139
+ "ui-sans-serif",
140
+ "sans-serif",
141
+ ),
142
+ font_mono=(
143
+ fonts.GoogleFont("IBM Plex Mono"),
144
+ "ui-monospace",
145
+ "monospace",
146
  ),
147
+ ):
148
+ super().__init__(
149
+ primary_hue=primary_hue,
150
+ secondary_hue=secondary_hue,
151
+ neutral_hue=neutral_hue,
152
+ spacing_size=spacing_size,
153
+ radius_size=radius_size,
154
+ text_size=text_size,
155
+ font=font,
156
+ font_mono=font_mono,
157
+ )
158
+ super().set(
159
+ body_background_fill="*neutral_50",
160
+ )
161
+
162
+
163
+ gvlabtheme = OpenGVLab(primary_hue=colors.blue,
164
+ secondary_hue=colors.sky,
165
+ neutral_hue=colors.gray,
166
+ spacing_size=sizes.spacing_md,
167
+ radius_size=sizes.radius_sm,
168
+ text_size=sizes.text_md,
169
+ )
170
+
171
+ title = """<h1 align="center"><a href="https://github.com/OpenGVLab/VideoChat-Flash"><img src="https://s1.ax1x.com/2023/05/07/p9dBMOU.png" alt="VideoChat-Flash" border="0" style="margin: 0 auto; height: 100px;" /></a> </h1>"""
172
+ description ="""
173
+ VideoChat-Flash-7B@448 powered by InternVideo!<br><p><a href='https://github.com/OpenGVLab/VideoChat-Flash'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p><p>
174
+ """
175
+
176
+
177
+ with gr.Blocks(title="VideoChat-Flash",theme=gvlabtheme,css="#chatbot {overflow:auto; height:500px;} #InputVideo {overflow:visible; height:320px;} footer {visibility: none}") as demo:
178
+ gr.Markdown(title)
179
+ gr.Markdown(description)
180
+ # with gr.Row():
181
+ # # options_yes_no = ["YES", "NO"]
182
+ # # with gr.Row():
183
+ # # radio_type = gr.Radio(choices=options_1, label="VideoChat-Flash", value=options_1[0])
184
+ # with gr.Row():
185
+
186
+ with gr.Row():
187
+ with gr.Column(scale=0.5, visible=True) as video_upload:
188
+ with gr.Column(elem_id="image", scale=0.5) as img_part:
189
+ up_video = gr.Video(interactive=True, include_audio=True, elem_id="video_upload")
190
+
191
+ upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
192
+ restart = gr.Button("Restart")
193
+
194
+ max_num_frames = gr.Slider(
195
+ minimum=4,
196
+ maximum=1024,
197
+ value=512,
198
+ step=4,
199
+ interactive=True,
200
+ label="Max Input Frames",
201
+ )
202
+
203
+ max_new_tokens = gr.Slider(
204
+ minimum=1,
205
+ maximum=4096,
206
+ value=1024,
207
+ step=1,
208
+ interactive=True,
209
+ label="Max Output Tokens",
210
+ )
211
+
212
+ check_asr = gr.Checkbox(label="Use ASR", info="Whether to extract speech using ASR.")
213
+ check_do_sample = gr.Checkbox(label="Do Sample", info="Whether to do sample during decoding.")
214
+
215
+ num_beams = gr.Slider(
216
+ minimum=1,
217
+ maximum=10,
218
+ value=1,
219
+ step=1,
220
+ interactive=True,
221
+ visible=False,
222
+ label="beam search numbers)",
223
+ )
224
+
225
+ top_p = gr.Slider(
226
+ minimum=0.0,
227
+ maximum=1.0,
228
+ value=0.1,
229
+ step=0.1,
230
+ visible=False,
231
+ interactive=True, label="Top_P",
232
+ )
233
+
234
+ temperature = gr.Slider(
235
+ minimum=0.1,
236
+ maximum=2.0,
237
+ value=0.1,
238
+ step=0.1,
239
+ visible=False,
240
+ interactive=True, label="Temperature",
241
+ )
242
+
243
+ def toggle_slide(is_checked):
244
+ return gr.update(visible=is_checked), gr.update(visible=is_checked), gr.update(visible=is_checked)
245
+
246
+ check_do_sample.select(fn=toggle_slide, inputs=check_do_sample, outputs=[num_beams, top_p, temperature])
247
+
248
+ with gr.Column(visible=True) as input_raws:
249
+ chat_state = gr.State([])
250
+ asr_msg = gr.State()
251
+ chatbot = gr.Chatbot(
252
+ elem_id="chatbot",
253
+ label='VideoChat',
254
+ avatar_images=[
255
+ "avatar/human.jpg", # 用户头像
256
+ "avatar/assistant.png", # AI头像
257
+ ])
258
+ with gr.Row():
259
+ with gr.Column(scale=0.7):
260
+ text_input = gr.Textbox(show_label=False, placeholder='Please upload your video first', interactive=False)
261
+ with gr.Column(scale=0.15, min_width=0):
262
+ run = gr.Button("💭Send", interactive=False)
263
+ with gr.Column(scale=0.15, min_width=0):
264
+ clear = gr.Button("🔄Clear️", interactive=False)
265
+ with gr.Row():
266
+ examples = gr.Examples(
267
+ examples=[
268
+ ["demo_videos/basketball.mp4", False, "Describe this video in detail."],
269
+ ["demo_videos/cup1.mp4", False, "Describe this video in detail."],
270
+ ["demo_videos/dog.mp4", False, "Describe this video in detail."],
271
+ ],
272
+ inputs = [up_video, text_input],
273
+ outputs = [run, clear, up_video, text_input, upload_button, asr_msg],
274
+ fn=upload_video,
275
+ run_on_click=True
276
+ )
277
+
278
+ up_video.clear(gradio_reset, None, [chatbot, up_video, text_input, run, clear, upload_button, chat_state, asr_msg], queue=False)
279
+
280
+ upload_button.click(upload_video, [up_video], [run, clear, up_video, text_input, upload_button, asr_msg])
281
+
282
+
283
+ text_input.submit(gradio_ask, [text_input, chatbot], [text_input, chatbot]).then(
284
+ gradio_answer, [chatbot, text_input, up_video, max_num_frames, check_asr, asr_msg, chat_state, max_new_tokens, check_do_sample, num_beams, top_p, temperature], [chatbot, chat_state]
285
+ ).then(lambda: "", None, text_input)
286
+
287
+ run.click(gradio_ask, [text_input, chatbot], [text_input, chatbot]).then(
288
+ gradio_answer, [chatbot, text_input, up_video, max_num_frames, check_asr, asr_msg, chat_state, max_new_tokens, check_do_sample, num_beams, top_p, temperature], [chatbot, chat_state]
289
+ ).then(lambda: "", None, text_input)
290
 
291
+ clear.click(clear_, None, [chatbot, chat_state])
292
+ restart.click(gradio_reset, None, [chatbot, up_video, text_input, run, clear, upload_button, chat_state, asr_msg], queue=False)
293
 
294
+ demo.launch(server_name='0.0.0.0',server_port=7864)