Update README.md
Browse files
README.md
CHANGED
@@ -83,7 +83,9 @@ def get_image(url: str) -> Union[Image.Image, list]:
|
|
83 |
if "://" not in url: # Local file
|
84 |
content_type = get_content_type(url)
|
85 |
else: # Remote URL
|
86 |
-
content_type = requests.head(url, stream=True, verify=False).headers.get(
|
|
|
|
|
87 |
|
88 |
if "image" in content_type:
|
89 |
if "://" not in url: # Local file
|
@@ -114,11 +116,23 @@ def get_formatted_prompt(prompt: str) -> str:
|
|
114 |
|
115 |
def get_response(input_data, prompt: str, model=None, image_processor=None) -> str:
|
116 |
if isinstance(input_data, Image.Image):
|
117 |
-
vision_x =
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
elif isinstance(input_data, list): # list of video frames
|
119 |
-
vision_x =
|
|
|
|
|
|
|
|
|
120 |
else:
|
121 |
-
raise ValueError(
|
|
|
|
|
122 |
|
123 |
lang_x = model.text_tokenizer(
|
124 |
[
|
@@ -148,36 +162,39 @@ def get_response(input_data, prompt: str, model=None, image_processor=None) -> s
|
|
148 |
)
|
149 |
return parsed_output
|
150 |
|
|
|
151 |
if __name__ == "__main__":
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
|
|
|
|
183 |
```
|
|
|
83 |
if "://" not in url: # Local file
|
84 |
content_type = get_content_type(url)
|
85 |
else: # Remote URL
|
86 |
+
content_type = requests.head(url, stream=True, verify=False).headers.get(
|
87 |
+
"Content-Type"
|
88 |
+
)
|
89 |
|
90 |
if "image" in content_type:
|
91 |
if "://" not in url: # Local file
|
|
|
116 |
|
117 |
def get_response(input_data, prompt: str, model=None, image_processor=None) -> str:
|
118 |
if isinstance(input_data, Image.Image):
|
119 |
+
vision_x = (
|
120 |
+
image_processor.preprocess([input_data], return_tensors="pt")[
|
121 |
+
"pixel_values"
|
122 |
+
]
|
123 |
+
.unsqueeze(1)
|
124 |
+
.unsqueeze(0)
|
125 |
+
)
|
126 |
elif isinstance(input_data, list): # list of video frames
|
127 |
+
vision_x = (
|
128 |
+
image_processor.preprocess(input_data, return_tensors="pt")["pixel_values"]
|
129 |
+
.unsqueeze(1)
|
130 |
+
.unsqueeze(0)
|
131 |
+
)
|
132 |
else:
|
133 |
+
raise ValueError(
|
134 |
+
"Invalid input data. Expected PIL Image or list of video frames."
|
135 |
+
)
|
136 |
|
137 |
lang_x = model.text_tokenizer(
|
138 |
[
|
|
|
162 |
)
|
163 |
return parsed_output
|
164 |
|
165 |
+
|
166 |
if __name__ == "__main__":
|
167 |
+
# ------------------- Main Function -------------------
|
168 |
+
load_bit = "fp16"
|
169 |
+
if load_bit == "fp16":
|
170 |
+
precision = {"torch_dtype": torch.float16}
|
171 |
+
elif load_bit == "bf16":
|
172 |
+
precision = {"torch_dtype": torch.bfloat16}
|
173 |
+
elif load_bit == "fp32":
|
174 |
+
precision = {"torch_dtype": torch.float32}
|
175 |
+
|
176 |
+
# This model version is trained on MIMIC-IT DC dataset.
|
177 |
+
model = OtterForConditionalGeneration.from_pretrained(
|
178 |
+
"luodian/otter-9b-dc-hf", device_map="auto", **precision
|
179 |
+
)
|
180 |
+
model.text_tokenizer.padding_side = "left"
|
181 |
+
tokenizer = model.text_tokenizer
|
182 |
+
image_processor = transformers.CLIPImageProcessor()
|
183 |
+
model.eval()
|
184 |
+
|
185 |
+
while True:
|
186 |
+
video_url = "demo.mp4" # Replace with the path to your video file
|
187 |
+
|
188 |
+
frames_list = get_image(video_url)
|
189 |
+
|
190 |
+
prompts_input = input("Enter prompts (comma-separated): ")
|
191 |
+
prompts = [prompt.strip() for prompt in prompts_input.split(",")]
|
192 |
+
|
193 |
+
for prompt in prompts:
|
194 |
+
print(f"\nPrompt: {prompt}")
|
195 |
+
response = get_response(frames_list, prompt, model, image_processor)
|
196 |
+
print(f"Response: {response}")
|
197 |
+
|
198 |
+
if prompts_input.lower() == "quit":
|
199 |
+
break
|
200 |
```
|