Update README.md
Browse files
README.md
CHANGED
@@ -78,6 +78,11 @@ pip install git+https://github.com/TIGER-AI-Lab/MantisScore.git
|
|
78 |
```python
|
79 |
import av
|
80 |
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
81 |
def _read_video_pyav(
|
82 |
frame_paths:List[str],
|
83 |
max_frames:int,
|
@@ -94,6 +99,7 @@ def _read_video_pyav(
|
|
94 |
return np.stack([x.to_ndarray(format="rgb24") for x in frames])
|
95 |
|
96 |
MAX_NUM_FRAMES=16
|
|
|
97 |
REGRESSION_QUERY_PROMPT = """
|
98 |
Suppose you are an expert in judging and evaluating the quality of AI-generated videos,
|
99 |
please watch the following frames of a given video and see the text prompt for generating the video,
|
@@ -119,6 +125,12 @@ all the frames of video are as follows:
|
|
119 |
"""
|
120 |
|
121 |
video_path="examples/video1.mp4"
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
|
123 |
# uniformly sample up to MAX_NUM_FRAMES frames from the video
|
124 |
container = av.open(video_path)
|
@@ -129,7 +141,7 @@ else:
|
|
129 |
indices = np.arange(total_frames)
|
130 |
|
131 |
frames = [Image.fromarray(x) for x in _read_video_pyav(container, indices)]
|
132 |
-
eval_prompt =
|
133 |
num_image_token = eval_prompt.count("<image>")
|
134 |
if num_image_token < len(frames):
|
135 |
eval_prompt += "<image> " * (len(frames) - num_image_token)
|
|
|
78 |
```python
|
79 |
import av
|
80 |
import numpy as np
|
81 |
+
from typing import List
|
82 |
+
import torch
|
83 |
+
from transformers import AutoProcessor
|
84 |
+
from models.idefics2 import Idefics2ForSequenceClassification
|
85 |
+
|
86 |
def _read_video_pyav(
|
87 |
frame_paths:List[str],
|
88 |
max_frames:int,
|
|
|
99 |
return np.stack([x.to_ndarray(format="rgb24") for x in frames])
|
100 |
|
101 |
MAX_NUM_FRAMES=16
|
102 |
+
ROUND_DIGIT=4
|
103 |
REGRESSION_QUERY_PROMPT = """
|
104 |
Suppose you are an expert in judging and evaluating the quality of AI-generated videos,
|
105 |
please watch the following frames of a given video and see the text prompt for generating the video,
|
|
|
125 |
"""
|
126 |
|
127 |
video_path="examples/video1.mp4"
|
128 |
+
video_prompt=""
|
129 |
+
|
130 |
+
processor = AutoProcessor.from_pretrained(f"TIGER-Lab/MantisScore",torch_dtype=torch.bfloat16)
|
131 |
+
model = Idefics2ForSequenceClassification.from_pretrained(f"TIGER-Lab/MantisScore",torch_dtype=torch.bfloat16).eval()
|
132 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
133 |
+
model.to(device)
|
134 |
|
135 |
# uniformly sample up to MAX_NUM_FRAMES frames from the video
|
136 |
container = av.open(video_path)
|
|
|
141 |
indices = np.arange(total_frames)
|
142 |
|
143 |
frames = [Image.fromarray(x) for x in _read_video_pyav(container, indices)]
|
144 |
+
eval_prompt = REGRESSION_QUERY_PROMPT.format(text_prompt=video_prompt)
|
145 |
num_image_token = eval_prompt.count("<image>")
|
146 |
if num_image_token < len(frames):
|
147 |
eval_prompt += "<image> " * (len(frames) - num_image_token)
|