scyonggg's picture
Initial commit
da77aaf
"""
Copyright (c) 2024-present Naver Cloud Corp.
This source code is based on code from the Segment Anything Model (SAM)
(https://github.com/facebookresearch/segment-anything).
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
import torch
from typing import Any, Callable
import onnxruntime
def np2tensor(np_array, device):
    """Wrap a NumPy array as a torch tensor and place it on ``device``.

    Args:
        np_array: NumPy array to convert.
        device: Target device, forwarded unchanged to ``Tensor.to``.

    Returns:
        The resulting ``torch.Tensor`` on ``device``.
    """
    # ``torch.from_numpy`` builds the tensor on the host first; ``.to`` then
    # handles any transfer to the requested device.
    host_tensor = torch.from_numpy(np_array)
    return host_tensor.to(device)
def tensor2np(torch_tensor):
    """Convert a torch tensor into a NumPy array.

    The tensor is detached from the autograd graph and moved to the CPU
    before the NumPy conversion.

    Args:
        torch_tensor: Tensor to convert.

    Returns:
        A ``numpy.ndarray`` holding the tensor's values.
    """
    cpu_tensor = torch_tensor.detach().cpu()
    return cpu_tensor.numpy()
class ZIM_Encoder():
    """Run an ONNX encoder model through ONNX Runtime with torch tensors.

    The session is created on the CPU execution provider; call ``cuda`` to
    switch it to the CUDA execution provider. Instances are callable:
    ``encoder(image)`` is the same as ``encoder.forward(image)``.
    """

    def __init__(self, onnx_path, num_threads=16):
        """Create the inference session for the model at ``onnx_path``.

        Args:
            onnx_path: Path to the ONNX model, passed to
                ``onnxruntime.InferenceSession``.
            num_threads: Value used for both the intra-op and inter-op
                thread counts of the session.
        """
        self.onnx_path = onnx_path

        session_options = onnxruntime.SessionOptions()
        session_options.intra_op_num_threads = num_threads
        session_options.inter_op_num_threads = num_threads

        self.ort_session = onnxruntime.InferenceSession(
            onnx_path,
            sess_options=session_options,
            providers=["CPUExecutionProvider"],
        )

    def cuda(self, device_id=0):
        """Switch the session to the CUDA execution provider.

        Args:
            device_id: CUDA device index placed in the provider options.
        """
        cuda_provider = ("CUDAExecutionProvider", {"device_id": device_id})
        self.ort_session.set_providers([cuda_provider])

    def forward(
        self,
        image,
    ):
        """Encode ``image`` and return the outputs as torch tensors.

        Args:
            image: Torch tensor fed to the model input named ``"image"``.

        Returns:
            A pair ``(image_embeddings, (feat_D0, feat_D1, feat_D2))`` of
            tensors placed on the same device as ``image``.
        """
        # Remember where the input lives so outputs can be returned there.
        target_device = image.device

        # The session consumes NumPy arrays; ``None`` requests every output.
        raw_outputs = self.ort_session.run(None, {"image": tensor2np(image)})

        # The model is expected to produce exactly four outputs.
        embeddings_np, d0_np, d1_np, d2_np = raw_outputs

        image_embeddings, feat_D0, feat_D1, feat_D2 = (
            np2tensor(array, target_device)
            for array in (embeddings_np, d0_np, d1_np, d2_np)
        )
        return image_embeddings, (feat_D0, feat_D1, feat_D2)

    # Make instances callable like a torch module.
    __call__: Callable[..., Any] = forward