|
import functools
import os
import time
import traceback
from threading import Thread

import gradio as gr
import numpy as np
import torch
from decord import VideoReader, cpu
from PIL import Image

from ChatUniVi.constants import *
from ChatUniVi.conversation import conv_templates, SeparatorStyle
from ChatUniVi.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from ChatUniVi.model.builder import load_pretrained_model
from ChatUniVi.utils import disable_torch_init
|
|
|
def _get_rawvideo_dec(video_path, image_processor, max_frames=MAX_IMAGE_LENGTH, image_resolution=224, video_framerate=1, s=None, e=None):
    """Decode a video file and return preprocessed, uniformly sampled frames.

    Frames are taken every ``round(fps / int(video_framerate))`` source frames
    inside the optional ``[s, e]`` window. If that yields more than
    ``max_frames`` positions, they are uniformly subsampled (first and last
    kept) down to ``max_frames``.

    Args:
        video_path: Path of the video file on disk.
        image_processor: Object whose ``preprocess(img, return_tensors='pt')``
            returns a mapping containing ``'pixel_values'``.
        max_frames: Upper bound on the number of returned frames.
        image_resolution: Unused; kept for interface compatibility.
        video_framerate: Target sampling rate in frames per second
            (truncated to ``int``; must be at least 1).
        s, e: Optional start/end times in seconds (truncated to ``int``,
            negatives clamped to 0, swapped if reversed). When ``s`` is None
            the whole video is used.

    Returns:
        ``(patch_images, slice_len)``: the per-frame ``pixel_values`` stacked
        along a new leading dimension, and the number of stacked frames.

    Raises:
        FileNotFoundError: ``video_path`` does not exist.
        ValueError: ``video_framerate`` truncates to less than 1, or the
            requested window contains no frames.
    """
    if s is None:
        start_time, end_time = None, None
    else:
        start_time = max(int(s), 0)
        end_time = max(int(e), 0)
        if start_time > end_time:
            start_time, end_time = end_time, start_time
        elif start_time == end_time:
            # Widen a zero-length window to one second so it contains frames.
            end_time = start_time + 1

    if not os.path.exists(video_path):
        raise FileNotFoundError(f"Video file not found: {video_path}")
    vreader = VideoReader(video_path, ctx=cpu(0))

    fps = vreader.get_avg_fps()
    f_start = 0 if start_time is None else int(start_time * fps)
    f_end = int(min(1000000000 if end_time is None else end_time * fps, len(vreader) - 1))
    num_frames = f_end - f_start + 1
    if num_frames <= 0:
        # Previously this only printed an unformatted message and returned
        # None, which made the caller fail later with an opaque unpack error.
        raise ValueError(f"video path: {video_path} error: no frames in the requested time window.")

    sample_fps = int(video_framerate)
    if sample_fps < 1:
        raise ValueError(f"video_framerate must be >= 1, got {video_framerate!r}")
    # Source-frame step between samples. Clamp to 1 so videos whose native fps
    # is below half the target rate don't produce a zero range() step.
    t_stride = max(1, int(round(float(fps) / sample_fps)))

    all_pos = list(range(f_start, f_end + 1, t_stride))
    if len(all_pos) > max_frames:
        sample_pos = [all_pos[idx] for idx in np.linspace(0, len(all_pos) - 1, num=max_frames, dtype=int)]
    else:
        sample_pos = all_pos

    frames = [Image.fromarray(f) for f in vreader.get_batch(sample_pos).asnumpy()]
    patch_images = torch.stack(
        [image_processor.preprocess(img, return_tensors='pt')['pixel_values'][0] for img in frames]
    )
    slice_len = patch_images.shape[0]
    return patch_images, slice_len
|
|
|
|
|
# Default Chat-UniVi checkpoint location. It can be overridden per deployment
# through the CHATUNIVI_MODEL_PATH environment variable.
_DEFAULT_MODEL_PATH = "/home/manish/Chat-UniVi/model/Chat-UniVi"


@functools.lru_cache(maxsize=1)
def _load_chatunivi(model_path, model_name="ChatUniVi"):
    """Load and prepare the Chat-UniVi model; return (tokenizer, model, image_processor).

    Cached so that every chat turn reuses the same weights instead of
    reloading the checkpoint per request. Token registration and the
    embedding resize also happen exactly once. A load that raises is not
    cached, so the next call retries.
    """
    disable_torch_init()
    tokenizer, model, _image_processor, _context_len = load_pretrained_model(model_path, None, model_name)

    mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
    mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
    if mm_use_im_patch_token:
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
    if mm_use_im_start_end:
        tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
    model.resize_token_embeddings(len(tokenizer))

    vision_tower = model.get_vision_tower()
    if not vision_tower.is_loaded:
        vision_tower.load_model()
    # The vision tower's processor is the one used for frame preprocessing.
    image_processor = vision_tower.image_processor
    return tokenizer, model, image_processor


def _find_video_path(message, history):
    """Return the video path to analyse for this turn, or None if there is none.

    The last file attached to the current message wins. Otherwise the first
    history entry whose user slot is a tuple is treated as an upload, with the
    path at index 0. History entries of any other shape are skipped instead
    of raising.
    """
    if isinstance(message, dict) and message.get("files"):
        video_path = message["files"][-1]
        if video_path:
            return video_path
    for entry in history or []:
        user_part = entry[0] if isinstance(entry, (list, tuple)) and entry else None
        if isinstance(user_part, tuple):
            return user_part[0]
    return None


def bot_streaming(message, history):
    """Gradio ChatInterface callback: answer a question about an uploaded video.

    Args:
        message: The current user turn. Expected to be a dict with ``"text"``
            and ``"files"`` (multimodal input); a plain string is also accepted.
        history: Previous turns as passed by Gradio. Used to recover a video
            uploaded earlier in the conversation.

    Yields:
        A single string. This is the model's answer, a request to upload a
        video, or ``"Error: ..."`` if anything fails. Errors are shown in the
        chat rather than raised, and the traceback goes to the server log.
        Despite the name, the answer is produced in one piece, not streamed
        token by token.
    """
    try:
        print("Received message:", message)

        video_path = _find_video_path(message, history)
        if not video_path:
            yield "You need to upload a video for ChatUniVi to work."
            return

        max_frames = 100
        video_framerate = 1
        qs = message.get("text", "") if isinstance(message, dict) else str(message)

        model_path = os.environ.get("CHATUNIVI_MODEL_PATH", _DEFAULT_MODEL_PATH)
        tokenizer, model, image_processor = _load_chatunivi(model_path)

        video_frames, slice_len = _get_rawvideo_dec(
            video_path, image_processor, max_frames=max_frames, video_framerate=video_framerate
        )

        # One image placeholder token per sampled frame, ahead of the question.
        if getattr(model.config, "mm_use_im_start_end", False):
            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN * slice_len + DEFAULT_IM_END_TOKEN + '\n' + qs
        else:
            qs = DEFAULT_IMAGE_TOKEN * slice_len + '\n' + qs

        conv = conv_templates["simple"].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=video_frames.half().cuda(),
                do_sample=True,
                temperature=0.2,
                top_p=None,
                num_beams=1,
                output_scores=True,
                return_dict_in_generate=True,
                max_new_tokens=1024,
                use_cache=True,
                stopping_criteria=[stopping_criteria],
            )

        output_ids = output_ids.sequences
        input_token_len = input_ids.shape[1]
        # Sanity check: generate() is expected to echo the prompt as a prefix.
        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')

        outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
        outputs = outputs.strip()
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        outputs = outputs.strip()
        yield outputs

    except Exception as e:
        # Top-level UI boundary: log the full traceback server-side and
        # surface a short message in the chat instead of crashing the app.
        traceback.print_exc()
        yield f"Error: {str(e)}"
|
|
|
# Multimodal chat UI: each turn passes {"text": ..., "files": [...]} to
# bot_streaming; the bundled example pairs a long annotation prompt with
# ./sample_video.mp4.
demo = gr.ChatInterface(
    fn=bot_streaming,
    title="Super Rapid Annotator",
    examples=[{"text": "For each question, analyze the given video carefully, detect objects, different colors and base your answers on the observations made. Describe the subject's clothing and different the objects the subject is attached to. Examine the subject’s both right and left hands in the video to check if they are holding anything like microphone, book, paper (white color), object or any electronic device, try segmentations and decide if the hands are free or not. Evaluate the subject’s body posture and movement within the video. Are they standing upright with both feet planted firmly on the ground? If so, they are standing. If they seem to be seated, they are seated. Assess the surroundings behind the subject in the video. Do they seem to interact with any visible screens by actively engaged in the presentation, such as laptops, TVs, or digital billboards? If yes, then they are interacting with a screen. If not, they are not interacting with a screen. Consider the broader environmental context shown in the video’s background. Are there signs of an open-air space, like greenery, structures, or people passing by? If so, it’s an outdoor setting. If the setting looks confined with furniture, walls, or home decorations, it’s an indoor environment. By taking these factors into account when watching the video, please answer the questions accurately.", "files":["./sample_video.mp4"]}],
    description="Upload a video and start chatting about it, or simply try the example below. If you don't upload a video, you will receive an error.",
    stop_btn="Stop Generation",
    multimodal=True,
)

if __name__ == "__main__":
    # Only start the server when run as a script, so importing this module
    # (tooling, tests, reload mode) doesn't launch it. share=True requests a
    # public Gradio share link; review before deploying.
    demo.launch(debug=True, share=True)
|
|