LiheYoung commited on
Commit
116da01
·
verified ·
1 Parent(s): 2493a1e

Load from Hugging Face

Browse files
Files changed (2) hide show
  1. depth_anything/dpt.py +21 -5
  2. requirements.txt +2 -0
depth_anything/dpt.py CHANGED
@@ -1,8 +1,10 @@
 
1
  import torch
2
  import torch.nn as nn
3
-
4
- from .blocks import FeatureFusionBlock, _make_scratch
5
  import torch.nn.functional as F
 
 
 
6
 
7
 
8
  def _make_fusion_block(features, use_bn, size = None):
@@ -143,7 +145,6 @@ class DPT_DINOv2(nn.Module):
143
  # in case the Internet connection is not stable, please load the DINOv2 locally
144
  if localhub:
145
  self.pretrained = torch.hub.load('torchhub/facebookresearch_dinov2_main', 'dinov2_{:}14'.format(encoder), source='local', pretrained=False)
146
- # self.pretrained.load_state_dict(torch.load('checkpoints/dinov2_{:}14_pretrain.pth'.format(encoder)))
147
  else:
148
  self.pretrained = torch.hub.load('facebookresearch/dinov2', 'dinov2_{:}14'.format(encoder))
149
 
@@ -165,7 +166,22 @@ class DPT_DINOv2(nn.Module):
165
  return depth.squeeze(1)
166
 
167
 
 
 
 
 
 
168
  if __name__ == '__main__':
169
- depth_anything = DPT_DINOv2()
170
- depth_anything.load_state_dict(torch.load('checkpoints/depth_anything_dinov2_vitl14.pth'))
 
 
 
 
 
 
 
 
 
 
171
 
 
1
+ import argparse
2
  import torch
3
  import torch.nn as nn
 
 
4
  import torch.nn.functional as F
5
+ from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
6
+
7
+ from depth_anything.blocks import FeatureFusionBlock, _make_scratch
8
 
9
 
10
  def _make_fusion_block(features, use_bn, size = None):
 
145
  # in case the Internet connection is not stable, please load the DINOv2 locally
146
  if localhub:
147
  self.pretrained = torch.hub.load('torchhub/facebookresearch_dinov2_main', 'dinov2_{:}14'.format(encoder), source='local', pretrained=False)
 
148
  else:
149
  self.pretrained = torch.hub.load('facebookresearch/dinov2', 'dinov2_{:}14'.format(encoder))
150
 
 
166
  return depth.squeeze(1)
167
 
168
 
169
+ class DepthAnything(DPT_DINOv2, PyTorchModelHubMixin):
170
+ def __init__(self, config):
171
+ super().__init__(**config)
172
+
173
+
174
  if __name__ == '__main__':
175
+ parser = argparse.ArgumentParser()
176
+ parser.add_argument(
177
+ "--encoder",
178
+ default="vits",
179
+ type=str,
180
+ choices=["vits", "vitb", "vitl"],
181
+ )
182
+ args = parser.parse_args()
183
+
184
+ model = DepthAnything.from_pretrained("LiheYoung/depth_anything_{:}14".format(args.encoder))
185
+
186
+ print(model)
187
 
requirements.txt CHANGED
@@ -1,4 +1,6 @@
1
  gradio_imageslider
 
2
  torch
3
  torchvision
4
  opencv-python
 
 
1
  gradio_imageslider
2
+ gradio==4.14.0
3
  torch
4
  torchvision
5
  opencv-python
6
+ huggingface_hub