|
from typing import Dict, Any |
|
import torch |
|
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration, BitsAndBytesConfig |
|
from PIL import Image |
|
import requests |
|
from io import BytesIO |
|
import base64 |
|
|
|
class EndpointHandler:
    """Endpoint handler that answers a text prompt about one image using LLaVA-NeXT.

    The model is loaded once at construction time with 4-bit NF4 quantization.
    Each call loads one image (from a URL or from base64 data), runs generation
    and returns the decoded text. Failures are reported in the returned dict
    (``{"error": ..., "logs": [...]}``) rather than raised. ``logs`` collects
    human-readable progress messages (in Catalan).
    """

    # Seconds to wait for the remote server when downloading an image by URL,
    # so a slow or unresponsive host cannot hang the worker indefinitely.
    _DOWNLOAD_TIMEOUT_S = 30

    def __init__(self, path=""):
        """Load the processor and the 4-bit quantized model.

        Args:
            path: Accepted for the handler interface but currently unused; the
                processor and model are loaded from fixed Hub repo ids below.
        """
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )

        # NOTE(review): the processor is taken from the Mistral-7B LLaVA-NeXT repo
        # while the weights come from a "34b" checkpoint. If that checkpoint uses a
        # different tokenizer / chat template, prompts will be tokenized wrongly —
        # confirm, and consider loading the processor from the model repo or `path`.
        self.processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
        self.model = LlavaNextForConditionalGeneration.from_pretrained(
            "rroset/llava-v1.6-34b",
            quantization_config=quantization_config,
            device_map="auto",
        )

    def _load_image(self, image_url, image_data, logs):
        """Fetch or decode the request image and return it as a loaded RGB PIL image.

        ``image_url`` takes precedence over ``image_data``. ``image_data`` is
        base64 text, optionally prefixed with a ``data:<mime>;base64,`` header.
        Progress messages are appended to ``logs``. Any error raised by
        ``requests``, ``base64`` or PIL propagates to the caller.
        """
        if image_url:
            logs.append(f"Carregant imatge des de URL: {image_url}")
            # NOTE(review): this fetches an arbitrary caller-supplied URL from the
            # server (SSRF risk); consider restricting the allowed hosts.
            response = requests.get(image_url, timeout=self._DOWNLOAD_TIMEOUT_S)
            # Fail on 4xx/5xx instead of handing an HTML error page to PIL.
            response.raise_for_status()
            # `.content` is decoded according to Content-Encoding (unlike
            # `response.raw`), and reading it fully releases the connection.
            image = Image.open(BytesIO(response.content))
        else:
            logs.append("Carregant imatge des de dades d'imatge en brut.")
            if isinstance(image_data, str) and image_data.startswith("data:") and "," in image_data:
                # Accept data URLs such as "data:image/png;base64,<payload>".
                image_data = image_data.split(",", 1)[1]
            image = Image.open(BytesIO(base64.b64decode(image_data)))

        # Image.open is lazy: force decoding here so corrupt data fails inside the
        # image-loading error handler, not later inside the model call.
        image.load()
        if image.mode != "RGB":
            # Normalize RGBA / palette / grayscale images directly; no lossy JPEG
            # round-trip is needed to obtain 3-channel RGB pixels.
            logs.append(f"Convertint imatge de mode {image.mode} a RGB.")
            image = image.convert("RGB")
        return image

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run one prompt + image request through the model.

        Args:
            data: Request payload. ``data["inputs"]`` must be a dict with
                ``prompt`` (required), one of ``url`` / ``image_data`` (base64),
                and optionally ``max_tokens`` (positive int, default 100).

        Returns:
            On success ``{"input_prompt", "model_output", "logs"}``; otherwise
            ``{"error", "logs"}``. Invalid ``inputs`` contents, image-loading
            failures and generation failures are reported this way, not raised.
        """
        logs = ["Iniciant processament de la petició."]

        request = data.get("inputs")
        if not request:
            logs.append("Format d'entrada invàlid. Manca la clau 'inputs'.")
            return {"error": "Invalid input format. 'inputs' key is missing.", "logs": logs}
        if not isinstance(request, dict):
            logs.append("Format d'entrada invàlid. 'inputs' ha de ser un objecte.")
            return {"error": "Invalid input format. 'inputs' must be an object.", "logs": logs}

        image_url = request.get("url")
        image_data = request.get("image_data")
        prompt = request.get("prompt")

        if not prompt:
            logs.append("S'ha de proporcionar 'prompt' en 'inputs'.")
            return {"error": "The 'prompt' must be provided in 'inputs'.", "logs": logs}

        if not image_url and not image_data:
            logs.append("S'ha de proporcionar 'url' o 'image_data' en 'inputs'.")
            return {"error": "Either 'url' or 'image_data' must be provided in 'inputs'.", "logs": logs}

        # Validate up front so a bad value is reported clearly instead of failing
        # deep inside generate(); an explicit null falls back to the default.
        raw_max_tokens = request.get("max_tokens")
        if raw_max_tokens is None:
            raw_max_tokens = 100
        try:
            max_tokens = int(raw_max_tokens)
        except (TypeError, ValueError):
            max_tokens = 0
        if max_tokens <= 0:
            logs.append("'max_tokens' ha de ser un enter positiu.")
            return {"error": "'max_tokens' must be a positive integer.", "logs": logs}

        logs.append(f"Processant entrada: url={image_url}, image_data={'present' if image_data else 'absent'}, prompt={prompt}")

        # Broad catches are deliberate at this service boundary: every failure is
        # turned into an error response instead of crashing the endpoint.
        try:
            image = self._load_image(image_url, image_data, logs)
        except Exception as e:
            logs.append(f"Error carregant imatge: {str(e)}")
            return {"error": str(e), "logs": logs}

        try:
            logs.append("Processant imatge amb el model.")
            # Keyword arguments: the positional order of the processor's
            # (text, images) parameters differs between transformers versions.
            model_inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(self.model.device)
            output = self.model.generate(**model_inputs, max_new_tokens=max_tokens)
            # NOTE(review): for decoder-only generation the output likely starts with
            # the prompt tokens, so this text probably echoes the prompt; kept as-is
            # to preserve the existing response contract.
            result = self.processor.decode(output[0], skip_special_tokens=True)
        except Exception as e:
            logs.append(f"Error processant el model: {str(e)}")
            return {"error": str(e), "logs": logs}

        logs.append("Processament complet.")
        return {"input_prompt": prompt, "model_output": result, "logs": logs}
|
|