# tsbir / inference.py
# tcm03: "Add image encoding" (commit 8cda892)
import torch
from PIL import Image
import base64
from io import BytesIO
import json
import sys
import os
import sys
# The vendored ``clip`` package lives in ``code/`` next to this script; put that
# directory on the import path so the ``clip`` imports below resolve.
_script_dir = os.path.dirname(os.path.abspath(__file__))
CODE_PATH = os.path.join(_script_dir, "code")
sys.path.append(CODE_PATH)
from clip.model import CLIP
from clip.clip import _transform, tokenize
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Resolve data files relative to this script (the same base CODE_PATH uses), so
# loading works no matter which directory the process was started from.
_MODULE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH = os.path.join(_MODULE_DIR, "model", "tsbir_model_final.pt")
CONFIG_PATH = os.path.join(
    _MODULE_DIR, "code", "training", "model_configs", "ViT-B-16.json"
)
def _strip_module_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``module.`` removed from keys.

    Checkpoints saved through a ``torch.nn.DataParallel`` / DDP wrapper
    presumably carry this prefix; the bare ``CLIP`` model expects unprefixed
    names. Keys without the prefix are kept unchanged.
    """
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def load_model():
    """Load the CLIP model and its image transform into module globals, once.

    On the first call this builds ``CLIP`` from the JSON config at
    ``CONFIG_PATH``, loads the weights from ``MODEL_PATH`` onto ``device``,
    switches the model to eval mode, and builds the matching eval-time image
    transform. The results are published as the module-level globals ``model``
    and ``transformer``, which the other functions in this module read. Later
    calls return immediately.

    Both globals are assigned only after every step has succeeded. A failed
    load (missing file, bad checkpoint) therefore leaves nothing
    half-initialised, and the next call retries from scratch.

    Raises:
        OSError: if the config or checkpoint file cannot be read.
        ValueError: if the checkpoint contains no weights.
    """
    global model, transformer
    if "model" in globals() and "transformer" in globals():
        return

    with open(CONFIG_PATH, "r", encoding="utf-8") as f:
        model_info = json.load(f)
    net = CLIP(**model_info)

    # NOTE: torch.load unpickles the file; only point MODEL_PATH at trusted data.
    checkpoint = torch.load(MODEL_PATH, map_location=device)
    # Accept both {"state_dict": ...} training checkpoints and bare state dicts.
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        sd = checkpoint["state_dict"]
    else:
        sd = checkpoint
    if not sd:
        raise ValueError(f"Checkpoint {MODEL_PATH!r} contains no weights.")
    sd = _strip_module_prefix(sd)

    # strict=False tolerates key mismatches; report them so that silently
    # uninitialised parameters do not go unnoticed.
    incompatible = net.load_state_dict(sd, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(
            f"Warning: {len(incompatible.missing_keys)} missing and "
            f"{len(incompatible.unexpected_keys)} unexpected keys while "
            f"loading {MODEL_PATH}."
        )
    net = net.to(device).eval()
    preprocess = _transform(net.visual.input_resolution, is_train=False)

    # Publish only once everything above has succeeded.
    model = net
    transformer = preprocess
    print("Model loaded successfully.")
def preprocess_image(image_base64):
    """Decode a base64-encoded image (sketch or photo) into a model input tensor.

    Args:
        image_base64: Base64 text (str or bytes) of an encoded image file. A
            ``data:<mime>;base64,`` URI header, as produced by browser
            ``toDataURL`` / ``readAsDataURL``, is stripped before decoding.

    Returns:
        The image converted to RGB and passed through the global
        ``transformer`` (created by ``load_model``). It has a leading batch
        dimension of size 1 and is placed on ``device``.

    Raises:
        binascii.Error: on malformed base64 (e.g. bad padding).
        PIL.UnidentifiedImageError: if the decoded bytes are not a readable image.
    """
    if isinstance(image_base64, str) and image_base64.startswith("data:"):
        # b64decode would drop the header's ':', ';' and ',' but keep letters
        # such as "data" and "image", corrupting the payload; cut the header off.
        image_base64 = image_base64.partition(",")[2]
    raw = base64.b64decode(image_base64)
    with Image.open(BytesIO(raw)) as img:
        # convert() forces the lazy decode and returns an independent copy,
        # so closing the source image here is safe.
        rgb = img.convert("RGB")
    return transformer(rgb).unsqueeze(0).to(device)
def preprocess_text(text):
    """Tokenize a text query into a single-row token tensor on ``device``.

    The query is coerced with ``str()``, tokenized as a one-element batch,
    and that batch's only row is re-wrapped with a leading dimension of 1.
    """
    token_batch = tokenize([str(text)])
    only_row = token_batch[0]
    return only_row.unsqueeze(0).to(device)
def get_fused_embedding(sketch_base64, text):
    """Embed a sketch together with a text query as one fused vector.

    The sketch goes through ``model.encode_sketch`` and the text through
    ``model.encode_text``. Each feature is scaled to unit L2 norm along its
    last dimension, and the pair is combined by ``model.feature_fuse``. All
    of this runs without gradient tracking.

    Returns:
        The fused embedding as nested Python lists (via ``.tolist()``).
    """
    def _unit(vec):
        # Scale to unit L2 norm along the feature (last) dimension.
        return vec / vec.norm(dim=-1, keepdim=True)

    with torch.no_grad():
        sketch_input = preprocess_image(sketch_base64)
        query_tokens = preprocess_text(text)
        sketch_vec = _unit(model.encode_sketch(sketch_input))
        text_vec = _unit(model.encode_text(query_tokens))
        fused = model.feature_fuse(sketch_vec, text_vec)
    return fused.cpu().numpy().tolist()
def get_image_embedding(image_base64):
    """Return the unit-normalised embedding of a base64-encoded image.

    The image is prepared with ``preprocess_image`` and then encoded by
    ``model.encode_image`` without gradient tracking. The result is scaled to
    unit L2 norm along its last dimension.

    Returns:
        The embedding as nested Python lists (via ``.tolist()``).
    """
    pixels = preprocess_image(image_base64)
    with torch.no_grad():
        raw = model.encode_image(pixels)
        lengths = raw.norm(dim=-1, keepdim=True)
        unit = raw / lengths
    return unit.cpu().numpy().tolist()
# Hugging Face Inference API Entry Point
def infer(inputs):
    """Inference API entry point.

    Loads the model on first use (see ``load_model``), then dispatches on the
    keys of *inputs*:

    - ``{"sketch": <base64>, "text": <str>}`` returns the fused sketch+text
      embedding; both values must be non-empty.
    - ``{"image": <base64>}`` returns the image embedding.

    If both ``"sketch"`` and ``"image"`` are present, ``"sketch"`` takes
    precedence.

    Returns:
        ``{"embedding": [...]}`` on success. On failure it returns
        ``{"error": <message>}`` when *inputs* is not a mapping, when required
        fields are missing or empty, or when the image data cannot be decoded.
    """
    import binascii
    from collections.abc import Mapping

    from PIL import UnidentifiedImageError

    load_model()  # Ensure the model is loaded once

    # A JSON payload may be a list or a string; `"sketch" in "..."` would even
    # be truthy for a string, so reject anything that is not a mapping.
    if not isinstance(inputs, Mapping):
        return {"error": "Inputs must be an object containing 'sketch' or 'image'."}

    if "sketch" in inputs:
        sketch_base64 = inputs.get("sketch", "")
        text_query = inputs.get("text", "")
        if not sketch_base64 or not text_query:
            return {"error": "Both 'sketch' (base64) and 'text' are required inputs."}
        # Generate Fused Embedding
        embed, args = get_fused_embedding, (sketch_base64, text_query)
    elif "image" in inputs:
        image_base64 = inputs.get("image", "")
        if not image_base64:
            return {"error": "Image 'image' (base64) is required input."}
        embed, args = get_image_embedding, (image_base64,)
    else:
        return {"error": "Input 'sketch' or 'image' is required."}

    # Malformed client data is reported through the same error contract as
    # missing fields; every other failure still propagates.
    try:
        embedding = embed(*args)
    except (binascii.Error, UnidentifiedImageError) as err:
        return {"error": f"Invalid image data: {err}"}
    return {"embedding": embedding}