Update app.py
Browse files
app.py
CHANGED
@@ -52,19 +52,19 @@ if __name__ == '__main__':
|
|
52 |
parser = get_args_parser()
|
53 |
args = parser.parse_args()
|
54 |
|
55 |
-
# Set default
|
56 |
-
if args.weights is None
|
|
|
57 |
args.model_name = 'MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric' # Default model_name
|
|
|
58 |
|
59 |
if args.server_name is not None:
|
60 |
server_name = args.server_name
|
61 |
else:
|
62 |
server_name = '0.0.0.0' if args.local_network else '127.0.0.1'
|
63 |
|
64 |
-
|
65 |
-
|
66 |
-
else:
|
67 |
-
weights_path = "naver/" + args.model_name
|
68 |
|
69 |
model = AsymmetricMASt3R.from_pretrained(weights_path).to(args.device)
|
70 |
chkpt_tag = hash_md5(weights_path)
|
|
|
52 |
parser = get_args_parser()
|
53 |
args = parser.parse_args()
|
54 |
|
55 |
+
# Set default value for `args.weights` if not provided
|
56 |
+
if args.weights is None:
|
57 |
+
# Set a default model_name if weights are not provided
|
58 |
args.model_name = 'MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric' # Default model_name
|
59 |
+
args.weights = "naver/" + args.model_name # Construct default weights path
|
60 |
|
61 |
if args.server_name is not None:
|
62 |
server_name = args.server_name
|
63 |
else:
|
64 |
server_name = '0.0.0.0' if args.local_network else '127.0.0.1'
|
65 |
|
66 |
+
# Use the provided or default weights_path
|
67 |
+
weights_path = args.weights
|
|
|
|
|
68 |
|
69 |
model = AsymmetricMASt3R.from_pretrained(weights_path).to(args.device)
|
70 |
chkpt_tag = hash_md5(weights_path)
|