# %%
from modeling_xgenmm import *

# %%
# Build the XGen-MM model from its default config and move it to GPU in fp16.
cfg = XGenMMConfig()
model = XGenMMModelForConditionalGeneration(cfg)
model = model.cuda()
model = model.half()

# %%
from transformers import AutoTokenizer, AutoImageProcessor

xgenmm_path = "Salesforce/xgen-mm-phi3-mini-instruct-interleave-r-v1.5"
tokenizer = AutoTokenizer.from_pretrained(
    xgenmm_path, trust_remote_code=True, use_fast=False, legacy=False
)
image_processor = AutoImageProcessor.from_pretrained(
    xgenmm_path, trust_remote_code=True
)
# Register the model's special tokens (e.g. the image placeholder) with the tokenizer.
tokenizer = model.update_special_tokens(tokenizer)
# model = model.to("cuda")
model.eval()
tokenizer.padding_side = "left"
tokenizer.eos_token = "<|end|>"

# %%
import math

import numpy as np
import torch
import torchvision
import torchvision.io


def sample_frames(vframes, num_frames):
    """Uniformly sample `num_frames` frames from a (T, C, H, W) video tensor as PIL images."""
    frame_indice = np.linspace(0, len(vframes) - 1, num_frames, dtype=int)
    video = vframes[frame_indice]
    video_list = []
    for i in range(len(video)):
        video_list.append(torchvision.transforms.functional.to_pil_image(video[i]))
    return video_list


def generate(messages, images):
    # img_bytes_list = [base64.b64decode(image.encode("utf-8")) for image in images]
    # images = [Image.open(BytesIO(img_bytes)) for img_bytes in img_bytes_list]
    image_sizes = [image.size for image in images]
    # Similar operation in model_worker.py
    image_tensor = [
        image_processor([img])["pixel_values"].to(model.device, dtype=torch.float16)
        for img in images
    ]
    image_tensor = torch.stack(image_tensor, dim=1)
    image_tensor = image_tensor.squeeze(2)
    inputs = {"pixel_values": image_tensor}

    # Assemble the Phi-3-style chat prompt: system message, then each turn, then the assistant tag.
    full_conv = "<|system|>\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.<|end|>\n"
    for msg in messages:
        msg_str = "<|{role}|>\n{content}<|end|>\n".format(
            role=msg["role"], content=msg["content"]
        )
        full_conv += msg_str

    full_conv += "<|assistant|>\n"
    print(full_conv)
    language_inputs = tokenizer([full_conv], return_tensors="pt")
    for name, value in language_inputs.items():
        language_inputs[name] = value.to(model.device)
    inputs.update(language_inputs)
    # print(inputs)

    with torch.inference_mode():
        generated_text = model.generate(
            **inputs,
            image_size=[image_sizes],
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            temperature=0.05,
            do_sample=False,
            max_new_tokens=1024,
            top_p=None,
            num_beams=1,
        )
    # Decode and keep only the assistant's reply (everything before the first <|end|>).
    outputs = (
        tokenizer.decode(generated_text[0], skip_special_tokens=True)
        .split("<|end|>")[0]
        .strip()
    )
    return outputs


def predict(video_file, num_frames=8):
    vframes, _, _ = torchvision.io.read_video(
        filename=video_file, pts_unit="sec", output_format="TCHW"
    )
    total_frames = len(vframes)
    images = sample_frames(vframes, num_frames)
    prompt = ""
    prompt = prompt + "<image>\n"  # image placeholder token expected by the model
    prompt = prompt + "Describe this video."
    messages = [{"role": "user", "content": prompt}]
    return generate(messages, images)


# %%
# Load your fine-tuned weights into the model.
your_checkpoint_path = ""  # set to your local checkpoint path
sd = torch.load(your_checkpoint_path)
model.load_state_dict(sd)

# %%
your_video_path = ""  # set to the video file you want to caption
print(predict(your_video_path, num_frames=16))