Skip to content

Commit 1fb97f9

Browse files
gkioxarifacebook-github-bot
authored andcommitted
update padded in meshes
Summary: Three changes to Meshes 1. `num_verts_per_mesh` and `num_faces_per_mesh` are assigned at construction time and are returned without the need for `compute_packed` 2. `update_padded` updates `verts_padded` and shallow copies faces list and faces_padded and existing attributes from construction. 3. `padded_to_packed_idx` does not need `compute_packed` Reviewed By: nikhilaravi Differential Revision: D21653674 fbshipit-source-id: dc6815a2e2a925fe4a834fe357919da2b2c14527
1 parent ae68a54 commit 1fb97f9

File tree

2 files changed

+201
-21
lines changed

2 files changed

+201
-21
lines changed

pytorch3d/structures/meshes.py

Lines changed: 107 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -462,9 +462,9 @@ def verts_list(self):
462462
assert (
463463
self._verts_padded is not None
464464
), "verts_padded is required to compute verts_list."
465-
self._verts_list = [
466-
v[0] for v in self._verts_padded.split([1] * self._N, 0)
467-
]
465+
self._verts_list = struct_utils.padded_to_list(
466+
self._verts_padded, self.num_verts_per_mesh().tolist()
467+
)
468468
return self._verts_list
469469

470470
def faces_list(self):
@@ -478,10 +478,9 @@ def faces_list(self):
478478
assert (
479479
self._faces_padded is not None
480480
), "faces_padded is required to compute faces_list."
481-
self._faces_list = []
482-
for i in range(self._N):
483-
valid = self._faces_padded[i].gt(-1).all(1)
484-
self._faces_list.append(self._faces_padded[i, valid, :])
481+
self._faces_list = struct_utils.padded_to_list(
482+
self._faces_padded, self.num_faces_per_mesh().tolist()
483+
)
485484
return self._faces_list
486485

487486
def verts_packed(self):
@@ -525,7 +524,6 @@ def num_verts_per_mesh(self):
525524
Returns:
526525
1D tensor of sizes.
527526
"""
528-
self._compute_packed()
529527
return self._num_verts_per_mesh
530528

531529
def faces_packed(self):
@@ -590,7 +588,6 @@ def num_faces_per_mesh(self):
590588
Returns:
591589
1D tensor of sizes.
592590
"""
593-
self._compute_packed()
594591
return self._num_faces_per_mesh
595592

596593
def edges_packed(self):
@@ -664,14 +661,13 @@ def verts_padded_to_packed_idx(self):
664661
Returns:
665662
1D tensor of indices.
666663
"""
667-
self._compute_packed()
668664
if self._verts_padded_to_packed_idx is not None:
669665
return self._verts_padded_to_packed_idx
670666

671667
self._verts_padded_to_packed_idx = torch.cat(
672668
[
673669
torch.arange(v, dtype=torch.int64, device=self.device) + i * self._V
674-
for (i, v) in enumerate(self._num_verts_per_mesh)
670+
for (i, v) in enumerate(self.num_verts_per_mesh())
675671
],
676672
dim=0,
677673
)
@@ -862,8 +858,8 @@ def _compute_padded(self, refresh: bool = False):
862858
):
863859
return
864860

865-
verts_list = self._verts_list
866-
faces_list = self._faces_list
861+
verts_list = self.verts_list()
862+
faces_list = self.faces_list()
867863
assert (
868864
faces_list is not None and verts_list is not None
869865
), "faces_list and verts_list arguments are required"
@@ -943,13 +939,15 @@ def _compute_packed(self, refresh: bool = False):
943939

944940
verts_list_to_packed = struct_utils.list_to_packed(verts_list)
945941
self._verts_packed = verts_list_to_packed[0]
946-
self._num_verts_per_mesh = verts_list_to_packed[1]
942+
if not torch.allclose(self.num_verts_per_mesh(), verts_list_to_packed[1]):
943+
raise ValueError("The number of verts per mesh should be consistent.")
947944
self._mesh_to_verts_packed_first_idx = verts_list_to_packed[2]
948945
self._verts_packed_to_mesh_idx = verts_list_to_packed[3]
949946

950947
faces_list_to_packed = struct_utils.list_to_packed(faces_list)
951948
faces_packed = faces_list_to_packed[0]
952-
self._num_faces_per_mesh = faces_list_to_packed[1]
949+
if not torch.allclose(self.num_faces_per_mesh(), faces_list_to_packed[1]):
950+
raise ValueError("The number of faces per mesh should be consistent.")
953951
self._mesh_to_faces_packed_first_idx = faces_list_to_packed[2]
954952
self._faces_packed_to_mesh_idx = faces_list_to_packed[3]
955953

@@ -1328,6 +1326,100 @@ def scale_verts(self, scale):
13281326
new_mesh = self.clone()
13291327
return new_mesh.scale_verts_(scale)
13301328

1329+
def update_padded(self, new_verts_padded):
1330+
"""
1331+
This function allows for an pdate of verts_padded without having to
1332+
explicitly convert it to the list representation for heterogeneous batches.
1333+
Returns a Meshes structure with updated padded tensors and copies of the
1334+
auxiliary tensors at construction time.
1335+
It updates self._verts_padded with new_verts_padded, and does a
1336+
shallow copy of (faces_padded, faces_list, num_verts_per_mesh, num_faces_per_mesh).
1337+
If packed representations are computed in self, they are updated as well.
1338+
1339+
Args:
1340+
new_points_padded: FloatTensor of shape (N, V, 3)
1341+
1342+
Returns:
1343+
Meshes with updated padded representations
1344+
"""
1345+
1346+
def check_shapes(x, size):
1347+
if x.shape[0] != size[0]:
1348+
raise ValueError("new values must have the same batch dimension.")
1349+
if x.shape[1] != size[1]:
1350+
raise ValueError("new values must have the same number of points.")
1351+
if x.shape[2] != size[2]:
1352+
raise ValueError("new values must have the same dimension.")
1353+
1354+
check_shapes(new_verts_padded, [self._N, self._V, 3])
1355+
1356+
new = self.__class__(verts=new_verts_padded, faces=self.faces_padded())
1357+
1358+
if new._N != self._N or new._V != self._V or new._F != self._F:
1359+
raise ValueError("Inconsistent sizes after construction.")
1360+
1361+
# overwrite the equisized flag
1362+
new.equisized = self.equisized
1363+
1364+
# overwrite textures if any
1365+
new.textures = self.textures
1366+
1367+
# copy auxiliary tensors
1368+
copy_tensors = ["_num_verts_per_mesh", "_num_faces_per_mesh", "valid"]
1369+
1370+
for k in copy_tensors:
1371+
v = getattr(self, k)
1372+
if torch.is_tensor(v):
1373+
setattr(new, k, v) # shallow copy
1374+
1375+
# shallow copy of faces_list if any, st new.faces_list()
1376+
# does not re-compute from _faces_padded
1377+
new._faces_list = self._faces_list
1378+
1379+
# update verts/faces packed if they are computed in self
1380+
if self._verts_packed is not None:
1381+
copy_tensors = [
1382+
"_faces_packed",
1383+
"_verts_packed_to_mesh_idx",
1384+
"_faces_packed_to_mesh_idx",
1385+
"_mesh_to_verts_packed_first_idx",
1386+
"_mesh_to_faces_packed_first_idx",
1387+
]
1388+
for k in copy_tensors:
1389+
v = getattr(self, k)
1390+
assert torch.is_tensor(v)
1391+
setattr(new, k, v) # shallow copy
1392+
# update verts_packed
1393+
pad_to_packed = self.verts_padded_to_packed_idx()
1394+
new_verts_packed = new_verts_padded.reshape(-1, 3)[pad_to_packed, :]
1395+
new._verts_packed = new_verts_packed
1396+
new._verts_padded_to_packed_idx = pad_to_packed
1397+
1398+
# update edges packed if they are computed in self
1399+
if self._edges_packed is not None:
1400+
copy_tensors = [
1401+
"_edges_packed",
1402+
"_edges_packed_to_mesh_idx",
1403+
"_mesh_to_edges_packed_first_idx",
1404+
"_faces_packed_to_edges_packed",
1405+
"_num_edges_per_mesh",
1406+
]
1407+
for k in copy_tensors:
1408+
v = getattr(self, k)
1409+
assert torch.is_tensor(v)
1410+
setattr(new, k, v) # shallow copy
1411+
1412+
# update laplacian if it is compute in self
1413+
if self._laplacian_packed is not None:
1414+
new._laplacian_packed = self._laplacian_packed
1415+
1416+
assert new._verts_list is None
1417+
assert new._verts_normals_packed is None
1418+
assert new._faces_normals_packed is None
1419+
assert new._faces_areas_packed is None
1420+
1421+
return new
1422+
13311423
# TODO(nikhilar) Move function to utils file.
13321424
def get_bounding_boxes(self):
13331425
"""

tests/test_meshes.py

Lines changed: 94 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -342,16 +342,11 @@ def test_clone(self):
342342

343343
# Modify tensors in both meshes.
344344
new_mesh._verts_list[0] = new_mesh._verts_list[0] * 5
345-
mesh._num_verts_per_mesh = torch.randint_like(
346-
mesh.num_verts_per_mesh(), high=10
347-
)
345+
348346
# Check cloned and original Meshes objects do not share tensors.
349347
self.assertFalse(
350348
torch.allclose(new_mesh._verts_list[0], mesh._verts_list[0])
351349
)
352-
self.assertFalse(
353-
torch.allclose(mesh.num_verts_per_mesh(), new_mesh.num_verts_per_mesh())
354-
)
355350
self.assertSeparate(new_mesh.verts_packed(), mesh.verts_packed())
356351
self.assertSeparate(new_mesh.verts_padded(), mesh.verts_padded())
357352
self.assertSeparate(new_mesh.faces_packed(), mesh.faces_packed())
@@ -690,6 +685,99 @@ def test_split_mesh(self):
690685
with self.assertRaises(ValueError):
691686
mesh.split(split_sizes)
692687

688+
def test_update_padded(self):
689+
# Define the test mesh object either as a list or tensor of faces/verts.
690+
N = 10
691+
for lists_to_tensors in (False, True):
692+
for force in (True, False):
693+
mesh = TestMeshes.init_mesh(
694+
N, 100, 300, lists_to_tensors=lists_to_tensors
695+
)
696+
num_verts_per_mesh = mesh.num_verts_per_mesh()
697+
if force:
698+
# force mesh to have computed attributes
699+
mesh.verts_packed()
700+
mesh.edges_packed()
701+
mesh.laplacian_packed()
702+
mesh.faces_areas_packed()
703+
704+
new_verts = torch.rand((mesh._N, mesh._V, 3), device=mesh.device)
705+
new_verts_list = [
706+
new_verts[i, : num_verts_per_mesh[i]] for i in range(N)
707+
]
708+
new_mesh = mesh.update_padded(new_verts)
709+
710+
# check the attributes assigned at construction time
711+
self.assertEqual(new_mesh._N, mesh._N)
712+
self.assertEqual(new_mesh._F, mesh._F)
713+
self.assertEqual(new_mesh._V, mesh._V)
714+
self.assertEqual(new_mesh.equisized, mesh.equisized)
715+
self.assertTrue(all(new_mesh.valid == mesh.valid))
716+
self.assertNotSeparate(
717+
new_mesh.num_verts_per_mesh(), mesh.num_verts_per_mesh()
718+
)
719+
self.assertClose(
720+
new_mesh.num_verts_per_mesh(), mesh.num_verts_per_mesh()
721+
)
722+
self.assertNotSeparate(
723+
new_mesh.num_faces_per_mesh(), mesh.num_faces_per_mesh()
724+
)
725+
self.assertClose(
726+
new_mesh.num_faces_per_mesh(), mesh.num_faces_per_mesh()
727+
)
728+
729+
# check that the following attributes are not assigned
730+
self.assertIsNone(new_mesh._verts_list)
731+
self.assertIsNone(new_mesh._faces_areas_packed)
732+
self.assertIsNone(new_mesh._faces_normals_packed)
733+
self.assertIsNone(new_mesh._verts_normals_packed)
734+
735+
check_tensors = [
736+
"_faces_packed",
737+
"_verts_packed_to_mesh_idx",
738+
"_faces_packed_to_mesh_idx",
739+
"_mesh_to_verts_packed_first_idx",
740+
"_mesh_to_faces_packed_first_idx",
741+
"_edges_packed",
742+
"_edges_packed_to_mesh_idx",
743+
"_mesh_to_edges_packed_first_idx",
744+
"_faces_packed_to_edges_packed",
745+
"_num_edges_per_mesh",
746+
]
747+
for k in check_tensors:
748+
v = getattr(new_mesh, k)
749+
if not force:
750+
self.assertIsNone(v)
751+
else:
752+
v_old = getattr(mesh, k)
753+
self.assertNotSeparate(v, v_old)
754+
self.assertClose(v, v_old)
755+
756+
# check verts/faces padded
757+
self.assertClose(new_mesh.verts_padded(), new_verts)
758+
self.assertNotSeparate(new_mesh.verts_padded(), new_verts)
759+
self.assertClose(new_mesh.faces_padded(), mesh.faces_padded())
760+
self.assertNotSeparate(new_mesh.faces_padded(), mesh.faces_padded())
761+
# check verts/faces list
762+
for i in range(N):
763+
self.assertNotSeparate(
764+
new_mesh.faces_list()[i], mesh.faces_list()[i]
765+
)
766+
self.assertClose(new_mesh.faces_list()[i], mesh.faces_list()[i])
767+
self.assertSeparate(new_mesh.verts_list()[i], mesh.verts_list()[i])
768+
self.assertClose(new_mesh.verts_list()[i], new_verts_list[i])
769+
# check verts/faces packed
770+
self.assertClose(new_mesh.verts_packed(), torch.cat(new_verts_list))
771+
self.assertSeparate(new_mesh.verts_packed(), mesh.verts_packed())
772+
self.assertClose(new_mesh.faces_packed(), mesh.faces_packed())
773+
# check pad_to_packed
774+
self.assertClose(
775+
new_mesh.verts_padded_to_packed_idx(),
776+
mesh.verts_padded_to_packed_idx(),
777+
)
778+
# check edges
779+
self.assertClose(new_mesh.edges_packed(), mesh.edges_packed())
780+
693781
def test_get_mesh_verts_faces(self):
694782
device = torch.device("cuda:0")
695783
verts_list = []

0 commit comments

Comments
 (0)