|
""" |
|
Copyright (c) 2024-present Naver Cloud Corp. |
|
This source code is based on code from the Segment Anything Model (SAM) |
|
(https://github.com/facebookresearch/segment-anything). |
|
|
|
This source code is licensed under the license found in the |
|
LICENSE file in the root directory of this source tree. |
|
""" |
|
import torch |
|
from typing import Any, Callable |
|
import onnxruntime |
|
import numpy as np |
|
|
|
def np2tensor(np_array, device):
    """Wrap a NumPy array as a torch tensor placed on ``device``.

    Args:
        np_array: Array to convert (e.g. an ONNX Runtime output).
        device: Target device, forwarded to ``Tensor.to``.

    Returns:
        A torch tensor with the same contents, located on ``device``.
    """
    host_tensor = torch.from_numpy(np_array)
    return host_tensor.to(device)
|
|
|
def tensor2np(torch_tensor):
    """Convert a torch tensor to a NumPy array, passing ``None`` through unchanged.

    The tensor is detached from autograd and moved to the CPU before the
    conversion, which is what ``numpy()`` requires.

    Args:
        torch_tensor: Tensor to convert, or ``None``.

    Returns:
        A NumPy array, or ``None`` when ``torch_tensor`` is ``None``.
    """
    return None if torch_tensor is None else torch_tensor.detach().cpu().numpy()
|
|
|
class ZIM_Decoder():
    """ONNX Runtime wrapper around the exported ZIM mask decoder.

    Torch inputs are converted to NumPy for ONNX Runtime, and the decoder
    outputs are converted back to torch tensors on the device of
    ``image_embeddings``.
    """

    def __init__(self, onnx_path, num_threads=16):
        """Create a CPU inference session for the decoder model.

        Args:
            onnx_path: Path to the exported decoder ONNX file.
            num_threads: Thread count applied to both the intra-op and
                inter-op thread pools of the session.
        """
        self.onnx_path = onnx_path

        session_options = onnxruntime.SessionOptions()
        session_options.intra_op_num_threads = num_threads
        session_options.inter_op_num_threads = num_threads
        providers = ["CPUExecutionProvider"]

        self.ort_session = onnxruntime.InferenceSession(
            onnx_path, sess_options=session_options, providers=providers
        )
        # NOTE(review): not read anywhere in this class; presumably kept for
        # parity with SAM's mask-decoder interface -- confirm before removing.
        self.num_mask_tokens = 4

    def cuda(self, device_id=0):
        """Switch the existing session to the CUDA execution provider.

        Args:
            device_id: CUDA device ordinal passed in the provider options.
        """
        providers = [
            ("CUDAExecutionProvider", {"device_id": device_id}),
        ]
        self.ort_session.set_providers(providers)

    @staticmethod
    def _encode_prompts(points, boxes):
        """Build the ONNX ``point_coords`` / ``point_labels`` arrays from the prompts.

        Follows SAM's prompt encoding:
        - Point prompts come first.
        - Each box contributes its two corners with labels 2 and 3.
        - A single padding entry with label -1 is appended only when no box is
          given.

        Args:
            points: ``(point_coords, point_labels)`` tensors, or ``None``.
            boxes: Box tensor reshapeable to ``(-1, 2, 2)``, or ``None``.

        Returns:
            ``(coords, labels)`` as float32 NumPy arrays, concatenated along axis 1.

        Raises:
            ValueError: if neither ``points`` nor ``boxes`` is provided.
        """
        if points is None and boxes is None:
            raise ValueError("ZIM_Decoder.forward requires `points`, `boxes`, or both.")

        coords, labels = [], []

        if points is not None:
            point_coords, point_labels = points
            coords.append(tensor2np(point_coords.float()))
            labels.append(tensor2np(point_labels.float()))

        if boxes is not None:
            # Cast like the point path so ONNX Runtime always receives float32.
            box_corners = tensor2np(boxes.float().reshape(-1, 2, 2))
            coords.append(box_corners)
            # One (2, 3) label pair per reshaped box, so the labels always match
            # the coordinates even for a flat (4,) box.
            labels.append(
                np.array([[2, 3]], dtype=np.float32).repeat(box_corners.shape[0], 0)
            )
        else:
            # Points only: append one padding entry (label -1). The -0.5 offset
            # mirrors the original padding value used with this decoder.
            batch = coords[0].shape[0]
            coords.append(np.full((batch, 1, 2), -0.5, dtype=np.float32))
            labels.append(np.full((batch, 1), -1.0, dtype=np.float32))

        # Previously a box silently replaced any point prompts; now both are kept.
        return np.concatenate(coords, axis=1), np.concatenate(labels, axis=1)

    def forward(
        self,
        interm_feats,
        image_embeddings,
        points,
        boxes,
        attn_mask,
    ):
        """Run the decoder for the given features and prompts.

        Args:
            interm_feats: Sequence whose first three entries are fed as
                ``feat_D0``, ``feat_D1`` and ``feat_D2``.
            image_embeddings: Encoder output fed as ``image_embeddings``; its
                device is used for the returned tensors.
            points: ``(point_coords, point_labels)`` tensors, or ``None``.
            boxes: Box tensor reshapeable to ``(-1, 2, 2)``, or ``None``.
            attn_mask: Tensor fed as ``attn_mask``.

        Returns:
            ``(masks, iou_predictions)`` as torch tensors on
            ``image_embeddings.device``.

        Raises:
            ValueError: if neither ``points`` nor ``boxes`` is provided.
        """
        device = image_embeddings.device

        ort_inputs = {
            "feat_D0": tensor2np(interm_feats[0]),
            "feat_D1": tensor2np(interm_feats[1]),
            "feat_D2": tensor2np(interm_feats[2]),
            "image_embeddings": tensor2np(image_embeddings),
            "attn_mask": tensor2np(attn_mask),
        }
        ort_inputs["point_coords"], ort_inputs["point_labels"] = self._encode_prompts(
            points, boxes
        )

        masks, iou_predictions = self.ort_session.run(None, ort_inputs)

        return np2tensor(masks, device), np2tensor(iou_predictions, device)

    __call__: Callable[..., Any] = forward
|
|