#!/usr/bin/env python3
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# gradio demo executable
# --------------------------------------------------------
import os
import torch
import tempfile
from contextlib import nullcontext
from mast3r.demo import get_args_parser, main_demo
from mast3r.model import AsymmetricMASt3R
from mast3r.utils.misc import hash_md5
import matplotlib.pyplot as pl
# Enable matplotlib interactive mode for any figures created by the demo.
pl.ion()
# Allow CUDA matmuls to use TF32, trading some float32 precision for speed.
torch.backends.cuda.matmul.allow_tf32 = True # for GPU >= Ampere and PyTorch >= 1.12
import argparse
def get_args_parser():
    """Build the command-line parser for the MASt3R gradio demo.

    This local definition shadows the ``get_args_parser`` imported from
    ``mast3r.demo``; the ``__main__`` block of this script uses this one.

    Returns:
        argparse.ArgumentParser: parser for weights/model selection, device,
        server settings, image size, temporary directory and gradio options.
        ``share`` defaults to True and can be disabled with ``--no_share``.
    """
    parser = argparse.ArgumentParser(description="MASt3R Demo")
    parser.add_argument("--weights", type=str, default=None, help="Path to the weights file.")
    parser.add_argument("--model_name", type=str, default=None, choices=[
        'MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric'], help="Name of the model to use.")
    parser.add_argument("--device", type=str, default='cuda', help="Device to run the model on.")
    parser.add_argument("--server_name", type=str, default=None, help="Server name to use.")
    parser.add_argument("--local_network", action='store_true', help="Run on local network.")
    parser.add_argument("--image_size", type=int, choices=[512, 224], default=512, help="Size of the images.")
    parser.add_argument("--server_port", type=int, default=None, help="Port for the server.")
    parser.add_argument("--tmp_dir", type=str, default=None, help="Temporary directory.")
    parser.add_argument("--silent", action='store_true', help="Run silently.")
    # Sharing stays enabled by default (unchanged), but `store_true` combined with
    # default=True made `--share` a no-op and sharing impossible to turn off.
    # Both flags write the same `share` destination; the last one given wins.
    parser.add_argument("--share", dest="share", action='store_true',
                        help="Share the application (enabled by default).")
    parser.add_argument("--no_share", dest="share", action='store_false',
                        help="Do not share the application.")
    parser.set_defaults(share=True)
    parser.add_argument("--gradio_delete_cache", action='store_true', help="Delete Gradio cache.")
    return parser
def get_default_weights_path(model_name):
    """Return the default checkpoint location for ``model_name``.

    The location is ``"naver/<model_name>"`` (presumably a Hugging Face Hub
    repo id consumed by ``from_pretrained`` -- confirm against the loader).

    Args:
        model_name: model identifier, formatted into the path as-is.

    Returns:
        str: the ``naver/``-prefixed weights path.
    """
    namespace = "naver"
    return "{}/{}".format(namespace, model_name)
def resolve_device(requested):
    """Return the device string to run the model on.

    A CUDA device (``'cuda'`` or ``'cuda:N'``) falls back to ``'cpu'`` when
    CUDA is not available; any other requested device is returned unchanged,
    so an explicit ``--device`` choice is honoured.
    """
    if str(requested).startswith('cuda') and not torch.cuda.is_available():
        return 'cpu'
    return requested


if __name__ == '__main__':
    parser = get_args_parser()
    args = parser.parse_args()

    # Pick the checkpoint: an explicit --weights wins; otherwise derive the path
    # from --model_name, defaulting to the only model the parser offers.
    if args.weights is None and args.model_name is None:
        args.model_name = 'MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric'
    if args.weights is None:
        args.weights = get_default_weights_path(args.model_name)

    # Bind to all interfaces only with --local_network; loopback otherwise.
    server_name = args.server_name or ('0.0.0.0' if args.local_network else '127.0.0.1')
    weights_path = args.weights
    # Previously this line ignored --device entirely; now it only falls back to
    # CPU when CUDA was requested but is unavailable.
    args.device = resolve_device(args.device)

    # Load the model
    model = AsymmetricMASt3R.from_pretrained(weights_path).to(args.device)
    # Cache sub-directory name is derived from the weights path, so distinct
    # checkpoints presumably get distinct cache directories.
    chkpt_tag = hash_md5(weights_path)

    def get_context(tmp_dir):
        """Yield a working directory: an auto-cleaned temp dir when ``tmp_dir``
        is None, otherwise ``tmp_dir`` itself (left in place on exit)."""
        return tempfile.TemporaryDirectory(suffix='_mast3r_gradio_demo') if tmp_dir is None \
            else nullcontext(tmp_dir)

    with get_context(args.tmp_dir) as tmpdirname:
        cache_path = os.path.join(tmpdirname, chkpt_tag)
        os.makedirs(cache_path, exist_ok=True)
        main_demo(cache_path, model, args.device, args.image_size, server_name, args.server_port, silent=args.silent,
                  share=args.share, gradio_delete_cache=args.gradio_delete_cache)