Skip to content

Commit 4c8338b

Browse files
khundmanfacebook-github-bot
authored andcommitted
Improve memory efficiency in VolumeSampler
Summary: Avoids use of `torch.cat` operation when rendering a volume by instead issuing multiple calls to `torch.nn.functional.grid_sample`. Density and color tensors can be large. Reviewed By: bottler Differential Revision: D40072399 fbshipit-source-id: eb4cd34f6171d54972bbf2877065f973db497de0
1 parent 0d8608b commit 4c8338b

File tree

1 file changed

+22
-17
lines changed

1 file changed

+22
-17
lines changed

pytorch3d/renderer/implicit/renderer.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -363,35 +363,40 @@ def forward(
363363
volumes_densities = self._volumes.densities()
364364
dim_density = volumes_densities.shape[1]
365365
volumes_features = self._volumes.features()
366-
# adjust the volumes_features variable in case we have a feature-less volume
367-
if volumes_features is None:
368-
dim_feature = 0
369-
data_to_sample = volumes_densities
370-
else:
371-
dim_feature = volumes_features.shape[1]
372-
data_to_sample = torch.cat((volumes_densities, volumes_features), dim=1)
373366

374367
# reshape to a size which grid_sample likes
375368
rays_points_local_flat = rays_points_local.view(
376369
rays_points_local.shape[0], -1, 1, 1, 3
377370
)
378371

379-
# run the grid sampler
380-
data_sampled = torch.nn.functional.grid_sample(
381-
data_to_sample,
372+
# run the grid sampler on the volumes densities
373+
rays_densities = torch.nn.functional.grid_sample(
374+
volumes_densities,
382375
rays_points_local_flat,
383376
align_corners=True,
384377
mode=self._sample_mode,
385378
)
386379

387-
# permute the dimensions & reshape after sampling
388-
data_sampled = data_sampled.permute(0, 2, 3, 4, 1).view(
389-
*rays_points_local.shape[:-1], data_sampled.shape[1]
380+
# permute the dimensions & reshape densities after sampling
381+
rays_densities = rays_densities.permute(0, 2, 3, 4, 1).view(
382+
*rays_points_local.shape[:-1], volumes_densities.shape[1]
390383
)
391384

392-
# split back to densities and features
393-
rays_densities, rays_features = data_sampled.split(
394-
[dim_density, dim_feature], dim=-1
395-
)
385+
# if features exist, run grid sampler again on the features densities
386+
if volumes_features is None:
387+
dim_feature = 0
388+
_, rays_features = rays_densities.split([dim_density, dim_feature], dim=-1)
389+
else:
390+
rays_features = torch.nn.functional.grid_sample(
391+
volumes_features,
392+
rays_points_local_flat,
393+
align_corners=True,
394+
mode=self._sample_mode,
395+
)
396+
397+
# permute the dimensions & reshape features after sampling
398+
rays_features = rays_features.permute(0, 2, 3, 4, 1).view(
399+
*rays_points_local.shape[:-1], volumes_features.shape[1]
400+
)
396401

397402
return rays_densities, rays_features

0 commit comments

Comments
 (0)