@@ -596,6 +596,7 @@ def __init__(
596
596
verts_uvs : Union [torch .Tensor , List [torch .Tensor ], Tuple [torch .Tensor ]],
597
597
padding_mode : str = "border" ,
598
598
align_corners : bool = True ,
599
+ sampling_mode : str = "bilinear" ,
599
600
) -> None :
600
601
"""
601
602
Textures are represented as a per mesh texture map and uv coordinates for each
@@ -613,6 +614,9 @@ def __init__(
613
614
indicate the centers of the edge pixels in the maps.
614
615
padding_mode: padding mode for outside grid values
615
616
("zeros", "border" or "reflection").
617
+ sampling_mode: type of interpolation used to sample the texture.
618
+ Corresponds to the mode parameter in PyTorch's
619
+ grid_sample ("nearest" or "bilinear").
616
620
617
621
The align_corners and padding_mode arguments correspond to the arguments
618
622
of the `grid_sample` torch function. There is an informative illustration of
@@ -641,6 +645,7 @@ def __init__(
641
645
"""
642
646
self .padding_mode = padding_mode
643
647
self .align_corners = align_corners
648
+ self .sampling_mode = sampling_mode
644
649
if isinstance (faces_uvs , (list , tuple )):
645
650
for fv in faces_uvs :
646
651
if fv .ndim != 2 or fv .shape [- 1 ] != 3 :
@@ -749,6 +754,9 @@ def clone(self) -> "TexturesUV":
749
754
self .maps_padded ().clone (),
750
755
self .faces_uvs_padded ().clone (),
751
756
self .verts_uvs_padded ().clone (),
757
+ align_corners = self .align_corners ,
758
+ padding_mode = self .padding_mode ,
759
+ sampling_mode = self .sampling_mode ,
752
760
)
753
761
if self ._maps_list is not None :
754
762
tex ._maps_list = [m .clone () for m in self ._maps_list ]
@@ -770,6 +778,9 @@ def detach(self) -> "TexturesUV":
770
778
self .maps_padded ().detach (),
771
779
self .faces_uvs_padded ().detach (),
772
780
self .verts_uvs_padded ().detach (),
781
+ align_corners = self .align_corners ,
782
+ padding_mode = self .padding_mode ,
783
+ sampling_mode = self .sampling_mode ,
773
784
)
774
785
if self ._maps_list is not None :
775
786
tex ._maps_list = [m .detach () for m in self ._maps_list ]
@@ -801,6 +812,7 @@ def __getitem__(self, index) -> "TexturesUV":
801
812
maps = maps ,
802
813
padding_mode = self .padding_mode ,
803
814
align_corners = self .align_corners ,
815
+ sampling_mode = self .sampling_mode ,
804
816
)
805
817
elif all (torch .is_tensor (f ) for f in [faces_uvs , verts_uvs , maps ]):
806
818
new_tex = self .__class__ (
@@ -809,6 +821,7 @@ def __getitem__(self, index) -> "TexturesUV":
809
821
maps = [maps ],
810
822
padding_mode = self .padding_mode ,
811
823
align_corners = self .align_corners ,
824
+ sampling_mode = self .sampling_mode ,
812
825
)
813
826
else :
814
827
raise ValueError ("Not all values are provided in the correct format" )
@@ -889,6 +902,7 @@ def extend(self, N: int) -> "TexturesUV":
889
902
verts_uvs = new_props ["verts_uvs_padded" ],
890
903
padding_mode = self .padding_mode ,
891
904
align_corners = self .align_corners ,
905
+ sampling_mode = self .sampling_mode ,
892
906
)
893
907
894
908
new_tex ._num_faces_per_mesh = new_props ["_num_faces_per_mesh" ]
@@ -966,6 +980,7 @@ def sample_textures(self, fragments, **kwargs) -> torch.Tensor:
966
980
texels = F .grid_sample (
967
981
texture_maps ,
968
982
pixel_uvs ,
983
+ mode = self .sampling_mode ,
969
984
align_corners = self .align_corners ,
970
985
padding_mode = self .padding_mode ,
971
986
)
@@ -1003,6 +1018,7 @@ def faces_verts_textures_packed(self) -> torch.Tensor:
1003
1018
textures = F .grid_sample (
1004
1019
texture_maps ,
1005
1020
faces_verts_uvs ,
1021
+ mode = self .sampling_mode ,
1006
1022
align_corners = self .align_corners ,
1007
1023
padding_mode = self .padding_mode ,
1008
1024
) # NxCxmax(Fi)x3
@@ -1060,6 +1076,7 @@ def join_batch(self, textures: List["TexturesUV"]) -> "TexturesUV":
1060
1076
faces_uvs = faces_uvs_list ,
1061
1077
padding_mode = self .padding_mode ,
1062
1078
align_corners = self .align_corners ,
1079
+ sampling_mode = self .sampling_mode ,
1063
1080
)
1064
1081
new_tex ._num_faces_per_mesh = num_faces_per_mesh
1065
1082
return new_tex
@@ -1227,6 +1244,7 @@ def join_scene(self) -> "TexturesUV":
1227
1244
faces_uvs = [torch .cat (faces_uvs_merged )],
1228
1245
align_corners = self .align_corners ,
1229
1246
padding_mode = self .padding_mode ,
1247
+ sampling_mode = self .sampling_mode ,
1230
1248
)
1231
1249
1232
1250
def centers_for_image (self , index : int ) -> torch .Tensor :
@@ -1259,6 +1277,7 @@ def centers_for_image(self, index: int) -> torch.Tensor:
1259
1277
torch .flip (coords .to (texture_image ), [2 ]),
1260
1278
# Convert from [0, 1] -> [-1, 1] range expected by grid sample
1261
1279
verts_uvs [:, None ] * 2.0 - 1 ,
1280
+ mode = self .sampling_mode ,
1262
1281
align_corners = self .align_corners ,
1263
1282
padding_mode = self .padding_mode ,
1264
1283
).cpu ()
0 commit comments