File size: 2,712 Bytes
04e1d9d
 
 
 
 
 
 
 
 
f879060
 
04e1d9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from transformers import PreTrainedModel
from typing import List
import torch 
from .configuration_flosmolv import FloSmolVConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor

class FloSmolV(PreTrainedModel):
    """Two-stage visual question answering wrapper.

    A vision-language checkpoint (``config.vision_config._name_or_path``)
    first turns the image into a detailed caption; a language-model
    checkpoint (``config.llm_config._name_or_path``) then answers the user's
    question using that caption as its only context.
    """

    config_class = FloSmolVConfig
    # Default placement used when the sub-models are loaded in ``__init__``.
    # NOTE(review): this plain string appears to shadow the ``device``
    # property that ``PreTrainedModel`` normally provides, so it does not
    # track later ``.to(...)`` calls. ``forward`` therefore reads the device
    # from the sub-models rather than from this attribute.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Task prompt handed to the vision processor and used as the key of the
    # post-processed caption.
    _CAPTION_TASK = "<MORE_DETAILED_CAPTION>"
    # End-of-turn marker stripped from the decoded language-model reply.
    _END_OF_TURN = "<|im_end|>"
    # Assistant headers the language model may emit before its answer when
    # the chat template did not already append one to the prompt.
    _ASSISTANT_HEADERS = ("<|im_start|>assistant\n", "assistant\n")

    def __init__(self, config: FloSmolVConfig):
        """Load both sub-models and their processor/tokenizer.

        Args:
            config: Composite configuration whose ``vision_config`` and
                ``llm_config`` each carry the checkpoint name to load.
        """
        super().__init__(config)
        vision_name = self.config.vision_config._name_or_path
        llm_name = self.config.llm_config._name_or_path

        self.florence2_model = AutoModelForCausalLM.from_pretrained(
            vision_name,
            trust_remote_code=True,
        ).eval().to(self.device)
        self.florence2_processor = AutoProcessor.from_pretrained(
            vision_name,
            trust_remote_code=True,
        )
        # ``.eval()`` is applied here as well so both sub-models are
        # explicitly in inference mode.
        self.smollm_model = AutoModelForCausalLM.from_pretrained(
            llm_name,
            trust_remote_code=True,
        ).eval().to(self.device)
        self.smollm_tokenizer = AutoTokenizer.from_pretrained(
            llm_name,
            trust_remote_code=True,
        )

    def forward(self, image, query: str):
        """Answer ``query`` about ``image``.

        Args:
            image: Image object accepted by the vision processor; it must
                expose ``width`` and ``height`` attributes (e.g. a PIL image
                -- assumed, confirm against callers).
            query: Natural-language question about the image.

        Returns:
            The language model's answer as a plain string with the chat
            markers removed.
        """
        caption = self._caption_image(image)
        return self._answer_from_caption(caption, query)

    def _caption_image(self, image):
        """Run the vision model and return the post-processed caption."""
        task = self._CAPTION_TASK
        # Follow the sub-model's actual device so the wrapper keeps working
        # after the model has been moved with ``.to(...)``.
        vision_device = self.florence2_model.device
        vision_inputs = self.florence2_processor(
            text=task, images=image, return_tensors="pt"
        )
        generated_ids = self.florence2_model.generate(
            input_ids=vision_inputs["input_ids"].to(torch.int64).to(vision_device),
            pixel_values=vision_inputs["pixel_values"].to(vision_device),
            max_new_tokens=1024,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
        # Special tokens are kept because the processor's post-processing
        # step receives the raw decoded text.
        generated_text = self.florence2_processor.batch_decode(
            generated_ids, skip_special_tokens=False
        )[0]
        parsed_answer = self.florence2_processor.post_process_generation(
            generated_text,
            task=task,
            image_size=(image.width, image.height),
        )
        return parsed_answer[task]

    def _answer_from_caption(self, caption, query: str) -> str:
        """Ask the language model ``query`` with ``caption`` as context."""
        messages = [{"role": "user", "content": f"You are an expert AI assistant. Based on the CONTENT, you should answer the QUESTION in short.\n\nCONTENT:{str(caption)}\n\nQUESTION:{str(query)}\n"}]
        # Ask the template to open the assistant turn so the token budget is
        # spent on the answer rather than on re-emitting the header.
        input_text = self.smollm_tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        llm_device = self.smollm_model.device
        llm_inputs = self.smollm_tokenizer(input_text, return_tensors="pt").to(llm_device)
        outputs = self.smollm_model.generate(
            input_ids=llm_inputs["input_ids"],
            # Passed explicitly so generation does not have to guess which
            # positions are padding.
            attention_mask=llm_inputs["attention_mask"],
            max_new_tokens=50,
            temperature=0.2,
            do_sample=True,
        )
        # Decode only the newly generated tokens: the prompt embeds the
        # caption and the user's query, and any chat marker appearing in
        # them must not be mistaken for the start of the answer.
        prompt_length = llm_inputs["input_ids"].shape[1]
        reply = self.smollm_tokenizer.decode(outputs[0][prompt_length:])
        return self._clean_reply(reply)

    @classmethod
    def _clean_reply(cls, reply: str) -> str:
        """Strip a leading assistant header and trailing end-of-turn marker."""
        text = reply.lstrip()
        for header in cls._ASSISTANT_HEADERS:
            if text.startswith(header):
                text = text[len(header):]
                break
        text = text.strip()
        if text.endswith(cls._END_OF_TURN):
            text = text[: -len(cls._END_OF_TURN)].rstrip()
        return text