# Source: TimeForge / vision_llm.py -- commit 38adfc7 ("Create vision_llm.py") by Ryukijano, 1.21 kB.
# timeforge/vision_llm.py
from transformers import AutoProcessor, AutoModelForVision2Seq
import torch
from PIL import Image
class VisionLLM:
    """Wrap a Hugging Face vision-to-sequence model to generate text for images.

    The processor and the model are loaded once in the constructor and reused
    by every :meth:`describe_images` call.
    """

    def __init__(self, device="cuda", model_id="microsoft/kosmos-2"):
        """Load the processor and model for *model_id* and move the model to *device*.

        Args:
            device: Anything ``torch.device`` accepts, e.g. ``"cuda"``,
                ``"cuda:0"`` or ``"cpu"``.
            model_id: Identifier handed to ``from_pretrained`` for both the
                processor and the model.
        """
        # NOTE(review): the published Kosmos-2 checkpoint appears to be named
        # "microsoft/kosmos-2-patch14-224" -- confirm this default resolves.
        self.device = device
        # Half precision is only requested on CUDA. On other devices the
        # processor's float32 inputs would not match float16 weights, because
        # autocast is not active there to reconcile the dtypes.
        on_cuda = torch.device(device).type == "cuda"
        dtype = torch.float16 if on_cuda else torch.float32
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.model = AutoModelForVision2Seq.from_pretrained(
            model_id, torch_dtype=dtype
        ).to(self.device)

    def describe_images(self, images, prompt="", max_length=128):
        """Generate text for one image or a list of images.

        Args:
            images: A single image or a list of images; every image must
                provide ``convert("RGB")`` (as PIL images do).
            prompt: Text passed to the processor. When *images* is a list of
                several images and *prompt* is a single string, the string is
                repeated so that each image gets its own copy. A list of
                prompts is passed through unchanged.
            max_length: Forwarded to ``generate`` as ``max_length``. In
                transformers this bounds prompt tokens plus generated tokens.

        Returns:
            The list of strings produced by the processor's ``batch_decode``.
            An empty list of images yields ``[]`` without touching the model.
        """
        if isinstance(images, list):
            if not images:
                # Nothing to describe; avoid handing an empty batch to the processor.
                return []
            images = [img.convert("RGB") for img in images]
            # The single-image list keeps the original call form; only real
            # batches need the prompt broadcast to one entry per image.
            if isinstance(prompt, str) and len(images) > 1:
                prompt = [prompt] * len(images)
        else:
            images = images.convert("RGB")

        inputs = self.processor(
            images=images, text=prompt, return_tensors="pt"
        ).to(self.device)

        # Mixed precision only makes sense when the model actually lives on a
        # CUDA device; elsewhere the autocast context is explicitly disabled.
        on_cuda = torch.device(self.device).type == "cuda"
        with torch.no_grad(), torch.autocast("cuda", enabled=on_cuda):
            outputs = self.model.generate(**inputs, max_length=max_length)

        return self.processor.batch_decode(outputs, skip_special_tokens=True)
if __name__ == "__main__":
    # Manual smoke test: load the default model, then describe one local image.
    captioner = VisionLLM()
    sample = Image.open("multi_view_output.png")
    print(captioner.describe_images([sample], "Describe the objects in the image"))