flosmolv / modeling_flosmolv.py
dmedhi's picture
Upload FloSmolV
f879060 verified
raw
history blame
2.71 kB
from transformers import PreTrainedModel
from typing import List
import torch
from .configuration_flosmolv import FloSmolVConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor
class FloSmolV(PreTrainedModel):
    """Answer a question about an image with a caption-then-answer pipeline.

    A vision model (``florence2_model``) first generates a detailed caption
    of the image; a language model (``smollm_model``) then answers the user's
    question using that caption as context.  Both checkpoints, and their
    processor/tokenizer, are loaded from the ``_name_or_path`` values stored
    in ``config.vision_config`` and ``config.llm_config``.
    """

    config_class = FloSmolVConfig
    # Device the sub-models are placed on at construction time.  Kept as a
    # class attribute for backward compatibility; ``forward`` deliberately
    # reads each sub-model's own device instead, so inputs keep following
    # the weights if a sub-model is moved later.
    # NOTE(review): this name appears to shadow the ``device`` attribute
    # inherited from ``PreTrainedModel`` -- confirm that is intended.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    def __init__(self, config: FloSmolVConfig):
        """Load the vision and language checkpoints named in *config*.

        Args:
            config: Composite configuration; ``config.vision_config`` and
                ``config.llm_config`` must each expose ``_name_or_path``.

        Note:
            Every ``from_pretrained`` call passes ``trust_remote_code=True``,
            so only point the config at checkpoints you trust.
        """
        super().__init__(config)
        vision_path = self.config.vision_config._name_or_path
        llm_path = self.config.llm_config._name_or_path

        self.florence2_model = AutoModelForCausalLM.from_pretrained(
            vision_path,
            trust_remote_code=True,
        ).eval().to(self.device)
        self.florence2_processor = AutoProcessor.from_pretrained(
            vision_path,
            trust_remote_code=True,
        )

        self.smollm_model = AutoModelForCausalLM.from_pretrained(
            llm_path,
            trust_remote_code=True,
        ).to(self.device)
        self.smollm_tokenizer = AutoTokenizer.from_pretrained(
            llm_path,
            trust_remote_code=True,
        )

    def forward(self, image, query: str) -> str:
        """Caption *image*, then answer *query* from that caption.

        Args:
            image: The picture to describe.  It is handed to the vision
                processor and must expose ``width`` and ``height``
                (presumably a ``PIL.Image.Image`` -- confirm with callers).
            query: The question to answer about the image.

        Returns:
            The language model's answer as plain text, without the prompt
            and without a trailing ``<|im_end|>`` marker.
        """
        ## Vision
        prompt = "<MORE_DETAILED_CAPTION>"
        # Follow the model's current placement rather than the class-level
        # default, which is stale if the model was moved after construction.
        vision_device = self.florence2_model.device
        vision_inputs = self.florence2_processor(text=prompt, images=image, return_tensors="pt")
        generated_ids = self.florence2_model.generate(
            input_ids=vision_inputs["input_ids"].to(torch.int64).to(vision_device),
            pixel_values=vision_inputs["pixel_values"].to(vision_device),
            max_new_tokens=1024,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
        # Special tokens are kept because the processor's post-processing
        # step receives the raw generated text.
        generated_text = self.florence2_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed_answer = self.florence2_processor.post_process_generation(
            generated_text,
            task=prompt,
            image_size=(image.width, image.height)
        )

        ## LM
        messages = [{"role": "user", "content": f"You are an expert AI assistant. Based on the CONTENT, you should answer the QUESTION in short.\n\nCONTENT:{str(parsed_answer[prompt])}\n\nQUESTION:{str(query)}\n"}]
        # Ask the template to open the assistant turn so the model spends its
        # token budget on the answer instead of on emitting the role header.
        input_text = self.smollm_tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        # The templated text already carries its special tokens; adding them
        # again here could duplicate them for some tokenizers.
        llm_inputs = self.smollm_tokenizer(
            input_text,
            return_tensors="pt",
            add_special_tokens=False,
        )
        llm_device = self.smollm_model.device
        input_ids = llm_inputs["input_ids"].to(llm_device)
        attention_mask = llm_inputs["attention_mask"].to(llm_device)
        outputs = self.smollm_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=50,
            temperature=0.2,
            do_sample=True,
        )

        # Decode only the newly generated tokens.  Decoding the full sequence
        # and splitting on "assistant\n" mis-parses whenever the prompt
        # itself contains that text (e.g. a query ending in "assistant").
        prompt_length = input_ids.shape[-1]
        response = self.smollm_tokenizer.decode(
            outputs[0][prompt_length:],
            skip_special_tokens=True,
        )
        cleaned_text = response.strip()

        # Fallback for tokenizers where the end-of-turn marker is not
        # registered as a special token and so survives decoding.
        end_marker = "<|im_end|>"
        if cleaned_text.endswith(end_marker):
            cleaned_text = cleaned_text[: -len(end_marker)]
        return cleaned_text