# VideoChat-Flash / app.py
# lixinhao's picture
# Update app.py
# 432c607 verified
import os
import spaces
import time
# Read the Hugging Face access token from the environment. Only a missing
# variable is treated as "not configured"; any other error is left to surface.
try:
    token = os.environ['HF_TOKEN']
except KeyError:
    print("paste your hf token here!")
    token = "hf_xxxxxxxxxxxxxxxxxxx"
# Publish the token (real or placeholder) under HF_TOKEN for later readers.
os.environ['HF_TOKEN'] = token
import torch
import gradio as gr
from gradio.themes.utils import colors, fonts, sizes
from faster_whisper import WhisperModel
from moviepy.editor import VideoFileClip
from transformers import AutoTokenizer, AutoModel
import subprocess
# Install flash-attn at startup with FLASH_ATTENTION_SKIP_CUDA_BUILD set.
# subprocess's `env` argument replaces the inherited environment instead of
# extending it, so the flag is merged into a copy of os.environ; passing the
# flag alone would start pip without PATH, HOME, proxy settings, etc.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)
# ========================================
# Model Initialization
# ========================================
# Load the models only under gr.NO_RELOAD so that Gradio's reload mode does
# not re-run the loading code.
if gr.NO_RELOAD:
    # Speech recognizer: float16 on CUDA when a GPU is visible, otherwise the
    # CPU configuration.
    speech_model = (
        WhisperModel("large-v3", device="cuda", compute_type="float16")
        if torch.cuda.is_available()
        else WhisperModel("large-v3", device="cpu")
    )

    # Video chat model and its tokenizer, both loaded from the same checkpoint.
    model_path = 'OpenGVLab/VideoChat-Flash-Qwen2-7B_res448'
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
    # Half precision, moved to the GPU.
    model = model.half().cuda()
    model.config.mm_llm_compress = False
# ========================================
# Define Utils
# ========================================
def extract_audio(name):
    """Write the audio track of a video file to a WAV file beside it.

    Args:
        name: Path of the video file.

    Returns:
        Path of the written ``.wav`` file (the video path with its extension
        replaced), or ``None`` when the video has no audio track.
    """
    with VideoFileClip(name) as video:
        audio = video.audio
        if audio is None:
            return None
        # Replace the real extension, whatever its length; cutting a fixed
        # four characters mangles names such as "clip.webm" or "clip.ts".
        audio_name = os.path.splitext(name)[0] + '.wav'
        audio.write_audiofile(audio_name, fps=16000)
        return audio_name
@spaces.GPU
def audio2text(audio):
    """Transcribe an audio file into a single timestamped string.

    Args:
        audio: Audio input handed to ``speech_model.transcribe``.

    Returns:
        The concatenation of every segment rendered as
        ``"[<start>s -> <end>s] <text> "``; an empty string when there are no
        segments.
    """
    segments, _ = speech_model.transcribe(audio)
    return "".join(
        "[%.2fs -> %.2fs] %s " % (piece.start, piece.end, piece.text)
        for piece in segments
    )
# ========================================
# Gradio Setting
# ========================================
def gradio_reset():
    """Return the values that put the page back into its initial state.

    Returns:
        An 8-tuple, in the order the event handlers list their outputs:
        chatbot, video, text box, send button, clear button, upload button,
        chat state, ASR transcript.
    """
    chatbot_value = None
    video_update = gr.update(value=None, interactive=True)
    textbox_update = gr.update(placeholder='Please upload your video first', interactive=False)
    send_update = gr.update(interactive=False)
    clear_update = gr.update(interactive=False)
    upload_update = gr.update(value="Upload & Start Chat", interactive=True)
    empty_chat_state = []
    empty_transcript = ""
    return (
        chatbot_value,
        video_update,
        textbox_update,
        send_update,
        clear_update,
        upload_update,
        empty_chat_state,
        empty_transcript,
    )
def upload_video(gr_video, text_input="Type and press Enter"):
    """Prepare the chat controls for an uploaded video and transcribe its audio.

    Args:
        gr_video: Path of the uploaded video, or ``None`` when nothing is uploaded.
        text_input: Text shown as the text box placeholder once chatting is enabled.

    Returns:
        A 6-tuple: updates for the send button, clear button, video, text box
        and upload button, followed by the ASR transcript (``""`` when there
        is no video or no audio track).
    """
    if gr_video is None:
        # Nothing uploaded: keep chatting disabled and the upload button armed.
        return (
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=True),
            gr.update(interactive=False),
            gr.update(value="Upload & Start Chat", interactive=True),
            "",
        )

    # The transcript is always extracted here; whether it is used is decided
    # later, when the first question is answered.
    audio_name = extract_audio(gr_video)
    asr_msg = "" if audio_name is None else audio2text(audio_name)

    return (
        gr.update(interactive=True),
        gr.update(interactive=True),
        gr.update(interactive=True),
        gr.update(interactive=True, placeholder=text_input),
        gr.update(value="Start Chatting", interactive=False),
        asr_msg,
    )
def clear_():
    """Return two new empty lists: one for the chatbot, one for the chat state."""
    fresh_chatbot = []
    fresh_chat_state = []
    return fresh_chatbot, fresh_chat_state
def gradio_ask(user_message, chatbot):
    """Append the user's message to the chat as a turn awaiting its answer.

    Args:
        user_message: Text the user submitted.
        chatbot: Current chat history, a list of ``[user, assistant]`` pairs.

    Returns:
        ``(user_message, history)`` where ``history`` is a new list ending
        with ``[user_message, None]``; the list passed in is not modified.
    """
    pending_turn = [user_message, None]
    extended_history = chatbot + [pending_turn]
    return user_message, extended_history
@spaces.GPU
def gradio_answer(chatbot, text_input, video_path, max_num_frames, check_asr, asr_msg, chat_state, max_new_tokens, do_sample, num_beams, top_p, temperature):
    """Answer the latest question about the video and stream the reply into the chat.

    On the first turn (empty ``chat_state``), when ``check_asr`` is set and a
    transcript is available, the transcript is prepended to the question.

    Args:
        chatbot: Chat history; the reply is written into its last entry.
        text_input: The user's question.
        video_path: Path of the video passed to ``model.chat``.
        max_num_frames: Forwarded to ``model.chat`` as ``max_num_frames``.
        check_asr: Whether the ASR transcript may be added to the prompt.
        asr_msg: ASR transcript of the video, possibly ``None`` or empty.
        chat_state: History from earlier turns; ``None`` or empty on the first turn.
        max_new_tokens, do_sample, num_beams, top_p, temperature: Entries of
            the generation config passed to ``model.chat``.

    Yields:
        ``(chatbot, chat_state)`` after each additional character of the reply
        (shown with a trailing cursor mark), then once more without the cursor.
    """
    is_first_turn = chat_state is None or len(chat_state) == 0
    has_transcript = not (asr_msg is None or len(asr_msg) == 0)
    if is_first_turn and has_transcript and check_asr:
        text_input = f"The speech extracted from the video via ASR is as follows: {asr_msg}\n{text_input}"

    print(f"\033[91m== text_input: \033[0m\n{text_input}\n")

    generation_config = {
        'max_new_tokens': max_new_tokens,
        'do_sample': do_sample,
        'num_beams': num_beams,
        'top_p': top_p,
        'temperature': temperature,
    }
    response, chat_state = model.chat(
        video_path=video_path,
        tokenizer=tokenizer,
        user_prompt=text_input,
        chat_history=chat_state,
        return_history=True,
        max_num_frames=max_num_frames,
        generation_config=generation_config,
    )

    # The full reply is already available; reveal it one character at a time
    # so the UI shows a typing effect.
    for shown in range(1, len(response) + 1):
        chatbot[-1][1] = response[:shown] + "▌"
        yield chatbot, chat_state
        time.sleep(0.008)

    # Final frame: the complete reply without the cursor mark.
    chatbot[-1][1] = response[:]
    yield chatbot, chat_state
class OpenGVLab(gr.themes.base.Base):
    """Gradio theme used by the demo, built on the base theme."""

    def __init__(
        self,
        *,
        primary_hue=colors.blue,
        secondary_hue=colors.sky,
        neutral_hue=colors.gray,
        spacing_size=sizes.spacing_md,
        radius_size=sizes.radius_sm,
        text_size=sizes.text_md,
        font=(
            fonts.GoogleFont("Noto Sans"),
            "ui-sans-serif",
            "sans-serif",
        ),
        font_mono=(
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        """Configure the base theme, then override the body background.

        All arguments are keyword-only and are forwarded unchanged to the
        base theme constructor.
        """
        base_options = {
            "primary_hue": primary_hue,
            "secondary_hue": secondary_hue,
            "neutral_hue": neutral_hue,
            "spacing_size": spacing_size,
            "radius_size": radius_size,
            "text_size": text_size,
            "font": font,
            "font_mono": font_mono,
        }
        super().__init__(**base_options)
        # The only override applied on top of the base theme.
        super().set(body_background_fill="*neutral_50")
# Theme instance handed to gr.Blocks; hues and sizes are spelled out explicitly.
gvlabtheme = OpenGVLab(
    primary_hue=colors.blue,
    secondary_hue=colors.sky,
    neutral_hue=colors.gray,
    spacing_size=sizes.spacing_md,
    radius_size=sizes.radius_sm,
    text_size=sizes.text_md,
)
# Page header: a centered logo linking to the project repository.
title = (
    '<h1 align="center">'
    '<a href="https://github.com/OpenGVLab/VideoChat-Flash">'
    '<img src="https://s1.ax1x.com/2023/05/07/p9dBMOU.png" alt="VideoChat-Flash" '
    'border="0" style="margin: 0 auto; height: 100px;" />'
    '</a> </h1>'
)

# Short blurb under the header, with a badge linking to the GitHub repository.
# The leading and trailing newlines are part of the value.
description = (
    "\n"
    "VideoChat-Flash-7B@448 powered by InternVideo!<br>"
    "<p><a href='https://github.com/OpenGVLab/VideoChat-Flash'>"
    "<img src='https://img.shields.io/badge/Github-Code-blue'></a></p><p>"
    "\n"
)
# Build the demo page: a column with the video upload and generation controls,
# a column with the chat window, a row of example videos, and the event wiring.
# NOTE(review): `footer {visibility: none}` is not a valid CSS value (`hidden`
# is), so that rule probably has no effect -- confirm the intent before changing.
with gr.Blocks(title="VideoChat-Flash",theme=gvlabtheme,css="#chatbot {overflow:auto; height:500px;} #InputVideo {overflow:visible; height:320px;} footer {visibility: none}") as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        # ---- Video upload and generation settings ----
        with gr.Column(scale=0.5, visible=True) as video_upload:
            with gr.Column(elem_id="image", scale=0.5) as img_part:
                up_video = gr.Video(interactive=True, include_audio=True, elem_id="video_upload")
            upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
            restart = gr.Button("Restart")
            max_num_frames = gr.Slider(
                minimum=4,
                maximum=1024,
                value=512,
                step=4,
                interactive=True,
                label="Max Input Frames",
            )
            max_new_tokens = gr.Slider(
                minimum=1,
                maximum=4096,
                value=1024,
                step=1,
                interactive=True,
                label="Max Output Tokens",
            )
            check_asr = gr.Checkbox(label="Use ASR", info="Whether to extract speech using ASR.")
            check_do_sample = gr.Checkbox(label="Do Sample", info="Whether to do sample during decoding.")
            # The three sliders below start hidden; the "Do Sample" checkbox
            # toggles their visibility.
            num_beams = gr.Slider(
                minimum=1,
                maximum=10,
                value=1,
                step=1,
                interactive=True,
                visible=False,
                label="beam search numbers)",
            )
            top_p = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.1,
                step=0.1,
                visible=False,
                interactive=True, label="Top_P",
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.1,
                step=0.1,
                visible=False,
                interactive=True, label="Temperature",
            )

            def toggle_slide(is_checked):
                """Show or hide the three sampling sliders together."""
                return gr.update(visible=is_checked), gr.update(visible=is_checked), gr.update(visible=is_checked)

            check_do_sample.select(fn=toggle_slide, inputs=check_do_sample, outputs=[num_beams, top_p, temperature])

        # ---- Chat window ----
        with gr.Column(visible=True) as input_raws:
            chat_state = gr.State([])
            asr_msg = gr.State()
            chatbot = gr.Chatbot(
                elem_id="chatbot",
                label='VideoChat',
                avatar_images=[
                    "human.jpg",  # user avatar
                    "assistant.png",  # assistant avatar
                ])
            with gr.Row():
                with gr.Column(scale=0.7):
                    text_input = gr.Textbox(show_label=False, placeholder='Please upload your video first', interactive=False)
                with gr.Column(scale=0.15, min_width=0):
                    run = gr.Button("💭Send", interactive=False)
                with gr.Column(scale=0.15, min_width=0):
                    clear = gr.Button("🔄Clear️", interactive=False)

    with gr.Row():
        # Each example row supplies exactly one value per component in
        # `inputs` (video, prompt). The rows used to carry an extra `False`
        # between the two, which lined up with the text box and pushed the
        # prompt out of range.
        examples = gr.Examples(
            examples=[
                ["demo_videos/basketball.mp4", "Describe this video in detail."],
                ["demo_videos/cup1.mp4", "Describe this video in detail."],
                ["demo_videos/dog.mp4", "Describe this video in detail."],
            ],
            inputs = [up_video, text_input],
            outputs = [run, clear, up_video, text_input, upload_button, asr_msg],
            fn=upload_video,
            run_on_click=True
        )

    # ---- Event wiring ----
    # Clearing the video or pressing "Restart" resets the whole page.
    up_video.clear(gradio_reset, None, [chatbot, up_video, text_input, run, clear, upload_button, chat_state, asr_msg], queue=False)
    upload_button.click(upload_video, [up_video], [run, clear, up_video, text_input, upload_button, asr_msg])
    # Enter in the text box and the "Send" button run the same chain:
    # record the question, stream the answer, then empty the text box.
    text_input.submit(gradio_ask, [text_input, chatbot], [text_input, chatbot]).then(
        gradio_answer, [chatbot, text_input, up_video, max_num_frames, check_asr, asr_msg, chat_state, max_new_tokens, check_do_sample, num_beams, top_p, temperature], [chatbot, chat_state]
    ).then(lambda: "", None, text_input)
    run.click(gradio_ask, [text_input, chatbot], [text_input, chatbot]).then(
        gradio_answer, [chatbot, text_input, up_video, max_num_frames, check_asr, asr_msg, chat_state, max_new_tokens, check_do_sample, num_beams, top_p, temperature], [chatbot, chat_state]
    ).then(lambda: "", None, text_input)
    clear.click(clear_, None, [chatbot, chat_state])
    restart.click(gradio_reset, None, [chatbot, up_video, text_input, run, clear, upload_button, chat_state, asr_msg], queue=False)

demo.launch(server_name='0.0.0.0',server_port=7864)