import os
import sys
from pathlib import Path

import torch

from .. import MODEL_REPO_ID, logger
from ..utils.base_model import BaseModel

# The vendored LightGlue package lives outside this package tree, so it must
# be put on sys.path before it can be imported below.
lightglue_path = Path(__file__).parent / "../../third_party/LightGlue"
sys.path.append(str(lightglue_path))

from lightglue import LightGlue as LG  # noqa: E402  (needs sys.path above)

# Release asset for the MINIMA-trained LightGlue checkpoint.
MINIMA_MODEL_URL = (
    "https://github.com/LSXI7/storage/releases/download/MINIMA/"
    "minima_lightglue.pth"
)
MINIMA_MODEL_FILENAME = "minima_lightglue.pth"


class LightGlue(BaseModel):
    """Matcher wrapper (a ``BaseModel`` subclass) around the vendored LightGlue.

    ``_init`` resolves a checkpoint and builds the third-party ``LightGlue``
    network from the config. The checkpoint is either the MINIMA release
    downloaded from GitHub or a file fetched via ``self._download_model``
    from ``MODEL_REPO_ID``. ``_forward`` repackages the flat per-image
    inputs into the nested ``{"image0": {...}, "image1": {...}}`` dict that
    the network is called with.
    """

    default_conf = {
        # Effective matching threshold: _init copies it over
        # "filter_threshold" before the network is built.
        "match_threshold": 0.0,
        "filter_threshold": 0.1,
        "width_confidence": 0.99,  # for point pruning
        "depth_confidence": 0.95,  # for early stopping
        "features": "superpoint",
        "model_name": "superpoint_lightglue.pth",
        "flash": True,  # enable FlashAttention if available
        "mp": False,  # enable mixed precision
        "add_scale_ori": False,
    }
    required_inputs = [
        "image0",
        "keypoints0",
        "scores0",
        "descriptors0",
        "image1",
        "keypoints1",
        "scores1",
        "descriptors1",
    ]

    def _init(self, conf):
        """Resolve the checkpoint for ``conf["model_name"]`` and build the net.

        ``conf`` is mutated in place and then forwarded, in full, to the
        third-party ``LightGlue`` constructor:

        - ``"weights"`` is set to the local checkpoint path.
        - ``"filter_threshold"`` is overwritten with ``"match_threshold"``.
        - For the MINIMA checkpoint, ``"MINIMA"`` and ``"MINIMA_path"`` are
          added.
        """
        model_name = conf["model_name"]
        logger.info("Loading lightglue model, {}".format(model_name))
        if model_name == "superpoint_minima_lightglue.pth":
            model_path = self._fetch_minima_checkpoint()
            conf["MINIMA"] = True
            conf["MINIMA_path"] = model_path
        else:
            model_path = self._download_model(
                repo_id=MODEL_REPO_ID,
                filename="{}/{}".format(
                    Path(__file__).stem, self.conf["model_name"]
                ),
            )
        conf["weights"] = str(model_path)
        logger.debug("lightglue weights: {}".format(conf["weights"]))
        conf["filter_threshold"] = conf["match_threshold"]
        self.net = LG(**conf)
        logger.info("Load lightglue model done.")

    @staticmethod
    def _fetch_minima_checkpoint():
        """Return the local path of the MINIMA checkpoint, downloading it once.

        The file is cached as ``<torch hub dir>/checkpoints/minima_lightglue.pth``.
        That is the default cache location of
        ``torch.hub.load_state_dict_from_url``, so a file downloaded there
        earlier is reused. The download only writes the file; it does not
        deserialize the checkpoint.
        """
        ckpt_dir = Path(torch.hub.get_dir()) / "checkpoints"
        model_path = ckpt_dir / MINIMA_MODEL_FILENAME
        if not model_path.is_file():
            ckpt_dir.mkdir(parents=True, exist_ok=True)
            torch.hub.download_url_to_file(MINIMA_MODEL_URL, str(model_path))
        return str(model_path)

    @staticmethod
    def _pack_view(data, idx):
        """Collect the entries of ``data`` that belong to image ``idx`` (0/1).

        NOTE: ``scores{idx}`` is listed in ``required_inputs`` but is not
        forwarded to the network here.
        """
        view = {
            "image": data["image{}".format(idx)],
            "keypoints": data["keypoints{}".format(idx)],
            # Swaps the last two axes; presumably (B, D, N) -> (B, N, D)
            # as the network expects -- TODO confirm against the extractor.
            "descriptors": data["descriptors{}".format(idx)].permute(0, 2, 1),
        }
        # Optional per-keypoint scale / orientation, only passed when present.
        for key in ("scales", "oris"):
            src = "{}{}".format(key, idx)
            if src in data:
                view[key] = data[src]
        return view

    def _forward(self, data):
        """Match one image pair and return the underlying network's output.

        ``data`` must contain the ``required_inputs`` entries.
        ``scales0``/``scales1`` and ``oris0``/``oris1`` are forwarded when
        present.
        """
        pair = {
            "image0": self._pack_view(data, 0),
            "image1": self._pack_view(data, 1),
        }
        return self.net(pair)