[Feature] Implement tiled VAE encoding/decoding for Wan model. #11414

Merged 3 commits on May 5, 2025

Changes from 2 commits
255 changes: 241 additions & 14 deletions src/diffusers/models/autoencoders/autoencoder_kl_wan.py
@@ -730,6 +730,76 @@ def __init__(
base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
)

self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)

# When decoding a batch of video latents, one can save memory by slicing across the batch dimension
# and decoding a single video latent at a time.
self.use_slicing = False

# When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
# frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
# intermediate tiles together, the memory requirement can be lowered.
self.use_tiling = False

# The minimal tile height and width for spatial tiling to be used
self.tile_sample_min_height = 256
self.tile_sample_min_width = 256

# The stride between the starts of two consecutive spatial tiles
self.tile_sample_stride_height = 192
self.tile_sample_stride_width = 192

def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
tile_sample_min_width: Optional[int] = None,
tile_sample_stride_height: Optional[float] = None,
tile_sample_stride_width: Optional[float] = None,
) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute encoding and decoding in several steps. This is useful for saving a large amount of memory and for
processing larger videos.

Args:
tile_sample_min_height (`int`, *optional*):
The minimum height required for a sample to be separated into tiles across the height dimension.
tile_sample_min_width (`int`, *optional*):
The minimum width required for a sample to be separated into tiles across the width dimension.
tile_sample_stride_height (`int`, *optional*):
The stride between the starts of two consecutive vertical tiles. The resulting overlap
(`tile_sample_min_height - tile_sample_stride_height`) is blended to avoid tiling artifacts across the
height dimension.
tile_sample_stride_width (`int`, *optional*):
The stride between the starts of two consecutive horizontal tiles. The resulting overlap
(`tile_sample_min_width - tile_sample_stride_width`) is blended to avoid tiling artifacts across the
width dimension.
"""
self.use_tiling = True
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width

def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False

def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor into slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True

def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
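
A minimal usage sketch of the new toggles, for anyone trying this out locally. The checkpoint id and the pipeline wiring are assumptions based on the existing Wan integration, not part of this diff:

```python
import torch
from diffusers import AutoencoderKLWan, WanPipeline

model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"  # assumed checkpoint id
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16).to("cuda")

# The two memory-saving toggles added in this PR: spatial tiling and per-sample slicing.
pipe.vae.enable_tiling(tile_sample_min_height=256, tile_sample_min_width=256)
pipe.vae.enable_slicing()

video = pipe("A cat walking on grass", height=480, width=832, num_frames=81).frames[0]
```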

def clear_cache(self):
def _count_conv3d(model):
count = 0
@@ -746,11 +816,14 @@ def _count_conv3d(model):
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num

def _encode(self, x: torch.Tensor) -> torch.Tensor:
def _encode(self, x: torch.Tensor):
_, _, num_frame, height, width = x.shape

if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
return self.tiled_encode(x)

self.clear_cache()
## cache
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
iter_ = 1 + (num_frame - 1) // 4
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
@@ -764,9 +837,6 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor:
out = torch.cat([out, out_], 2)

enc = self.quant_conv(out)
mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :]
enc = torch.cat([mu, logvar], dim=1)
Contributor Author: This is a no-op.

Member: Nice find!

self.clear_cache()
return enc
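
The `iter_ = 1 + (num_frame - 1) // 4` loop above follows Wan's causal temporal compression: the first frame is encoded on its own, and every subsequent group of four frames reuses the feature cache. A standalone sketch of the index arithmetic (plain Python, no model required):

```python
def encode_chunks(num_frames: int):
    # Reproduces the loop bounds used in `_encode`: frame 0 alone, then groups of four frames.
    chunks = []
    for i in range(1 + (num_frames - 1) // 4):
        if i == 0:
            chunks.append((0, 1))
        else:
            chunks.append((1 + 4 * (i - 1), 1 + 4 * i))
    return chunks

print(encode_chunks(81)[:3], encode_chunks(81)[-1])
# [(0, 1), (1, 5), (5, 9)] (77, 81) -> 21 chunks, i.e. 21 latent frames for an 81-frame clip
```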

@apply_forward_hook
@@ -785,18 +855,28 @@ def encode(
The latent representations of the encoded videos. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
h = self._encode(x)
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)

if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)

def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
self.clear_cache()
def _decode(self, z: torch.Tensor, return_dict: bool = True):
_, _, num_frame, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio

if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
return self.tiled_decode(z, return_dict=return_dict)

iter_ = z.shape[2]
self.clear_cache()
x = self.post_quant_conv(z)
for i in range(iter_):
for i in range(num_frame):
self._conv_idx = [0]
if i == 0:
out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
@@ -826,12 +906,159 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
decoded = self._decode(z).sample
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample

if not return_dict:
return (decoded,)

return DecoderOutput(sample=decoded)
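
The new slicing branch splits a batched latent along the batch dimension and decodes one latent video per forward pass, trading extra passes for lower peak memory. A small illustration of the split semantics; the shape below is illustrative only:

```python
import torch

z = torch.randn(2, 16, 21, 60, 104)  # (batch, z_dim, latent_frames, latent_height, latent_width), illustrative

# Mirrors `decode` with use_slicing enabled: one forward pass per latent video, then concatenate.
slices = list(z.split(1))
print([tuple(s.shape) for s in slices])  # [(1, 16, 21, 60, 104), (1, 16, 21, 60, 104)]
```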

def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
y / blend_extent
)
return b

def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
x / blend_extent
)
return b
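
The two blend helpers perform a linear crossfade over the overlap region: the incoming tile's weight ramps up as `x / blend_extent` while the previous tile's weight ramps down. A tiny numeric check of the `blend_h` arithmetic on 1-pixel-high strips, using the same 5D layout as the tensors above:

```python
import torch

a = torch.ones(1, 1, 1, 1, 4)   # right edge of the previous (left) tile
b = torch.zeros(1, 1, 1, 1, 4)  # left edge of the current (right) tile
blend_extent = 4

for x in range(blend_extent):
    b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)

print(b.flatten().tolist())  # [1.0, 0.75, 0.5, 0.25]: `a` fades out while `b` fades in across the overlap
```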

def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
r"""Encode a batch of images using a tiled encoder.

Args:
x (`torch.Tensor`): Input batch of videos.

Returns:
`torch.Tensor`:
The latent representation of the encoded videos.
"""
_, _, num_frames, height, width = x.shape
latent_height = height // self.spatial_compression_ratio
latent_width = width // self.spatial_compression_ratio

tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio

blend_height = tile_latent_min_height - tile_latent_stride_height
blend_width = tile_latent_min_width - tile_latent_stride_width

# Split x into overlapping tiles and encode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, height, self.tile_sample_stride_height):
row = []
for j in range(0, width, self.tile_sample_stride_width):
self.clear_cache()
time = []
frame_range = 1 + (num_frames - 1) // 4
for k in range(frame_range):
self._enc_conv_idx = [0]
if k == 0:
tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
else:
tile = x[
:,
:,
1 + 4 * (k - 1) : 1 + 4 * k,
i : i + self.tile_sample_min_height,
j : j + self.tile_sample_min_width,
]
tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
tile = self.quant_conv(tile)
time.append(tile)
row.append(torch.cat(time, dim=2))
rows.append(row)

result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
result_rows.append(torch.cat(result_row, dim=-1))

enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
return enc
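
To make the tile geometry concrete, this is the grid the default settings produce for a hypothetical 480x832 input. The 8x spatial factor below assumes the published Wan 2.1 VAE config; the numbers simply evaluate the ranges used in `tiled_encode`:

```python
height, width = 480, 832     # illustrative sample resolution
tile_min, stride = 256, 192  # tile_sample_min_* and tile_sample_stride_* defaults

row_starts = list(range(0, height, stride))  # [0, 192, 384]
col_starts = list(range(0, width, stride))   # [0, 192, 384, 576, 768]
print(len(row_starts) * len(col_starts))     # 15 tiles, each up to 256x256 pixels

# In latent space (256 // 8 = 32, 192 // 8 = 24), adjacent tiles overlap by
# blend_height = blend_width = 32 - 24 = 8 latent rows/columns, which is the region that gets blended.
```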

def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of videos using a tiled decoder.

Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
_, _, num_frames, height, width = z.shape
sample_height = height * self.spatial_compression_ratio
sample_width = width * self.spatial_compression_ratio

tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio

blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
blend_width = self.tile_sample_min_width - self.tile_sample_stride_width

# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, height, tile_latent_stride_height):
row = []
for j in range(0, width, tile_latent_stride_width):
self.clear_cache()
time = []
for k in range(num_frames):
self._conv_idx = [0]
tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
tile = self.post_quant_conv(tile)
decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
time.append(decoded)
row.append(torch.cat(time, dim=2))
rows.append(row)

result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
result_rows.append(torch.cat(result_row, dim=-1))

dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]

if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
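
As a shape sanity check for the tiled roundtrip, a hedged sketch; the checkpoint id is an assumption, and the 16-channel / 8x spatial / 4x temporal factors come from the published Wan 2.1 VAE config, so adjust them for other checkpoints:

```python
import torch
from diffusers import AutoencoderKLWan

vae = AutoencoderKLWan.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32  # assumed checkpoint id
)
vae.enable_tiling()

video = torch.randn(1, 3, 81, 480, 832)           # (batch, channels, frames, height, width)
latents = vae.encode(video).latent_dist.sample()  # expected shape (1, 16, 21, 60, 104)
decoded = vae.decode(latents).sample              # expected shape (1, 3, 81, 480, 832)
```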

def forward(
self,
sample: torch.Tensor,