@@ -730,6 +730,76 @@ def __init__(
             base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
         )
 
+        self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
+
+        # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
+        # to perform decoding of a single video latent at a time.
+        self.use_slicing = False
+
+        # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+        # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
+        # intermediate tiles together, the memory requirement can be lowered.
+        self.use_tiling = False
+
+        # The minimal tile height and width for spatial tiling to be used
+        self.tile_sample_min_height = 256
+        self.tile_sample_min_width = 256
+
+        # The minimal distance between two spatial tiles
+        self.tile_sample_stride_height = 192
+        self.tile_sample_stride_width = 192
+
+    def enable_tiling(
+        self,
+        tile_sample_min_height: Optional[int] = None,
+        tile_sample_min_width: Optional[int] = None,
+        tile_sample_stride_height: Optional[float] = None,
+        tile_sample_stride_width: Optional[float] = None,
+    ) -> None:
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
+
+        Args:
+            tile_sample_min_height (`int`, *optional*):
+                The minimum height required for a sample to be separated into tiles across the height dimension.
+            tile_sample_min_width (`int`, *optional*):
+                The minimum width required for a sample to be separated into tiles across the width dimension.
+            tile_sample_stride_height (`int`, *optional*):
+                The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
+                no tiling artifacts produced across the height dimension.
+            tile_sample_stride_width (`int`, *optional*):
+                The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
+                artifacts produced across the width dimension.
+        """
+        self.use_tiling = True
+        self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+        self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+        self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+
+    def disable_tiling(self) -> None:
+        r"""
+        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_tiling = False
+
+    def enable_slicing(self) -> None:
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_slicing(self) -> None:
+        r"""
+        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_slicing = False
+
     def clear_cache(self):
         def _count_conv3d(model):
             count = 0
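For context, a minimal usage sketch of the switches added above. The model id, subfolder, and dtype are illustrative assumptions, not part of this diff:

```python
# Illustrative only: model id, subfolder, and dtype are assumptions, not part of this diff.
import torch
from diffusers import AutoencoderKLWan

vae = AutoencoderKLWan.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
)

# Decode one video latent at a time instead of the whole batch.
vae.enable_slicing()

# Decode in overlapping 256x256 pixel tiles on a 192-pixel stride (the defaults set above).
vae.enable_tiling()

# Both switches can be reverted.
vae.disable_slicing()
vae.disable_tiling()
```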
@@ -746,11 +816,14 @@ def _count_conv3d(model):
         self._enc_conv_idx = [0]
         self._enc_feat_map = [None] * self._enc_conv_num
 
-    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+    def _encode(self, x: torch.Tensor):
+        _, _, num_frame, height, width = x.shape
+
+        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+            return self.tiled_encode(x)
+
         self.clear_cache()
-        ## cache
-        t = x.shape[2]
-        iter_ = 1 + (t - 1) // 4
+        iter_ = 1 + (num_frame - 1) // 4
         for i in range(iter_):
             self._enc_conv_idx = [0]
             if i == 0:
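The `1 + (num_frame - 1) // 4` iteration count above reflects the causal feature cache: the first frame is encoded alone, and every later pass consumes four frames. A standalone sketch of that indexing (the `temporal_chunks` helper is hypothetical, written only to mirror the loop bounds in `_encode`):

```python
# Hypothetical helper mirroring the loop bounds in _encode above: the first frame is
# processed on its own, then the remaining frames are consumed four at a time
# (num_frame is expected to be of the form 4 * k + 1).
def temporal_chunks(num_frame: int) -> list[tuple[int, int]]:
    iter_ = 1 + (num_frame - 1) // 4
    chunks = []
    for i in range(iter_):
        if i == 0:
            chunks.append((0, 1))                        # x[:, :, :1]
        else:
            chunks.append((1 + 4 * (i - 1), 1 + 4 * i))  # x[:, :, 1 + 4*(i-1) : 1 + 4*i]
    return chunks

print(temporal_chunks(17))  # [(0, 1), (1, 5), (5, 9), (9, 13), (13, 17)]
```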
@@ -764,9 +837,6 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor:
                 out = torch.cat([out, out_], 2)
 
         enc = self.quant_conv(out)
-        mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :]
-        enc = torch.cat([mu, logvar], dim=1)
-        self.clear_cache()
         return enc
 
     @apply_forward_hook
@@ -785,18 +855,28 @@ def encode(
                 The latent representations of the encoded videos. If `return_dict` is True, a
                 [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
         """
-        h = self._encode(x)
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self._encode(x)
         posterior = DiagonalGaussianDistribution(h)
+
         if not return_dict:
             return (posterior,)
         return AutoencoderKLOutput(latent_dist=posterior)
 
-    def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
-        self.clear_cache()
+    def _decode(self, z: torch.Tensor, return_dict: bool = True):
+        _, _, num_frame, height, width = z.shape
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+
+        if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
+            return self.tiled_decode(z, return_dict=return_dict)
 
-        iter_ = z.shape[2]
+        self.clear_cache()
         x = self.post_quant_conv(z)
-        for i in range(iter_):
+        for i in range(num_frame):
             self._conv_idx = [0]
             if i == 0:
                 out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
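Both `encode` and `decode` now share the same slicing pattern: split the batch into single-element slices, run them sequentially, and concatenate the results. A self-contained sketch of that pattern (the `sliced_apply` helper and the toy `process_one` callable are illustrative stand-ins, not part of the diff):

```python
# Minimal sketch of the slicing path; sliced_apply and process_one are stand-ins for the
# way encode/decode now wrap self._encode / self._decode.
import torch

def sliced_apply(x: torch.Tensor, process_one, use_slicing: bool) -> torch.Tensor:
    if use_slicing and x.shape[0] > 1:
        # One forward pass per batch element keeps peak activation memory at batch size 1.
        return torch.cat([process_one(x_slice) for x_slice in x.split(1)])
    return process_one(x)

x = torch.randn(3, 16, 1, 32, 32)
out = sliced_apply(x, lambda t: t * 2, use_slicing=True)
assert out.shape == x.shape
```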
@@ -826,12 +906,159 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp
                 If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                 returned.
         """
-        decoded = self._decode(z).sample
+        if self.use_slicing and z.shape[0] > 1:
+            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+            decoded = torch.cat(decoded_slices)
+        else:
+            decoded = self._decode(z).sample
+
         if not return_dict:
             return (decoded,)
-
         return DecoderOutput(sample=decoded)
 
+    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
+        for y in range(blend_extent):
+            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+                y / blend_extent
+            )
+        return b
+
+    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+                x / blend_extent
+            )
+        return b
+
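`blend_v` and `blend_h` linearly ramp the overlapping rows/columns from the previous tile's values to the current tile's values, which is what hides the seams between tiles. A standalone numeric check of `blend_h`, copied out of the class for illustration:

```python
# blend_h copied out of the class for a standalone check. Tensors are 5-D
# (batch, channels, frames, height, width); the first `blend_extent` columns of `b`
# fade linearly from `a`'s trailing columns to `b`'s own values.
import torch

def blend_h(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
    for x in range(blend_extent):
        b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
            x / blend_extent
        )
    return b

a = torch.zeros(1, 1, 1, 1, 4)  # left tile ends in zeros
b = torch.ones(1, 1, 1, 1, 4)   # right tile is all ones
print(blend_h(a, b, blend_extent=4)[0, 0, 0, 0])  # tensor([0.0000, 0.2500, 0.5000, 0.7500])
```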
+    def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
+        r"""Encode a batch of images using a tiled encoder.
+
+        Args:
+            x (`torch.Tensor`): Input batch of videos.
+
+        Returns:
+            `torch.Tensor`:
+                The latent representation of the encoded videos.
+        """
+        _, _, num_frames, height, width = x.shape
+        latent_height = height // self.spatial_compression_ratio
+        latent_width = width // self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = tile_latent_min_height - tile_latent_stride_height
+        blend_width = tile_latent_min_width - tile_latent_stride_width
+
+        # Split x into overlapping tiles and encode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, self.tile_sample_stride_height):
+            row = []
+            for j in range(0, width, self.tile_sample_stride_width):
+                self.clear_cache()
+                time = []
+                frame_range = 1 + (num_frames - 1) // 4
+                for k in range(frame_range):
+                    self._enc_conv_idx = [0]
+                    if k == 0:
+                        tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+                    else:
+                        tile = x[
+                            :,
+                            :,
+                            1 + 4 * (k - 1) : 1 + 4 * k,
+                            i : i + self.tile_sample_min_height,
+                            j : j + self.tile_sample_min_width,
+                        ]
+                    tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
+                    tile = self.quant_conv(tile)
+                    time.append(tile)
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
+            result_rows.append(torch.cat(result_row, dim=-1))
+
+        enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
+        return enc
+
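With the defaults introduced above (256-pixel tiles on a 192-pixel stride) and the 8x spatial compression implied by `2 ** len(self.temperal_downsample)` for the standard Wan config, the latent-space tile geometry works out as in this sketch (the numbers are illustrative assumptions, not asserted by the diff):

```python
# Tile geometry sketch under the defaults above, assuming the 8x spatial compression of
# the standard Wan config (three spatial downsample stages).
tile_sample_min = 256          # tile_sample_min_height / _width
tile_sample_stride = 192       # tile_sample_stride_height / _width
spatial_compression_ratio = 8  # assumption for this sketch

tile_latent_min = tile_sample_min // spatial_compression_ratio        # 32 latent pixels per tile
tile_latent_stride = tile_sample_stride // spatial_compression_ratio  # 24 latent pixels between tile origins
blend = tile_latent_min - tile_latent_stride                          # 8 latent rows/cols of overlap to blend

height = 480  # example frame height in pixels
print(list(range(0, height, tile_sample_stride)))  # [0, 192, 384] -> three tile rows
```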
+    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+        r"""
+        Decode a batch of images using a tiled decoder.
+
+        Args:
+            z (`torch.Tensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vae.DecoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                returned.
+        """
+        _, _, num_frames, height, width = z.shape
+        sample_height = height * self.spatial_compression_ratio
+        sample_width = width * self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+        # Split z into overlapping tiles and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, tile_latent_stride_height):
+            row = []
+            for j in range(0, width, tile_latent_stride_width):
+                self.clear_cache()
+                time = []
+                for k in range(num_frames):
+                    self._conv_idx = [0]
+                    tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
+                    tile = self.post_quant_conv(tile)
+                    decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
+                    time.append(decoded)
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+            result_rows.append(torch.cat(result_row, dim=-1))
+
+        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+
+        if not return_dict:
+            return (dec,)
+        return DecoderOutput(sample=dec)
+
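Continuing the earlier usage sketch, a tiled and sliced decode might look like the following. The latent shape, 16 latent channels, and the 8x spatial upsampling are assumptions based on the standard Wan config; the temporal length of the output depends on the temporal upsampling schedule and is not asserted here.

```python
# Illustrative decode with both memory savers enabled; `vae` is the model loaded in the
# earlier sketch, and the latent shape below is an assumption, not part of this diff.
import torch

latent = torch.randn(2, 16, 5, 60, 104)  # e.g. a batch of two 480x832 video latents

vae.enable_slicing()  # one latent per forward pass
vae.enable_tiling()   # 256x256 pixel tiles, 192-pixel stride

with torch.no_grad():
    video = vae.decode(latent).sample

print(video.shape)  # expected spatial size 480x832 after 8x upsampling
```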
     def forward(
         self,
         sample: torch.Tensor,