"""
Copyright (c) 2024-present Naver Cloud Corp.
This source code is based on code from the Segment Anything Model (SAM)
(https://github.com/facebookresearch/segment-anything).

This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
from typing import Any, Callable

import numpy as np
import onnxruntime
import torch

def np2tensor(np_array, device):
    """Convert a NumPy array to a torch tensor on the given device."""
    return torch.from_numpy(np_array).to(device)

def tensor2np(torch_tensor):
    """Convert a torch tensor to a NumPy array; pass None through unchanged."""
    if torch_tensor is None:
        return None

    return torch_tensor.detach().cpu().numpy()

class ZIM_Decoder:
    """ONNX Runtime wrapper around the ZIM prompt decoder."""

    def __init__(self, onnx_path, num_threads=16):
        self.onnx_path = onnx_path

        session_options = onnxruntime.SessionOptions()
        session_options.intra_op_num_threads = num_threads
        session_options.inter_op_num_threads = num_threads
        providers = ["CPUExecutionProvider"]

        self.ort_session = onnxruntime.InferenceSession(
            onnx_path, sess_options=session_options, providers=providers
        )
        self.num_mask_tokens = 4
        
    def cuda(self, device_id=0):
        """Switch the ONNX Runtime session to the CUDA execution provider."""
        providers = [
            (
                "CUDAExecutionProvider",
                {
                    "device_id": device_id,
                },
            ),
        ]

        self.ort_session.set_providers(providers)
        
    def forward(
        self,
        interm_feats,
        image_embeddings,
        points,
        boxes,
        attn_mask,
    ):
        """Run the ONNX decoder and return (masks, iou_predictions) as torch tensors."""
        device = image_embeddings.device

        # Multi-scale encoder features, image embeddings, and the prompt
        # attention mask expected by the exported decoder graph.
        ort_inputs = {
            "feat_D0": tensor2np(interm_feats[0]),
            "feat_D1": tensor2np(interm_feats[1]),
            "feat_D2": tensor2np(interm_feats[2]),
            "image_embeddings": tensor2np(image_embeddings),
            "attn_mask": tensor2np(attn_mask),
        }

        if points is not None:
            point_coords, point_labels = points
            ort_inputs["point_coords"] = tensor2np(point_coords.float())
            ort_inputs["point_labels"] = tensor2np(point_labels.float())

            # Append a padding point with label -1, as done in SAM, so the
            # decoder always receives a "not a point" sentinel entry.
            padding_point = np.zeros(
                (ort_inputs["point_coords"].shape[0], 1, 2), dtype=np.float32
            ) - 0.5
            padding_label = -np.ones(
                (ort_inputs["point_labels"].shape[0], 1), dtype=np.float32
            )
            ort_inputs["point_coords"] = np.concatenate(
                [ort_inputs["point_coords"], padding_point], axis=1
            )
            ort_inputs["point_labels"] = np.concatenate(
                [ort_inputs["point_labels"], padding_label], axis=1
            )

        if boxes is not None:
            # A box is encoded as its two corner points with labels 2 and 3,
            # following SAM's convention; note this overwrites any point prompt.
            ort_inputs["point_coords"] = tensor2np(boxes.reshape(-1, 2, 2))
            ort_inputs["point_labels"] = np.array([[2, 3]], dtype=np.float32).repeat(
                boxes.shape[0], 0
            )

        masks, iou_predictions = self.ort_session.run(None, ort_inputs)

        # Move the outputs back onto the device of the input embeddings.
        masks = np2tensor(masks, device)
        iou_predictions = np2tensor(iou_predictions, device)

        return masks, iou_predictions
    
    # Calling an instance dispatches to forward(), mirroring nn.Module's API.
    __call__: Callable[..., Any] = forward
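

# --- Usage sketch (illustrative; not part of the original file) ---
# A minimal example of driving ZIM_Decoder with a single point prompt.
# The ONNX path, the dummy encoder outputs, and all tensor shapes below are
# hypothetical placeholders; real inputs come from the ZIM image encoder.
if __name__ == "__main__":
    decoder = ZIM_Decoder("decoder.onnx")  # hypothetical path

    # Dummy encoder outputs (shapes are assumptions for illustration only).
    interm_feats = [
        torch.randn(1, 64, 256, 256),
        torch.randn(1, 128, 128, 128),
        torch.randn(1, 256, 64, 64),
    ]
    image_embeddings = torch.randn(1, 256, 64, 64)
    attn_mask = torch.ones(1, 1, 64, 64)

    # One positive click at pixel (512, 512); label 1 marks a foreground point.
    point_coords = torch.tensor([[[512.0, 512.0]]])
    point_labels = torch.tensor([[1.0]])

    masks, iou_predictions = decoder(
        interm_feats,
        image_embeddings,
        points=(point_coords, point_labels),
        boxes=None,
        attn_mask=attn_mask,
    )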