from typing import List

import torch
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
from transformers import PreTrainedModel

from .configuration_flosmolv import FloSmolVConfig


class FloSmolV(PreTrainedModel):
    """Two-stage image question answering: a vision model describes, an LLM answers.

    ``forward`` first runs the vision model (``florence2_model``) on the image
    to obtain a textual description, then feeds that description together
    with the user's question to the language model (``smollm_model``) and
    returns the cleaned-up answer text.

    Both sub-models are loaded in ``__init__`` from the checkpoint names
    stored under ``"_name_or_path"`` in ``config.vision_config`` and
    ``config.llm_config``.
    """

    config_class = FloSmolVConfig
    # Evaluated once, when the class body is executed.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Task prompt handed to the vision processor and used as the key into the
    # post-processed result.
    # NOTE(review): this is an empty string in the source; it may originally
    # have been a task token written in angle brackets that was lost in
    # extraction -- confirm against the intended vision task.
    _VISION_TASK_PROMPT = ""

    # Chat-template fragments stripped from the decoded LLM output.
    _ASSISTANT_MARKER = "assistant\n"
    _END_OF_TURN = "<|im_end|>"

    def __init__(self, config: FloSmolVConfig):
        """Load the vision model/processor and the LLM/tokenizer named in *config*.

        Args:
            config: Must provide ``vision_config["_name_or_path"]`` and
                ``llm_config["_name_or_path"]``; each is passed to
                ``from_pretrained`` with ``trust_remote_code=True``.
        """
        super().__init__(config)
        vision_name = self.config.vision_config["_name_or_path"]
        llm_name = self.config.llm_config["_name_or_path"]

        self.florence2_model = (
            AutoModelForCausalLM.from_pretrained(vision_name, trust_remote_code=True)
            .eval()
            .to(self.device)
        )
        self.florence2_processor = AutoProcessor.from_pretrained(
            vision_name, trust_remote_code=True
        )

        self.smollm_model = AutoModelForCausalLM.from_pretrained(
            llm_name, trust_remote_code=True
        ).to(self.device)
        self.smollm_tokenizer = AutoTokenizer.from_pretrained(
            llm_name, trust_remote_code=True
        )

    def forward(self, image, query: str) -> str:
        """Answer *query* about *image*.

        Args:
            image: Input image; must expose ``width`` and ``height``
                attributes (presumably a PIL image -- confirm with callers).
            query: The question to answer.

        Returns:
            The language model's answer with chat-template markup removed.
        """
        content = self._describe_image(image)
        return self._answer(content, query)

    def _describe_image(self, image):
        """Run the vision model on *image* and return its post-processed output."""
        task = self._VISION_TASK_PROMPT
        vision_inputs = self.florence2_processor(
            text=task, images=image, return_tensors="pt"
        )
        generated_ids = self.florence2_model.generate(
            input_ids=vision_inputs["input_ids"].to(torch.int64).to(self.device),
            pixel_values=vision_inputs["pixel_values"].to(self.device),
            max_new_tokens=1024,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
        # Special tokens are kept in the decoded text that is handed to the
        # processor's post-processing step.
        generated_text = self.florence2_processor.batch_decode(
            generated_ids, skip_special_tokens=False
        )[0]
        parsed_answer = self.florence2_processor.post_process_generation(
            generated_text, task=task, image_size=(image.width, image.height)
        )
        return parsed_answer[task]

    def _answer(self, content, query: str) -> str:
        """Ask the language model *query*, grounded on *content*, and clean the reply."""
        # NOTE(review): the source had a raw line break inside this literal
        # right after "assistant. "; it is written here as an explicit "\n".
        user_message = (
            "You are an expert AI assistant. \n"
            "Based on the CONTENT, you should answer the QUESTION in short."
            f"\n\nCONTENT:{str(content)}\n\nQUESTION:{str(query)}\n"
        )
        messages = [{"role": "user", "content": user_message}]
        input_text = self.smollm_tokenizer.apply_chat_template(messages, tokenize=False)
        llm_inputs = self.smollm_tokenizer.encode(input_text, return_tensors="pt").to(
            self.device
        )
        outputs = self.smollm_model.generate(
            llm_inputs,
            max_new_tokens=50,
            temperature=0.2,
            do_sample=True,
        )
        # Decode only the tokens produced after the prompt. Decoding the full
        # sequence and splitting on the assistant marker would cut in the
        # wrong place whenever the content or the question itself contains
        # "assistant\n", leaking the rest of the prompt into the answer.
        prompt_length = llm_inputs.shape[-1]
        response = self.smollm_tokenizer.decode(outputs[0][prompt_length:])
        return self._clean_response(response)

    @classmethod
    def _clean_response(cls, response: str) -> str:
        """Drop everything up to the assistant marker and a trailing end-of-turn token."""
        cleaned_text = response.split(cls._ASSISTANT_MARKER, 1)[-1].strip()
        if cleaned_text.endswith(cls._END_OF_TURN):
            return cleaned_text[: -len(cls._END_OF_TURN)]
        return cleaned_text