File size: 3,526 Bytes
af896ec
 
 
 
5c18c06
af896ec
941ce80
5c18c06
941ce80
 
 
af896ec
5c18c06
af896ec
 
 
5c18c06
 
 
 
 
 
 
 
 
af896ec
5c18c06
 
 
 
 
 
 
 
 
 
 
 
 
af896ec
8cda892
af896ec
 
 
 
 
 
 
 
8cda892
af896ec
 
8cda892
af896ec
 
8cda892
af896ec
 
 
 
 
 
 
 
8cda892
 
 
 
 
 
 
 
af896ec
 
 
5c18c06
af896ec
8cda892
af896ec
 
5c18c06
8cda892
 
 
 
 
5c18c06
8cda892
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import torch
from PIL import Image
import base64
from io import BytesIO
import json
import sys
import os

import sys
CODE_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "code")
sys.path.append(CODE_PATH)
from clip.model import CLIP
from clip.clip import _transform, tokenize

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Resolve model assets relative to this file (the same way CODE_PATH is
# resolved) rather than the process working directory, so the handler works
# regardless of where it is launched from.
_BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH = os.path.join(_BASE_DIR, "model", "tsbir_model_final.pt")
CONFIG_PATH = os.path.join(
    _BASE_DIR, "code", "training", "model_configs", "ViT-B-16.json"
)

def _strip_module_prefix(state_dict):
    """Return a copy of *state_dict* with any leading ``module.`` removed from its keys."""
    prefix = "module."
    # Strip per key: only keys that actually carry the prefix are rewritten,
    # so a checkpoint with mixed keys is not corrupted.
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def load_model():
    """Load the model and its image transform only once.

    Populates the module globals ``model`` and ``transformer``; once ``model``
    exists, later calls return immediately. Both globals are bound only after
    every loading step has succeeded, so a failure part-way through (missing
    file, bad checkpoint) does not leave a half-initialised ``model`` that
    would make later calls skip loading and then fail on ``transformer``.
    """
    global model, transformer
    if "model" in globals():
        return

    with open(CONFIG_PATH, "r", encoding="utf-8") as f:
        model_info = json.load(f)

    net = CLIP(**model_info)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(MODEL_PATH, map_location=device)
    sd = _strip_module_prefix(checkpoint["state_dict"])

    # strict=False tolerates key mismatches; report them so weights that were
    # silently not loaded do not go unnoticed.
    incompatible = net.load_state_dict(sd, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(
            "Warning: checkpoint/model key mismatch "
            f"(missing={len(incompatible.missing_keys)}, "
            f"unexpected={len(incompatible.unexpected_keys)})."
        )
    net = net.to(device).eval()

    transformer = _transform(net.visual.input_resolution, is_train=False)
    model = net  # bind last: its presence marks loading as complete
    print("Model loaded successfully.")

def preprocess_image(image_base64):
    """Convert a base64-encoded image (sketch or photo) to a model-ready tensor.

    Accepts raw base64 or a ``data:<mime>;base64,<payload>`` data URI; the URI
    header is removed before decoding. Uses the global ``transformer``, so
    ``load_model()`` must have run first.

    Returns:
        The transformed image with a leading batch dimension of 1, on ``device``.

    Raises:
        binascii.Error: if the payload cannot be base64-decoded (e.g. bad padding).
        PIL.UnidentifiedImageError: if the decoded bytes are not a readable image.
    """
    if isinstance(image_base64, str) and image_base64.startswith("data:"):
        # b64decode (non-validating) would drop ':', ';' and ',' but keep the
        # header's letters, silently producing corrupt bytes.
        image_base64 = image_base64.partition(",")[2]
    with Image.open(BytesIO(base64.b64decode(image_base64))) as img:
        image = img.convert("RGB")  # new, fully loaded image; safe after close
    return transformer(image).unsqueeze(0).to(device)

def preprocess_text(text):
    """Tokenize a text query into a single-row token tensor on ``device``.

    The query is passed through ``str()`` first, so non-string values are
    tokenized by their string form.
    """
    token_batch = tokenize([str(text)])
    single_row = token_batch[:1]  # keep the batch dimension, length 1
    return single_row.to(device)

def get_fused_embedding(sketch_base64, text):
    """Produce one joint embedding from a sketch and a text query.

    The sketch and text are encoded separately, each feature is divided by its
    L2 norm along the last dimension, and the pair is combined with
    ``model.feature_fuse``. Requires ``load_model()`` to have run.

    Returns:
        The fused embedding converted to nested Python lists.
    """
    def _unit(vec):
        # Scale to unit L2 norm along the feature dimension.
        return vec / vec.norm(dim=-1, keepdim=True)

    with torch.no_grad():
        sketch_input = preprocess_image(sketch_base64)
        text_input = preprocess_text(text)

        raw_sketch = model.encode_sketch(sketch_input)
        raw_text = model.encode_text(text_input)

        fused = model.feature_fuse(_unit(raw_sketch), _unit(raw_text))
    return fused.cpu().numpy().tolist()

def get_image_embedding(image_base64):
    """Embed a base64-encoded image with the model's image encoder.

    The image feature is divided by its L2 norm along the last dimension, as
    the sketch and text features are in ``get_fused_embedding``. Requires
    ``load_model()`` to have run.

    Returns:
        The unit-normalised image embedding converted to nested Python lists.
    """
    # Preprocessing runs inside no_grad too, matching get_fused_embedding, so
    # the whole inference path is free of autograd tracking.
    with torch.no_grad():
        image_tensor = preprocess_image(image_base64)
        image_feature = model.encode_image(image_tensor)
        image_feature = image_feature / image_feature.norm(dim=-1, keepdim=True)
    return image_feature.cpu().numpy().tolist()

# Hugging Face Inference API Entry Point
def infer(inputs):
    """
    Inference API entry point.

    Inputs (a mapping, e.g. the decoded JSON request body):
      - 'sketch' + 'text': base64-encoded sketch image and a text query;
        returns their fused embedding.
      - 'image': base64-encoded image; returns its image embedding.
    If both 'sketch' and 'image' are present, 'sketch' is used.

    Returns:
      {"embedding": <nested list of floats>} on success, or
      {"error": "<message>"} when the request is malformed or its image
      payload cannot be decoded.
    """
    from collections.abc import Mapping

    # Validate the request before touching the model: malformed requests are
    # rejected cheaply and never depend on the model being loadable.
    if not isinstance(inputs, Mapping):
        return {"error": "Inputs must be a JSON object with 'sketch' or 'image'."}

    if "sketch" in inputs:
        sketch_base64 = inputs.get("sketch", "")
        text_query = inputs.get("text", "")
        if not sketch_base64 or not text_query:
            return {"error": "Both 'sketch' (base64) and 'text' are required inputs."}
        compute, args = get_fused_embedding, (sketch_base64, text_query)
    elif "image" in inputs:
        image_base64 = inputs.get("image", "")
        if not image_base64:
            return {"error": "Image 'image' (base64) is required input."}
        compute, args = get_image_embedding, (image_base64,)
    else:
        return {"error": "Input 'sketch' or 'image' is required."}

    load_model()  # Ensure the model is loaded once

    try:
        embedding = compute(*args)
    except (ValueError, OSError) as err:
        # binascii.Error (bad base64) subclasses ValueError and
        # PIL.UnidentifiedImageError (undecodable bytes) subclasses OSError;
        # report them in the handler's error shape instead of crashing.
        return {"error": f"Could not process input image: {err}"}
    return {"embedding": embedding}