-
Notifications
You must be signed in to change notification settings - Fork 6k
[Feature] Implement tiled VAE encoding/decoding for Wan model. #11414
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 1 commit
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -730,6 +730,76 @@ def __init__( | |
base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout | ||
) | ||
|
||
self.spatial_compression_ratio = 2 ** len(self.temperal_downsample) | ||
|
||
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension | ||
# to perform decoding of a single video latent at a time. | ||
self.use_slicing = False | ||
|
||
# When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent | ||
# frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the | ||
# intermediate tiles together, the memory requirement can be lowered. | ||
self.use_tiling = False | ||
|
||
# The minimal tile height and width for spatial tiling to be used | ||
self.tile_sample_min_height = 256 | ||
self.tile_sample_min_width = 256 | ||
|
||
# The minimal distance between two spatial tiles | ||
self.tile_sample_stride_height = 192 | ||
self.tile_sample_stride_width = 192 | ||
|
||
def enable_tiling( | ||
self, | ||
tile_sample_min_height: Optional[int] = None, | ||
tile_sample_min_width: Optional[int] = None, | ||
tile_sample_stride_height: Optional[float] = None, | ||
tile_sample_stride_width: Optional[float] = None, | ||
) -> None: | ||
r""" | ||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to | ||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow | ||
processing larger images. | ||
|
||
Args: | ||
tile_sample_min_height (`int`, *optional*): | ||
The minimum height required for a sample to be separated into tiles across the height dimension. | ||
tile_sample_min_width (`int`, *optional*): | ||
The minimum width required for a sample to be separated into tiles across the width dimension. | ||
tile_sample_stride_height (`int`, *optional*): | ||
The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are | ||
no tiling artifacts produced across the height dimension. | ||
tile_sample_stride_width (`int`, *optional*): | ||
The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling | ||
artifacts produced across the width dimension. | ||
""" | ||
self.use_tiling = True | ||
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height | ||
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width | ||
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height | ||
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width | ||
|
||
def disable_tiling(self) -> None: | ||
r""" | ||
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing | ||
decoding in one step. | ||
""" | ||
self.use_tiling = False | ||
|
||
def enable_slicing(self) -> None: | ||
r""" | ||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to | ||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. | ||
""" | ||
self.use_slicing = True | ||
|
||
def disable_slicing(self) -> None: | ||
r""" | ||
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing | ||
decoding in one step. | ||
""" | ||
self.use_slicing = False | ||
|
||
def clear_cache(self): | ||
def _count_conv3d(model): | ||
count = 0 | ||
|
@@ -746,11 +816,14 @@ def _count_conv3d(model): | |
self._enc_conv_idx = [0] | ||
self._enc_feat_map = [None] * self._enc_conv_num | ||
|
||
def _encode(self, x: torch.Tensor) -> torch.Tensor: | ||
def _encode(self, x: torch.Tensor): | ||
_, _, num_frame, height, width = x.shape | ||
|
||
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): | ||
return self.tiled_encode(x) | ||
|
||
self.clear_cache() | ||
## cache | ||
t = x.shape[2] | ||
iter_ = 1 + (t - 1) // 4 | ||
iter_ = 1 + (num_frame - 1) // 4 | ||
for i in range(iter_): | ||
self._enc_conv_idx = [0] | ||
if i == 0: | ||
|
@@ -764,9 +837,6 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor: | |
out = torch.cat([out, out_], 2) | ||
|
||
enc = self.quant_conv(out) | ||
mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :] | ||
enc = torch.cat([mu, logvar], dim=1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a no-op. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice find! |
||
self.clear_cache() | ||
a-r-r-o-w marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return enc | ||
|
||
@apply_forward_hook | ||
|
@@ -785,18 +855,28 @@ def encode( | |
The latent representations of the encoded videos. If `return_dict` is True, a | ||
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. | ||
""" | ||
h = self._encode(x) | ||
if self.use_slicing and x.shape[0] > 1: | ||
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] | ||
h = torch.cat(encoded_slices) | ||
else: | ||
h = self._encode(x) | ||
posterior = DiagonalGaussianDistribution(h) | ||
|
||
if not return_dict: | ||
return (posterior,) | ||
return AutoencoderKLOutput(latent_dist=posterior) | ||
|
||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: | ||
self.clear_cache() | ||
def _decode(self, z: torch.Tensor, return_dict: bool = True): | ||
_, _, num_frame, height, width = z.shape | ||
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio | ||
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio | ||
|
||
if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): | ||
return self.tiled_decode(z, return_dict=return_dict) | ||
|
||
iter_ = z.shape[2] | ||
self.clear_cache() | ||
x = self.post_quant_conv(z) | ||
for i in range(iter_): | ||
for i in range(num_frame): | ||
self._conv_idx = [0] | ||
if i == 0: | ||
out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) | ||
|
@@ -826,12 +906,159 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp | |
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is | ||
returned. | ||
""" | ||
decoded = self._decode(z).sample | ||
if self.use_slicing and z.shape[0] > 1: | ||
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] | ||
decoded = torch.cat(decoded_slices) | ||
else: | ||
decoded = self._decode(z).sample | ||
|
||
if not return_dict: | ||
return (decoded,) | ||
|
||
return DecoderOutput(sample=decoded) | ||
|
||
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: | ||
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) | ||
for y in range(blend_extent): | ||
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( | ||
y / blend_extent | ||
) | ||
return b | ||
|
||
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: | ||
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) | ||
for x in range(blend_extent): | ||
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( | ||
x / blend_extent | ||
) | ||
return b | ||
|
||
def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: | ||
r"""Encode a batch of images using a tiled encoder. | ||
|
||
Args: | ||
x (`torch.Tensor`): Input batch of videos. | ||
|
||
Returns: | ||
`torch.Tensor`: | ||
The latent representation of the encoded videos. | ||
""" | ||
_, _, num_frames, height, width = x.shape | ||
latent_height = height // self.spatial_compression_ratio | ||
latent_width = width // self.spatial_compression_ratio | ||
|
||
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio | ||
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio | ||
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio | ||
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio | ||
|
||
blend_height = tile_latent_min_height - tile_latent_stride_height | ||
blend_width = tile_latent_min_width - tile_latent_stride_width | ||
|
||
# Split x into overlapping tiles and encode them separately. | ||
# The tiles have an overlap to avoid seams between tiles. | ||
rows = [] | ||
for i in range(0, height, self.tile_sample_stride_height): | ||
row = [] | ||
for j in range(0, width, self.tile_sample_stride_width): | ||
self.clear_cache() | ||
time = [] | ||
frame_range = 1 + (num_frames - 1) // 4 | ||
for k in range(frame_range): | ||
self._enc_conv_idx = [0] | ||
if k == 0: | ||
tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] | ||
else: | ||
tile = x[ | ||
:, | ||
:, | ||
1 + 4 * (k - 1) : 1 + 4 * k, | ||
i : i + self.tile_sample_min_height, | ||
j : j + self.tile_sample_min_width, | ||
] | ||
tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) | ||
tile = self.quant_conv(tile) | ||
time.append(tile) | ||
row.append(torch.cat(time, dim=2)) | ||
rows.append(row) | ||
a-r-r-o-w marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
result_rows = [] | ||
for i, row in enumerate(rows): | ||
result_row = [] | ||
for j, tile in enumerate(row): | ||
# blend the above tile and the left tile | ||
# to the current tile and add the current tile to the result row | ||
if i > 0: | ||
tile = self.blend_v(rows[i - 1][j], tile, blend_height) | ||
if j > 0: | ||
tile = self.blend_h(row[j - 1], tile, blend_width) | ||
result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) | ||
result_rows.append(torch.cat(result_row, dim=-1)) | ||
|
||
enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] | ||
return enc | ||
|
||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: | ||
r""" | ||
Decode a batch of images using a tiled decoder. | ||
|
||
Args: | ||
z (`torch.Tensor`): Input batch of latent vectors. | ||
return_dict (`bool`, *optional*, defaults to `True`): | ||
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. | ||
|
||
Returns: | ||
[`~models.vae.DecoderOutput`] or `tuple`: | ||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is | ||
returned. | ||
""" | ||
_, _, num_frames, height, width = z.shape | ||
sample_height = height * self.spatial_compression_ratio | ||
sample_width = width * self.spatial_compression_ratio | ||
|
||
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio | ||
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio | ||
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio | ||
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio | ||
|
||
blend_height = self.tile_sample_min_height - self.tile_sample_stride_height | ||
blend_width = self.tile_sample_min_width - self.tile_sample_stride_width | ||
|
||
# Split z into overlapping tiles and decode them separately. | ||
# The tiles have an overlap to avoid seams between tiles. | ||
rows = [] | ||
for i in range(0, height, tile_latent_stride_height): | ||
row = [] | ||
for j in range(0, width, tile_latent_stride_width): | ||
self.clear_cache() | ||
time = [] | ||
for k in range(num_frames): | ||
self._conv_idx = [0] | ||
tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width] | ||
tile = self.post_quant_conv(tile) | ||
decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx) | ||
time.append(decoded) | ||
row.append(torch.cat(time, dim=2)) | ||
rows.append(row) | ||
a-r-r-o-w marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
result_rows = [] | ||
for i, row in enumerate(rows): | ||
result_row = [] | ||
for j, tile in enumerate(row): | ||
# blend the above tile and the left tile | ||
# to the current tile and add the current tile to the result row | ||
if i > 0: | ||
tile = self.blend_v(rows[i - 1][j], tile, blend_height) | ||
if j > 0: | ||
tile = self.blend_h(row[j - 1], tile, blend_width) | ||
result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) | ||
result_rows.append(torch.cat(result_row, dim=-1)) | ||
|
||
dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] | ||
|
||
if not return_dict: | ||
return (dec,) | ||
return DecoderOutput(sample=dec) | ||
|
||
def forward( | ||
self, | ||
sample: torch.Tensor, | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.