Commit 780cf00

implement tiled encode/decode
1 parent edd7880 commit 780cf00

2 files changed: 318 additions, 15 deletions


src/diffusers/models/autoencoders/autoencoder_kl_wan.py

Lines changed: 241 additions & 14 deletions
@@ -730,6 +730,76 @@ def __init__(
             base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
         )
 
+        self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
+
+        # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
+        # to perform decoding of a single video latent at a time.
+        self.use_slicing = False
+
+        # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+        # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
+        # intermediate tiles together, the memory requirement can be lowered.
+        self.use_tiling = False
+
+        # The minimal tile height and width for spatial tiling to be used
+        self.tile_sample_min_height = 256
+        self.tile_sample_min_width = 256
+
+        # The minimal distance between two spatial tiles
+        self.tile_sample_stride_height = 192
+        self.tile_sample_stride_width = 192
+
+    def enable_tiling(
+        self,
+        tile_sample_min_height: Optional[int] = None,
+        tile_sample_min_width: Optional[int] = None,
+        tile_sample_stride_height: Optional[float] = None,
+        tile_sample_stride_width: Optional[float] = None,
+    ) -> None:
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
+
+        Args:
+            tile_sample_min_height (`int`, *optional*):
+                The minimum height required for a sample to be separated into tiles across the height dimension.
+            tile_sample_min_width (`int`, *optional*):
+                The minimum width required for a sample to be separated into tiles across the width dimension.
+            tile_sample_stride_height (`int`, *optional*):
+                The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
+                no tiling artifacts produced across the height dimension.
+            tile_sample_stride_width (`int`, *optional*):
+                The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
+                artifacts produced across the width dimension.
+        """
+        self.use_tiling = True
+        self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+        self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+        self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+
+    def disable_tiling(self) -> None:
+        r"""
+        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_tiling = False
+
+    def enable_slicing(self) -> None:
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_slicing(self) -> None:
+        r"""
+        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_slicing = False
+
     def clear_cache(self):
         def _count_conv3d(model):
             count = 0
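
The toggles added in this hunk are ordinary instance methods, so callers opt into tiling and/or slicing before running inference. A minimal usage sketch follows; the class name AutoencoderKLWan, the import path, and the checkpoint id are assumptions for illustration and are not part of this commit.

import torch
from diffusers import AutoencoderKLWan  # assumed export; the class is defined in autoencoder_kl_wan.py

# hypothetical checkpoint id, used only to illustrate the call pattern
vae = AutoencoderKLWan.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32)

# split large frames into 256x256 tiles decoded 192 px apart; the 64 px overlap is blended away
vae.enable_tiling(
    tile_sample_min_height=256,
    tile_sample_min_width=256,
    tile_sample_stride_height=192,
    tile_sample_stride_width=192,
)
vae.enable_slicing()  # additionally process one video of the batch at a time

# both can be switched off again
vae.disable_tiling()
vae.disable_slicing()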
@@ -746,11 +816,14 @@ def _count_conv3d(model):
         self._enc_conv_idx = [0]
         self._enc_feat_map = [None] * self._enc_conv_num
 
-    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+    def _encode(self, x: torch.Tensor):
+        _, _, num_frame, height, width = x.shape
+
+        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+            return self.tiled_encode(x)
+
         self.clear_cache()
-        ## cache
-        t = x.shape[2]
-        iter_ = 1 + (t - 1) // 4
+        iter_ = 1 + (num_frame - 1) // 4
         for i in range(iter_):
             self._enc_conv_idx = [0]
             if i == 0:
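
The rewritten `_encode` keeps the original causal frame chunking: the first pass sees only the first frame and each later pass sees four frames, hence `iter_ = 1 + (num_frame - 1) // 4`. A standalone sketch of the resulting chunk boundaries, mirroring the slicing pattern that `tiled_encode` below spells out explicitly:

num_frame = 17  # e.g. a 17-frame clip; any 4k + 1 frame count chunks evenly
iter_ = 1 + (num_frame - 1) // 4  # -> 5 encoder passes

chunks = []
for i in range(iter_):
    if i == 0:
        chunks.append((0, 1))  # first pass: frame 0 only
    else:
        chunks.append((1 + 4 * (i - 1), 1 + 4 * i))  # then frames 1-4, 5-8, 9-12, 13-16

print(chunks)  # [(0, 1), (1, 5), (5, 9), (9, 13), (13, 17)]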
@@ -764,9 +837,6 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor:
                 out = torch.cat([out, out_], 2)
 
         enc = self.quant_conv(out)
-        mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :]
-        enc = torch.cat([mu, logvar], dim=1)
-        self.clear_cache()
         return enc
 
     @apply_forward_hook
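
The first two deleted lines were a no-op: splitting `enc` into `mu` and `logvar` along the channel dimension and immediately concatenating them back reproduces the original tensor (the trailing `clear_cache()` is also dropped, since the cache is cleared at the start of each `_encode` call anyway). A tiny check of that equivalence, with an arbitrary `z_dim` chosen for illustration:

import torch

z_dim = 16
enc = torch.randn(1, 2 * z_dim, 4, 8, 8)  # (B, 2 * z_dim, T, H, W)

mu, logvar = enc[:, :z_dim], enc[:, z_dim:]
assert torch.equal(torch.cat([mu, logvar], dim=1), enc)  # split + cat along dim=1 changes nothing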
@@ -785,18 +855,28 @@ def encode(
                 The latent representations of the encoded videos. If `return_dict` is True, a
                 [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
         """
-        h = self._encode(x)
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self._encode(x)
         posterior = DiagonalGaussianDistribution(h)
+
         if not return_dict:
             return (posterior,)
         return AutoencoderKLOutput(latent_dist=posterior)
 
-    def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
-        self.clear_cache()
+    def _decode(self, z: torch.Tensor, return_dict: bool = True):
+        _, _, num_frame, height, width = z.shape
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+
+        if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
+            return self.tiled_decode(z, return_dict=return_dict)
 
-        iter_ = z.shape[2]
+        self.clear_cache()
         x = self.post_quant_conv(z)
-        for i in range(iter_):
+        for i in range(num_frame):
             self._conv_idx = [0]
             if i == 0:
                 out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
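
`_decode` compares the latent height and width against the sample-space tile thresholds mapped into latent space through `spatial_compression_ratio`. A worked example with the defaults set in `__init__`; the compression ratio of 8 assumes a config with three entries in `temperal_downsample`, which is not spelled out in this diff:

spatial_compression_ratio = 2 ** 3  # assumption: len(temperal_downsample) == 3
tile_sample_min_height = 256        # default from __init__ above
tile_sample_stride_height = 192     # default from __init__ above

tile_latent_min_height = tile_sample_min_height // spatial_compression_ratio        # 256 // 8 = 32
tile_latent_stride_height = tile_sample_stride_height // spatial_compression_ratio  # 192 // 8 = 24

# a latent taller than 32 rows is routed to tiled_decode; in latent space adjacent tiles
# overlap by 32 - 24 = 8 rows (tiled_encode's blend_height), i.e. 256 - 192 = 64 rows in sample space
print(tile_latent_min_height, tile_latent_stride_height)  # 32 24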
@@ -826,12 +906,159 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
             If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
             returned.
         """
-        decoded = self._decode(z).sample
+        if self.use_slicing and z.shape[0] > 1:
+            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+            decoded = torch.cat(decoded_slices)
+        else:
+            decoded = self._decode(z).sample
+
         if not return_dict:
             return (decoded,)
-
         return DecoderOutput(sample=decoded)
 
+    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
+        for y in range(blend_extent):
+            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+                y / blend_extent
+            )
+        return b
+
+    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+                x / blend_extent
+            )
+        return b
+
+    def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
+        r"""Encode a batch of images using a tiled encoder.
+
+        Args:
+            x (`torch.Tensor`): Input batch of videos.
+
+        Returns:
+            `torch.Tensor`:
+                The latent representation of the encoded videos.
+        """
+        _, _, num_frames, height, width = x.shape
+        latent_height = height // self.spatial_compression_ratio
+        latent_width = width // self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = tile_latent_min_height - tile_latent_stride_height
+        blend_width = tile_latent_min_width - tile_latent_stride_width
+
+        # Split x into overlapping tiles and encode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, self.tile_sample_stride_height):
+            row = []
+            for j in range(0, width, self.tile_sample_stride_width):
+                self.clear_cache()
+                time = []
+                frame_range = 1 + (num_frames - 1) // 4
+                for k in range(frame_range):
+                    self._enc_conv_idx = [0]
+                    if k == 0:
+                        tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+                    else:
+                        tile = x[
+                            :,
+                            :,
+                            1 + 4 * (k - 1) : 1 + 4 * k,
+                            i : i + self.tile_sample_min_height,
+                            j : j + self.tile_sample_min_width,
+                        ]
+                    tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
+                    tile = self.quant_conv(tile)
+                    time.append(tile)
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
+            result_rows.append(torch.cat(result_row, dim=-1))
+
+        enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
+        return enc
+
+    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+        r"""
+        Decode a batch of images using a tiled decoder.
+
+        Args:
+            z (`torch.Tensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vae.DecoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                returned.
+        """
+        _, _, num_frames, height, width = z.shape
+        sample_height = height * self.spatial_compression_ratio
+        sample_width = width * self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+        # Split z into overlapping tiles and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, tile_latent_stride_height):
+            row = []
+            for j in range(0, width, tile_latent_stride_width):
+                self.clear_cache()
+                time = []
+                for k in range(num_frames):
+                    self._conv_idx = [0]
+                    tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
+                    tile = self.post_quant_conv(tile)
+                    decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
+                    time.append(decoded)
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+            result_rows.append(torch.cat(result_row, dim=-1))
+
+        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+
+        if not return_dict:
+            return (dec,)
+        return DecoderOutput(sample=dec)
+
     def forward(
         self,
         sample: torch.Tensor,
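
For a concrete sense of how the tiling loops in `tiled_encode` / `tiled_decode` cover a frame, here is the tile grid for an assumed 480x832 input at the default tile size (256) and stride (192); the resolution is only an example, not something fixed by this commit:

height, width = 480, 832        # example sample resolution (assumption)
tile_sample_min = 256
tile_sample_stride = 192

row_starts = list(range(0, height, tile_sample_stride))  # [0, 192, 384]
col_starts = list(range(0, width, tile_sample_stride))   # [0, 192, 384, 576, 768]
print(len(row_starts) * len(col_starts))                 # 15 tiles per frame chunk

# neighbouring tiles overlap by 256 - 192 = 64 pixels; tiles that run past the frame edge are
# clamped by Python slicing, and the concatenated result is cropped back to the frame size at the end
print(tile_sample_min - tile_sample_stride)              # 64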

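The `blend_v` / `blend_h` helpers added above do a linear cross-fade over the overlapping band of two neighbouring tiles. A standalone sketch of the same ramp on toy 5-D tensors (shapes are arbitrary):

import torch

def blend_h(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # same logic as the method in the diff: fade a's right edge into b's left edge
    blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
    for x in range(blend_extent):
        b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
            x / blend_extent
        )
    return b

left = torch.zeros(1, 4, 1, 8, 8)  # (B, C, T, H, W) left tile, all zeros
right = torch.ones(1, 4, 1, 8, 8)  # right tile, all ones
blended = blend_h(left, right, blend_extent=4)
print(blended[0, 0, 0, 0, :4])     # tensor([0.0000, 0.2500, 0.5000, 0.7500]) -- linear ramp across the overlap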