#!/usr/bin/env python3
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# gradio demo executable
# --------------------------------------------------------
import os
import torch
import tempfile
from contextlib import nullcontext

from mast3r.demo import get_args_parser, main_demo

from mast3r.model import AsymmetricMASt3R
from mast3r.utils.misc import hash_md5

# pyplot is imported under the alias `pl` and switched to interactive mode at
# import time via pl.ion().
# NOTE(review): nothing in this file plots directly — presumably this is needed
# by code inside mast3r.demo; confirm before removing.
import matplotlib.pyplot as pl
pl.ion()

# Allow TensorFloat-32 for CUDA matmuls (faster on supporting GPUs, at slightly
# reduced matmul precision).
torch.backends.cuda.matmul.allow_tf32 = True  # for GPU >= Ampere and PyTorch >= 1.12

import argparse

def get_args_parser():
    """Build the command-line parser for the MASt3R gradio demo.

    Note: this local definition shadows the ``get_args_parser`` imported from
    ``mast3r.demo`` at the top of this file; it is the one the script uses.

    Sharing is enabled by default; pass ``--no_share`` to disable it.

    Returns:
        argparse.ArgumentParser: parser for model, device, server and cache options.
    """
    parser = argparse.ArgumentParser(description="MASt3R Demo")
    parser.add_argument("--weights", type=str, default=None, help="Path to the weights file.")
    parser.add_argument("--model_name", type=str, default=None, choices=[
        'MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric'], help="Name of the model to use.")
    parser.add_argument("--device", type=str, default='cuda', help="Device to run the model on.")
    parser.add_argument("--server_name", type=str, default=None, help="Server name to use.")
    parser.add_argument("--local_network", action='store_true', help="Run on local network.")
    parser.add_argument("--image_size", type=int, choices=[512, 224], default=512, help="Size of the images.")
    parser.add_argument("--server_port", type=int, default=None, help="Port for the server.")
    parser.add_argument("--tmp_dir", type=str, default=None, help="Temporary directory.")
    parser.add_argument("--silent", action='store_true', help="Run silently.")
    # Sharing defaults to True, so a store_true flag on its own can never turn it
    # off. --no_share writes False into the same `share` destination; --share is
    # kept for backward compatibility.
    parser.add_argument("--share", default=True, action='store_true', help="Share the application.")
    parser.add_argument("--no_share", dest='share', action='store_false', help="Do not share the application.")
    # NOTE(review): store_true yields a bool that is forwarded unchanged to
    # main_demo; confirm main_demo/gradio expect a bool here rather than a
    # cache age/frequency value.
    parser.add_argument("--gradio_delete_cache", action='store_true', help="Delete Gradio cache.")
    return parser

def get_default_weights_path(model_name):
    """Return the default checkpoint identifier for ``model_name``.

    The result is ``model_name`` placed under the ``naver/`` prefix,
    i.e. ``"naver/<model_name>"``.
    """
    prefix = "naver"
    return "{}/{}".format(prefix, model_name)


if __name__ == '__main__':
    parser = get_args_parser()
    args = parser.parse_args()

    # Neither --weights nor --model_name given: use the only model name the
    # parser accepts for --model_name.
    if args.weights is None and args.model_name is None:
        args.model_name = 'MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric'

    # Without explicit --weights, derive the checkpoint id from the model name
    # through the shared helper rather than duplicating its format.
    if args.weights is None:
        args.weights = get_default_weights_path(args.model_name)

    # An explicit --server_name wins; otherwise bind to all interfaces with
    # --local_network, or to loopback only.
    server_name = args.server_name or ('0.0.0.0' if args.local_network else '127.0.0.1')
    weights_path = args.weights

    # Honour the requested --device (e.g. "cpu" or "cuda:1"). Only fall back to
    # CPU when a CUDA device was requested but CUDA is not available.
    if args.device.startswith('cuda') and not torch.cuda.is_available():
        if not args.silent:
            print(f"CUDA is not available, falling back to CPU (requested device: {args.device}).")
        args.device = 'cpu'

    # Load the model
    model = AsymmetricMASt3R.from_pretrained(weights_path).to(args.device)
    # Per-checkpoint tag used to name the cache subdirectory below, so different
    # weights get different cache directories (assuming distinct hashes).
    chkpt_tag = hash_md5(weights_path)

    def get_context(tmp_dir):
        """Yield a working directory path.

        With no --tmp_dir, a fresh TemporaryDirectory is created and removed on
        exit; otherwise the user-supplied path is yielded as-is and kept.
        """
        return tempfile.TemporaryDirectory(suffix='_mast3r_gradio_demo') if tmp_dir is None \
            else nullcontext(tmp_dir)

    with get_context(args.tmp_dir) as tmpdirname:
        cache_path = os.path.join(tmpdirname, chkpt_tag)
        os.makedirs(cache_path, exist_ok=True)
        main_demo(cache_path, model, args.device, args.image_size, server_name, args.server_port, silent=args.silent,
                  share=args.share, gradio_delete_cache=args.gradio_delete_cache)