zino36 commited on
Commit
bce5a1f
·
verified ·
1 Parent(s): e31b8c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -52,19 +52,19 @@ if __name__ == '__main__':
52
  parser = get_args_parser()
53
  args = parser.parse_args()
54
 
55
- # Set default values for required arguments
56
- if args.weights is None and args.model_name is None:
 
57
  args.model_name = 'MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric' # Default model_name
 
58
 
59
  if args.server_name is not None:
60
  server_name = args.server_name
61
  else:
62
  server_name = '0.0.0.0' if args.local_network else '127.0.0.1'
63
 
64
- if args.weights is not None:
65
- weights_path = args.weights
66
- else:
67
- weights_path = "naver/" + args.model_name
68
 
69
  model = AsymmetricMASt3R.from_pretrained(weights_path).to(args.device)
70
  chkpt_tag = hash_md5(weights_path)
 
52
  parser = get_args_parser()
53
  args = parser.parse_args()
54
 
55
+ # Set default value for `args.weights` if not provided
56
+ if args.weights is None:
57
+ # Set a default model_name if weights are not provided
58
  args.model_name = 'MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric' # Default model_name
59
+ args.weights = "naver/" + args.model_name # Construct default weights path
60
 
61
  if args.server_name is not None:
62
  server_name = args.server_name
63
  else:
64
  server_name = '0.0.0.0' if args.local_network else '127.0.0.1'
65
 
66
+ # Use the provided or default weights_path
67
+ weights_path = args.weights
 
 
68
 
69
  model = AsymmetricMASt3R.from_pretrained(weights_path).to(args.device)
70
  chkpt_tag = hash_md5(weights_path)