Skip to content

Commit d9f7095

Browse files
anadodikfacebook-github-bot
authored andcommitted
Adding the option to choose the texture sampling mode in TexturesUV.
Summary: This diff adds the `sample_mode` parameter to `TexturesUV` to control the interpolation mode during texture sampling. It simply gets forwarded to `torch.nn.funcitonal.grid_sample`. This option was requested in this [GitHub issue](#805). Reviewed By: patricklabatut Differential Revision: D32665185 fbshipit-source-id: ac0bc66a018bd4cb20d75fec2d7c11145dd20199
1 parent e4456db commit d9f7095

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

pytorch3d/renderer/mesh/textures.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,7 @@ def __init__(
596596
verts_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
597597
padding_mode: str = "border",
598598
align_corners: bool = True,
599+
sampling_mode: str = "bilinear",
599600
) -> None:
600601
"""
601602
Textures are represented as a per mesh texture map and uv coordinates for each
@@ -613,6 +614,9 @@ def __init__(
613614
indicate the centers of the edge pixels in the maps.
614615
padding_mode: padding mode for outside grid values
615616
("zeros", "border" or "reflection").
617+
sampling_mode: type of interpolation used to sample the texture.
618+
Corresponds to the mode parameter in PyTorch's
619+
grid_sample ("nearest" or "bilinear").
616620
617621
The align_corners and padding_mode arguments correspond to the arguments
618622
of the `grid_sample` torch function. There is an informative illustration of
@@ -641,6 +645,7 @@ def __init__(
641645
"""
642646
self.padding_mode = padding_mode
643647
self.align_corners = align_corners
648+
self.sampling_mode = sampling_mode
644649
if isinstance(faces_uvs, (list, tuple)):
645650
for fv in faces_uvs:
646651
if fv.ndim != 2 or fv.shape[-1] != 3:
@@ -749,6 +754,9 @@ def clone(self) -> "TexturesUV":
749754
self.maps_padded().clone(),
750755
self.faces_uvs_padded().clone(),
751756
self.verts_uvs_padded().clone(),
757+
align_corners=self.align_corners,
758+
padding_mode=self.padding_mode,
759+
sampling_mode=self.sampling_mode,
752760
)
753761
if self._maps_list is not None:
754762
tex._maps_list = [m.clone() for m in self._maps_list]
@@ -770,6 +778,9 @@ def detach(self) -> "TexturesUV":
770778
self.maps_padded().detach(),
771779
self.faces_uvs_padded().detach(),
772780
self.verts_uvs_padded().detach(),
781+
align_corners=self.align_corners,
782+
padding_mode=self.padding_mode,
783+
sampling_mode=self.sampling_mode,
773784
)
774785
if self._maps_list is not None:
775786
tex._maps_list = [m.detach() for m in self._maps_list]
@@ -801,6 +812,7 @@ def __getitem__(self, index) -> "TexturesUV":
801812
maps=maps,
802813
padding_mode=self.padding_mode,
803814
align_corners=self.align_corners,
815+
sampling_mode=self.sampling_mode,
804816
)
805817
elif all(torch.is_tensor(f) for f in [faces_uvs, verts_uvs, maps]):
806818
new_tex = self.__class__(
@@ -809,6 +821,7 @@ def __getitem__(self, index) -> "TexturesUV":
809821
maps=[maps],
810822
padding_mode=self.padding_mode,
811823
align_corners=self.align_corners,
824+
sampling_mode=self.sampling_mode,
812825
)
813826
else:
814827
raise ValueError("Not all values are provided in the correct format")
@@ -889,6 +902,7 @@ def extend(self, N: int) -> "TexturesUV":
889902
verts_uvs=new_props["verts_uvs_padded"],
890903
padding_mode=self.padding_mode,
891904
align_corners=self.align_corners,
905+
sampling_mode=self.sampling_mode,
892906
)
893907

894908
new_tex._num_faces_per_mesh = new_props["_num_faces_per_mesh"]
@@ -966,6 +980,7 @@ def sample_textures(self, fragments, **kwargs) -> torch.Tensor:
966980
texels = F.grid_sample(
967981
texture_maps,
968982
pixel_uvs,
983+
mode=self.sampling_mode,
969984
align_corners=self.align_corners,
970985
padding_mode=self.padding_mode,
971986
)
@@ -1003,6 +1018,7 @@ def faces_verts_textures_packed(self) -> torch.Tensor:
10031018
textures = F.grid_sample(
10041019
texture_maps,
10051020
faces_verts_uvs,
1021+
mode=self.sampling_mode,
10061022
align_corners=self.align_corners,
10071023
padding_mode=self.padding_mode,
10081024
) # NxCxmax(Fi)x3
@@ -1060,6 +1076,7 @@ def join_batch(self, textures: List["TexturesUV"]) -> "TexturesUV":
10601076
faces_uvs=faces_uvs_list,
10611077
padding_mode=self.padding_mode,
10621078
align_corners=self.align_corners,
1079+
sampling_mode=self.sampling_mode,
10631080
)
10641081
new_tex._num_faces_per_mesh = num_faces_per_mesh
10651082
return new_tex
@@ -1227,6 +1244,7 @@ def join_scene(self) -> "TexturesUV":
12271244
faces_uvs=[torch.cat(faces_uvs_merged)],
12281245
align_corners=self.align_corners,
12291246
padding_mode=self.padding_mode,
1247+
sampling_mode=self.sampling_mode,
12301248
)
12311249

12321250
def centers_for_image(self, index: int) -> torch.Tensor:
@@ -1259,6 +1277,7 @@ def centers_for_image(self, index: int) -> torch.Tensor:
12591277
torch.flip(coords.to(texture_image), [2]),
12601278
# Convert from [0, 1] -> [-1, 1] range expected by grid sample
12611279
verts_uvs[:, None] * 2.0 - 1,
1280+
mode=self.sampling_mode,
12621281
align_corners=self.align_corners,
12631282
padding_mode=self.padding_mode,
12641283
).cpu()

0 commit comments

Comments
 (0)