import base64
import json
import os
import sys
from io import BytesIO

import torch
from PIL import Image
# Make the vendored CLIP code importable before importing from it.
CODE_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "code")
sys.path.append(CODE_PATH)

from clip.model import CLIP
from clip.clip import _transform, tokenize

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_PATH = "model/tsbir_model_final.pt"
CONFIG_PATH = "code/training/model_configs/ViT-B-16.json"

def load_model():
    """Load the model and preprocessing transform only once."""
    global model, transformer
    if "model" not in globals():
        with open(CONFIG_PATH, 'r') as f:
            model_info = json.load(f)
        model = CLIP(**model_info)

        checkpoint = torch.load(MODEL_PATH, map_location=device)
        sd = checkpoint["state_dict"]
        # Strip the 'module.' prefix left behind by DataParallel checkpoints.
        if next(iter(sd)).startswith('module'):
            sd = {k[len('module.'):]: v for k, v in sd.items()}
        model.load_state_dict(sd, strict=False)
        model = model.to(device).eval()

        transformer = _transform(model.visual.input_resolution, is_train=False)
        print("Model loaded successfully.")

def preprocess_image(image_base64):
    """Decode a base64-encoded sketch or image into a model-ready tensor."""
    image = Image.open(BytesIO(base64.b64decode(image_base64))).convert("RGB")
    image = transformer(image).unsqueeze(0).to(device)
    return image

def preprocess_text(text):
    """Tokenize a text query."""
    return tokenize([str(text)])[0].unsqueeze(0).to(device)

def get_fused_embedding(sketch_base64, text):
    """Fuse sketch and text features into a single query embedding."""
    with torch.no_grad():
        sketch_tensor = preprocess_image(sketch_base64)
        text_tensor = preprocess_text(text)

        sketch_feature = model.encode_sketch(sketch_tensor)
        text_feature = model.encode_text(text_tensor)

        # L2-normalize each modality before fusing.
        sketch_feature = sketch_feature / sketch_feature.norm(dim=-1, keepdim=True)
        text_feature = text_feature / text_feature.norm(dim=-1, keepdim=True)

        fused_embedding = model.feature_fuse(sketch_feature, text_feature)
    return fused_embedding.cpu().numpy().tolist()

def get_image_embedding(image_base64):
    """Encode a base64-encoded gallery image into a normalized embedding."""
    image_tensor = preprocess_image(image_base64)
    with torch.no_grad():
        image_feature = model.encode_image(image_tensor)
        image_feature = image_feature / image_feature.norm(dim=-1, keepdim=True)
    return image_feature.cpu().numpy().tolist()
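
# Retrieval sketch (illustrative helper, not part of the original file): since
# the image embeddings above are L2-normalized, the dot product between a query
# embedding (from get_fused_embedding) and gallery embeddings (from
# get_image_embedding) is their cosine similarity -- assuming feature_fuse
# keeps the fused query in the same normalized space as the image encoder.
def rank_gallery(query_embedding, gallery_embeddings):
    import numpy as np
    q = np.asarray(query_embedding)[0]                  # shape [D]
    g = np.asarray([e[0] for e in gallery_embeddings])  # shape [N, D]
    scores = g @ q                                      # cosine score per image
    return np.argsort(-scores).tolist()                 # best matches first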

# Hugging Face Inference API Entry Point
def infer(inputs):
    """
    Inference API entry point.

    Inputs:
        - 'sketch': Base64-encoded sketch image (requires 'text').
        - 'text': Text query to pair with the sketch.
        - 'image': Base64-encoded gallery image (used when no sketch is given).
    """
    load_model()  # Ensure the model is loaded once
    if "sketch" in inputs:
        sketch_base64 = inputs.get("sketch", "")
        text_query = inputs.get("text", "")
        if not sketch_base64 or not text_query:
            return {"error": "Both 'sketch' (base64) and 'text' are required inputs."}
        # Generate the fused sketch+text query embedding.
        fused_embedding = get_fused_embedding(sketch_base64, text_query)
        return {"embedding": fused_embedding}
    elif "image" in inputs:
        image_base64 = inputs.get("image", "")
        if not image_base64:
            return {"error": "Input 'image' (base64) is required."}
        embedding = get_image_embedding(image_base64)
        return {"embedding": embedding}
    else:
        return {"error": "Input 'sketch' or 'image' is required."}