hogepodge commited on
Commit
6307f85
·
1 Parent(s): d733ed5

Initial commit of the Label Studio Segment Anything space

Browse files

Implementation of a Label Studio ML backend using MobileSAM
for image segmentation.

Files changed (7) hide show
  1. Dockerfile +39 -0
  2. _wsgi.py +113 -0
  3. download_models.sh +23 -0
  4. model.py +145 -0
  5. requirements.txt +13 -0
  6. sam_predictor.py +198 -0
  7. start.sh +4 -0
Dockerfile ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.8-slimjjjjjj
2
+
3
+ # Install Dependencies
4
+ RUN apt-get update -q \
5
+ && apt-get install -qy --no-install-recommends wget git libopencv-dev python3-opencv \
6
+ && apt-get autoremove -y \
7
+ && apt-get clean \
8
+ && rm -rf /var/lib/apt/lists/*
9
+
10
+ # Set up a non-root user
11
+ RUN useradd -m -u 1000 user \
12
+ && mkdir /app \
13
+ && chown -R user /app
14
+
15
+ # Switch to the "user" user
16
+ USER user
17
+
18
+ # Set the working directory to the user's home directory
19
+ WORKDIR /app
20
+
21
+ ENV PYTHONUNBUFFERED=True \
22
+ VITH_CHECKPOINT=/app/models/sam_vit_h_4b8939.pth \
23
+ MOBILESAM_CHECKPOINT=/app/models/mobile_sam.pt \
24
+ ONNX_CHECKPOINT=/app/models/sam_onnx_quantized_example.onnx \
25
+ PORT=7860
26
+
27
+ # Copy and run the model download script
28
+ COPY download_models.sh .
29
+ RUN bash /app/download_models.sh
30
+
31
+ # Install Python dependencies
32
+ COPY requirements.txt .
33
+ RUN pip install --user --no-cache-dir -r requirements.txt
34
+
35
+ COPY . ./
36
+
37
+ EXPOSE 7860
38
+
39
+ CMD ["/app/start.sh"]
_wsgi.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import logging
4
+ import logging.config
5
+ import json
6
+
7
+ logging.config.dictConfig({
8
+ "version": 1,
9
+ "formatters": {
10
+ "standard": {
11
+ "format": "[%(asctime)s] [%(levelname)s] [%(name)s::%(funcName)s::%(lineno)d] %(message)s"
12
+ }
13
+ },
14
+ "handlers": {
15
+ "console": {
16
+ "class": "logging.StreamHandler",
17
+ "level": os.getenv('LOG_LEVEL', 'INFO'),
18
+ "stream": "ext://sys.stdout",
19
+ "formatter": "standard"
20
+ }
21
+ },
22
+ "root": {
23
+ "level": os.getenv('LOG_LEVEL', 'INFO'),
24
+ "handlers": [
25
+ "console"
26
+ ],
27
+ "propagate": True
28
+ }
29
+ })
30
+
31
+ from label_studio_ml.api import init_app
32
+ from model import SamMLBackend
33
+
34
+ _DEFAULT_CONFIG_PATH = os.path.join(os.path.dirname(__file__), 'config.json')
35
+
36
+
37
+ def get_kwargs_from_config(config_path=_DEFAULT_CONFIG_PATH):
38
+ if not os.path.exists(config_path):
39
+ return dict()
40
+ with open(config_path) as f:
41
+ config = json.load(f)
42
+ assert isinstance(config, dict)
43
+ return config
44
+
45
+
46
+ if __name__ == "__main__":
47
+ parser = argparse.ArgumentParser(description='Label studio')
48
+ parser.add_argument(
49
+ '-p', '--port', dest='port', type=int, default=9090,
50
+ help='Server port')
51
+ parser.add_argument(
52
+ '--host', dest='host', type=str, default='0.0.0.0',
53
+ help='Server host')
54
+ parser.add_argument(
55
+ '--kwargs', '--with', dest='kwargs', metavar='KEY=VAL', nargs='+', type=lambda kv: kv.split('='),
56
+ help='Additional LabelStudioMLBase model initialization kwargs')
57
+ parser.add_argument(
58
+ '-d', '--debug', dest='debug', action='store_true',
59
+ help='Switch debug mode')
60
+ parser.add_argument(
61
+ '--log-level', dest='log_level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], default=None,
62
+ help='Logging level')
63
+ parser.add_argument(
64
+ '--model-dir', dest='model_dir', default=os.path.dirname(__file__),
65
+ help='Directory where models are stored (relative to the project directory)')
66
+ parser.add_argument(
67
+ '--check', dest='check', action='store_true',
68
+ help='Validate model instance before launching server')
69
+
70
+ args = parser.parse_args()
71
+
72
+ # setup logging level
73
+ if args.log_level:
74
+ logging.root.setLevel(args.log_level)
75
+
76
+ def isfloat(value):
77
+ try:
78
+ float(value)
79
+ return True
80
+ except ValueError:
81
+ return False
82
+
83
+ def parse_kwargs():
84
+ param = dict()
85
+ for k, v in args.kwargs:
86
+ if v.isdigit():
87
+ param[k] = int(v)
88
+ elif v == 'True' or v == 'true':
89
+ param[k] = True
90
+ elif v == 'False' or v == 'False':
91
+ param[k] = False
92
+ elif isfloat(v):
93
+ param[k] = float(v)
94
+ else:
95
+ param[k] = v
96
+ return param
97
+
98
+ kwargs = get_kwargs_from_config()
99
+
100
+ if args.kwargs:
101
+ kwargs.update(parse_kwargs())
102
+
103
+ if args.check:
104
+ print('Check "' + SamMLBackend.__name__ + '" instance creation..')
105
+ model = SamMLBackend(**kwargs)
106
+
107
+ app = init_app(model_class=SamMLBackend)
108
+
109
+ app.run(host=args.host, port=args.port, debug=args.debug)
110
+
111
+ else:
112
+ # for uWSGI use
113
+ app = init_app(model_class=SamMLBackend)
download_models.sh ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ MODELS_DIR="models"
4
+ mkdir -p ${MODELS_DIR}
5
+
6
+ download_model() {
7
+ FILE_PATH="${MODELS_DIR}/$1"
8
+ URL="$2"
9
+
10
+ if [ ! -f "${FILE_PATH}" ]; then
11
+ wget -q "${URL}" -P ${MODELS_DIR}/
12
+ fi
13
+ }
14
+
15
+ # Model files and their corresponding URLs
16
+ declare -A MODELS
17
+ # We just run with MobileSAM for this example
18
+ # MODELS["sam_vit_h_4b8939.pth"]="https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
19
+ MODELS["mobile_sam.pt"]="https://github.com/ChaoningZhang/MobileSAM/raw/master/weights/mobile_sam.pt"
20
+
21
+ for model in "${!MODELS[@]}"; do
22
+ download_model "${model}" "${MODELS[${model}]}"
23
+ done
model.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from label_studio_converter import brush
4
+ from typing import List, Dict, Optional
5
+ from uuid import uuid4
6
+ from sam_predictor import SAMPredictor
7
+ from label_studio_ml.model import LabelStudioMLBase
8
+
9
+ SAM_CHOICE = os.environ.get("SAM_CHOICE", "MobileSAM") # other option is just SAM
10
+ PREDICTOR = SAMPredictor(SAM_CHOICE)
11
+
12
+
13
+ class SamMLBackend(LabelStudioMLBase):
14
+
15
+ def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> List[Dict]:
16
+ """ Returns the predicted mask for a smart keypoint that has been placed."""
17
+
18
+ from_name, to_name, value = self.get_first_tag_occurence('BrushLabels', 'Image')
19
+
20
+ if not context or not context.get('result'):
21
+ # if there is no context, no interaction has happened yet
22
+ return []
23
+
24
+ image_width = context['result'][0]['original_width']
25
+ image_height = context['result'][0]['original_height']
26
+
27
+ # collect context information
28
+ point_coords = []
29
+ point_labels = []
30
+ input_box = None
31
+ selected_label = None
32
+ for ctx in context['result']:
33
+ x = ctx['value']['x'] * image_width / 100
34
+ y = ctx['value']['y'] * image_height / 100
35
+ ctx_type = ctx['type']
36
+ selected_label = ctx['value'][ctx_type][0]
37
+ if ctx_type == 'keypointlabels':
38
+ point_labels.append(int(ctx['is_positive']))
39
+ point_coords.append([int(x), int(y)])
40
+ elif ctx_type == 'rectanglelabels':
41
+ box_width = ctx['value']['width'] * image_width / 100
42
+ box_height = ctx['value']['height'] * image_height / 100
43
+ input_box = [int(x), int(y), int(box_width + x), int(box_height + y)]
44
+
45
+ print(f'Point coords are {point_coords}, point labels are {point_labels}, input box is {input_box}')
46
+
47
+ img_path = tasks[0]['data'][value]
48
+ predictor_results = PREDICTOR.predict(
49
+ img_path=img_path,
50
+ point_coords=point_coords or None,
51
+ point_labels=point_labels or None,
52
+ input_box=input_box
53
+ )
54
+
55
+ predictions = self.get_results(
56
+ masks=predictor_results['masks'],
57
+ probs=predictor_results['probs'],
58
+ width=image_width,
59
+ height=image_height,
60
+ from_name=from_name,
61
+ to_name=to_name,
62
+ label=selected_label)
63
+
64
+ return predictions
65
+
66
+ def get_results(self, masks, probs, width, height, from_name, to_name, label):
67
+ results = []
68
+ for mask, prob in zip(masks, probs):
69
+ # creates a random ID for your label everytime so no chance for errors
70
+ label_id = str(uuid4())[:4]
71
+ # converting the mask from the model to RLE format which is usable in Label Studio
72
+ mask = mask * 255
73
+ rle = brush.mask2rle(mask)
74
+
75
+ results.append({
76
+ 'id': label_id,
77
+ 'from_name': from_name,
78
+ 'to_name': to_name,
79
+ 'original_width': width,
80
+ 'original_height': height,
81
+ 'image_rotation': 0,
82
+ 'value': {
83
+ 'format': 'rle',
84
+ 'rle': rle,
85
+ 'brushlabels': [label],
86
+ },
87
+ 'score': prob,
88
+ 'type': 'brushlabels',
89
+ 'readonly': False
90
+ })
91
+
92
+ return [{
93
+ 'result': results,
94
+ 'model_version': PREDICTOR.model_name
95
+ }]
96
+
97
+
98
+ if __name__ == '__main__':
99
+ # test the model
100
+ model = SamMLBackend()
101
+ model.use_label_config('''
102
+ <View>
103
+ <Image name="image" value="$image" zoom="true"/>
104
+ <BrushLabels name="tag" toName="image">
105
+ <Label value="Banana" background="#FF0000"/>
106
+ <Label value="Orange" background="#0d14d3"/>
107
+ </BrushLabels>
108
+ <KeyPointLabels name="tag2" toName="image" smart="true" >
109
+ <Label value="Banana" background="#000000" showInline="true"/>
110
+ <Label value="Orange" background="#000000" showInline="true"/>
111
+ </KeyPointLabels>
112
+ <RectangleLabels name="tag3" toName="image" >
113
+ <Label value="Banana" background="#000000" showInline="true"/>
114
+ <Label value="Orange" background="#000000" showInline="true"/>
115
+ </RectangleLabels>
116
+ </View>
117
+ ''')
118
+ results = model.predict(
119
+ tasks=[{
120
+ 'data': {
121
+ 'image': 'https://s3.amazonaws.com/htx-pub/datasets/images/125245483_152578129892066_7843809718842085333_n.jpg'
122
+ }}],
123
+ context={
124
+ 'result': [{
125
+ 'original_width': 1080,
126
+ 'original_height': 1080,
127
+ 'image_rotation': 0,
128
+ 'value': {
129
+ 'x': 49.441786283891545,
130
+ 'y': 59.96810207336522,
131
+ 'width': 0.3189792663476874,
132
+ 'labels': ['Banana'],
133
+ 'keypointlabels': ['Banana']
134
+ },
135
+ 'is_positive': True,
136
+ 'id': 'fBWv1t0S2L',
137
+ 'from_name': 'tag2',
138
+ 'to_name': 'image',
139
+ 'type': 'keypointlabels',
140
+ 'origin': 'manual'
141
+ }]}
142
+ )
143
+ import json
144
+ results[0]['result'][0]['value']['rle'] = f'...{len(results[0]["result"][0]["value"]["rle"])} integers...'
145
+ print(json.dumps(results, indent=2))
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ label_studio_converter
2
+ opencv-python
3
+ onnxruntime
4
+ onnx
5
+ torch==2.0.1
6
+ torchvision==0.15.2
7
+ gunicorn==20.1.0
8
+ rq==1.10.1
9
+ timm==0.4.12
10
+
11
+ segment_anything @ git+https://github.com/facebookresearch/segment-anything.git
12
+ mobile-sam @ git+https://github.com/ChaoningZhang/MobileSAM.git
13
+ label-studio-ml @ git+https://github.com/heartexlabs/label-studio-ml-backend.git
sam_predictor.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import torch
4
+ import cv2
5
+ import numpy as np
6
+
7
+ from typing import List, Dict, Optional
8
+ from label_studio_ml.utils import get_image_local_path, InMemoryLRUDictCache
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ VITH_CHECKPOINT = os.environ.get("VITH_CHECKPOINT")
13
+ ONNX_CHECKPOINT = os.environ.get("ONNX_CHECKPOINT")
14
+ MOBILESAM_CHECKPOINT = os.environ.get("MOBILESAM_CHECKPOINT", "mobile_sam.pt")
15
+ LABEL_STUDIO_ACCESS_TOKEN = os.environ.get("LABEL_STUDIO_ACCESS_TOKEN")
16
+ LABEL_STUDIO_HOST = os.environ.get("LABEL_STUDIO_HOST")
17
+
18
+
19
+ class SAMPredictor(object):
20
+
21
+ def __init__(self, model_choice):
22
+ self.model_choice = model_choice
23
+
24
+ # cache for embeddings
25
+ # TODO: currently it supports only one image in cache,
26
+ # since predictor.set_image() should be called each time the new image comes
27
+ # before making predictions
28
+ # to extend it to >1 image, we need to store the "active image" state in the cache
29
+ self.cache = InMemoryLRUDictCache(1)
30
+
31
+ # if you're not using CUDA, use "cpu" instead .... good luck not burning your computer lol
32
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
33
+ logger.debug(f"Using device {self.device}")
34
+
35
+ if model_choice == 'ONNX':
36
+ import onnxruntime
37
+ from segment_anything import sam_model_registry, SamPredictor
38
+
39
+ self.model_checkpoint = VITH_CHECKPOINT
40
+ if self.model_checkpoint is None:
41
+ raise FileNotFoundError("VITH_CHECKPOINT is not set: please set it to the path to the SAM checkpoint")
42
+ if ONNX_CHECKPOINT is None:
43
+ raise FileNotFoundError("ONNX_CHECKPOINT is not set: please set it to the path to the ONNX checkpoint")
44
+ logger.info(f"Using ONNX checkpoint {ONNX_CHECKPOINT} and SAM checkpoint {self.model_checkpoint}")
45
+
46
+ self.ort = onnxruntime.InferenceSession(ONNX_CHECKPOINT)
47
+ reg_key = "vit_h"
48
+
49
+ elif model_choice == 'SAM':
50
+ from segment_anything import SamPredictor, sam_model_registry
51
+
52
+ self.model_checkpoint = VITH_CHECKPOINT
53
+ if self.model_checkpoint is None:
54
+ raise FileNotFoundError("VITH_CHECKPOINT is not set: please set it to the path to the SAM checkpoint")
55
+
56
+ logger.info(f"Using SAM checkpoint {self.model_checkpoint}")
57
+ reg_key = "vit_h"
58
+
59
+ elif model_choice == 'MobileSAM':
60
+ from mobile_sam import SamPredictor, sam_model_registry
61
+
62
+ self.model_checkpoint = MOBILESAM_CHECKPOINT
63
+ if not self.model_checkpoint:
64
+ raise FileNotFoundError("MOBILE_CHECKPOINT is not set: please set it to the path to the MobileSAM checkpoint")
65
+ logger.info(f"Using MobileSAM checkpoint {self.model_checkpoint}")
66
+ reg_key = 'vit_t'
67
+ else:
68
+ raise ValueError(f"Invalid model choice {model_choice}")
69
+
70
+ sam = sam_model_registry[reg_key](checkpoint=self.model_checkpoint)
71
+ sam.to(device=self.device)
72
+ self.predictor = SamPredictor(sam)
73
+
74
+ @property
75
+ def model_name(self):
76
+ return f'{self.model_choice}:{self.model_checkpoint}:{self.device}'
77
+
78
+ def set_image(self, img_path, calculate_embeddings=True):
79
+ payload = self.cache.get(img_path)
80
+ if payload is None:
81
+ # Get image and embeddings
82
+ logger.debug(f'Payload not found for {img_path} in `IN_MEM_CACHE`: calculating from scratch')
83
+ image_path = get_image_local_path(
84
+ img_path,
85
+ label_studio_access_token=LABEL_STUDIO_ACCESS_TOKEN,
86
+ label_studio_host=LABEL_STUDIO_HOST
87
+ )
88
+ image = cv2.imread(image_path)
89
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
90
+ self.predictor.set_image(image)
91
+ payload = {'image_shape': image.shape[:2]}
92
+ logger.debug(f'Finished set_image({img_path}) in `IN_MEM_CACHE`: image shape {image.shape[:2]}')
93
+ if calculate_embeddings:
94
+ image_embedding = self.predictor.get_image_embedding().cpu().numpy()
95
+ payload['image_embedding'] = image_embedding
96
+ logger.debug(f'Finished storing embeddings for {img_path} in `IN_MEM_CACHE`: '
97
+ f'embedding shape {image_embedding.shape}')
98
+ self.cache.put(img_path, payload)
99
+ else:
100
+ logger.debug(f"Using embeddings for {img_path} from `IN_MEM_CACHE`")
101
+ return payload
102
+
103
+ def predict_onnx(
104
+ self,
105
+ img_path,
106
+ point_coords: Optional[List[List]] = None,
107
+ point_labels: Optional[List] = None,
108
+ input_box: Optional[List] = None
109
+ ):
110
+ # calculate embeddings
111
+ payload = self.set_image(img_path, calculate_embeddings=True)
112
+ image_shape = payload['image_shape']
113
+ image_embedding = payload['image_embedding']
114
+
115
+ onnx_point_coords = np.array(point_coords, dtype=np.float32) if point_coords else None
116
+ onnx_point_labels = np.array(point_labels, dtype=np.float32) if point_labels else None
117
+ onnx_box_coords = np.array(input_box, dtype=np.float32).reshape(2, 2) if input_box else None
118
+
119
+ onnx_coords, onnx_labels = None, None
120
+ if onnx_point_coords is not None and onnx_box_coords is not None:
121
+ # both keypoints and boxes are present
122
+ onnx_coords = np.concatenate([onnx_point_coords, onnx_box_coords], axis=0)[None, :, :]
123
+ onnx_labels = np.concatenate([onnx_point_labels, np.array([2, 3])], axis=0)[None, :].astype(np.float32)
124
+
125
+ elif onnx_point_coords is not None:
126
+ # only keypoints are present
127
+ onnx_coords = np.concatenate([onnx_point_coords, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
128
+ onnx_labels = np.concatenate([onnx_point_labels, np.array([-1])], axis=0)[None, :].astype(np.float32)
129
+
130
+ elif onnx_box_coords is not None:
131
+ # only boxes are present
132
+ raise NotImplementedError("Boxes without keypoints are not supported yet")
133
+
134
+ onnx_coords = self.predictor.transform.apply_coords(onnx_coords, image_shape).astype(np.float32)
135
+
136
+ # TODO: support mask inputs
137
+ onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
138
+
139
+ onnx_has_mask_input = np.zeros(1, dtype=np.float32)
140
+
141
+ ort_inputs = {
142
+ "image_embeddings": image_embedding,
143
+ "point_coords": onnx_coords,
144
+ "point_labels": onnx_labels,
145
+ "mask_input": onnx_mask_input,
146
+ "has_mask_input": onnx_has_mask_input,
147
+ "orig_im_size": np.array(image_shape, dtype=np.float32)
148
+ }
149
+
150
+ masks, prob, low_res_logits = self.ort.run(None, ort_inputs)
151
+ masks = masks > self.predictor.model.mask_threshold
152
+ mask = masks[0, 0, :, :].astype(np.uint8) # each mask has shape [H, W]
153
+ prob = float(prob[0][0])
154
+ # TODO: support the real multimask output as in https://github.com/facebookresearch/segment-anything/blob/main/notebooks/predictor_example.ipynb
155
+ return {
156
+ 'masks': [mask],
157
+ 'probs': [prob]
158
+ }
159
+
160
+ def predict_sam(
161
+ self,
162
+ img_path,
163
+ point_coords: Optional[List[List]] = None,
164
+ point_labels: Optional[List] = None,
165
+ input_box: Optional[List] = None
166
+ ):
167
+ self.set_image(img_path, calculate_embeddings=False)
168
+ point_coords = np.array(point_coords, dtype=np.float32) if point_coords else None
169
+ point_labels = np.array(point_labels, dtype=np.float32) if point_labels else None
170
+ input_box = np.array(input_box, dtype=np.float32) if input_box else None
171
+
172
+ masks, probs, logits = self.predictor.predict(
173
+ point_coords=point_coords,
174
+ point_labels=point_labels,
175
+ box=input_box,
176
+ # TODO: support multimask output
177
+ multimask_output=False
178
+ )
179
+ mask = masks[0, :, :].astype(np.uint8) # each mask has shape [H, W]
180
+ prob = float(probs[0])
181
+ return {
182
+ 'masks': [mask],
183
+ 'probs': [prob]
184
+ }
185
+
186
+ def predict(
187
+ self, img_path: str,
188
+ point_coords: Optional[List[List]] = None,
189
+ point_labels: Optional[List] = None,
190
+ input_box: Optional[List] = None
191
+ ):
192
+ if self.model_choice == 'ONNX':
193
+ return self.predict_onnx(img_path, point_coords, point_labels, input_box)
194
+ elif self.model_choice in ('SAM', 'MobileSAM'):
195
+ return self.predict_sam(img_path, point_coords, point_labels, input_box)
196
+ else:
197
+ raise NotImplementedError(f"Model choice {self.model_choice} is not supported yet")
198
+
start.sh ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # Execute the gunicorn command
4
+ exec /home/user/.local/bin/gunicorn --preload --bind :$PORT --workers 1 --threads 8 --timeout 0 _wsgi:app