@@ -730,6 +730,77 @@ def __init__(
            base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
        )

+        self.temporal_compression_ratio = 2 ** sum(self.temperal_downsample)
+        self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
+
+        # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
+        # to perform decoding of a single video latent at a time.
+        self.use_slicing = False
+
+        # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+        # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
+        # intermediate tiles together, the memory requirement can be lowered.
+        self.use_tiling = False
+
+        # The minimal tile height and width for spatial tiling to be used
+        self.tile_sample_min_height = 256
+        self.tile_sample_min_width = 256
+
+        # The minimal distance between two spatial tiles
+        self.tile_sample_stride_height = 192
+        self.tile_sample_stride_width = 192
+
+    def enable_tiling(
+        self,
+        tile_sample_min_height: Optional[int] = None,
+        tile_sample_min_width: Optional[int] = None,
+        tile_sample_stride_height: Optional[float] = None,
+        tile_sample_stride_width: Optional[float] = None,
+    ) -> None:
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
+
+        Args:
+            tile_sample_min_height (`int`, *optional*):
+                The minimum height required for a sample to be separated into tiles across the height dimension.
+            tile_sample_min_width (`int`, *optional*):
+                The minimum width required for a sample to be separated into tiles across the width dimension.
+            tile_sample_stride_height (`int`, *optional*):
+                The stride between two consecutive vertical tiles. This is to ensure that there are no tiling
+                artifacts produced across the height dimension.
+            tile_sample_stride_width (`int`, *optional*):
+                The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
+                artifacts produced across the width dimension.
+        """
+        self.use_tiling = True
+        self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+        self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+        self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+
+    def disable_tiling(self) -> None:
+        r"""
+        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_tiling = False
+
+    def enable_slicing(self) -> None:
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_slicing(self) -> None:
+        r"""
+        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_slicing = False
+
    def clear_cache(self):
        def _count_conv3d(model):
            count = 0
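The slicing and tiling switches added above follow the same pattern as other diffusers video autoencoders. A minimal usage sketch, assuming this class is `AutoencoderKLWan` and that a checkpoint with a `vae` subfolder such as `Wan-AI/Wan2.1-T2V-1.3B-Diffusers` is available (both the model id and dtype are assumptions, not stated in this diff):

```python
# Hedged sketch: turning on the memory-saving modes added in this diff.
import torch
from diffusers import AutoencoderKLWan

vae = AutoencoderKLWan.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
)

vae.enable_slicing()  # decode/encode one video of the batch at a time
vae.enable_tiling()   # split frames into overlapping 256x256 tiles, stride 192 (the defaults above)

# Both modes can be reverted later:
# vae.disable_slicing()
# vae.disable_tiling()
```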
@@ -746,7 +817,7 @@ def _count_conv3d(model):
        self._enc_conv_idx = [0]
        self._enc_feat_map = [None] * self._enc_conv_num

-    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+    def vanilla_encode(self, x: torch.Tensor) -> torch.Tensor:
        self.clear_cache()
        ## cache
        t = x.shape[2]
@@ -769,6 +840,12 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor:
        self.clear_cache()
        return enc

+    def _encode(self, x: torch.Tensor):
+        _, _, _, height, width = x.shape
+        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+            return self.tiled_encode(x)
+        return self.vanilla_encode(x)
+
    @apply_forward_hook
    def encode(
        self, x: torch.Tensor, return_dict: bool = True
@@ -785,13 +862,18 @@ def encode(
                The latent representations of the encoded videos. If `return_dict` is True, a
                [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
        """
-        h = self._encode(x)
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self._encode(x)
        posterior = DiagonalGaussianDistribution(h)
+
        if not return_dict:
            return (posterior,)
        return AutoencoderKLOutput(latent_dist=posterior)

-    def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+    def vanilla_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        self.clear_cache()

        iter_ = z.shape[2]
@@ -811,6 +893,15 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut

        return DecoderOutput(sample=out)

+    def _decode(self, z: torch.Tensor, return_dict: bool = True):
+        _, _, _, height, width = z.shape
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+
+        if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
+            return self.tiled_decode(z, return_dict=return_dict)
+        return self.vanilla_decode(z, return_dict=return_dict)
+
    @apply_forward_hook
    def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        r"""
@@ -826,12 +917,167 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
-        decoded = self._decode(z).sample
+        if self.use_slicing and z.shape[0] > 1:
+            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+            decoded = torch.cat(decoded_slices)
+        else:
+            decoded = self._decode(z).sample
+
        if not return_dict:
            return (decoded,)
-
        return DecoderOutput(sample=decoded)

+    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
+        for y in range(blend_extent):
+            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+                y / blend_extent
+            )
+        return b
+
+    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+                x / blend_extent
+            )
+        return b
+
+    def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
+                x / blend_extent
+            )
+        return b
+
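The three `blend_*` helpers apply the same linear cross-fade along different axes (height, width, time): inside the overlap region the previous tile's weight ramps down from 1 to 0 while the current tile's ramps up, which hides tile seams. A self-contained illustration of the idea, written as a standalone reimplementation over the width axis rather than the class method itself:

```python
# Minimal sketch of the linear cross-fade used by blend_h/blend_v/blend_t,
# applied to the width axis of two (B, C, T, H, W) tiles.
import torch

def blend_width(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
    for x in range(blend_extent):
        weight = x / blend_extent  # 0 at the seam with `a`, approaching 1 inside `b`
        b[..., x] = a[..., -blend_extent + x] * (1 - weight) + b[..., x] * weight
    return b

left = torch.zeros(1, 3, 1, 4, 8)   # previous tile (all zeros)
right = torch.ones(1, 3, 1, 4, 8)   # current tile (all ones)
blended = blend_width(left, right, blend_extent=4)
print(blended[0, 0, 0, 0, :4])      # tensor([0.0000, 0.2500, 0.5000, 0.7500])
```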
+    def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
+        r"""Encode a batch of videos using a tiled encoder.
+
+        Args:
+            x (`torch.Tensor`): Input batch of videos.
+
+        Returns:
+            `torch.Tensor`:
+                The latent representation of the encoded videos.
+        """
+        _, _, num_frames, height, width = x.shape
+        latent_height = height // self.spatial_compression_ratio
+        latent_width = width // self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = tile_latent_min_height - tile_latent_stride_height
+        blend_width = tile_latent_min_width - tile_latent_stride_width
+
+        # Split x into overlapping tiles and encode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, self.tile_sample_stride_height):
+            row = []
+            for j in range(0, width, self.tile_sample_stride_width):
+                self.clear_cache()
+                time = []
+                frame_range = 1 + (num_frames - 1) // 4
+                for k in range(frame_range):
+                    self._enc_conv_idx = [0]
+                    if k == 0:
+                        tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+                    else:
+                        tile = x[
+                            :,
+                            :,
+                            1 + 4 * (k - 1) : 1 + 4 * k,
+                            i : i + self.tile_sample_min_height,
+                            j : j + self.tile_sample_min_width,
+                        ]
+                    tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
+                    tile = self.quant_conv(tile)
+                    time.append(tile)
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
+            result_rows.append(torch.cat(result_row, dim=-1))
+
+        enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
+        return enc
+
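Within each spatial tile, `tiled_encode` walks the clip temporally: the first frame is encoded on its own and the remaining frames in groups of four, which lines up with the causal feature cache and a 4x temporal compression (an assumption about the usual Wan configuration, where `temperal_downsample` has two `True` entries so `2 ** sum(...) == 4`). A small sketch of the resulting frame chunks for a 17-frame clip, pure index arithmetic with no model required:

```python
# Sketch of the temporal chunking used inside tiled_encode: first frame alone,
# then groups of four frames.
num_frames = 17  # illustrative clip length of the form 1 + 4 * n
frame_range = 1 + (num_frames - 1) // 4

chunks = []
for k in range(frame_range):
    if k == 0:
        chunks.append((0, 1))                        # first frame alone
    else:
        chunks.append((1 + 4 * (k - 1), 1 + 4 * k))  # groups of four frames
print(chunks)  # [(0, 1), (1, 5), (5, 9), (9, 13), (13, 17)]
```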
+    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+        r"""
+        Decode a batch of videos using a tiled decoder.
+
+        Args:
+            z (`torch.Tensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vae.DecoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                returned.
+        """
+        _, _, num_frames, height, width = z.shape
+        sample_height = height * self.spatial_compression_ratio
+        sample_width = width * self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+        # Split z into overlapping tiles and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, tile_latent_stride_height):
+            row = []
+            for j in range(0, width, tile_latent_stride_width):
+                self.clear_cache()
+                time = []
+                for k in range(num_frames):
+                    self._conv_idx = [0]
+                    tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
+                    tile = self.post_quant_conv(tile)
+                    decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
+                    time.append(decoded)
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+            result_rows.append(torch.cat(result_row, dim=-1))
+
+        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+
+        if not return_dict:
+            return (dec,)
+        return DecoderOutput(sample=dec)
+
    def forward(
        self,
        sample: torch.Tensor,
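Putting the pieces together, decoding with tiling and slicing enabled keeps peak memory bounded by one latent window of roughly 32x32 (256 // 8) per forward pass while still producing the full-resolution video. An end-to-end sketch, reusing the assumptions from the earlier example and additionally assuming the default Wan VAE config of 16 latent channels with 8x spatial and 4x temporal compression (assumptions, not stated in this diff):

```python
# Hedged end-to-end sketch: tiled + sliced decoding of a random latent.
import torch
from diffusers import AutoencoderKLWan

vae = AutoencoderKLWan.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
)
vae.enable_tiling()
vae.enable_slicing()

# Latent for a 480x832, 33-frame clip: (B, C, 1 + (T - 1) // 4, H // 8, W // 8)
latents = torch.randn(1, 16, 9, 60, 104)
with torch.no_grad():
    video = vae.decode(latents).sample
print(video.shape)  # expected under these assumptions: torch.Size([1, 3, 33, 480, 832])
```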