Ryukijano committed on
Commit
84ce5d1
·
verified ·
1 Parent(s): d3680ea

Update vision_llm.py

Browse files
Files changed (1) hide show
  1. vision_llm.py +3 -3
vision_llm.py CHANGED
@@ -1,14 +1,14 @@
1
- # timeforge/vision_llm.py
2
  from transformers import AutoProcessor, AutoModelForVision2Seq
3
  import torch
4
  from PIL import Image
 
5
  class VisionLLM:
6
- def __init__(self, device="cuda", model_id="microsoft/kosmos-2"):
7
  self.device = device
8
  self.processor = AutoProcessor.from_pretrained(model_id)
9
  self.model = AutoModelForVision2Seq.from_pretrained(model_id, torch_dtype=torch.float16).to(self.device)
10
 
11
-
12
  def describe_images(self, images, prompt="", max_length=128):
13
  if isinstance(images, list):
14
  images = [img.convert("RGB") for img in images]
 
1
+ # vision_llm.py
2
  from transformers import AutoProcessor, AutoModelForVision2Seq
3
  import torch
4
  from PIL import Image
5
+
6
  class VisionLLM:
7
+ def __init__(self, device="cuda", model_id="google/paligemma2-3b-pt-224"): # Corrected model ID
8
  self.device = device
9
  self.processor = AutoProcessor.from_pretrained(model_id)
10
  self.model = AutoModelForVision2Seq.from_pretrained(model_id, torch_dtype=torch.float16).to(self.device)
11
 
 
12
  def describe_images(self, images, prompt="", max_length=128):
13
  if isinstance(images, list):
14
  images = [img.convert("RGB") for img in images]