|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
Processor class for Centurio.
|
|
"""
|
|
import timm
|
|
import torch
|
|
import transformers
|
|
from tokenizers import AddedToken
|
|
from torchvision.transforms import InterpolationMode, Compose, Resize, ToTensor, Normalize
|
|
from transformers import BaseImageProcessor, AutoTokenizer, AutoProcessor, AutoImageProcessor
|
|
from typing import List, Union, Optional
|
|
|
|
from transformers.feature_extraction_utils import BatchFeature
|
|
from transformers.image_utils import ImageInput, get_image_size, to_numpy_array
|
|
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
|
|
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
|
from transformers.utils import logging
|
|
|
|
# Module-level logger, obtained through the logging helper of `transformers.utils`.
logger = logging.get_logger(__name__)
|
|
|
|
class CenturioTimmImageProcessor(BaseImageProcessor):
    r"""
    Image processor that applies the preprocessing described by a `timm` pretrained config.

    The resize target, interpolation mode and normalisation statistics are read from
    `timm.get_pretrained_cfg(timm_model)`. Every image is resized to the model input size,
    converted to a tensor and normalised. When `tiling > 1`, each image is additionally
    resized to `tiling` times the input size and cut into a `tiling x tiling` grid of
    input-sized tiles that follow the global view in the output.

    Args:
        timm_model (`str`, defaults to `"vit_so400m_patch14_siglip_384"`):
            Name of the `timm` model whose pretrained config is used.
        tiling (`int`, defaults to `1`):
            Number of tiles per image side. Values `<= 1` disable tiling.
        **kwargs:
            Accepted for signature compatibility; not used by this class.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        timm_model="vit_so400m_patch14_siglip_384",
        tiling=1,
        **kwargs,
    ) -> None:
        config = timm.get_pretrained_cfg(timm_model)
        # NOTE(review): presumably `config.input_size` is (channels, height, width), so
        # index 1 is the height; the processor always uses a square target of that size.
        input_size = config.input_size[1]
        self.timm_model = timm_model
        self.interpolation = config.interpolation
        self.mean = config.mean
        self.std = config.std
        self.tiling = tiling
        self.input_size = (input_size, input_size)

    def __call__(
        self,
        images: ImageInput,
        **kwargs
    ):
        """Preprocess `images`; thin alias for [`~CenturioTimmImageProcessor.preprocess`]."""
        return self.preprocess(images, **kwargs)

    def _build_transform(self, size):
        """Return the resize -> tensor -> normalise pipeline for a `(height, width)` target."""
        return Compose([
            Resize(size, interpolation=InterpolationMode(self.interpolation)),
            ToTensor(),
            Normalize(mean=self.mean, std=self.std),
        ])

    def _match_channels(self, image):
        """
        Convert PIL-like images whose band count is not 3 to RGB.

        Only applied when the normalisation statistics have three entries; without the
        conversion such images (grayscale, palette, RGBA, ...) fail inside `Normalize`.
        Images that already have three bands, and objects without `getbands` (e.g. arrays),
        are returned unchanged.
        """
        getbands = getattr(image, "getbands", None)
        if getbands is None or len(self.mean) != 3:
            return image
        if len(getbands()) == 3:
            return image
        return image.convert("RGB")

    def preprocess(
        self,
        images: ImageInput,
        **kwargs
    ):
        """
        Turn one image or a list/tuple of images into a stacked `pixel_values` tensor.

        Args:
            images:
                A single image or a list/tuple of images accepted by the torchvision
                `Resize`/`ToTensor` transforms.
            **kwargs:
                Ignored.

        Returns:
            `BatchFeature` with a single key `"pixel_values"`. The tensor stacks, per input
            image and in input order, the resized image followed (when `tiling > 1`) by its
            `tiling ** 2` tiles in row-major order.
        """
        tile_h, tile_w = self.input_size
        transform = self._build_transform(self.input_size)

        transform_large = None
        if self.tiling > 1:
            # Stored on the instance exactly as before so existing readers of
            # `input_size_large` keep working.
            self.input_size_large = (tile_h * self.tiling, tile_h * self.tiling)
            transform_large = self._build_transform(self.input_size_large)

        # Tuples are treated like lists; anything else is a single image.
        if not isinstance(images, (list, tuple)):
            images = [images]

        processed_images = []
        for raw_image in images:
            raw_image = self._match_channels(raw_image)
            # Global (down-scaled) view always comes first.
            processed_images.append(transform(raw_image))
            if transform_large is None:
                continue
            image_large = transform_large(raw_image)
            processed_images.extend(
                image_large[:, row * tile_h:(row + 1) * tile_h, col * tile_w:(col + 1) * tile_w]
                for row in range(self.tiling)
                for col in range(self.tiling)
            )

        return BatchFeature(
            data={"pixel_values": torch.stack(processed_images, dim=0)}
        )
|
|
|
|
# Register the image processor with `AutoImageProcessor` under its class name.
AutoImageProcessor.register(
    "CenturioTimmImageProcessor",
    CenturioTimmImageProcessor,
)

# Also expose the class as an attribute of the top-level `transformers` module.
# NOTE(review): presumably needed so the `image_processor_class` string used by
# `CenturioProcessor` can be resolved by name — confirm against `ProcessorMixin`.
setattr(transformers, "CenturioTimmImageProcessor", CenturioTimmImageProcessor)
|
|
|
|
class CenturioProcessor(ProcessorMixin):
    """
    Processor that pairs a `CenturioTimmImageProcessor` with a tokenizer.

    Calling the processor runs the image processor on `images` (if given) and the
    tokenizer on `text` (if given) and merges both outputs into one `BatchFeature`.

    Args:
        image_processor: Image processor used for `images`.
        tokenizer: Tokenizer used for `text`.
        tiling (`int`, defaults to `1`): Accepted for compatibility; not used here.
        **kwargs: Only `chat_template` is read; all other entries are ignored.
    """

    attributes = ["image_processor", "tokenizer"]
    optional_attributes = ["chat_template"]
    image_processor_class = "CenturioTimmImageProcessor"
    tokenizer_class = ("AutoTokenizer")
    image_token = "<image_placeholder>"

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        tiling=1,
        **kwargs,
    ):
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        # `chat_template` is declared in `optional_attributes`, so make sure the
        # attribute always exists (None when not supplied).
        # NOTE(review): the base `ProcessorMixin.__init__` is not called here; this
        # keeps code that reads `self.chat_template` from failing — confirm.
        self.chat_template = kwargs.get("chat_template")

    def __call__(
        self,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Prepare model inputs from images and/or text.

        Args:
            images: Image or list of images forwarded to the image processor. Optional.
            text: A string or a list of strings (or of pre-tokenized inputs) forwarded to
                the tokenizer. Optional when `images` is given.
            **kwargs: Forwarded unchanged to the tokenizer call.

        Returns:
            `BatchFeature` holding the tokenizer outputs merged with the image processor
            outputs. With `text=None` only the image features are returned.

        Raises:
            ValueError: If neither `images` nor `text` is given, or if `text` is not a
                string or a list of strings.
        """
        if images is None and text is None:
            raise ValueError("You have to specify at least one of `images` or `text`.")

        # Swap the arguments back if the caller passed them in the wrong order.
        images, text = _validate_images_text_input_order(images, text)

        image_inputs = self.image_processor(images) if images is not None else {}

        if text is None:
            # Images-only call: nothing to tokenize.
            return BatchFeature(data={**image_inputs})

        if isinstance(text, str):
            text = [text]
        elif not isinstance(text, (list, tuple)) or (
            len(text) > 0 and not isinstance(text[0], (str, list, tuple))
        ):
            # Reject anything that is not a sequence of strings / pre-tokenized inputs
            # before it reaches the tokenizer.
            raise ValueError("Invalid input text. Please provide a string, or a list of strings")

        text_inputs = self.tokenizer(text, **kwargs)
        return BatchFeature(data={**text_inputs, **image_inputs})

    def batch_decode(self, *args, **kwargs):
        """
        Forward all arguments to the `batch_decode` method of the wrapped tokenizer and
        return its result.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        Forward all arguments to the `decode` method of the wrapped tokenizer and return
        its result.
        """
        return self.tokenizer.decode(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pass |