Skip to content

Commit 9320100

Browse files
bottlerfacebook-github-bot
authored andcommitted
object_mask only if required
Summary: New function to check if a renderer needs the object mask. Reviewed By: davnov134 Differential Revision: D35254009 fbshipit-source-id: 4c99e8a1c0f6641d910eb32bfd6cfae9d3463d50
1 parent 2edb93d commit 9320100

File tree

3 files changed

+11
-2
lines changed

3 files changed

+11
-2
lines changed

pytorch3d/implicitron/models/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ def forward(
397397
func.bind_args(**custom_args)
398398

399399
chunked_renderer_inputs = {}
400-
if fg_probability is not None:
400+
if fg_probability is not None and self.renderer.requires_object_mask():
401401
sampled_fb_prob = rend_utils.ndc_grid_sample(
402402
fg_probability[:n_targets], ray_bundle.xys, mode="nearest"
403403
)

pytorch3d/implicitron/models/renderer/base.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,15 @@ class BaseRenderer(ABC, ReplaceableBase):
7272
Base class for all Renderer implementations.
7373
"""
7474

75-
def __init__(self):
75+
def __init__(self) -> None:
7676
super().__init__()
7777

78+
def requires_object_mask(self) -> bool:
79+
"""
80+
Whether `forward` needs the object_mask.
81+
"""
82+
return False
83+
7884
@abstractmethod
7985
def forward(
8086
self,

pytorch3d/implicitron/models/renderer/sdf_renderer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ def __post_init__(
4949

5050
self.register_buffer("_bg_color", torch.tensor(self.bg_color), persistent=False)
5151

52+
def requires_object_mask(self) -> bool:
53+
return True
54+
5255
def forward(
5356
self,
5457
ray_bundle: RayBundle,

0 commit comments

Comments
 (0)