# MINIMA / hloc / matchers / lightglue.py
# Web-viewer residue kept for provenance:
#   lsxi77777's picture | commit message | a930e1f | raw | history blame | 3.4 kB
import sys
from pathlib import Path
from .. import MODEL_REPO_ID, logger
from ..utils.base_model import BaseModel
lightglue_path = Path(__file__).parent / "../../third_party/LightGlue"
sys.path.append(str(lightglue_path))
from lightglue import LightGlue as LG
import torch
import os
class LightGlue(BaseModel):
    """Matcher wrapper around the third-party LightGlue network (``LG``).

    ``_init`` resolves a checkpoint file on disk and builds the network as
    ``LG(**conf)``. ``_forward`` repacks the flat feature dictionary into the
    nested ``image0`` / ``image1`` layout that is handed to that network.
    """

    # Default options. After ``_init`` has added the resolved paths, the whole
    # dict is expanded into keyword arguments for the network constructor.
    default_conf = {
        "match_threshold": 0.0,
        # Overwritten with ``match_threshold`` in ``_init``.
        "filter_threshold": 0.1,
        "width_confidence": 0.99,  # for point pruning
        "depth_confidence": 0.95,  # for early stopping
        "features": "superpoint",
        "model_name": "superpoint_lightglue.pth",
        "flash": True,  # enable FlashAttention if available
        "mp": False,  # enable mixed precision
        "add_scale_ori": False,
    }

    # Keys expected in the ``data`` dict. ``_forward`` itself only reads the
    # image/keypoints/descriptors entries (plus optional scales/oris); how
    # this list is enforced is up to the base class — TODO confirm.
    required_inputs = [
        "image0",
        "keypoints0",
        "scores0",
        "descriptors0",
        "image1",
        "keypoints1",
        "scores1",
        "descriptors1",
    ]

    # The MINIMA checkpoint is fetched from a GitHub release instead of the
    # model repository used for every other ``model_name``.
    _MINIMA_MODEL_NAME = "superpoint_minima_lightglue.pth"
    _MINIMA_URL = (
        "https://github.com/LSXI7/storage/releases/download/MINIMA/"
        "minima_lightglue.pth"
    )
    _MINIMA_FILENAME = "minima_lightglue.pth"

    def _init(self, conf):
        """Resolve the checkpoint for ``conf["model_name"]`` and build the net.

        Args:
            conf: Mutable configuration mapping. It is updated in place with
                ``weights`` and ``filter_threshold`` (plus ``MINIMA`` and
                ``MINIMA_path`` when the MINIMA checkpoint is selected) and
                then expanded into ``LG(**conf)``.
        """
        model_name = conf["model_name"]
        logger.info("Loading lightglue model, {}".format(model_name))

        if model_name == self._MINIMA_MODEL_NAME:
            model_path = self._fetch_minima_checkpoint()
            conf["MINIMA"] = True
            conf["MINIMA_path"] = model_path
        else:
            model_path = self._download_model(
                repo_id=MODEL_REPO_ID,
                filename="{}/{}".format(
                    Path(__file__).stem, self.conf["model_name"]
                ),
            )

        conf["weights"] = str(model_path)
        logger.debug("LightGlue weights: {}".format(conf["weights"]))
        # The network reads ``filter_threshold``; expose it to callers under
        # the ``match_threshold`` name.
        conf["filter_threshold"] = conf["match_threshold"]
        self.net = LG(**conf)
        logger.info("Load lightglue model done.")

    def _fetch_minima_checkpoint(self):
        """Return the local path of the MINIMA checkpoint, downloading if absent.

        The file is stored at ``<torch hub dir>/checkpoints/<filename>``.
        Only the file is fetched; the weights are not deserialised here, so
        the checkpoint is not loaded into memory just to fill the cache.

        Returns:
            The checkpoint path as a ``str``.
        """
        checkpoint_dir = os.path.join(torch.hub.get_dir(), "checkpoints")
        os.makedirs(checkpoint_dir, exist_ok=True)
        local_path = os.path.join(checkpoint_dir, self._MINIMA_FILENAME)
        if not os.path.exists(local_path):
            logger.info(
                "Downloading {} to {}".format(self._MINIMA_URL, local_path)
            )
            torch.hub.download_url_to_file(self._MINIMA_URL, local_path)
        return local_path

    @staticmethod
    def _pack_view(data, suffix):
        """Collect the entries of ``data`` that belong to one image.

        Args:
            data: Flat mapping with keys such as ``image0`` / ``keypoints0``.
            suffix: ``"0"`` or ``"1"``, selecting which image to pack.

        Returns:
            A dict with ``image``, ``keypoints`` and ``descriptors``, plus
            ``scales`` / ``oris`` when the matching keys exist in ``data``.
        """
        view = {
            "image": data["image" + suffix],
            "keypoints": data["keypoints" + suffix],
            # Swap the last two axes of the 3-D descriptor tensor; presumably
            # (batch, dim, num_kpts) -> (batch, num_kpts, dim) — TODO confirm.
            "descriptors": data["descriptors" + suffix].permute(0, 2, 1),
        }
        for optional in ("scales", "oris"):
            key = optional + suffix
            if key in data:
                view[optional] = data[key]
        return view

    def _forward(self, data):
        """Run the network on the feature pair stored in ``data``.

        Args:
            data: Mapping holding ``image{0,1}``, ``keypoints{0,1}`` and
                ``descriptors{0,1}``, optionally ``scales{0,1}`` and
                ``oris{0,1}``.

        Returns:
            Whatever ``self.net`` returns for the repacked input.
        """
        net_input = {
            "image0": self._pack_view(data, "0"),
            "image1": self._pack_view(data, "1"),
        }
        logger.debug(
            "LightGlue input shapes: image0={}, image1={}".format(
                data["image0"].shape, data["image1"].shape
            )
        )
        return self.net(net_input)