import os
import spaces
import time

try:
    token = os.environ['HF_TOKEN']
except KeyError:
    print("paste your hf token here!")
    token = "hf_xxxxxxxxxxxxxxxxxxx"
    os.environ['HF_TOKEN'] = token

import torch
import gradio as gr
from gradio.themes.utils import colors, fonts, sizes
from faster_whisper import WhisperModel
from moviepy.editor import VideoFileClip
from transformers import AutoTokenizer, AutoModel
import subprocess

# Install flash-attn from the prebuilt wheel (CUDA build skipped); pass the current
# environment through so pip and CUDA paths remain visible to the subprocess.
subprocess.run('pip install flash-attn --no-build-isolation', env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

# ========================================
# Model Initialization
# ========================================
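# Load the heavy models once under gr.NO_RELOAD so they are not re-initialized
# when the script is re-executed by Gradio's hot reloader (`gradio app.py`).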
if gr.NO_RELOAD:
    if torch.cuda.is_available():
        speech_model = WhisperModel("large-v3", device="cuda", compute_type="float16")
    else:
        speech_model = WhisperModel("large-v3", device="cpu")

    model_path = 'OpenGVLab/VideoChat-Flash-Qwen2-7B_res448'
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda()
    model.config.mm_llm_compress = False

# ========================================
# Define Utils
# ========================================
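# extract_audio: write the video's audio track to a 16 kHz WAV next to the source
# file; returns None when the clip has no audio stream.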
def extract_audio(name):
    with VideoFileClip(name) as video:
        if video.audio is None:
            return None
        audio = video.audio
        audio_name = name[:-4] + '.wav'
        audio.write_audiofile(audio_name, fps=16000)
        return audio_name
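
# audio2text: transcribe the extracted WAV with faster-whisper and return the
# timestamped segments concatenated into a single string.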
def audio2text(audio):
    segments, _ = speech_model.transcribe(audio)
    text = ""
    for segment in segments:
        # print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
        text += ("[%.2fs -> %.2fs] %s " % (segment.start, segment.end, segment.text))
    # print(text)
    return text

# ========================================
# Gradio Setting
# ========================================
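# gradio_reset: restore every control to its initial state (clear the video, lock
# the textbox and buttons, empty the chat history and the cached ASR transcript).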
def gradio_reset():
    return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your video first', interactive=False), gr.update(interactive=False), gr.update(interactive=False), gr.update(value="Upload & Start Chat", interactive=True), [], ""
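
# upload_video: enable the chat controls once a video is present and, if the clip
# has an audio track, run ASR so the transcript can be prepended to the first prompt.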
def upload_video(gr_video, text_input="Type and press Enter"):
    if gr_video is None:
        return gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True), gr.update(interactive=False), gr.update(value="Upload & Start Chat", interactive=True), ""
    # if check_asr:  # extract the audio only when ASR is requested
    audio_name = extract_audio(gr_video)
    if audio_name is not None:
        asr_msg = audio2text(audio_name)
    else:
        asr_msg = ""
    # else:
    #     asr_msg = ""
    return gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True, placeholder=text_input), gr.update(value="Start Chatting", interactive=False), asr_msg

def clear_():
    return [], []
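
# gradio_ask: echo the user's message into the chatbot as a new, not-yet-answered turn.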
def gradio_ask(user_message, chatbot):
    # if len(user_message) == 0:
    #     return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot
    chatbot = chatbot + [[user_message, None]]
    return user_message, chatbot
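
# gradio_answer: on the first turn, optionally prepend the ASR transcript to the
# prompt, then query VideoChat-Flash and stream the reply character by character.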
def gradio_answer(chatbot, text_input, video_path, max_num_frames, check_asr, asr_msg, chat_state, max_new_tokens, do_sample, num_beams, top_p, temperature):
    if chat_state is None or len(chat_state) == 0:
        if asr_msg is None or len(asr_msg) == 0:
            # text_input = f"Answer the question based on the video content.\n{text_input}"
            pass
        elif check_asr:
            text_input = f"The speech extracted from the video via ASR is as follows: {asr_msg}\n{text_input}"
    print(f"\033[91m== text_input: \033[0m\n{text_input}\n")
    response, chat_state = model.chat(video_path=video_path, tokenizer=tokenizer, user_prompt=text_input, chat_history=chat_state, return_history=True, max_num_frames=max_num_frames, generation_config={
        'max_new_tokens': max_new_tokens, 'do_sample': do_sample,
        'num_beams': num_beams, 'top_p': top_p, 'temperature': temperature
    })
    current_response = ""
    for char in response:
        current_response += char
        chatbot[-1][1] = current_response + "▌"
        yield chatbot, chat_state
        time.sleep(0.008)
    chatbot[-1][1] = current_response
    yield chatbot, chat_state
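
# OpenGVLab: a light Gradio theme override (blue/sky palette, Noto Sans) used to brand the demo.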
class OpenGVLab(gr.themes.base.Base):
    def __init__(
        self,
        *,
        primary_hue=colors.blue,
        secondary_hue=colors.sky,
        neutral_hue=colors.gray,
        spacing_size=sizes.spacing_md,
        radius_size=sizes.radius_sm,
        text_size=sizes.text_md,
        font=(
            fonts.GoogleFont("Noto Sans"),
            "ui-sans-serif",
            "sans-serif",
        ),
        font_mono=(
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            spacing_size=spacing_size,
            radius_size=radius_size,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        super().set(
            body_background_fill="*neutral_50",
        )

gvlabtheme = OpenGVLab(primary_hue=colors.blue,
                       secondary_hue=colors.sky,
                       neutral_hue=colors.gray,
                       spacing_size=sizes.spacing_md,
                       radius_size=sizes.radius_sm,
                       text_size=sizes.text_md,
                       )

title = """<h1 align="center"><a href="https://github.com/OpenGVLab/VideoChat-Flash"><img src="https://s1.ax1x.com/2023/05/07/p9dBMOU.png" alt="VideoChat-Flash" border="0" style="margin: 0 auto; height: 100px;" /></a></h1>"""
description = """
VideoChat-Flash-7B@448 powered by InternVideo!<br><p><a href='https://github.com/OpenGVLab/VideoChat-Flash'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p>
"""
with gr.Blocks(title="VideoChat-Flash", theme=gvlabtheme, css="#chatbot {overflow:auto; height:500px;} #InputVideo {overflow:visible; height:320px;} footer {visibility: hidden}") as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    # with gr.Row():
    #     options_yes_no = ["YES", "NO"]
    #     radio_type = gr.Radio(choices=options_1, label="VideoChat-Flash", value=options_1[0])
    with gr.Row():
        with gr.Column(scale=0.5, visible=True) as video_upload:
            with gr.Column(elem_id="image", scale=0.5) as img_part:
                up_video = gr.Video(interactive=True, include_audio=True, elem_id="video_upload")
            upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
            restart = gr.Button("Restart")
            max_num_frames = gr.Slider(
                minimum=4,
                maximum=1024,
                value=512,
                step=4,
                interactive=True,
                label="Max Input Frames",
            )
            max_new_tokens = gr.Slider(
                minimum=1,
                maximum=4096,
                value=1024,
                step=1,
                interactive=True,
                label="Max Output Tokens",
            )
            check_asr = gr.Checkbox(label="Use ASR", info="Whether to extract speech using ASR.")
            check_do_sample = gr.Checkbox(label="Do Sample", info="Whether to sample during decoding.")
            num_beams = gr.Slider(
                minimum=1,
                maximum=10,
                value=1,
                step=1,
                interactive=True,
                visible=False,
                label="Beam Search Numbers",
            )
            top_p = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.1,
                step=0.1,
                visible=False,
                interactive=True,
                label="Top_P",
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.1,
                step=0.1,
                visible=False,
                interactive=True,
                label="Temperature",
            )
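
            # Reveal the sampling controls (beams, top_p, temperature) only when "Do Sample" is checked.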
            def toggle_slide(is_checked):
                return gr.update(visible=is_checked), gr.update(visible=is_checked), gr.update(visible=is_checked)

            check_do_sample.select(fn=toggle_slide, inputs=check_do_sample, outputs=[num_beams, top_p, temperature])
        with gr.Column(visible=True) as input_raws:
            chat_state = gr.State([])
            asr_msg = gr.State()
            chatbot = gr.Chatbot(
                elem_id="chatbot",
                label='VideoChat',
                avatar_images=[
                    "human.jpg",      # user avatar
                    "assistant.png",  # assistant avatar
                ])
            with gr.Row():
                with gr.Column(scale=0.7):
                    text_input = gr.Textbox(show_label=False, placeholder='Please upload your video first', interactive=False)
                with gr.Column(scale=0.15, min_width=0):
                    run = gr.Button("💭Send", interactive=False)
                with gr.Column(scale=0.15, min_width=0):
                    clear = gr.Button("🔄Clear️", interactive=False)
            with gr.Row():
                examples = gr.Examples(
                    examples=[
                        ["demo_videos/basketball.mp4", "Describe this video in detail."],
                        ["demo_videos/cup1.mp4", "Describe this video in detail."],
                        ["demo_videos/dog.mp4", "Describe this video in detail."],
                    ],
                    inputs=[up_video, text_input],
                    outputs=[run, clear, up_video, text_input, upload_button, asr_msg],
                    fn=upload_video,
                    run_on_click=True,
                )
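
    # Wire up the events: clearing or restarting resets the UI, while submitting a prompt
    # first appends the user turn (gradio_ask) and then streams the model reply (gradio_answer).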
    up_video.clear(gradio_reset, None, [chatbot, up_video, text_input, run, clear, upload_button, chat_state, asr_msg], queue=False)
    upload_button.click(upload_video, [up_video], [run, clear, up_video, text_input, upload_button, asr_msg])

    text_input.submit(gradio_ask, [text_input, chatbot], [text_input, chatbot]).then(
        gradio_answer, [chatbot, text_input, up_video, max_num_frames, check_asr, asr_msg, chat_state, max_new_tokens, check_do_sample, num_beams, top_p, temperature], [chatbot, chat_state]
    ).then(lambda: "", None, text_input)
    run.click(gradio_ask, [text_input, chatbot], [text_input, chatbot]).then(
        gradio_answer, [chatbot, text_input, up_video, max_num_frames, check_asr, asr_msg, chat_state, max_new_tokens, check_do_sample, num_beams, top_p, temperature], [chatbot, chat_state]
    ).then(lambda: "", None, text_input)
    clear.click(clear_, None, [chatbot, chat_state])
    restart.click(gradio_reset, None, [chatbot, up_video, text_input, run, clear, upload_button, chat_state, asr_msg], queue=False)

demo.launch(server_name='0.0.0.0', server_port=7864)