Update app.py (get weights using hf_hub_download)
Browse files
app.py
CHANGED
@@ -18,6 +18,8 @@ import cv2
|
|
18 |
from zim import zim_model_registry, ZimPredictor, ZimAutomaticMaskGenerator
|
19 |
from zim.utils import show_mat_anns
|
20 |
|
|
|
|
|
21 |
def get_shortest_axis(image):
|
22 |
h, w, _ = image.shape
|
23 |
return h if h < w else w
|
@@ -213,12 +215,17 @@ def get_examples():
|
|
213 |
images = os.listdir(assets_dir)
|
214 |
return [os.path.join(assets_dir, img) for img in images]
|
215 |
|
216 |
-
|
|
|
|
|
|
|
|
|
217 |
|
218 |
-
backbone = "vit_l"
|
219 |
-
ckpt_p = "ckpts/zim_vit_l_2092"
|
220 |
|
221 |
-
|
|
|
|
|
|
|
222 |
if torch.cuda.is_available():
|
223 |
model.cuda()
|
224 |
|
|
|
18 |
from zim import zim_model_registry, ZimPredictor, ZimAutomaticMaskGenerator
|
19 |
from zim.utils import show_mat_anns
|
20 |
|
21 |
+
from huggingface_hub import hf_hub_download
|
22 |
+
|
23 |
def get_shortest_axis(image):
|
24 |
h, w, _ = image.shape
|
25 |
return h if h < w else w
|
|
|
215 |
images = os.listdir(assets_dir)
|
216 |
return [os.path.join(assets_dir, img) for img in images]
|
217 |
|
218 |
+
def download_onnx_weights(repo_id="naver-iv/zim-anything-vitl", file_dir="zim_vit_l_2092"):
|
219 |
+
hf_hub_download(repo_id=repo_id, filename=f"{file_dir}/encoder.onnx")
|
220 |
+
filepath = hf_hub_download(repo_id=repo_id, filename=f"{file_dir}/decoder.onnx")
|
221 |
+
|
222 |
+
return os.path.dirname(filepath)
|
223 |
|
|
|
|
|
224 |
|
225 |
+
if __name__ == "__main__":
|
226 |
+
backbone = "vit_l"
|
227 |
+
model = zim_model_registry[backbone](checkpoint=download_onnx_weights())
|
228 |
+
|
229 |
if torch.cuda.is_available():
|
230 |
model.cuda()
|
231 |
|