import base64
import json
import logging
import math
import os
from io import BytesIO
from typing import Any, Dict, List, Literal, Optional, Union

import requests
import torch
from PIL import Image
from torch import nn
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

logger = logging.getLogger(__name__)


class Transformer(nn.Module):
    """Embed text queries and images with a Qwen2-VL checkpoint.

    Strings are wrapped in a query prompt and paired with a small blank
    placeholder image. Images are resized and wrapped in a document prompt.
    The embedding is the last-layer hidden state of the final token,
    truncated to ``dimension`` features and L2-normalised.
    """

    save_in_root: bool = True

    # Both image sides are snapped to multiples of this many pixels.
    # NOTE(review): presumably Qwen2-VL patch size (14) x spatial merge (2);
    # confirm against the processor config.
    _RESIZE_FACTOR: int = 28

    def __init__(
        self,
        model_name_or_path: str = 'llamaindex/vdr-2b-multi-v1',
        processor_name_or_path: Optional[str] = None,
        max_pixels: int = 768 * 28 * 28,
        min_pixels: int = 1 * 28 * 28,
        dimension: int = 2048,
        max_seq_length: Optional[int] = None,
        model_args: Optional[Dict[str, Any]] = None,
        processor_args: Optional[Dict[str, Any]] = None,
        cache_dir: Optional[str] = None,
        device: str = 'cuda:0',
        backend: Literal['torch', 'onnx', 'openvino'] = 'torch',
        **kwargs,
    ) -> None:
        """Load the Qwen2-VL model and its processor.

        Args:
            model_name_or_path: Hub id or local directory of the model weights.
            processor_name_or_path: Processor location; defaults to
                ``model_name_or_path``.
            max_pixels: Upper bound on pixels per image after resizing.
            min_pixels: Lower bound on pixels per image after resizing.
            dimension: Number of leading hidden features kept in the embedding.
            max_seq_length: Informational; inferred from the model/tokenizer
                when not given.
            model_args: Extra keyword arguments for ``from_pretrained`` of the
                model (not mutated).
            processor_args: Extra keyword arguments for the processor
                (not mutated); ``min_pixels``/``max_pixels``/``cache_dir``
                always override entries here.
            cache_dir: Download cache directory.
            device: Passed as ``device_map`` and used for input tensors.
            backend: Only ``'torch'`` is supported.
            **kwargs: Merged into the model keyword arguments.

        Raises:
            ValueError: if ``backend`` is not ``'torch'``.
        """
        super().__init__()
        if backend != 'torch':
            raise ValueError(
                f'Backend \'{backend}\' is not supported, please use \'torch\' instead'
            )

        self.device = device
        self.dimension = dimension
        self.max_pixels = max_pixels
        self.min_pixels = min_pixels
        self.max_seq_length = max_seq_length

        # Copy so the caller's dictionaries are never modified by update().
        model_kwargs = dict(model_args or {})
        model_kwargs.update(kwargs)
        processor_kwargs = dict(processor_args or {})
        processor_kwargs.update({
            'min_pixels': min_pixels,
            'max_pixels': max_pixels,
            'cache_dir': cache_dir,
        })

        load_kwargs = dict(
            torch_dtype=torch.bfloat16,
            device_map=device,
            cache_dir=cache_dir,
            **model_kwargs,
        )
        # Prefer flash attention; fall back to the default implementation
        # when it is unavailable (missing package / unsupported setup).
        try:
            model = Qwen2VLForConditionalGeneration.from_pretrained(
                model_name_or_path,
                attn_implementation="flash_attention_2",
                **load_kwargs,
            )
        except (ImportError, ValueError) as e:
            logger.warning(
                "Flash attention not available, falling back to default attention: %s", e
            )
            model = Qwen2VLForConditionalGeneration.from_pretrained(
                model_name_or_path,
                **load_kwargs,
            )
        self.model = model.eval()

        self.processor = AutoProcessor.from_pretrained(
            processor_name_or_path or model_name_or_path,
            **processor_kwargs,
        )

        # Left padding keeps every sequence's real last token at position -1,
        # which is where forward() reads the embedding from.
        self.model.padding_side = "left"
        self.processor.tokenizer.padding_side = "left"

        # Prompt templates; both contain one image slot.
        self.document_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>What is shown in this image?<|im_end|>\n<|endoftext|>"
        self.query_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Query: %s<|im_end|>\n<|endoftext|>"

        # Infer max_seq_length from model and tokenizer limits when possible.
        if self.max_seq_length is None:
            if (
                hasattr(self.model, 'config')
                and hasattr(self.model.config, 'max_position_embeddings')
                and hasattr(self.processor.tokenizer, 'model_max_length')
            ):
                self.max_seq_length = min(
                    self.model.config.max_position_embeddings,
                    self.processor.tokenizer.model_max_length,
                )

    def _smart_resize(self, height: int, width: int) -> tuple[int, int]:
        """Return a ``(width, height)`` target size suitable for ``Image.resize``.

        Both sides are snapped to multiples of ``_RESIZE_FACTOR``. The total
        pixel count is scaled into ``[min_pixels, max_pixels]`` while roughly
        keeping the aspect ratio. Each side is always at least one factor.
        """
        factor = self._RESIZE_FACTOR
        h_bar = max(factor, self._round_by_factor(height, factor))
        w_bar = max(factor, self._round_by_factor(width, factor))
        if h_bar * w_bar > self.max_pixels:
            beta = math.sqrt((height * width) / self.max_pixels)
            # Clamp so very elongated images never collapse to a 0-pixel side.
            h_bar = max(factor, self._floor_by_factor(height / beta, factor))
            w_bar = max(factor, self._floor_by_factor(width / beta, factor))
        elif h_bar * w_bar < self.min_pixels:
            beta = math.sqrt(self.min_pixels / (height * width))
            h_bar = self._ceil_by_factor(height * beta, factor)
            w_bar = self._ceil_by_factor(width * beta, factor)
        return w_bar, h_bar

    @staticmethod
    def _round_by_factor(number: float, factor: int) -> int:
        """Round ``number`` to the nearest multiple of ``factor``."""
        return round(number / factor) * factor

    @staticmethod
    def _ceil_by_factor(number: float, factor: int) -> int:
        """Round ``number`` up to a multiple of ``factor``."""
        return math.ceil(number / factor) * factor

    @staticmethod
    def _floor_by_factor(number: float, factor: int) -> int:
        """Round ``number`` down to a multiple of ``factor``."""
        return math.floor(number / factor) * factor

    def _resize_image(self, image: Image.Image) -> Image.Image:
        """Resize ``image`` to the size chosen by ``_smart_resize``."""
        new_size = self._smart_resize(image.height, image.width)
        return image.resize(new_size)

    @staticmethod
    def _decode_data_image(data_image_str: str) -> Image.Image:
        """Decode a base64 data-URI string into a PIL image.

        Everything before the first comma (the URI header) is ignored; the
        rest is base64-decoded and opened with PIL.
        """
        _header, data = data_image_str.split(',', 1)
        image_data = base64.b64decode(data)
        return Image.open(BytesIO(image_data))

    def _process_input(
        self, texts: List[Union[str, Image.Image]]
    ) -> tuple[List[str], List[Image.Image]]:
        """Build parallel prompt and image lists for a mixed batch.

        Strings become query prompts paired with a blank 56x56 RGB
        placeholder, because the prompt always contains an image slot.
        Images become document prompts paired with the resized image.

        Raises:
            TypeError: if an element is neither ``str`` nor ``PIL.Image.Image``.
                Dropping such elements would misalign outputs with inputs.
        """
        processed_texts: List[str] = []
        processed_images: List[Image.Image] = []
        dummy_image = Image.new('RGB', (56, 56))
        for index, sample in enumerate(texts):
            if isinstance(sample, str):
                processed_texts.append(self.query_prompt % sample)
                processed_images.append(dummy_image)
            elif isinstance(sample, Image.Image):
                processed_texts.append(self.document_prompt)
                processed_images.append(self._resize_image(sample))
            else:
                raise TypeError(
                    f'Unsupported input type at index {index}: '
                    f'{type(sample).__name__} (expected str or PIL.Image.Image)'
                )
        return processed_texts, processed_images

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Add L2-normalised embeddings to ``features`` and return it.

        The embedding is the final-layer hidden state of the last token,
        truncated to ``self.dimension`` features. It is stored under
        ``'sentence_embedding'``.
        """
        input_ids = features['input_ids']
        # cache_position enumerates token positions (dim 1 of input_ids, the
        # sequence axis), not batch rows. It must live on the inputs' device.
        cache_position = torch.arange(input_ids.shape[1], device=input_ids.device)
        inputs = self.model.prepare_inputs_for_generation(
            **features, cache_position=cache_position, use_cache=False
        )
        with torch.no_grad():
            output = self.model(
                **inputs, return_dict=True, output_hidden_states=True
            )
        # Left padding guarantees position -1 is each sequence's last real token.
        embeddings = output.hidden_states[-1][:, -1]
        features['sentence_embedding'] = torch.nn.functional.normalize(
            embeddings[:, :self.dimension], p=2, dim=-1
        )
        return features

    def tokenize(
        self, texts: List[Union[str, Image.Image]], padding: str = 'longest'
    ) -> Dict[str, torch.Tensor]:
        """Run the processor on a mixed batch and move all tensors to ``self.device``."""
        processed_texts, processed_images = self._process_input(texts)
        inputs = self.processor(
            text=processed_texts,
            images=processed_images,
            videos=None,
            padding=padding,
            return_tensors='pt',
        )
        return {k: v.to(self.device) for k, v in inputs.items()}

    def save(self, output_path: str, safe_serialization: bool = True) -> None:
        """Save the model, the processor and ``sentence_bert_config.json`` to ``output_path``."""
        self.model.save_pretrained(output_path, safe_serialization=safe_serialization)
        self.processor.save_pretrained(output_path)

        config = {
            'model_name_or_path': output_path,
            'max_pixels': self.max_pixels,
            'min_pixels': self.min_pixels,
            'dimension': self.dimension,
            'max_seq_length': self.max_seq_length,
        }
        config_path = os.path.join(output_path, 'sentence_bert_config.json')
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(config, f)

    @staticmethod
    def load(input_path: str) -> 'Transformer':
        """Load a module previously written by ``save`` from ``input_path``.

        Settings come from ``sentence_bert_config.json`` when present. Weights
        are always loaded from ``input_path`` itself.
        """
        config: Dict[str, Any] = {}
        config_path = os.path.join(input_path, 'sentence_bert_config.json')
        if os.path.exists(config_path):
            with open(config_path, encoding='utf-8') as f:
                config = json.load(f)
        # save() writes the weights next to the config, so the stored path may
        # be stale (e.g. the directory was moved); input_path is authoritative.
        config['model_name_or_path'] = input_path
        return Transformer(**config)