# xgen-mm-vid-phi3-mini-r-v1.5-128tokens-8frames / xgen-mm-vid-inference-script_hf.py
# Inference example for the Salesforce xGen-MM-Vid video-language model.
from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoImageProcessor, LogitsProcessor
import torch

# Load the model, tokenizer, and image processor from the Hugging Face Hub.
model_name_or_path = "Salesforce/xgen-mm-vid-phi3-mini-r-v1.5-128tokens-8frames"
model = AutoModelForVision2Seq.from_pretrained(model_name_or_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True, use_fast=False, legacy=False)
image_processor = AutoImageProcessor.from_pretrained(model_name_or_path, trust_remote_code=True)

# Register the model's special tokens with the tokenizer, then prepare for inference.
tokenizer = model.update_special_tokens(tokenizer)
model = model.to('cuda')
model.eval()
tokenizer.padding_side = "left"
tokenizer.eos_token = "<|end|>"
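
# Note: the script above assumes a CUDA GPU is available. A device-agnostic
# fallback (an assumption, not part of the original script) would look like:
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   model = model.to(device)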
# %%
import numpy as np
import torchvision
import torchvision.io
import math


def sample_frames(vframes, num_frames):
    # Uniformly sample `num_frames` frames from the decoded video tensor and
    # convert each sampled (C, H, W) frame into a PIL image.
    frame_indice = np.linspace(int(num_frames / 2), len(vframes) - int(num_frames / 2), num_frames, dtype=int)
    video = vframes[frame_indice]
    video_list = []
    for i in range(len(video)):
        video_list.append(torchvision.transforms.functional.to_pil_image(video[i]))
    return video_list
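
# Illustrative sanity check (not part of the original script): sampling 8 frames
# from a dummy 32-frame uint8 clip should yield 8 PIL images. Uncomment to run.
# _dummy_clip = torch.zeros(32, 3, 64, 64, dtype=torch.uint8)
# assert len(sample_frames(_dummy_clip, 8)) == 8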
def generate(messages, images):
    # img_bytes_list = [base64.b64decode(image.encode("utf-8")) for image in images]
    # images = [Image.open(BytesIO(img_bytes)) for img_bytes in img_bytes_list]
    image_sizes = [image.size for image in images]
    # Preprocess each frame and stack into a single pixel-value tensor.
    # (Similar operation in model_worker.py.)
    image_tensor = [image_processor([img])["pixel_values"].to(model.device, dtype=torch.float32) for img in images]
    image_tensor = torch.stack(image_tensor, dim=1)
    image_tensor = image_tensor.squeeze(2)
    inputs = {"pixel_values": image_tensor}

    # Build the Phi-3-style chat prompt: system turn, the supplied user/assistant
    # turns, then an open assistant turn for the model to complete.
    full_conv = "<|system|>\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.<|end|>\n"
    for msg in messages:
        msg_str = "<|{role}|>\n{content}<|end|>\n".format(
            role=msg["role"], content=msg["content"]
        )
        full_conv += msg_str
    full_conv += "<|assistant|>\n"
    print(full_conv)

    # Tokenize the prompt and move all inputs to the model's device.
    language_inputs = tokenizer([full_conv], return_tensors="pt")
    for name, value in language_inputs.items():
        language_inputs[name] = value.to(model.device)
    inputs.update(language_inputs)

    # Greedy decoding; temperature/top_p have no effect when do_sample=False.
    with torch.inference_mode():
        generated_text = model.generate(
            **inputs,
            image_size=[image_sizes],
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            temperature=0.05,
            do_sample=False,
            max_new_tokens=1024,
            top_p=None,
            num_beams=1,
        )

    # Decode and keep only the text before the first end-of-turn marker.
    outputs = (
        tokenizer.decode(generated_text[0], skip_special_tokens=True)
        .split("<|end|>")[0]
        .strip()
    )
    return outputs
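
# Illustrative call (not part of the original script): `generate` takes the chat
# messages plus the list of sampled PIL frames returned by sample_frames, e.g.
#   frames = sample_frames(vframes, 8)
#   caption = generate([{"role": "user", "content": "<image>\nDescribe the clip."}], frames)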
def predict(video_file, num_frames=8):
    # Decode the video, sample `num_frames` frames, and ask the model to
    # describe the primary subject of the clip.
    vframes, _, _ = torchvision.io.read_video(
        filename=video_file, pts_unit="sec", output_format="TCHW"
    )
    total_frames = len(vframes)
    images = sample_frames(vframes, num_frames)

    prompt = ""
    prompt = prompt + "<image>\n"
    # prompt = prompt + "What's the main gist of the video ?"
    prompt = prompt + "Please describe the primary object or subject in the video, capturing their attributes, actions, positions, and movements."
    messages = [{"role": "user", "content": prompt}]
    return generate(messages, images)
# %%
video_path = ""
print(
predict(
video_path,
num_frames = 8
)
)
# %%