from fastapi import FastAPI, File, UploadFile, Form, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
import torch
import os
from ChatUniVi.constants import *
from ChatUniVi.conversation import conv_templates, SeparatorStyle
from ChatUniVi.model.builder import load_pretrained_model
from ChatUniVi.utils import disable_torch_init
from ChatUniVi.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from PIL import Image
from decord import VideoReader, cpu
import numpy as np
import asyncio

app = FastAPI()

# Global variables to store the model components
model = None
tokenizer = None
image_processor = None
loading_progress = 0


def _get_rawvideo_dec(video_path, image_processor, max_frames=MAX_IMAGE_LENGTH, image_resolution=224,
                      video_framerate=1, s=None, e=None):
    # Resolve the requested time window; clamp negatives, swap inverted bounds,
    # and avoid an empty range when start == end.
    if s is None:
        start_time, end_time = None, None
    else:
        start_time = int(s)
        end_time = int(e)
        start_time = start_time if start_time >= 0. else 0.
        end_time = end_time if end_time >= 0. else 0.
        if start_time > end_time:
            start_time, end_time = end_time, start_time
        elif start_time == end_time:
            end_time = start_time + 1

    if os.path.exists(video_path):
        vreader = VideoReader(video_path, ctx=cpu(0))
    else:
        raise FileNotFoundError(video_path)

    fps = vreader.get_avg_fps()
    f_start = 0 if start_time is None else int(start_time * fps)
    f_end = int(min(1000000000 if end_time is None else end_time * fps, len(vreader) - 1))
    num_frames = f_end - f_start + 1
    if num_frames > 0:
        # Sample at the target framerate, then thin uniformly to at most max_frames.
        sample_fps = int(video_framerate)
        t_stride = int(round(float(fps) / sample_fps))
        all_pos = list(range(f_start, f_end + 1, t_stride))
        if len(all_pos) > max_frames:
            sample_pos = [all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=max_frames, dtype=int)]
        else:
            sample_pos = all_pos

        patch_images = [Image.fromarray(f) for f in vreader.get_batch(sample_pos).asnumpy()]
        patch_images = torch.stack(
            [image_processor.preprocess(img, return_tensors='pt')['pixel_values'][0] for img in patch_images])
        slice_len = patch_images.shape[0]
        return patch_images, slice_len
    else:
        raise ValueError("video path: {} error.".format(video_path))
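# Worked example of the sampling above (no new assumptions, just arithmetic):
# for a 300 s clip at 30 fps with video_framerate=1, t_stride = round(30 / 1) = 30,
# so all_pos holds ~300 frame indices (one per second of video). With
# max_frames=100, as the /process endpoint below passes, np.linspace then thins
# those to 100 evenly spaced frames.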
@app.on_event("startup")
async def load_model():
    global model, tokenizer, image_processor, loading_progress
    disable_torch_init()
    model_path = "/home/manish/Chat-UniVi/model/Chat-UniVi"
    model_name = "ChatUniVi"
    loading_progress = 10
    await asyncio.sleep(1)  # Simulating progress step

    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name)
    loading_progress = 50
    await asyncio.sleep(1)  # Simulating progress step

    mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
    mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
    if mm_use_im_patch_token:
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
    if mm_use_im_start_end:
        tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
    model.resize_token_embeddings(len(tokenizer))
    loading_progress = 70
    await asyncio.sleep(1)  # Simulating progress step

    vision_tower = model.get_vision_tower()
    if not vision_tower.is_loaded:
        vision_tower.load_model()
    image_processor = vision_tower.image_processor
    loading_progress = 100


@app.post("/process")
async def process_video(question: str = Form(...), video: UploadFile = File(...)):
    video_path = f"temp_{video.filename}"
    try:
        # Persist the upload so decord can open it from disk.
        with open(video_path, "wb") as f:
            f.write(video.file.read())

        max_frames = 100
        video_framerate = 1
        video_frames, slice_len = _get_rawvideo_dec(video_path, image_processor, max_frames=max_frames,
                                                    video_framerate=video_framerate)

        # Prepend one image token per sampled frame to the question.
        if model.config.mm_use_im_start_end:
            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN * slice_len + DEFAULT_IM_END_TOKEN + '\n' + question
        else:
            qs = DEFAULT_IMAGE_TOKEN * slice_len + '\n' + question

        conv = conv_templates["simple"].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX,
                                          return_tensors='pt').unsqueeze(0).cuda()

        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=video_frames.half().cuda(),
                do_sample=True,
                temperature=0.2,
                top_p=None,
                num_beams=1,
                output_scores=True,
                return_dict_in_generate=True,
                max_new_tokens=1024,
                use_cache=True,
                stopping_criteria=[stopping_criteria]
            )

        output_ids = output_ids.sequences
        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')

        # Decode only the newly generated tokens and strip the stop string.
        outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
        outputs = outputs.strip()
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        outputs = outputs.strip()
        return {"answer": outputs}
    except Exception as e:
        return {"error": str(e)}
    finally:
        # Clean up the temporary upload.
        if os.path.exists(video_path):
            os.remove(video_path)


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    global loading_progress
    try:
        # Push loading progress once per second until the model is ready.
        while loading_progress < 100:
            await websocket.send_json({"progress": loading_progress})
            await asyncio.sleep(1)
        await websocket.send_json({"progress": loading_progress, "status": "Model Loaded"})
    except WebSocketDisconnect:
        print("WebSocket disconnected")
    finally:
        await websocket.close()


### HTML Page
@app.get("/", response_class=HTMLResponse)
async def get():
    # Minimal front-end for the endpoints above; the markup is a sketch and
    # can be restyled as needed.
    return """

    <!DOCTYPE html>
    <html>
    <head>
        <title>Video Question Answering</title>
    </head>
    <body>
        <h1>Video Question Answering</h1>
        <p id="status">Loading model...</p>
        <form id="qa-form">
            <input type="file" name="video" accept="video/*" required>
            <input type="text" name="question" placeholder="Ask a question about the video" required>
            <button type="submit">Ask</button>
        </form>
        <pre id="answer"></pre>
        <script>
            // Track model loading progress over the WebSocket.
            const ws = new WebSocket(`ws://${location.host}/ws`);
            ws.onmessage = (event) => {
                const data = JSON.parse(event.data);
                document.getElementById("status").textContent =
                    data.status || `Loading model: ${data.progress}%`;
            };
            // Submit the question and video to /process and show the result.
            document.getElementById("qa-form").addEventListener("submit", async (e) => {
                e.preventDefault();
                const resp = await fetch("/process", { method: "POST", body: new FormData(e.target) });
                const result = await resp.json();
                document.getElementById("answer").textContent = result.answer || result.error;
            });
        </script>
    </body>
    </html>
    """