|
from transformers import PreTrainedModel |
|
from typing import List |
|
import torch |
|
from .configuration_flosmolv import FloSmolVConfig |
|
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor |
|
|
|
class FloSmolV(PreTrainedModel):
    """Two-stage visual question answering model.

    An image is first captioned by a Florence-2 model (loaded from
    ``config.vision_config._name_or_path``); the caption and the user's
    question are then given to a SmolLM chat model (loaded from
    ``config.llm_config._name_or_path``), whose short answer is returned.
    """

    config_class = FloSmolVConfig

    # Device the sub-models are moved to in ``__init__``; evaluated once when
    # the class is defined.
    # NOTE(review): this plain class attribute shadows the ``device`` property
    # inherited from ``PreTrainedModel``, so it does NOT follow later
    # ``.to(...)`` calls. ``forward`` therefore uses the sub-models' own
    # ``.device`` instead of this value.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Florence-2 task token requesting a long, detailed image caption.
    _CAPTION_TASK = "<MORE_DETAILED_CAPTION>"

    def __init__(self, config: FloSmolVConfig):
        """Load the Florence-2 captioner and the SmolLM answerer.

        Args:
            config: Holds ``vision_config`` and ``llm_config``, whose
                ``_name_or_path`` values name the checkpoints to load.
        """
        super().__init__(config)

        vision_path = self.config.vision_config._name_or_path
        llm_path = self.config.llm_config._name_or_path

        # SECURITY: trust_remote_code=True executes Python code shipped with
        # the checkpoints; only point the config at trusted repositories.

        # Stage 1: image captioner and its processor.
        self.florence2_model = (
            AutoModelForCausalLM.from_pretrained(vision_path, trust_remote_code=True)
            .eval()
            .to(self.device)
        )
        self.florence2_processor = AutoProcessor.from_pretrained(
            vision_path, trust_remote_code=True
        )

        # Stage 2: chat LLM and tokenizer that answer from the caption.
        self.smollm_model = (
            AutoModelForCausalLM.from_pretrained(llm_path, trust_remote_code=True)
            .eval()
            .to(self.device)
        )
        self.smollm_tokenizer = AutoTokenizer.from_pretrained(
            llm_path, trust_remote_code=True
        )

    def forward(
        self,
        image,
        query: str,
        *,
        max_new_tokens: int = 50,
        temperature: float = 0.2,
    ) -> str:
        """Answer ``query`` about ``image``.

        Args:
            image: Input image accepted by the Florence-2 processor. It must
                expose ``width`` and ``height`` attributes (presumably a
                ``PIL.Image.Image`` -- TODO confirm with callers).
            query: The question to answer about the image.
            max_new_tokens: Upper bound on tokens generated for the answer.
            temperature: Sampling temperature used for the answer.

        Returns:
            The generated answer as plain text, with special tokens removed
            and surrounding whitespace stripped.
        """
        caption = self._caption_image(image)
        return self._answer_from_caption(
            caption,
            query,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
        )

    def _caption_image(self, image) -> str:
        """Return Florence-2's detailed caption of ``image`` as a string."""
        task = self._CAPTION_TASK
        model = self.florence2_model

        vision_input = self.florence2_processor(
            text=task, images=image, return_tensors="pt"
        )
        # Move inputs to wherever the captioner actually lives (and match its
        # dtype), so this keeps working after ``.to(...)`` / ``.half()``.
        generated_ids = model.generate(
            input_ids=vision_input["input_ids"].to(
                device=model.device, dtype=torch.int64
            ),
            pixel_values=vision_input["pixel_values"].to(
                device=model.device, dtype=model.dtype
            ),
            max_new_tokens=1024,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
        # Keep special tokens: the raw decoded text is handed to
        # post_process_generation below.
        generated_text = self.florence2_processor.batch_decode(
            generated_ids, skip_special_tokens=False
        )[0]
        parsed_answer = self.florence2_processor.post_process_generation(
            generated_text,
            task=task,
            image_size=(image.width, image.height),
        )
        return str(parsed_answer[task])

    def _answer_from_caption(
        self,
        caption: str,
        query: str,
        *,
        max_new_tokens: int,
        temperature: float,
    ) -> str:
        """Ask the chat LLM to answer ``query`` using ``caption`` as context."""
        messages = [
            {
                "role": "user",
                "content": f"You are an expert AI assistant. Based on the CONTENT, you should answer the QUESTION in short.\n\nCONTENT:{caption}\n\nQUESTION:{str(query)}\n",
            }
        ]
        # add_generation_prompt=True appends the assistant-turn header, so
        # the model starts directly with the answer.
        prompt_text = self.smollm_tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # The chat template already contains any special tokens it needs;
        # don't let the tokenizer add another BOS/etc.
        llm_inputs = self.smollm_tokenizer(
            prompt_text, return_tensors="pt", add_special_tokens=False
        ).to(self.smollm_model.device)

        outputs = self.smollm_model.generate(
            **llm_inputs,  # input_ids + attention_mask
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=True,
        )
        # Decode only the newly generated tokens so prompt text can never
        # leak into the answer; special tokens (e.g. <|im_end|>) are dropped.
        prompt_len = llm_inputs["input_ids"].shape[-1]
        answer = self.smollm_tokenizer.decode(
            outputs[0][prompt_len:], skip_special_tokens=True
        )
        return answer.strip()