# File size: 5,824 Bytes (snapshot 5657e8d)
import torch
import os
from ChatUniVi.constants import *
from ChatUniVi.conversation import conv_templates, SeparatorStyle
from ChatUniVi.model.builder import load_pretrained_model
from ChatUniVi.utils import disable_torch_init
from ChatUniVi.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from PIL import Image
from decord import VideoReader, cpu
import numpy as np
def _get_rawvideo_dec(video_path, image_processor, max_frames=MAX_IMAGE_LENGTH, image_resolution=224, video_framerate=1, s=None, e=None):
    """Decode a video with decord and return preprocessed, sub-sampled frames.

    Args:
        video_path: Path of the video file to decode.
        image_processor: Object whose ``preprocess(img, return_tensors='pt')``
            returns a mapping with a ``'pixel_values'`` entry; it is applied
            to every sampled frame.
        max_frames: Upper bound on the number of frames returned. When more
            frames are selected, they are uniformly down-sampled to this count.
        image_resolution: Not used by this function; kept so existing callers
            that pass it keep working.
        video_framerate: Number of frames to keep per second of video
            (truncated to an integer).
        s: Optional clip start in seconds. When ``None`` the whole video is
            used and ``e`` is ignored.
        e: Optional clip end in seconds. When ``None`` while ``s`` is given,
            the clip runs to the end of the video.

    Returns:
        A tuple ``(patch_images, slice_len)`` where ``patch_images`` is the
        stacked tensor of preprocessed frames and ``slice_len`` is its first
        dimension (the number of frames). Returns ``None`` (after printing a
        message) when the requested range contains no frames.

    Raises:
        FileNotFoundError: If ``video_path`` does not exist.
    """
    # speed up video decode via decord.
    if s is None:
        start_time, end_time = None, None
    else:
        # Negative timestamps are clamped to the beginning of the video.
        start_time = max(int(s), 0)
        if e is None:
            # Open-ended clip: previously int(None) raised a TypeError here.
            end_time = None
        else:
            end_time = max(int(e), 0)
            if start_time > end_time:
                start_time, end_time = end_time, start_time
            elif start_time == end_time:
                # Guarantee a non-empty, one-second window.
                end_time = start_time + 1

    if not os.path.exists(video_path):
        raise FileNotFoundError(video_path)
    vreader = VideoReader(video_path, ctx=cpu(0))

    fps = vreader.get_avg_fps()
    f_start = 0 if start_time is None else int(start_time * fps)
    # The large constant stands for "until the end"; min() caps it at the last frame.
    f_end = int(min(1000000000 if end_time is None else end_time * fps, len(vreader) - 1))
    num_frames = f_end - f_start + 1
    if num_frames <= 0:
        print("video path: {} error.".format(video_path))
        return None

    # T x 3 x H x W
    sample_fps = int(video_framerate)
    # Clamp to 1: when the requested rate is more than twice the video's fps
    # the rounded stride would be 0 and range() would raise ValueError.
    t_stride = max(int(round(float(fps) / sample_fps)), 1)
    all_pos = list(range(f_start, f_end + 1, t_stride))
    if len(all_pos) > max_frames:
        # Uniformly pick max_frames positions across the candidate list.
        keep = np.linspace(0, len(all_pos) - 1, num=max_frames, dtype=int)
        sample_pos = [all_pos[i] for i in keep]
    else:
        sample_pos = all_pos

    frames = [Image.fromarray(f) for f in vreader.get_batch(sample_pos).asnumpy()]
    patch_images = torch.stack(
        [image_processor.preprocess(img, return_tensors='pt')['pixel_values'][0] for img in frames]
    )
    slice_len = patch_images.shape[0]
    return patch_images, slice_len
if __name__ == '__main__':
    # Script entry point: load Chat-UniVi, decode one video, ask a question
    # about it, and print the generated answer.
    import sys  # only the entry point needs it, so it is imported locally

    # Model Parameter
    model_path = "Chat-UniVi/Chat-UniVi"  # or "Chat-UniVi/Chat-UniVi-13B"

    # The video to describe is taken from the first command-line argument
    # (this replaces the unsubstituted "${video_path}" template placeholder,
    # which was not valid Python). Fail before the model is loaded.
    if len(sys.argv) < 2:
        raise SystemExit("usage: {} VIDEO_PATH".format(sys.argv[0]))
    video_path = sys.argv[1]

    # The number of visual tokens varies with the length of the video.
    # "max_frames" is the maximum number of frames: longer videos are
    # uniformly down-sampled so that at most "max_frames" frames are kept.
    max_frames = 100

    # The number of frames retained per second in the video.
    video_framerate = 1

    # Input Text
    qs = "Describe the video."

    # Sampling Parameter
    conv_mode = "simple"
    temperature = 0.2
    top_p = None
    num_beams = 1

    disable_torch_init()
    model_path = os.path.expanduser(model_path)
    model_name = "ChatUniVi"
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name)

    # Register the multimodal special tokens the model config asks for, then
    # resize the embedding table to the new vocabulary size.
    mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
    mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
    if mm_use_im_patch_token:
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
    if mm_use_im_start_end:
        tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
    model.resize_token_embeddings(len(tokenizer))

    vision_tower = model.get_vision_tower()
    if not vision_tower.is_loaded:
        vision_tower.load_model()
    # Use the vision tower's own processor for frame preprocessing.
    image_processor = vision_tower.image_processor

    if model.config.config["use_cluster"]:
        # Cast every sub-module to bfloat16.
        for module in model.modules():
            module.to(dtype=torch.bfloat16)

    decoded = _get_rawvideo_dec(video_path, image_processor, max_frames=max_frames, video_framerate=video_framerate)
    if decoded is None:
        # The decoder prints its own message and returns None when the video
        # yields no frames; stop here instead of failing on tuple unpacking.
        raise SystemExit("No frames could be decoded from {}".format(video_path))
    video_frames, slice_len = decoded

    # Prepend one image token per sampled frame to the question. Reuse the
    # getattr() result from above so a config without the attribute still works.
    if mm_use_im_start_end:
        qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN * slice_len + DEFAULT_IM_END_TOKEN + '\n' + qs
    else:
        qs = DEFAULT_IMAGE_TOKEN * slice_len + '\n' + qs

    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

    # Generation stops once the conversation separator is produced.
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

    # NOTE(review): frames are sent as float16 while clustered models are cast
    # to bfloat16 above -- confirm the model handles this dtype difference.
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=video_frames.half().cuda(),
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            num_beams=num_beams,
            output_scores=True,
            return_dict_in_generate=True,
            max_new_tokens=1024,
            use_cache=True,
            stopping_criteria=[stopping_criteria])

    output_ids = output_ids.sequences
    input_token_len = input_ids.shape[1]
    # Sanity check: the returned sequence should start with the prompt tokens.
    n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
    if n_diff_input_output > 0:
        print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')

    # Decode only the newly generated tokens and trim a trailing stop string.
    outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
    outputs = outputs.strip()
    if outputs.endswith(stop_str):
        outputs = outputs[:-len(stop_str)]
    outputs = outputs.strip()
    print(outputs)