|
|
|
from modeling_xgenmm import * |
|
|
|
|
|
|
|
# Instantiate the XGen-MM model from the default XGenMMConfig and prepare it
# for fp16 inference on the GPU.
cfg = XGenMMConfig()
model = XGenMMModelForConditionalGeneration(cfg)
# Cast to fp16 *before* moving to the GPU. The reverse order (.cuda() then
# .half()) first materialises a full fp32 copy of every parameter on the
# device, doubling the host->device transfer and the peak GPU memory. The
# final module (fp16 parameters on CUDA) is the same either way.
model = model.half()
model = model.cuda()
|
|
|
|
|
|
|
from transformers import AutoTokenizer, AutoImageProcessor |
|
|
|
# Tokenizer and image processor are fetched from the published XGen-MM
# checkpoint on the Hugging Face Hub; both ship custom code, hence
# trust_remote_code=True.
xgenmm_path = "Salesforce/xgen-mm-phi3-mini-instruct-interleave-r-v1.5"
tokenizer = AutoTokenizer.from_pretrained(
    xgenmm_path,
    use_fast=False,
    legacy=False,
    trust_remote_code=True,
)
image_processor = AutoImageProcessor.from_pretrained(
    xgenmm_path,
    trust_remote_code=True,
)

# Let the model adjust the tokenizer (presumably registering its own special
# tokens -- see XGenMMModelForConditionalGeneration.update_special_tokens),
# then put the model in eval mode.
tokenizer = model.update_special_tokens(tokenizer)
model.eval()

# Pad on the left, and treat the "<|end|>" turn terminator as end-of-sequence.
tokenizer.padding_side = "left"
tokenizer.eos_token = "<|end|>"
|
|
|
|
|
|
|
import numpy as np |
|
import torchvision |
|
|
|
import torchvision.io |
|
|
|
import math |
|
|
|
|
|
def sample_frames(vframes, num_frames):
    """Pick ``num_frames`` evenly spaced frames and convert them to PIL images.

    Args:
        vframes: indexable stack of frames with frames along the first axis,
            e.g. the TCHW tensor returned by
            ``torchvision.io.read_video(..., output_format="TCHW")``.
        num_frames: number of frames to sample; must be >= 1. If it exceeds
            the number of available frames, some frames are repeated.

    Returns:
        A list of ``num_frames`` PIL images in temporal order.

    Raises:
        ValueError: if ``vframes`` is empty or ``num_frames`` < 1.
    """
    total = len(vframes)
    if total == 0:
        # Without this check the code below would index [0, ..., -1] into an
        # empty tensor and fail with an unhelpful IndexError.
        raise ValueError("cannot sample frames from an empty video")
    if num_frames < 1:
        # num_frames == 0 would silently yield [] and break later callers.
        raise ValueError(f"num_frames must be >= 1, got {num_frames}")

    # Evenly spaced positions from the first to the last frame (inclusive);
    # dtype=int truncates fractional positions toward zero.
    indices = np.linspace(0, total - 1, num_frames, dtype=int)
    return [
        torchvision.transforms.functional.to_pil_image(frame)
        for frame in vframes[indices]
    ]
|
|
|
|
|
def _build_conversation(messages):
    """Render ``messages`` into the role-tagged chat prompt fed to the model.

    A fixed system turn is prepended, each message becomes
    ``<|role|>\\ncontent<|end|>\\n``, and an open ``<|assistant|>`` turn is
    appended so the model continues as the assistant.
    """
    turns = [
        "<|system|>\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.<|end|>\n"
    ]
    turns.extend(
        "<|{role}|>\n{content}<|end|>\n".format(
            role=msg["role"], content=msg["content"]
        )
        for msg in messages
    )
    turns.append("<|assistant|>\n")
    # join instead of repeated += keeps construction linear in prompt length.
    return "".join(turns)


def generate(messages, images, max_new_tokens=1024):
    """Run greedy generation for one conversation over a list of images.

    Args:
        messages: list of ``{"role": ..., "content": ...}`` dicts forming the
            conversation; any ``<image>`` placeholders live in the content.
        images: non-empty list of PIL images, passed to the model as a single
            sample.
        max_new_tokens: maximum number of tokens to generate.

    Returns:
        The decoded assistant reply, cut at the first ``<|end|>`` and stripped.

    Raises:
        ValueError: if ``images`` is empty (``torch.stack`` would otherwise
            fail with a less helpful RuntimeError).
    """
    if not images:
        raise ValueError("generate() needs at least one image")

    # PIL .size is (width, height).
    image_sizes = [image.size for image in images]

    # Each image is preprocessed separately and the results are stacked along
    # dim 1 so all images belong to the same batch entry. NOTE(review): the
    # squeeze(2) assumes each processor output carries a singleton axis that
    # lands at dim 2 after stacking -- confirm against the image processor.
    pixel_values = torch.stack(
        [
            image_processor([img])["pixel_values"].to(
                model.device, dtype=torch.float16
            )
            for img in images
        ],
        dim=1,
    ).squeeze(2)
    inputs = {"pixel_values": pixel_values}

    full_conv = _build_conversation(messages)
    print(full_conv)

    language_inputs = tokenizer([full_conv], return_tensors="pt")
    inputs.update(
        {name: value.to(model.device) for name, value in language_inputs.items()}
    )

    with torch.inference_mode():
        # Greedy decoding (do_sample=False, num_beams=1). No temperature is
        # passed: it is ignored when sampling is off, and transformers warns
        # about it being set.
        generated_text = model.generate(
            **inputs,
            image_size=[image_sizes],  # one list of sizes per batch entry
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            do_sample=False,
            max_new_tokens=max_new_tokens,
            top_p=None,
            num_beams=1,
        )

    reply = tokenizer.decode(generated_text[0], skip_special_tokens=True)
    # Defensive cut in case "<|end|>" is not removed as a special token.
    return reply.split("<|end|>")[0].strip()
|
|
|
|
|
def predict(video_file, num_frames=8, question="Describe this video."):
    """Ask the model ``question`` about a video file.

    The video is decoded with torchvision, ``num_frames`` evenly spaced frames
    are sampled, and the prompt is a single ``<image>`` placeholder followed by
    ``question``.

    Args:
        video_file: path of the video to read.
        num_frames: number of frames to sample from the video.
        question: user prompt placed after the ``<image>`` placeholder.

    Returns:
        The model's decoded reply.

    Raises:
        ValueError: if no frames could be decoded from ``video_file``.
    """
    # read_video returns (video, audio, info); only the frames are used here.
    vframes, _, _ = torchvision.io.read_video(
        filename=video_file, pts_unit="sec", output_format="TCHW"
    )
    if len(vframes) == 0:
        # Report the file name instead of failing later with an IndexError.
        raise ValueError(f"no video frames decoded from {video_file!r}")

    images = sample_frames(vframes, num_frames)
    # All sampled frames go to the model behind one "<image>" placeholder.
    messages = [{"role": "user", "content": "<image>\n" + question}]
    return generate(messages, images)
|
|
|
|
|
|
|
import torch |
|
|
|
# Path to the state dict to load into the model; fill in before running.
your_checkpoint_path = ""
# map_location="cpu": load_state_dict copies each tensor into the model's
# existing CUDA fp16 parameters, so there is no reason to deserialise the
# checkpoint onto a GPU first. Without it, tensors saved from CUDA are
# restored to their original device, holding a second full copy of the
# weights on the GPU, or failing if that device index does not exist here.
# NOTE(review): torch.load unpickles the file and can execute arbitrary code;
# only load trusted checkpoints (or pass weights_only=True on torch versions
# that support it).
sd = torch.load(your_checkpoint_path, map_location="cpu")
model.load_state_dict(sd)

# Path to the video to describe; fill in before running.
your_video_path = ""
print(predict(your_video_path, num_frames=16))
|
|