@@ -159,12 +159,13 @@ def test_init_transformation(self, batch_size=10):
159
159
self .assertClose (s_init , s , atol = atol )
160
160
self .assertClose (Xt_init , Xt , atol = atol )
161
161
162
- def test_heterogeneous_inputs (self , batch_size = 10 ):
162
+ def test_heterogeneous_inputs (self , batch_size = 7 ):
163
163
"""
164
164
Tests whether we get the same result when running ICP on
165
165
a set of randomly-sized Pointclouds and on their padded versions.
166
166
"""
167
167
168
+ torch .manual_seed (4 )
168
169
device = torch .device ("cuda:0" )
169
170
170
171
for estimate_scale in (True , False ):
@@ -501,7 +502,6 @@ def test_corresponding_points_alignment(self, batch_size=10):
501
502
- use_pointclouds ... If True, passes the Pointclouds objects
502
503
to corresponding_points_alignment.
503
504
"""
504
- self .skipTest ("Temporarily disabled pending investigation" )
505
505
# run this for several different point cloud sizes
506
506
for n_points in (100 , 3 , 2 , 1 ):
507
507
# run this for several different dimensionalities
@@ -640,7 +640,10 @@ def align_and_get_mse(weights_):
640
640
if reflect and not allow_reflection :
641
641
# check that all rotations have det=1
642
642
self ._assert_all_close (
643
- torch .det (R_est ), R_est .new_ones (batch_size ), assert_error_message
643
+ torch .det (R_est ),
644
+ R_est .new_ones (batch_size ),
645
+ assert_error_message ,
646
+ atol = 2e-5 ,
644
647
)
645
648
646
649
else :
@@ -665,13 +668,13 @@ def align_and_get_mse(weights_):
665
668
desired_det = R_est .new_ones (batch_size )
666
669
if reflect :
667
670
desired_det *= - 1.0
668
- self ._assert_all_close (torch .det (R_est ), desired_det , msg , w )
671
+ self ._assert_all_close (torch .det (R_est ), desired_det , msg , w , atol = 2e-5 )
669
672
670
673
# check that the transformed point cloud
671
674
# X matches X_t
672
675
X_t_est = _apply_pcl_transformation (X , R_est , T_est , s = s_est )
673
676
self ._assert_all_close (
674
- X_t , X_t_est , assert_error_message , w [:, None , None ], atol = 1e -5
677
+ X_t , X_t_est , assert_error_message , w [:, None , None ], atol = 2e -5
675
678
)
676
679
677
680
def _assert_all_close (self , a_ , b_ , err_message , weights = None , atol = 1e-6 ):
0 commit comments