lsxi77777 commited on
Commit
aaf6642
·
1 Parent(s): 89426d4

fix on gpu

Browse files
Files changed (2) hide show
  1. ui/app_class.py +0 -1
  2. ui/utils.py +88 -88
ui/app_class.py CHANGED
@@ -83,7 +83,6 @@ a:hover {
83
  }
84
  """
85
 
86
- @spaces.GPU
87
  class ImageMatchingApp:
88
  def __init__(self, server_name="0.0.0.0", server_port=7860, **kwargs):
89
  self.server_name = server_name
 
83
  }
84
  """
85
 
 
86
  class ImageMatchingApp:
87
  def __init__(self, server_name="0.0.0.0", server_port=7860, **kwargs):
88
  self.server_name = server_name
ui/utils.py CHANGED
@@ -1,22 +1,21 @@
 
 
 
 
1
  import os
2
  import pickle
 
 
3
  import random
4
  import shutil
5
  import sys
6
  import time
7
  import warnings
 
8
  from itertools import combinations
9
  from pathlib import Path
10
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union
11
 
12
- import cv2
13
- import gradio as gr
14
- import matplotlib.pyplot as plt
15
- import numpy as np
16
- import poselib
17
- import psutil
18
- from PIL import Image
19
-
20
  sys.path.append(str(Path(__file__).parents[1]))
21
 
22
  from hloc import (
@@ -30,6 +29,7 @@ from hloc import (
30
  )
31
  from hloc.utils.base_model import dynamic_load
32
  from ui.viz import display_keypoints, display_matches, fig2im, plot_images
 
33
 
34
  warnings.simplefilter("ignore")
35
 
@@ -130,7 +130,7 @@ def load_config(config_name: str) -> Dict[str, Any]:
130
 
131
 
132
  def get_matcher_zoo(
133
- matcher_zoo: Dict[str, Dict[str, Union[str, bool]]]
134
  ) -> Dict[str, Dict[str, Union[Callable, bool]]]:
135
  """
136
  Restore matcher configurations from a dictionary.
@@ -220,7 +220,7 @@ def gen_examples():
220
  img1 = os.path.join(path, lines[i].strip())
221
  img2 = os.path.join(path, lines[i + 1].strip())
222
  image_pairs.append((img1, img2))
223
- count=len(image_pairs)
224
 
225
  if len(image_pairs) < count:
226
  count = len(image_pairs)
@@ -276,13 +276,13 @@ def set_null_pred(feature_type: str, pred: dict):
276
 
277
 
278
  def _filter_matches_opencv(
279
- kp0: np.ndarray,
280
- kp1: np.ndarray,
281
- method: int = cv2.RANSAC,
282
- reproj_threshold: float = 3.0,
283
- confidence: float = 0.99,
284
- max_iter: int = 2000,
285
- geometry_type: str = "Homography",
286
  ) -> Tuple[np.ndarray, np.ndarray]:
287
  """
288
  Filters matches between two sets of keypoints using OpenCV's findHomography.
@@ -322,13 +322,13 @@ def _filter_matches_opencv(
322
 
323
 
324
  def _filter_matches_poselib(
325
- kp0: np.ndarray,
326
- kp1: np.ndarray,
327
- method: int = None, # not used
328
- reproj_threshold: float = 3,
329
- confidence: float = 0.99,
330
- max_iter: int = 2000,
331
- geometry_type: str = "Homography",
332
  ) -> dict:
333
  """
334
  Filters matches between two sets of keypoints using the poselib library.
@@ -364,13 +364,13 @@ def _filter_matches_poselib(
364
 
365
 
366
  def proc_ransac_matches(
367
- mkpts0: np.ndarray,
368
- mkpts1: np.ndarray,
369
- ransac_method: str = DEFAULT_RANSAC_METHOD,
370
- ransac_reproj_threshold: float = 3.0,
371
- ransac_confidence: float = 0.99,
372
- ransac_max_iter: int = 2000,
373
- geometry_type: str = "Homography",
374
  ):
375
  if ransac_method.startswith("CV2"):
376
  logger.info(
@@ -403,12 +403,12 @@ def proc_ransac_matches(
403
 
404
 
405
  def filter_matches(
406
- pred: Dict[str, Any],
407
- ransac_method: str = DEFAULT_RANSAC_METHOD,
408
- ransac_reproj_threshold: float = DEFAULT_RANSAC_REPROJ_THRESHOLD,
409
- ransac_confidence: float = DEFAULT_RANSAC_CONFIDENCE,
410
- ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER,
411
- ransac_estimator: str = None,
412
  ):
413
  """
414
  Filter matches using RANSAC. If keypoints are available, filter by keypoints.
@@ -433,8 +433,8 @@ def filter_matches(
433
  mkpts1 = pred["mkeypoints1_orig"]
434
  feature_type = "KEYPOINT"
435
  elif (
436
- "line_keypoints0_orig" in pred.keys()
437
- and "line_keypoints1_orig" in pred.keys()
438
  ):
439
  mkpts0 = pred["line_keypoints0_orig"]
440
  mkpts1 = pred["line_keypoints1_orig"]
@@ -477,11 +477,11 @@ def filter_matches(
477
 
478
 
479
  def compute_geometry(
480
- pred: Dict[str, Any],
481
- ransac_method: str = DEFAULT_RANSAC_METHOD,
482
- ransac_reproj_threshold: float = DEFAULT_RANSAC_REPROJ_THRESHOLD,
483
- ransac_confidence: float = DEFAULT_RANSAC_CONFIDENCE,
484
- ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER,
485
  ) -> Dict[str, List[float]]:
486
  """
487
  Compute geometric information of matches, including Fundamental matrix,
@@ -504,8 +504,8 @@ def compute_geometry(
504
  mkpts0 = pred["mkeypoints0_orig"]
505
  mkpts1 = pred["mkeypoints1_orig"]
506
  elif (
507
- "line_keypoints0_orig" in pred.keys()
508
- and "line_keypoints1_orig" in pred.keys()
509
  ):
510
  mkpts0 = pred["line_keypoints0_orig"]
511
  mkpts1 = pred["line_keypoints1_orig"]
@@ -561,10 +561,10 @@ def compute_geometry(
561
 
562
 
563
  def wrap_images(
564
- img0: np.ndarray,
565
- img1: np.ndarray,
566
- geo_info: Optional[Dict[str, List[float]]],
567
- geom_type: str,
568
  ) -> Tuple[Optional[str], Optional[Dict[str, List[float]]]]:
569
  """
570
  Wraps the images based on the geometric transformation used to align them.
@@ -617,10 +617,10 @@ def wrap_images(
617
 
618
 
619
  def generate_warp_images(
620
- input_image0: np.ndarray,
621
- input_image1: np.ndarray,
622
- matches_info: Dict[str, Any],
623
- choice: str,
624
  ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
625
  """
626
  Changes the estimate of the geometric transformation used to align the images.
@@ -635,9 +635,9 @@ def generate_warp_images(
635
  A tuple containing the updated images and the warpped images.
636
  """
637
  if (
638
- matches_info is None
639
- or len(matches_info) < 1
640
- or "geom_info" not in matches_info.keys()
641
  ):
642
  return None, None
643
  geom_info = matches_info["geom_info"]
@@ -671,12 +671,12 @@ def send_to_match(state_cache: Dict[str, Any]):
671
 
672
 
673
  def run_ransac(
674
- state_cache: Dict[str, Any],
675
- choice_geometry_type: str,
676
- ransac_method: str = DEFAULT_RANSAC_METHOD,
677
- ransac_reproj_threshold: int = DEFAULT_RANSAC_REPROJ_THRESHOLD,
678
- ransac_confidence: float = DEFAULT_RANSAC_CONFIDENCE,
679
- ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER,
680
  ) -> Tuple[Optional[np.ndarray], Optional[Dict[str, int]]]:
681
  """
682
  Run RANSAC matches and return the output images and the number of matches.
@@ -710,7 +710,7 @@ def run_ransac(
710
  ransac_confidence=ransac_confidence,
711
  ransac_max_iter=ransac_max_iter,
712
  )
713
- logger.info(f"RANSAC matches done using: {time.time()-t1:.3f}s")
714
  t1 = time.time()
715
 
716
  # plot images with ransac matches
@@ -721,7 +721,7 @@ def run_ransac(
721
  output_matches_ransac, num_matches_ransac = display_matches(
722
  state_cache, titles=titles, tag="KPTS_RANSAC"
723
  )
724
- logger.info(f"Display matches done using: {time.time()-t1:.3f}s")
725
  t1 = time.time()
726
 
727
  # compute warp images
@@ -753,24 +753,24 @@ def run_ransac(
753
  tmp_state_cache,
754
  )
755
 
756
-
757
  def run_matching(
758
- image0: np.ndarray,
759
- image1: np.ndarray,
760
- match_threshold: float,
761
- extract_max_keypoints: int,
762
- keypoint_threshold: float,
763
- key: str,
764
- ransac_method: str = DEFAULT_RANSAC_METHOD,
765
- ransac_reproj_threshold: int = DEFAULT_RANSAC_REPROJ_THRESHOLD,
766
- ransac_confidence: float = DEFAULT_RANSAC_CONFIDENCE,
767
- ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER,
768
- choice_geometry_type: str = DEFAULT_SETTING_GEOMETRY,
769
- matcher_zoo: Dict[str, Any] = None,
770
- force_resize: bool = False,
771
- image_width: int = 640,
772
- image_height: int = 480,
773
- use_cached_model: bool = False,
774
  ) -> Tuple[
775
  np.ndarray,
776
  np.ndarray,
@@ -846,7 +846,7 @@ def run_matching(
846
  else:
847
  matcher = get_model(match_conf)
848
  print('match_conf2', match_conf)
849
- logger.info(f"Loading model using: {time.time()-t0:.3f}s")
850
  t1 = time.time()
851
 
852
  if model["dense"]:
@@ -899,13 +899,13 @@ def run_matching(
899
  )
900
  pred = match_features.match_images(matcher, pred0, pred1)
901
  # print('pred', pred)
902
- mconf= pred["mconf"]
903
  print('mconf', mconf.min(), mconf.max())
904
  del extractor
905
  # gr.Info(
906
  # f"Matching images done using: {time.time()-t1:.3f}s",
907
  # )
908
- logger.info(f"Matching images done using: {time.time()-t1:.3f}s")
909
  t1 = time.time()
910
 
911
  # plot images with keypoints
@@ -932,7 +932,7 @@ def run_matching(
932
  )
933
 
934
  # gr.Info(f"RANSAC matches done using: {time.time()-t1:.3f}s")
935
- logger.info(f"RANSAC matches done using: {time.time()-t1:.3f}s")
936
  t1 = time.time()
937
 
938
  # plot images with ransac matches
@@ -944,7 +944,7 @@ def run_matching(
944
  pred, titles=titles, tag="KPTS_RANSAC"
945
  )
946
  # gr.Info(f"Display matches done using: {time.time()-t1:.3f}s")
947
- logger.info(f"Display matches done using: {time.time()-t1:.3f}s")
948
 
949
  t1 = time.time()
950
  # plot wrapped images
@@ -956,7 +956,7 @@ def run_matching(
956
  )
957
  plt.close("all")
958
  # gr.Info(f"In summary, total time: {time.time()-t0:.3f}s")
959
- logger.info(f"TOTAL time: {time.time()-t0:.3f}s")
960
 
961
  state_cache = pred
962
  state_cache["num_matches_raw"] = num_matches_raw
 
1
+ import cv2
2
+ import gradio as gr
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
  import os
6
  import pickle
7
+ import poselib
8
+ import psutil
9
  import random
10
  import shutil
11
  import sys
12
  import time
13
  import warnings
14
+ from PIL import Image
15
  from itertools import combinations
16
  from pathlib import Path
17
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union
18
 
 
 
 
 
 
 
 
 
19
  sys.path.append(str(Path(__file__).parents[1]))
20
 
21
  from hloc import (
 
29
  )
30
  from hloc.utils.base_model import dynamic_load
31
  from ui.viz import display_keypoints, display_matches, fig2im, plot_images
32
+ import spaces
33
 
34
  warnings.simplefilter("ignore")
35
 
 
130
 
131
 
132
  def get_matcher_zoo(
133
+ matcher_zoo: Dict[str, Dict[str, Union[str, bool]]]
134
  ) -> Dict[str, Dict[str, Union[Callable, bool]]]:
135
  """
136
  Restore matcher configurations from a dictionary.
 
220
  img1 = os.path.join(path, lines[i].strip())
221
  img2 = os.path.join(path, lines[i + 1].strip())
222
  image_pairs.append((img1, img2))
223
+ count = len(image_pairs)
224
 
225
  if len(image_pairs) < count:
226
  count = len(image_pairs)
 
276
 
277
 
278
  def _filter_matches_opencv(
279
+ kp0: np.ndarray,
280
+ kp1: np.ndarray,
281
+ method: int = cv2.RANSAC,
282
+ reproj_threshold: float = 3.0,
283
+ confidence: float = 0.99,
284
+ max_iter: int = 2000,
285
+ geometry_type: str = "Homography",
286
  ) -> Tuple[np.ndarray, np.ndarray]:
287
  """
288
  Filters matches between two sets of keypoints using OpenCV's findHomography.
 
322
 
323
 
324
  def _filter_matches_poselib(
325
+ kp0: np.ndarray,
326
+ kp1: np.ndarray,
327
+ method: int = None, # not used
328
+ reproj_threshold: float = 3,
329
+ confidence: float = 0.99,
330
+ max_iter: int = 2000,
331
+ geometry_type: str = "Homography",
332
  ) -> dict:
333
  """
334
  Filters matches between two sets of keypoints using the poselib library.
 
364
 
365
 
366
  def proc_ransac_matches(
367
+ mkpts0: np.ndarray,
368
+ mkpts1: np.ndarray,
369
+ ransac_method: str = DEFAULT_RANSAC_METHOD,
370
+ ransac_reproj_threshold: float = 3.0,
371
+ ransac_confidence: float = 0.99,
372
+ ransac_max_iter: int = 2000,
373
+ geometry_type: str = "Homography",
374
  ):
375
  if ransac_method.startswith("CV2"):
376
  logger.info(
 
403
 
404
 
405
  def filter_matches(
406
+ pred: Dict[str, Any],
407
+ ransac_method: str = DEFAULT_RANSAC_METHOD,
408
+ ransac_reproj_threshold: float = DEFAULT_RANSAC_REPROJ_THRESHOLD,
409
+ ransac_confidence: float = DEFAULT_RANSAC_CONFIDENCE,
410
+ ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER,
411
+ ransac_estimator: str = None,
412
  ):
413
  """
414
  Filter matches using RANSAC. If keypoints are available, filter by keypoints.
 
433
  mkpts1 = pred["mkeypoints1_orig"]
434
  feature_type = "KEYPOINT"
435
  elif (
436
+ "line_keypoints0_orig" in pred.keys()
437
+ and "line_keypoints1_orig" in pred.keys()
438
  ):
439
  mkpts0 = pred["line_keypoints0_orig"]
440
  mkpts1 = pred["line_keypoints1_orig"]
 
477
 
478
 
479
  def compute_geometry(
480
+ pred: Dict[str, Any],
481
+ ransac_method: str = DEFAULT_RANSAC_METHOD,
482
+ ransac_reproj_threshold: float = DEFAULT_RANSAC_REPROJ_THRESHOLD,
483
+ ransac_confidence: float = DEFAULT_RANSAC_CONFIDENCE,
484
+ ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER,
485
  ) -> Dict[str, List[float]]:
486
  """
487
  Compute geometric information of matches, including Fundamental matrix,
 
504
  mkpts0 = pred["mkeypoints0_orig"]
505
  mkpts1 = pred["mkeypoints1_orig"]
506
  elif (
507
+ "line_keypoints0_orig" in pred.keys()
508
+ and "line_keypoints1_orig" in pred.keys()
509
  ):
510
  mkpts0 = pred["line_keypoints0_orig"]
511
  mkpts1 = pred["line_keypoints1_orig"]
 
561
 
562
 
563
  def wrap_images(
564
+ img0: np.ndarray,
565
+ img1: np.ndarray,
566
+ geo_info: Optional[Dict[str, List[float]]],
567
+ geom_type: str,
568
  ) -> Tuple[Optional[str], Optional[Dict[str, List[float]]]]:
569
  """
570
  Wraps the images based on the geometric transformation used to align them.
 
617
 
618
 
619
  def generate_warp_images(
620
+ input_image0: np.ndarray,
621
+ input_image1: np.ndarray,
622
+ matches_info: Dict[str, Any],
623
+ choice: str,
624
  ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
625
  """
626
  Changes the estimate of the geometric transformation used to align the images.
 
635
  A tuple containing the updated images and the warpped images.
636
  """
637
  if (
638
+ matches_info is None
639
+ or len(matches_info) < 1
640
+ or "geom_info" not in matches_info.keys()
641
  ):
642
  return None, None
643
  geom_info = matches_info["geom_info"]
 
671
 
672
 
673
  def run_ransac(
674
+ state_cache: Dict[str, Any],
675
+ choice_geometry_type: str,
676
+ ransac_method: str = DEFAULT_RANSAC_METHOD,
677
+ ransac_reproj_threshold: int = DEFAULT_RANSAC_REPROJ_THRESHOLD,
678
+ ransac_confidence: float = DEFAULT_RANSAC_CONFIDENCE,
679
+ ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER,
680
  ) -> Tuple[Optional[np.ndarray], Optional[Dict[str, int]]]:
681
  """
682
  Run RANSAC matches and return the output images and the number of matches.
 
710
  ransac_confidence=ransac_confidence,
711
  ransac_max_iter=ransac_max_iter,
712
  )
713
+ logger.info(f"RANSAC matches done using: {time.time() - t1:.3f}s")
714
  t1 = time.time()
715
 
716
  # plot images with ransac matches
 
721
  output_matches_ransac, num_matches_ransac = display_matches(
722
  state_cache, titles=titles, tag="KPTS_RANSAC"
723
  )
724
+ logger.info(f"Display matches done using: {time.time() - t1:.3f}s")
725
  t1 = time.time()
726
 
727
  # compute warp images
 
753
  tmp_state_cache,
754
  )
755
 
756
+ @spaces.GPU
757
  def run_matching(
758
+ image0: np.ndarray,
759
+ image1: np.ndarray,
760
+ match_threshold: float,
761
+ extract_max_keypoints: int,
762
+ keypoint_threshold: float,
763
+ key: str,
764
+ ransac_method: str = DEFAULT_RANSAC_METHOD,
765
+ ransac_reproj_threshold: int = DEFAULT_RANSAC_REPROJ_THRESHOLD,
766
+ ransac_confidence: float = DEFAULT_RANSAC_CONFIDENCE,
767
+ ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER,
768
+ choice_geometry_type: str = DEFAULT_SETTING_GEOMETRY,
769
+ matcher_zoo: Dict[str, Any] = None,
770
+ force_resize: bool = False,
771
+ image_width: int = 640,
772
+ image_height: int = 480,
773
+ use_cached_model: bool = False,
774
  ) -> Tuple[
775
  np.ndarray,
776
  np.ndarray,
 
846
  else:
847
  matcher = get_model(match_conf)
848
  print('match_conf2', match_conf)
849
+ logger.info(f"Loading model using: {time.time() - t0:.3f}s")
850
  t1 = time.time()
851
 
852
  if model["dense"]:
 
899
  )
900
  pred = match_features.match_images(matcher, pred0, pred1)
901
  # print('pred', pred)
902
+ mconf = pred["mconf"]
903
  print('mconf', mconf.min(), mconf.max())
904
  del extractor
905
  # gr.Info(
906
  # f"Matching images done using: {time.time()-t1:.3f}s",
907
  # )
908
+ logger.info(f"Matching images done using: {time.time() - t1:.3f}s")
909
  t1 = time.time()
910
 
911
  # plot images with keypoints
 
932
  )
933
 
934
  # gr.Info(f"RANSAC matches done using: {time.time()-t1:.3f}s")
935
+ logger.info(f"RANSAC matches done using: {time.time() - t1:.3f}s")
936
  t1 = time.time()
937
 
938
  # plot images with ransac matches
 
944
  pred, titles=titles, tag="KPTS_RANSAC"
945
  )
946
  # gr.Info(f"Display matches done using: {time.time()-t1:.3f}s")
947
+ logger.info(f"Display matches done using: {time.time() - t1:.3f}s")
948
 
949
  t1 = time.time()
950
  # plot wrapped images
 
956
  )
957
  plt.close("all")
958
  # gr.Info(f"In summary, total time: {time.time()-t0:.3f}s")
959
+ logger.info(f"TOTAL time: {time.time() - t0:.3f}s")
960
 
961
  state_cache = pred
962
  state_cache["num_matches_raw"] = num_matches_raw