dmytromishkin commited on
Commit
d7cb5e4
·
1 Parent(s): 5cd2bb7

Cleaned-up, and added diameter-based cv cost

Browse files
Files changed (1) hide show
  1. hoho/wed.py +24 -46
hoho/wed.py CHANGED
@@ -3,12 +3,6 @@ from scipy.optimize import linear_sum_assignment
3
  import numpy as np
4
 
5
 
6
- def zeromean_normalize(vertices):
7
- vertices = np.array(vertices)
8
- vertices = vertices - vertices.mean(axis=0)
9
- vertices = vertices / (1e-6 + np.linalg.norm(vertices, axis=1)[:, None]) # project all verts to sphere (not what we meant)
10
- return vertices
11
-
12
  def preregister_mean_std(verts_to_transform, target_verts, single_scale=True):
13
  mu_target = target_verts.mean(axis=0)
14
  mu_in = verts_to_transform.mean(axis=0)
@@ -34,52 +28,38 @@ def preregister_mean_std(verts_to_transform, target_verts, single_scale=True):
34
  return transformed_verts
35
 
36
 
37
- def compute_WED(pd_vertices, pd_edges, gt_vertices, gt_edges, cv=1000.0, ce=1.0, normalized=True, prenorm=False, preregister=True, register=False, single_scale=True):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  pd_vertices = np.array(pd_vertices)
39
  gt_vertices = np.array(gt_vertices)
40
 
 
 
 
 
 
 
41
  # Step 0: Prenormalize / preregister
42
- if prenorm:
43
- pd_vertices = zeromean_normalize(pd_vertices)
44
- gt_vertices = zeromean_normalize(gt_vertices)
45
-
46
  if preregister:
47
  pd_vertices = preregister_mean_std(pd_vertices, gt_vertices, single_scale=single_scale)
48
 
49
 
50
  pd_edges = np.array(pd_edges)
51
- gt_edges = np.array(gt_edges)
52
-
53
-
54
- # Step 0.5: Register
55
- if register:
56
- # find the optimal rotation, translation, and scale
57
- from scipy.spatial.transform import Rotation as R
58
- from scipy.optimize import minimize
59
-
60
- def transform(x, pd_vertices):
61
- # x is a 7-element vector, first 3 elements are the rotation vector, next 3 elements are the translation vector, finally scale
62
- rotation = R.from_rotvec(x[:3])
63
- translation = x[3:6]
64
- scale = x[6]
65
- return scale * rotation.apply(pd_vertices) + translation
66
-
67
- def cost_function(x, pd_vertices, gt_vertices):
68
- pd_vertices_transformed = transform(x, pd_vertices)
69
- distances = cdist(pd_vertices_transformed, gt_vertices, metric='euclidean')
70
- row_ind, col_ind = linear_sum_assignment(distances)
71
- translation_costs = np.sum(distances[row_ind, col_ind])
72
-
73
- return translation_costs
74
-
75
- x0 = np.array([0, 0, 0, 0, 0, 0, 1])
76
- # minimize subject to scale > 1e-6
77
- # res = minimize(cost_function, x0, args=(pd_vertices, gt_vertices), constraints={'type': 'ineq', 'fun': lambda x: x[6] - 1e-6})
78
- res = minimize(cost_function, x0, args=(pd_vertices, gt_vertices), bounds=[(-np.pi, np.pi), (-np.pi, np.pi), (-np.pi, np.pi), (-500, 500), (-500, 500), (-500, 500), (0.1, 3)])
79
- # print("scale:", res.x)
80
-
81
- pd_vertices = transform(res.x, pd_vertices)
82
-
83
 
84
  # Step 1: Bipartite Matching
85
  distances = cdist(pd_vertices, gt_vertices, metric='euclidean')
@@ -106,7 +86,6 @@ def compute_WED(pd_vertices, pd_edges, gt_vertices, gt_edges, cv=1000.0, ce=1.0,
106
  # Delete edges not in ground truth
107
  edges_to_delete = pd_edges_set - gt_edges_set
108
 
109
- #deletion_edge_costs = ce * sum(np.linalg.norm(pd_vertices[edge[0]] - pd_vertices[edge[1]]) for edge in edges_to_delete)
110
  vert_tf = [np.where(col_ind == v)[0][0] if v in col_ind else 0 for v in range(len(gt_vertices))]
111
  deletion_edge_costs = ce * sum(np.linalg.norm(pd_vertices[vert_tf[edge[0]]] - pd_vertices[vert_tf[edge[1]]]) for edge in edges_to_delete)
112
 
@@ -117,8 +96,7 @@ def compute_WED(pd_vertices, pd_edges, gt_vertices, gt_edges, cv=1000.0, ce=1.0,
117
 
118
  # Step 5: Calculation of WED
119
  WED = translation_costs + deletion_costs + insertion_costs + deletion_edge_costs + insertion_edge_costs
120
- # print("translation_costs, deletion_costs, insertion_costs, deletion_edge_costs, insertion_edge_costs")
121
- # print(translation_costs, deletion_costs, insertion_costs, deletion_edge_costs, insertion_edge_costs)
122
 
123
  if normalized:
124
  total_length_of_gt_edges = np.linalg.norm((gt_vertices[gt_edges[:, 0]] - gt_vertices[gt_edges[:, 1]]), axis=1).sum()
 
3
  import numpy as np
4
 
5
 
 
 
 
 
 
 
6
  def preregister_mean_std(verts_to_transform, target_verts, single_scale=True):
7
  mu_target = target_verts.mean(axis=0)
8
  mu_in = verts_to_transform.mean(axis=0)
 
28
  return transformed_verts
29
 
30
 
31
+ def compute_WED(pd_vertices, pd_edges, gt_vertices, gt_edges, cv=-1, ce=1.0, normalized=True, preregister=True, single_scale=True):
32
+ '''The function computes the Weighted Edge Distance (WED) between two graphs.
33
+ pd_vertices: list of predicted vertices
34
+ pd_edges: list of predicted edges
35
+ gt_vertices: list of ground truth vertices
36
+ gt_edges: list of ground truth edges
37
+ cv: vertex cost
38
+ ce: edge cost
39
+ normalized: if True, the WED is normalized by the total length of the ground truth edges
40
+ preregister: if True, the predicted vertices are pre-registered to the ground truth vertices
41
+ '''
42
+
43
+ # vertex coordinates are in centimeters, so cv and ce are set to 100.0 and 1.0 respectively.
44
+ # This means the missing a vertex is equivanlent predicting it 1 meters off,
45
+ # and that is the same as cv and ce equal to 1.0, if GT is in meters
46
+
47
  pd_vertices = np.array(pd_vertices)
48
  gt_vertices = np.array(gt_vertices)
49
 
50
+ diameter = cdist(gt_vertices, gt_vertices).max()
51
+
52
+ if cv < 0:
53
+ cv = diameter / 4.0
54
+ # Cost of addining or deleting a vertex is set to 1/4 of the diameter of the ground truth mesh
55
+
56
  # Step 0: Prenormalize / preregister
 
 
 
 
57
  if preregister:
58
  pd_vertices = preregister_mean_std(pd_vertices, gt_vertices, single_scale=single_scale)
59
 
60
 
61
  pd_edges = np.array(pd_edges)
62
+ gt_edges = np.array(gt_edges)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  # Step 1: Bipartite Matching
65
  distances = cdist(pd_vertices, gt_vertices, metric='euclidean')
 
86
  # Delete edges not in ground truth
87
  edges_to_delete = pd_edges_set - gt_edges_set
88
 
 
89
  vert_tf = [np.where(col_ind == v)[0][0] if v in col_ind else 0 for v in range(len(gt_vertices))]
90
  deletion_edge_costs = ce * sum(np.linalg.norm(pd_vertices[vert_tf[edge[0]]] - pd_vertices[vert_tf[edge[1]]]) for edge in edges_to_delete)
91
 
 
96
 
97
  # Step 5: Calculation of WED
98
  WED = translation_costs + deletion_costs + insertion_costs + deletion_edge_costs + insertion_edge_costs
99
+
 
100
 
101
  if normalized:
102
  total_length_of_gt_edges = np.linalg.norm((gt_vertices[gt_edges[:, 0]] - gt_vertices[gt_edges[:, 1]]), axis=1).sum()