Beom0 commited on
Commit
33dba5a
·
verified ·
1 Parent(s): 7bb0d89

Update app.py (get weights using hf_hub_download)

Browse files
Files changed (1) hide show
  1. app.py +11 -4
app.py CHANGED
@@ -18,6 +18,8 @@ import cv2
18
  from zim import zim_model_registry, ZimPredictor, ZimAutomaticMaskGenerator
19
  from zim.utils import show_mat_anns
20
 
 
 
21
  def get_shortest_axis(image):
22
  h, w, _ = image.shape
23
  return h if h < w else w
@@ -213,12 +215,17 @@ def get_examples():
213
  images = os.listdir(assets_dir)
214
  return [os.path.join(assets_dir, img) for img in images]
215
 
216
- if __name__ == "__main__":
 
 
 
 
217
 
218
- backbone = "vit_l"
219
- ckpt_p = "ckpts/zim_vit_l_2092"
220
 
221
- model = zim_model_registry[backbone](checkpoint=ckpt_p)
 
 
 
222
  if torch.cuda.is_available():
223
  model.cuda()
224
 
 
18
  from zim import zim_model_registry, ZimPredictor, ZimAutomaticMaskGenerator
19
  from zim.utils import show_mat_anns
20
 
21
+ from huggingface_hub import hf_hub_download
22
+
23
  def get_shortest_axis(image):
24
  h, w, _ = image.shape
25
  return h if h < w else w
 
215
  images = os.listdir(assets_dir)
216
  return [os.path.join(assets_dir, img) for img in images]
217
 
218
+ def download_onnx_weights(repo_id="naver-iv/zim-anything-vitl", file_dir="zim_vit_l_2092"):
219
+ hf_hub_download(repo_id=repo_id, filename=f"{file_dir}/encoder.onnx")
220
+ filepath = hf_hub_download(repo_id=repo_id, filename=f"{file_dir}/decoder.onnx")
221
+
222
+ return os.path.dirname(filepath)
223
 
 
 
224
 
225
+ if __name__ == "__main__":
226
+ backbone = "vit_l"
227
+ model = zim_model_registry[backbone](checkpoint=download_onnx_weights())
228
+
229
  if torch.cuda.is_available():
230
  model.cuda()
231