tcm03 committed
Commit effd0b9 · Parent: 9d78c7b

Add custom inference handler

Files changed (2)
  1. README.md +4 -6
  2. pipeline.py +76 -0
README.md CHANGED
@@ -1,13 +1,11 @@
 ---
 tags:
-- image-retrieval
+- feature-extraction
 - text-sketch
-- clip
-- open_clip
-- inference
-library_name: open_clip
+- endpoints-template
+library_name: generic
+license: bsd-3-clause
 inference: true
-custom_handler: true
 ---
 
 # Image Retrieval with Text and Sketch
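
Because the updated front matter declares library_name: generic with inference: true, the Hub serves this repository through the custom handler in pipeline.py (added below) rather than through a standard library pipeline. As a rough illustration only, a client request to a deployed endpoint could look like the following sketch; the endpoint URL, access token, sketch file name, and query text are placeholders, and the payload shape simply mirrors what PreTrainedPipeline.__call__ expects:

import base64
import requests

API_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder deployment URL
HEADERS = {"Authorization": "Bearer <hf_token>"}                  # placeholder access token

# Base64-encode a sketch image, matching the handler's expected input format.
with open("sketch.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {"inputs": {"image": image_b64, "text": "a cat sitting on a chair"}}
response = requests.post(API_URL, headers=HEADERS, json=payload)
print(response.json())  # expected shape: {"fused_embedding": [[...]]}
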
pipeline.py ADDED
@@ -0,0 +1,76 @@
+from typing import Dict, List, Any
+from PIL import Image
+import torch
+import base64
+import os
+from io import BytesIO
+import json
+
+import sys
+sys.path.append("code")
+from clip.model import CLIP
+from clip.clip import _transform, tokenize
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+class PreTrainedPipeline:
+    def __init__(self, path: str = ""):
+        """
+        Initialize the pipeline by loading the model.
+        Args:
+            path (str): Path to the directory containing model weights and config.
+        """
+        model_config_file = os.path.join(path, "code/training/model_configs/ViT-B-16.json")
+        with open(model_config_file, "r") as f:
+            model_info = json.load(f)
+
+        model_file = os.path.join(path, "model/tsbir_model_final.pt")
+        self.model = CLIP(**model_info)
+        checkpoint = torch.load(model_file, map_location=device)
+
+        sd = checkpoint["state_dict"]
+        if next(iter(sd.items()))[0].startswith("module"):
+            sd = {k[len("module."):]: v for k, v in sd.items()}
+
+        self.model.load_state_dict(sd, strict=False)
+        self.model = self.model.to(device).eval()
+
+        # Preprocessing
+        self.transform = _transform(self.model.visual.input_resolution, is_train=False)
+
+    def __call__(self, data: Any) -> Dict[str, List[float]]:
+        """
+        Process the request and return the fused embedding.
+        Args:
+            data (dict): Includes 'image' (base64) and 'text' (str) inputs.
+        Returns:
+            dict: {"fused_embedding": [float, float, ...]}
+        """
+        # Parse inputs
+        inputs = data.pop("inputs", data)
+        image_base64 = inputs.get("image", "")
+        text_query = inputs.get("text", "")
+
+        if not image_base64 or not text_query:
+            return {"error": "Both 'image' (base64) and 'text' are required inputs."}
+
+        # Preprocess the image
+        image = Image.open(BytesIO(base64.b64decode(image_base64))).convert("RGB")
+        image_tensor = self.transform(image).unsqueeze(0).to(device)
+
+        # Preprocess the text
+        text_tensor = tokenize([str(text_query)])[0].unsqueeze(0).to(device)
+
+        # Generate features
+        with torch.no_grad():
+            sketch_feature = self.model.encode_sketch(image_tensor)
+            text_feature = self.model.encode_text(text_tensor)
+
+            # Normalize features
+            sketch_feature = sketch_feature / sketch_feature.norm(dim=-1, keepdim=True)
+            text_feature = text_feature / text_feature.norm(dim=-1, keepdim=True)
+
+            # Fuse features
+            fused_embedding = self.model.feature_fuse(sketch_feature, text_feature)
+
+        return {"fused_embedding": fused_embedding.cpu().numpy().tolist()}
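
For a quick local sanity check of the handler outside Inference Endpoints, a sketch along these lines should work when run from the repository root (so that the code/, model/, and pipeline.py paths referenced above resolve); the sketch file name and query text are placeholders:

import base64
from pipeline import PreTrainedPipeline

# Instantiate the handler; the default path resolves code/ and model/ relative to the repo root.
pipe = PreTrainedPipeline(path="")

# The handler takes a base64-encoded sketch plus a text query.
with open("sketch.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

result = pipe({"inputs": {"image": image_b64, "text": "a cat sitting on a chair"}})
embedding = result["fused_embedding"]  # nested list, typically one row per input
print(len(embedding), len(embedding[0]))
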