@@ -24,6 +24,21 @@ def setUp(self) -> None:
24
24
super ().setUp ()
25
25
torch .manual_seed (42 )
26
26
27
+ @classmethod
28
+ def _generate_epnp_test_from_2d (cls , y ):
29
+ """
30
+ Instantiate random x_world, x_cam, R, T given a set of input
31
+ 2D projections y.
32
+ """
33
+ batch_size = y .shape [0 ]
34
+ x_cam = torch .cat ((y , torch .rand_like (y [:, :, :1 ]) * 2.0 + 3.5 ), dim = 2 )
35
+ x_cam [:, :, :2 ] *= x_cam [:, :, 2 :] # unproject
36
+ R = rotation_conversions .random_rotations (batch_size ).to (y )
37
+ T = torch .randn_like (R [:, :1 , :])
38
+ T [:, :, 2 ] = (T [:, :, 2 ] + 3.0 ).clamp (2.0 )
39
+ x_world = torch .matmul (x_cam - T , R .transpose (1 , 2 ))
40
+ return x_cam , x_world , R , T
41
+
27
42
def _run_and_print (self , x_world , y , R , T , print_stats , skip_q , check_output = False ):
28
43
sol = perspective_n_points .efficient_pnp (
29
44
x_world , y .expand_as (x_world [:, :, :2 ]), skip_quadratic_eq = skip_q
@@ -45,16 +60,16 @@ def _run_and_print(self, x_world, y, R, T, print_stats, skip_q, check_output=Fal
45
60
)
46
61
47
62
self .assertClose (err_2d , sol .err_2d , msg = assert_msg )
48
- self .assertTrue ((err_2d < 1e -4 ).all (), msg = assert_msg )
63
+ self .assertTrue ((err_2d < 5e -4 ).all (), msg = assert_msg )
49
64
50
65
def norm_fn (t ):
51
66
return t .norm (dim = - 1 )
52
67
53
68
self .assertNormsClose (
54
- T , sol .T [:, None , :], rtol = 3e -3 , norm_fn = norm_fn , msg = assert_msg
69
+ T , sol .T [:, None , :], rtol = 4e -3 , norm_fn = norm_fn , msg = assert_msg
55
70
)
56
71
self .assertNormsClose (
57
- R_quat , R_est_quat , rtol = 3e-4 , norm_fn = norm_fn , msg = assert_msg
72
+ R_quat , R_est_quat , rtol = 3e-3 , norm_fn = norm_fn , msg = assert_msg
58
73
)
59
74
60
75
if print_stats :
@@ -71,12 +86,9 @@ def norm_fn(t):
71
86
print ("T_hat | T_gt\n " , T_gt )
72
87
73
88
def _testcase_from_2d (self , y , print_stats , benchmark , skip_q = False ):
74
- x_cam = torch .cat ((y , torch .rand_like (y [:, :1 ]) * 2.0 + 3.5 ), dim = 1 )
75
- x_cam [:, :2 ] *= x_cam [:, 2 :] # unproject
76
-
77
- R = rotation_conversions .random_rotations (16 ).to (y )
78
- T = torch .randn_like (R [:, :1 , :])
79
- x_world = torch .matmul (x_cam - T , R .transpose (1 , 2 ))
89
+ x_cam , x_world , R , T = TestPerspectiveNPoints ._generate_epnp_test_from_2d (
90
+ y [None ].repeat (16 , 1 , 1 )
91
+ )
80
92
81
93
if print_stats :
82
94
print ("Run without noise" )
@@ -129,3 +141,45 @@ def test_perspective_n_points(self, print_stats=False):
129
141
benchmark = False ,
130
142
skip_q = skip_q ,
131
143
)
144
+
145
+ def test_weighted_perspective_n_points (self , batch_size = 16 , num_pts = 200 ):
146
+ # instantiate random x_world and y
147
+ y = torch .randn ((batch_size , num_pts , 2 )).cuda () / 3.0
148
+ x_cam , x_world , R , T = TestPerspectiveNPoints ._generate_epnp_test_from_2d (y )
149
+
150
+ # randomly drop 50% of the rows
151
+ weights = (torch .rand_like (x_world [:, :, 0 ]) > 0.5 ).float ()
152
+
153
+ # make sure we retain at least 6 points for each case
154
+ weights [:, :6 ] = 1.0
155
+
156
+ # fill ignored y with trash to ensure that we get different
157
+ # solution in case the weighting is wrong
158
+ y = y + (1 - weights [:, :, None ]) * 100.0
159
+
160
+ def norm_fn (t ):
161
+ return t .norm (dim = - 1 )
162
+
163
+ for skip_quadratic_eq in (True , False ):
164
+ # get the solution for the 0/1 weighted case
165
+ sol = perspective_n_points .efficient_pnp (
166
+ x_world , y , skip_quadratic_eq = skip_quadratic_eq , weights = weights
167
+ )
168
+ sol_R_quat = rotation_conversions .matrix_to_quaternion (sol .R )
169
+ sol_T = sol .T
170
+
171
+ # check that running only on points with non-zero weights ends in the
172
+ # same place as running the 0/1 weighted version
173
+ for i in range (batch_size ):
174
+ ok = weights [i ] > 0
175
+ x_world_ok = x_world [i , ok ][None ]
176
+ y_ok = y [i , ok ][None ]
177
+ sol_ok = perspective_n_points .efficient_pnp (
178
+ x_world_ok , y_ok , skip_quadratic_eq = False
179
+ )
180
+ R_est_quat_ok = rotation_conversions .matrix_to_quaternion (sol_ok .R )
181
+
182
+ self .assertNormsClose (sol_T [i ], sol_ok .T [0 ], rtol = 3e-3 , norm_fn = norm_fn )
183
+ self .assertNormsClose (
184
+ sol_R_quat [i ], R_est_quat_ok [0 ], rtol = 3e-4 , norm_fn = norm_fn
185
+ )
0 commit comments