Skip to content

Commit 2603985

Browse files
committed
Update test to include non-uniform case
1 parent f124dc4 commit 2603985

File tree

2 files changed

+28
-13
lines changed

2 files changed

+28
-13
lines changed

pytorch3d/structures/meshes.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1559,20 +1559,24 @@ def volume_centroid(self):
15591559
"""
15601560
v_idxs = self.faces_padded().split([1, 1, 1], dim=-1)
15611561
verts = self.verts_padded()
1562-
1563-
v0, v1, v2 = [torch.gather(verts, 1, idx.expand(-1, -1, 3)) for idx in v_idxs]
1562+
valid = (self.faces_padded() != -1).all(dim=-1, keepdim=True)
1563+
1564+
v0, v1, v2 = [
1565+
torch.gather(
1566+
verts,
1567+
1,
1568+
idx.where(valid, torch.zeros_like(idx)).expand(-1, -1, 3),
1569+
).where(valid, torch.zeros_like(idx, dtype=verts.dtype))
1570+
for idx in v_idxs
1571+
]
15641572

15651573
tetra_center = (v0 + v1 + v2) / 4
15661574
signed_tetra_vol = (v0 * torch.cross(v1, v2, dim=-1)).sum(
15671575
dim=-1, keepdim=True
15681576
) / 6
15691577
denom = signed_tetra_vol.sum(dim=-2)
15701578
# clamp the denominator to prevent instability for degenerate meshes.
1571-
denom = torch.where(
1572-
denom < 0,
1573-
denom.clamp(max=-1e-5),
1574-
denom.clamp(min=1e-5)
1575-
)
1579+
denom = torch.where(denom < 0, denom.clamp(max=-1e-5), denom.clamp(min=1e-5))
15761580
return (tetra_center * signed_tetra_vol).sum(dim=-2) / denom
15771581

15781582
def submeshes(

tests/test_meshes.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1299,13 +1299,24 @@ def test_assigned_normals(self):
12991299
self.assertFalse(torch.allclose(yes_normals.verts_normals_padded(), verts))
13001300

13011301
def test_centroid(self):
1302+
meshes = init_simple_mesh()
1303+
# Check that it returns a valid value for multiple meshes with an inconsistent number
1304+
# of vertices
1305+
meshes.volume_centroid()
1306+
13021307
cube = init_cube_meshes()
1303-
self.assertClose(cube.volume_centroid(), torch.tensor([
1304-
[0.5] * 3,
1305-
[1.5] * 3,
1306-
[2.5] * 3,
1307-
[3.5] * 3,
1308-
]))
1308+
self.assertClose(
1309+
cube.volume_centroid(),
1310+
torch.tensor(
1311+
[
1312+
[0.5] * 3,
1313+
[1.5] * 3,
1314+
[2.5] * 3,
1315+
[3.5] * 3,
1316+
]
1317+
),
1318+
)
1319+
13091320
def test_submeshes(self):
13101321
empty_mesh = Meshes([], [])
13111322
# Four cubes with offsets [0, 1, 2, 3].

0 commit comments

Comments
 (0)