Skip to content

Commit 806ca36

Browse files
nikhilaravifacebook-github-bot
authored andcommitted
making sorting for K >1 optional in KNN points function
Summary: Added `sorted` argument to the `knn_points` function. This came up during the benchmarking against Faiss - sorting added extra memory usage. Match the memory usage of Faiss by making sorting optional. Reviewed By: bottler, gkioxari Differential Revision: D22329070 fbshipit-source-id: 0828ff9b48eefce99ce1f60089389f6885d03139
1 parent dd4a35c commit 806ca36

File tree

2 files changed

+37
-9
lines changed

2 files changed

+37
-9
lines changed

pytorch3d/ops/knn.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ class _knn_points(Function):
1818
"""
1919

2020
@staticmethod
21-
def forward(ctx, p1, p2, lengths1, lengths2, K, version):
21+
def forward(
22+
ctx, p1, p2, lengths1, lengths2, K, version, return_sorted: bool = True
23+
):
2224
"""
2325
K-Nearest neighbors on point clouds.
2426
@@ -36,6 +38,8 @@ def forward(ctx, p1, p2, lengths1, lengths2, K, version):
3638
K: Integer giving the number of nearest neighbors to return.
3739
version: Which KNN implementation to use in the backend. If version=-1,
3840
the correct implementation is selected based on the shapes of the inputs.
41+
return_sorted: (bool) whether to return the nearest neighbors sorted in
42+
ascending order of distance.
3943
4044
Returns:
4145
p1_dists: Tensor of shape (N, P1, K) giving the squared distances to
@@ -52,7 +56,7 @@ def forward(ctx, p1, p2, lengths1, lengths2, K, version):
5256
idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, K, version)
5357

5458
# sort KNN in ascending order if K > 1
55-
if K > 1:
59+
if K > 1 and return_sorted:
5660
if lengths2.min() < K:
5761
P1 = p1.shape[1]
5862
mask = lengths2[:, None] <= torch.arange(K, device=dists.device)[None]
@@ -84,7 +88,7 @@ def backward(ctx, grad_dists, grad_idx):
8488
grad_p1, grad_p2 = _C.knn_points_backward(
8589
p1, p2, lengths1, lengths2, idx, grad_dists
8690
)
87-
return grad_p1, grad_p2, None, None, None, None
91+
return grad_p1, grad_p2, None, None, None, None, None
8892

8993

9094
def knn_points(
@@ -95,6 +99,7 @@ def knn_points(
9599
K: int = 1,
96100
version: int = -1,
97101
return_nn: bool = False,
102+
return_sorted: bool = True,
98103
):
99104
"""
100105
K-Nearest neighbors on point clouds.
@@ -113,7 +118,9 @@ def knn_points(
113118
K: Integer giving the number of nearest neighbors to return.
114119
version: Which KNN implementation to use in the backend. If version=-1,
115120
the correct implementation is selected based on the shapes of the inputs.
116-
return_nn: If set to True returns the K nearest neighors in p2 for each point in p1.
121+
return_nn: If set to True returns the K nearest neighbors in p2 for each point in p1.
122+
return_sorted: (bool) whether to return the nearest neighbors sorted in
123+
ascending order of distance.
117124
118125
Returns:
119126
dists: Tensor of shape (N, P1, K) giving the squared distances to
@@ -158,7 +165,9 @@ def knn_points(
158165
lengths2 = torch.full((p1.shape[0],), P2, dtype=torch.int64, device=p1.device)
159166

160167
# pyre-fixme[16]: `_knn_points` has no attribute `apply`.
161-
p1_dists, p1_idx = _knn_points.apply(p1, p2, lengths1, lengths2, K, version)
168+
p1_dists, p1_idx = _knn_points.apply(
169+
p1, p2, lengths1, lengths2, K, version, return_sorted
170+
)
162171

163172
p2_nn = None
164173
if return_nn:

tests/test_knn.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def _knn_points_naive(p1, p2, lengths1, lengths2, K: int) -> torch.Tensor:
4949

5050
return _KNN(dists=dists, idx=idx, knn=None)
5151

52-
def _knn_vs_python_square_helper(self, device):
52+
def _knn_vs_python_square_helper(self, device, return_sorted):
5353
Ns = [1, 4]
5454
Ds = [3, 5, 8]
5555
P1s = [8, 24]
@@ -70,7 +70,24 @@ def _knn_vs_python_square_helper(self, device):
7070

7171
# forward
7272
out1 = self._knn_points_naive(x, y, lengths1=None, lengths2=None, K=K)
73-
out2 = knn_points(x_cuda, y_cuda, K=K, version=version)
73+
out2 = knn_points(
74+
x_cuda, y_cuda, K=K, version=version, return_sorted=return_sorted
75+
)
76+
if K > 1 and not return_sorted:
77+
# check out2 is not sorted
78+
self.assertFalse(torch.allclose(out1[0], out2[0]))
79+
self.assertFalse(torch.allclose(out1[1], out2[1]))
80+
# now sort out2
81+
dists, idx, _ = out2
82+
if P2 < K:
83+
dists[..., P2:] = float("inf")
84+
dists, sort_idx = dists.sort(dim=2)
85+
dists[..., P2:] = 0
86+
else:
87+
dists, sort_idx = dists.sort(dim=2)
88+
idx = idx.gather(2, sort_idx)
89+
out2 = _KNN(dists, idx, None)
90+
7491
self.assertClose(out1[0], out2[0])
7592
self.assertTrue(torch.all(out1[1] == out2[1]))
7693

@@ -86,11 +103,13 @@ def _knn_vs_python_square_helper(self, device):
86103

87104
def test_knn_vs_python_square_cpu(self):
88105
device = torch.device("cpu")
89-
self._knn_vs_python_square_helper(device)
106+
self._knn_vs_python_square_helper(device, return_sorted=True)
90107

91108
def test_knn_vs_python_square_cuda(self):
92109
device = get_random_cuda_device()
93-
self._knn_vs_python_square_helper(device)
110+
# Check both cases where the output is sorted and unsorted
111+
self._knn_vs_python_square_helper(device, return_sorted=True)
112+
self._knn_vs_python_square_helper(device, return_sorted=False)
94113

95114
def _knn_vs_python_ragged_helper(self, device):
96115
Ns = [1, 4]

0 commit comments

Comments
 (0)