luodian committed on
Commit
2079d37
·
1 Parent(s): 9f5457a

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +52 -35
README.md CHANGED
@@ -83,7 +83,9 @@ def get_image(url: str) -> Union[Image.Image, list]:
83
  if "://" not in url: # Local file
84
  content_type = get_content_type(url)
85
  else: # Remote URL
86
- content_type = requests.head(url, stream=True, verify=False).headers.get("Content-Type")
 
 
87
 
88
  if "image" in content_type:
89
  if "://" not in url: # Local file
@@ -114,11 +116,23 @@ def get_formatted_prompt(prompt: str) -> str:
114
 
115
  def get_response(input_data, prompt: str, model=None, image_processor=None) -> str:
116
  if isinstance(input_data, Image.Image):
117
- vision_x = image_processor.preprocess([input_data], return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)
 
 
 
 
 
 
118
  elif isinstance(input_data, list): # list of video frames
119
- vision_x = image_processor.preprocess(input_data, return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)
 
 
 
 
120
  else:
121
- raise ValueError("Invalid input data. Expected PIL Image or list of video frames.")
 
 
122
 
123
  lang_x = model.text_tokenizer(
124
  [
@@ -148,36 +162,39 @@ def get_response(input_data, prompt: str, model=None, image_processor=None) -> s
148
  )
149
  return parsed_output
150
 
 
151
  if __name__ == "__main__":
152
- # ------------------- Main Function -------------------
153
- load_bit = "fp16"
154
- if load_bit == "fp16":
155
- precision = {"torch_dtype": torch.float16}
156
- elif load_bit == "bf16":
157
- precision = {"torch_dtype": torch.bfloat16}
158
- elif load_bit == "fp32":
159
- precision = {"torch_dtype": torch.float32}
160
-
161
- # This model version is trained on MIMIC-IT DC dataset.
162
- model = OtterForConditionalGeneration.from_pretrained("luodian/otter-9b-dc-hf", device_map="auto", **precision)
163
- model.text_tokenizer.padding_side = "left"
164
- tokenizer = model.text_tokenizer
165
- image_processor = transformers.CLIPImageProcessor()
166
- model.eval()
167
-
168
- while True:
169
- video_url = "demo.mp4" # Replace with the path to your video file
170
-
171
- frames_list = get_image(video_url)
172
-
173
- prompts_input = input("Enter prompts (comma-separated): ")
174
- prompts = [prompt.strip() for prompt in prompts_input.split(",")]
175
-
176
- for prompt in prompts:
177
- print(f"\nPrompt: {prompt}")
178
- response = get_response(frames_list, prompt, model, image_processor)
179
- print(f"Response: {response}")
180
-
181
- if prompts_input.lower() == "quit":
182
- break
 
 
183
  ```
 
83
  if "://" not in url: # Local file
84
  content_type = get_content_type(url)
85
  else: # Remote URL
86
+ content_type = requests.head(url, stream=True, verify=False).headers.get(
87
+ "Content-Type"
88
+ )
89
 
90
  if "image" in content_type:
91
  if "://" not in url: # Local file
 
116
 
117
  def get_response(input_data, prompt: str, model=None, image_processor=None) -> str:
118
  if isinstance(input_data, Image.Image):
119
+ vision_x = (
120
+ image_processor.preprocess([input_data], return_tensors="pt")[
121
+ "pixel_values"
122
+ ]
123
+ .unsqueeze(1)
124
+ .unsqueeze(0)
125
+ )
126
  elif isinstance(input_data, list): # list of video frames
127
+ vision_x = (
128
+ image_processor.preprocess(input_data, return_tensors="pt")["pixel_values"]
129
+ .unsqueeze(1)
130
+ .unsqueeze(0)
131
+ )
132
  else:
133
+ raise ValueError(
134
+ "Invalid input data. Expected PIL Image or list of video frames."
135
+ )
136
 
137
  lang_x = model.text_tokenizer(
138
  [
 
162
  )
163
  return parsed_output
164
 
165
+
166
  if __name__ == "__main__":
167
+ # ------------------- Main Function -------------------
168
+ load_bit = "fp16"
169
+ if load_bit == "fp16":
170
+ precision = {"torch_dtype": torch.float16}
171
+ elif load_bit == "bf16":
172
+ precision = {"torch_dtype": torch.bfloat16}
173
+ elif load_bit == "fp32":
174
+ precision = {"torch_dtype": torch.float32}
175
+
176
+ # This model version is trained on MIMIC-IT DC dataset.
177
+ model = OtterForConditionalGeneration.from_pretrained(
178
+ "luodian/otter-9b-dc-hf", device_map="auto", **precision
179
+ )
180
+ model.text_tokenizer.padding_side = "left"
181
+ tokenizer = model.text_tokenizer
182
+ image_processor = transformers.CLIPImageProcessor()
183
+ model.eval()
184
+
185
+ while True:
186
+ video_url = "demo.mp4" # Replace with the path to your video file
187
+
188
+ frames_list = get_image(video_url)
189
+
190
+ prompts_input = input("Enter prompts (comma-separated): ")
191
+ prompts = [prompt.strip() for prompt in prompts_input.split(",")]
192
+
193
+ for prompt in prompts:
194
+ print(f"\nPrompt: {prompt}")
195
+ response = get_response(frames_list, prompt, model, image_processor)
196
+ print(f"Response: {response}")
197
+
198
+ if prompts_input.lower() == "quit":
199
+ break
200
  ```