#!/usr/bin/env python3
"""Custom Hugging Face Inference Endpoint handler for a PaliGemma model."""
import base64
import io
import urllib.request
from typing import Any, Dict, List

from PIL import Image
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

# Image used when a request does not supply one (the original handler always
# used this image, so it is kept as the fallback for backward compatibility).
DEFAULT_IMAGE_URL = (
    "https://huggingface.co/datasets/huggingface/documentation-images/"
    "resolve/main/bee.jpg?download=true"
)
# Generation length used when the request gives no "max_new_tokens"
# (same value the original handler hard-coded).
DEFAULT_MAX_NEW_TOKENS = 20
# Seconds to wait when downloading an image over HTTP(S).
IMAGE_DOWNLOAD_TIMEOUT = 30


def _load_image(source: Any) -> Image.Image:
    """Return an RGB PIL image built from *source*.

    Accepted forms: a ``PIL.Image.Image``, an ``http://``/``https://`` URL,
    raw image bytes, or a base64-encoded string of image bytes.

    Raises:
        TypeError: if *source* is none of the accepted forms.
    """
    if isinstance(source, Image.Image):
        image = source
    elif isinstance(source, str) and source.startswith(("http://", "https://")):
        # Context manager closes the HTTP response; timeout avoids hanging forever.
        with urllib.request.urlopen(source, timeout=IMAGE_DOWNLOAD_TIMEOUT) as resp:
            image = Image.open(io.BytesIO(resp.read()))
    elif isinstance(source, (bytes, bytearray)):
        image = Image.open(io.BytesIO(source))
    elif isinstance(source, str):
        image = Image.open(io.BytesIO(base64.b64decode(source)))
    else:
        raise TypeError(f"Unsupported image input of type {type(source).__name__}")
    # Normalise palette / alpha / grayscale images to 3-channel RGB.
    return image.convert("RGB")


class EndpointHandler:
    """Serve PaliGemma image-conditioned text generation on an Inference Endpoint."""

    def __init__(self, path: str = "") -> None:
        """Load the PaliGemma model and its processor from *path*."""
        self.model = PaliGemmaForConditionalGeneration.from_pretrained(path)
        self.processor = AutoProcessor.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate text for a prompt conditioned on an image.

        Args:
            data: request payload (this method pops keys from it).
                ``data["inputs"]`` is either the prompt string, or a dict with
                a required ``"prompt"`` and an optional ``"image"`` (PIL image,
                http(s) URL, raw bytes, or base64 string; defaults to
                ``DEFAULT_IMAGE_URL``). If ``"inputs"`` is absent, ``data``
                itself is used as that dict. Optional ``data["parameters"]``
                may set ``"max_new_tokens"`` (default 20).

        Returns:
            ``[{"generated_text": text}]`` where ``text`` is only the newly
            generated continuation (the prompt is not echoed back).

        Raises:
            KeyError: if no ``"prompt"`` is provided.
            TypeError: if ``inputs`` or the image has an unsupported type.
        """
        raw_inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None) or {}

        # The docstring allows a bare prompt string as "inputs".
        if isinstance(raw_inputs, str):
            raw_inputs = {"prompt": raw_inputs}
        elif not isinstance(raw_inputs, dict):
            raise TypeError(
                f"Unsupported 'inputs' of type {type(raw_inputs).__name__}"
            )

        prompt = raw_inputs["prompt"]
        image = _load_image(raw_inputs.get("image", DEFAULT_IMAGE_URL))
        max_new_tokens = int(parameters.get("max_new_tokens", DEFAULT_MAX_NEW_TOKENS))

        # Keyword arguments: the positional order of PaliGemmaProcessor.__call__
        # differs between transformers versions (text-first vs images-first).
        model_inputs = self.processor(
            text=prompt, images=image, return_tensors="pt"
        ).to(self.model.device)
        input_len = model_inputs["input_ids"].shape[-1]

        output = self.model.generate(**model_inputs, max_new_tokens=max_new_tokens)
        # generate() returns prompt tokens followed by new tokens; decode only
        # the new ones so the prompt is not echoed in the response.
        generated = self.processor.decode(
            output[0][input_len:], skip_special_tokens=True
        )
        return [{"generated_text": generated}]