Commit 822afe0

implement tiled encode/decode
1 parent edd7880 commit 822afe0

File tree

2 files changed: +328 −6 lines

src/diffusers/models/autoencoders/autoencoder_kl_wan.py

Lines changed: 251 additions & 5 deletions
@@ -730,6 +730,77 @@ def __init__(
             base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
         )
 
+        self.temporal_compression_ratio = 2 ** sum(self.temperal_downsample)
+        self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
+
+        # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
+        # to perform decoding of a single video latent at a time.
+        self.use_slicing = False
+
+        # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+        # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
+        # intermediate tiles together, the memory requirement can be lowered.
+        self.use_tiling = False
+
+        # The minimal tile height and width for spatial tiling to be used
+        self.tile_sample_min_height = 256
+        self.tile_sample_min_width = 256
+
+        # The minimal distance between two spatial tiles
+        self.tile_sample_stride_height = 192
+        self.tile_sample_stride_width = 192
+
+    def enable_tiling(
+        self,
+        tile_sample_min_height: Optional[int] = None,
+        tile_sample_min_width: Optional[int] = None,
+        tile_sample_stride_height: Optional[float] = None,
+        tile_sample_stride_width: Optional[float] = None,
+    ) -> None:
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and
+        allows processing larger images.
+
+        Args:
+            tile_sample_min_height (`int`, *optional*):
+                The minimum height required for a sample to be separated into tiles across the height dimension.
+            tile_sample_min_width (`int`, *optional*):
+                The minimum width required for a sample to be separated into tiles across the width dimension.
+            tile_sample_stride_height (`int`, *optional*):
+                The stride between two consecutive vertical tiles. This is to ensure that there are no tiling
+                artifacts produced across the height dimension.
+            tile_sample_stride_width (`int`, *optional*):
+                The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
+                artifacts produced across the width dimension.
+        """
+        self.use_tiling = True
+        self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+        self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+        self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+
+    def disable_tiling(self) -> None:
+        r"""
+        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_tiling = False
+
+    def enable_slicing(self) -> None:
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_slicing(self) -> None:
+        r"""
+        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_slicing = False
+
     def clear_cache(self):
         def _count_conv3d(model):
             count = 0
@@ -746,7 +817,7 @@ def _count_conv3d(model):
         self._enc_conv_idx = [0]
         self._enc_feat_map = [None] * self._enc_conv_num
 
-    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+    def vanilla_encode(self, x: torch.Tensor) -> torch.Tensor:
         self.clear_cache()
         ## cache
         t = x.shape[2]
@@ -769,6 +840,12 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor:
         self.clear_cache()
         return enc
 
+    def _encode(self, x: torch.Tensor):
+        _, _, _, height, width = x.shape
+        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+            return self.tiled_encode(x)
+        return self.vanilla_encode(x)
+
     @apply_forward_hook
     def encode(
         self, x: torch.Tensor, return_dict: bool = True
@@ -785,13 +862,18 @@ def encode(
                 The latent representations of the encoded videos. If `return_dict` is True, a
                 [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
         """
-        h = self._encode(x)
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self._encode(x)
         posterior = DiagonalGaussianDistribution(h)
+
         if not return_dict:
             return (posterior,)
         return AutoencoderKLOutput(latent_dist=posterior)
 
-    def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+    def vanilla_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        self.clear_cache()
 
        iter_ = z.shape[2]
@@ -811,6 +893,15 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
 
         return DecoderOutput(sample=out)
 
+    def _decode(self, z: torch.Tensor, return_dict: bool = True):
+        _, _, _, height, width = z.shape
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+
+        if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
+            return self.tiled_decode(z, return_dict=return_dict)
+        return self.vanilla_decode(z, return_dict=return_dict)
+
     @apply_forward_hook
     def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
         r"""
@@ -826,12 +917,167 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
                 If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                 returned.
         """
-        decoded = self._decode(z).sample
+        if self.use_slicing and z.shape[0] > 1:
+            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+            decoded = torch.cat(decoded_slices)
+        else:
+            decoded = self._decode(z).sample
+
         if not return_dict:
             return (decoded,)
-
         return DecoderOutput(sample=decoded)
 
+    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
+        for y in range(blend_extent):
+            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+                y / blend_extent
+            )
+        return b
+
+    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+                x / blend_extent
+            )
+        return b
+
+    def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
+                x / blend_extent
+            )
+        return b
+
+    def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
+        r"""Encode a batch of videos using a tiled encoder.
+
+        Args:
+            x (`torch.Tensor`): Input batch of videos.
+
+        Returns:
+            `torch.Tensor`:
+                The latent representation of the encoded videos.
+        """
+        _, _, num_frames, height, width = x.shape
+        latent_height = height // self.spatial_compression_ratio
+        latent_width = width // self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = tile_latent_min_height - tile_latent_stride_height
+        blend_width = tile_latent_min_width - tile_latent_stride_width
+
+        # Split x into overlapping tiles and encode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, self.tile_sample_stride_height):
+            row = []
+            for j in range(0, width, self.tile_sample_stride_width):
+                self.clear_cache()
+                time = []
+                frame_range = 1 + (num_frames - 1) // 4
+                for k in range(frame_range):
+                    self._enc_conv_idx = [0]
+                    if k == 0:
+                        tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+                    else:
+                        tile = x[
+                            :,
+                            :,
+                            1 + 4 * (k - 1) : 1 + 4 * k,
+                            i : i + self.tile_sample_min_height,
+                            j : j + self.tile_sample_min_width,
+                        ]
+                    tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
+                    tile = self.quant_conv(tile)
+                    time.append(tile)
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
+            result_rows.append(torch.cat(result_row, dim=-1))
+
+        enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
+        return enc
+
+    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+        r"""
+        Decode a batch of videos using a tiled decoder.
+
+        Args:
+            z (`torch.Tensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vae.DecoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                returned.
+        """
+        _, _, num_frames, height, width = z.shape
+        sample_height = height * self.spatial_compression_ratio
+        sample_width = width * self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+        # Split z into overlapping tiles and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, tile_latent_stride_height):
+            row = []
+            for j in range(0, width, tile_latent_stride_width):
+                self.clear_cache()
+                time = []
+                for k in range(num_frames):
+                    self._conv_idx = [0]
+                    tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
+                    tile = self.post_quant_conv(tile)
+                    decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
+                    time.append(decoded)
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+            result_rows.append(torch.cat(result_row, dim=-1))
+
+        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+
+        if not return_dict:
+            return (dec,)
+        return DecoderOutput(sample=dec)
+
     def forward(
         self,
         sample: torch.Tensor,
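
For reference, a minimal usage sketch of the API added in this commit. It assumes the class is exported as diffusers.AutoencoderKLWan and that a compatible VAE checkpoint exists at "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"; the checkpoint ID, latent shape, and dtype below are illustrative, not taken from this diff:

import torch
from diffusers import AutoencoderKLWan

# Hypothetical checkpoint; any weights compatible with AutoencoderKLWan would do.
vae = AutoencoderKLWan.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
)

# Memory optimizations introduced by this commit.
vae.enable_slicing()  # decode/encode one video of the batch per forward pass
vae.enable_tiling(    # spatial tiling with overlapping, blended tiles
    tile_sample_min_height=256,
    tile_sample_min_width=256,
    tile_sample_stride_height=192,
    tile_sample_stride_width=192,
)

# Illustrative latent of shape (batch, z_dim, latent_frames, latent_height, latent_width).
# Because the latent height/width exceed the latent-space thresholds
# (tile_sample_min_* // spatial_compression_ratio), decode() dispatches to tiled_decode()
# instead of vanilla_decode().
latents = torch.randn(2, 16, 13, 90, 160)
with torch.no_grad():
    video = vae.decode(latents).sample

Overlapping tile borders are merged with a linear cross-fade (blend_v, blend_h): within an overlap of width blend_extent, the output at offset x is a * (1 - x / blend_extent) + b * (x / blend_extent), so each seam fades smoothly from one tile into its neighbour.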

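The interaction between tile size, stride, and overlap can also be checked with a quick standalone calculation. Only the 256/192 defaults come from this commit; the 720x1280 frame size and the 8x spatial compression ratio are illustrative assumptions:

# Defaults added in this commit (sample space).
tile_min = 256      # tile_sample_min_height / tile_sample_min_width
tile_stride = 192   # tile_sample_stride_height / tile_sample_stride_width
overlap = tile_min - tile_stride  # 64 pixels shared by neighbouring tiles

# An illustrative 720x1280 frame; tiling applies because both sides exceed tile_min.
height, width = 720, 1280
tile_rows = len(range(0, height, tile_stride))  # 4, same loop bounds as tiled_encode
tile_cols = len(range(0, width, tile_stride))   # 7

# Assuming an 8x spatial compression ratio, tiled_encode blends this many latent pixels per seam.
spatial_compression_ratio = 8
blend_latent = overlap // spatial_compression_ratio  # 8

print(tile_rows, tile_cols, overlap, blend_latent)  # 4 7 64 8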