tsbir / pipeline.py
Add custom inference handler
history blame
2.79 kB
from typing import Dict, List, Any
from PIL import Image
import torch
import base64
import os
from io import BytesIO
import json
import sys
from clip.model import CLIP
from clip.clip import _transform, tokenize
# Module-wide compute device: the default CUDA device when PyTorch reports one
# is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class PreTrainedPipeline:
    """Inference handler that fuses a sketch image and a text query into one embedding.

    The model is a ``CLIP`` instance on which this handler calls
    ``encode_sketch``, ``encode_text`` and ``feature_fuse``. It is loaded once
    in ``__init__`` and reused for every request handled by ``__call__``.
    """

    def __init__(self, path: str = ""):
        """Load the model weights and build the image preprocessing transform.

        Args:
            path: Directory containing
                ``code/training/model_configs/ViT-B-16.json`` (model config) and
                ``model/tsbir_model_final.pt`` (checkpoint). The default ``""``
                resolves both relative to the current working directory.
        """
        import warnings  # local import: only used to report checkpoint key mismatches

        model_config_file = os.path.join(
            path, "code", "training", "model_configs", "ViT-B-16.json"
        )
        with open(model_config_file, "r", encoding="utf-8") as f:
            model_info = json.load(f)
        self.model = CLIP(**model_info)

        model_file = os.path.join(path, "model", "tsbir_model_final.pt")
        # NOTE(review): torch.load unpickles the file, so it must come from a
        # trusted source. On torch releases where weights_only defaults to True,
        # a checkpoint holding non-tensor objects may refuse to load — confirm.
        checkpoint = torch.load(model_file, map_location=device)
        # Accept both training checkpoints ({"state_dict": ...}) and bare state dicts.
        state_dict = checkpoint.get("state_dict", checkpoint)

        load_result = self.model.load_state_dict(
            self._strip_module_prefix(state_dict), strict=False
        )
        # strict=False tolerates mismatched keys; report them rather than silently
        # serving a partially initialised model.
        if load_result.missing_keys or load_result.unexpected_keys:
            warnings.warn(
                f"Checkpoint {model_file!r} did not match the model exactly: "
                f"{len(load_result.missing_keys)} missing and "
                f"{len(load_result.unexpected_keys)} unexpected key(s).",
                stacklevel=2,
            )
        self.model = self.model.to(device).eval()

        # Preprocessing for incoming images (evaluation mode, model's input size).
        self.transform = _transform(self.model.visual.input_resolution, is_train=False)

    @staticmethod
    def _strip_module_prefix(state_dict):
        """Return a copy of *state_dict* with a leading ``module.`` removed from each key.

        Checkpoints saved from a wrapped model (typically ``DataParallel``) name
        parameters ``module.<name>``. Keys without the prefix are kept unchanged,
        and an empty mapping yields an empty dict.
        """
        prefix = "module."
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    def __call__(self, data: Any) -> Dict[str, Any]:
        """Encode a sketch image and a text query and return their fused embedding.

        Args:
            data: Request payload, either ``{"inputs": {...}}`` or the inner dict
                itself. The inner dict has keys ``"image"`` (base64-encoded image
                bytes, optionally as a ``data:`` URI) and ``"text"`` (the query).
                The payload is not modified.

        Returns:
            ``{"fused_embedding": <nested list of floats from Tensor.tolist()>}``
            on success; presumably shaped ``[[...]]`` for the single-item batch —
            confirm against ``feature_fuse``. On bad input, returns
            ``{"error": <message>}`` instead.
        """
        # Read (never pop) so the caller's payload is left untouched.
        inputs = data.get("inputs", data) if isinstance(data, dict) else data
        if not isinstance(inputs, dict):
            return {"error": "Inputs must be an object containing 'image' (base64) and 'text'."}

        image_base64 = inputs.get("image", "")
        text_query = inputs.get("text", "")
        if not image_base64 or not text_query:
            return {"error": "Both 'image' (base64) and 'text' are required inputs."}

        # Tolerate data URIs such as "data:image/png;base64,<payload>".
        if isinstance(image_base64, str) and image_base64.startswith("data:"):
            image_base64 = image_base64.partition(",")[2]

        try:
            image_bytes = base64.b64decode(image_base64)
            image = Image.open(BytesIO(image_bytes)).convert("RGB")
        except (ValueError, OSError) as err:
            # binascii.Error (bad base64) subclasses ValueError; PIL's
            # UnidentifiedImageError and truncated-file errors subclass OSError.
            return {"error": f"Could not decode 'image' as a base64-encoded image: {err}"}

        image_tensor = self.transform(image).unsqueeze(0).to(device)
        text_tensor = tokenize([str(text_query)])[0].unsqueeze(0).to(device)

        with torch.no_grad():
            sketch_feature = self.model.encode_sketch(image_tensor)
            text_feature = self.model.encode_text(text_tensor)
            # L2-normalise each modality along the last dimension before fusion.
            sketch_feature = sketch_feature / sketch_feature.norm(dim=-1, keepdim=True)
            text_feature = text_feature / text_feature.norm(dim=-1, keepdim=True)
            # Fusion stays inside no_grad: if feature_fuse has trainable
            # parameters, its output would otherwise require grad and .numpy()
            # would raise.
            fused_embedding = self.model.feature_fuse(sketch_feature, text_feature)

        return {"fused_embedding": fused_embedding.cpu().numpy().tolist()}