@@ -18,7 +18,9 @@ class _knn_points(Function):
18
18
"""
19
19
20
20
@staticmethod
21
- def forward (ctx , p1 , p2 , lengths1 , lengths2 , K , version ):
21
+ def forward (
22
+ ctx , p1 , p2 , lengths1 , lengths2 , K , version , return_sorted : bool = True
23
+ ):
22
24
"""
23
25
K-Nearest neighbors on point clouds.
24
26
@@ -36,6 +38,8 @@ def forward(ctx, p1, p2, lengths1, lengths2, K, version):
36
38
K: Integer giving the number of nearest neighbors to return.
37
39
version: Which KNN implementation to use in the backend. If version=-1,
38
40
the correct implementation is selected based on the shapes of the inputs.
41
+ return_sorted: (bool) whether to return the nearest neighbors sorted in
42
+ ascending order of distance.
39
43
40
44
Returns:
41
45
p1_dists: Tensor of shape (N, P1, K) giving the squared distances to
@@ -52,7 +56,7 @@ def forward(ctx, p1, p2, lengths1, lengths2, K, version):
52
56
idx , dists = _C .knn_points_idx (p1 , p2 , lengths1 , lengths2 , K , version )
53
57
54
58
# sort KNN in ascending order if K > 1
55
- if K > 1 :
59
+ if K > 1 and return_sorted :
56
60
if lengths2 .min () < K :
57
61
P1 = p1 .shape [1 ]
58
62
mask = lengths2 [:, None ] <= torch .arange (K , device = dists .device )[None ]
@@ -84,7 +88,7 @@ def backward(ctx, grad_dists, grad_idx):
84
88
grad_p1 , grad_p2 = _C .knn_points_backward (
85
89
p1 , p2 , lengths1 , lengths2 , idx , grad_dists
86
90
)
87
- return grad_p1 , grad_p2 , None , None , None , None
91
+ return grad_p1 , grad_p2 , None , None , None , None , None
88
92
89
93
90
94
def knn_points (
@@ -95,6 +99,7 @@ def knn_points(
95
99
K : int = 1 ,
96
100
version : int = - 1 ,
97
101
return_nn : bool = False ,
102
+ return_sorted : bool = True ,
98
103
):
99
104
"""
100
105
K-Nearest neighbors on point clouds.
@@ -113,7 +118,9 @@ def knn_points(
113
118
K: Integer giving the number of nearest neighbors to return.
114
119
version: Which KNN implementation to use in the backend. If version=-1,
115
120
the correct implementation is selected based on the shapes of the inputs.
116
- return_nn: If set to True returns the K nearest neighors in p2 for each point in p1.
121
+ return_nn: If set to True returns the K nearest neighbors in p2 for each point in p1.
122
+ return_sorted: (bool) whether to return the nearest neighbors sorted in
123
+ ascending order of distance.
117
124
118
125
Returns:
119
126
dists: Tensor of shape (N, P1, K) giving the squared distances to
@@ -158,7 +165,9 @@ def knn_points(
158
165
lengths2 = torch .full ((p1 .shape [0 ],), P2 , dtype = torch .int64 , device = p1 .device )
159
166
160
167
# pyre-fixme[16]: `_knn_points` has no attribute `apply`.
161
- p1_dists , p1_idx = _knn_points .apply (p1 , p2 , lengths1 , lengths2 , K , version )
168
+ p1_dists , p1_idx = _knn_points .apply (
169
+ p1 , p2 , lengths1 , lengths2 , K , version , return_sorted
170
+ )
162
171
163
172
p2_nn = None
164
173
if return_nn :
0 commit comments