@@ -187,20 +187,20 @@ def forward(
        hidden_states = self.norm(hidden_states)
        hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)

-        hidden_states = self.proj_in(hidden_states)
+        hidden_states = self.proj_in(input=hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
-                hidden_states,
+                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
            )

        # 3. Output
-        hidden_states = self.proj_out(hidden_states)
+        hidden_states = self.proj_out(input=hidden_states)
        hidden_states = (
            hidden_states[None, None, :]
            .reshape(batch_size, height, width, num_frames, channel)
@@ -344,15 +344,15 @@ def custom_forward(*inputs):
                )

            else:
-                hidden_states = resnet(hidden_states, temb)
+                hidden_states = resnet(input_tensor=hidden_states, temb=temb)

            hidden_states = motion_module(hidden_states, num_frames=num_frames)

            output_states = output_states + (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
-                hidden_states = downsampler(hidden_states)
+                hidden_states = downsampler(hidden_states=hidden_states)

            output_states = output_states + (hidden_states,)
@@ -531,25 +531,18 @@ def custom_forward(*inputs):
                    temb,
                    **ckpt_kwargs,
                )
-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
            else:
-                hidden_states = resnet(hidden_states, temb)
+                hidden_states = resnet(input_tensor=hidden_states, temb=temb)
+
+            hidden_states = attn(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                cross_attention_kwargs=cross_attention_kwargs,
+                attention_mask=attention_mask,
+                encoder_attention_mask=encoder_attention_mask,
+                return_dict=False,
+            )[0]

-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
            hidden_states = motion_module(
                hidden_states,
                num_frames=num_frames,
@@ -563,7 +556,7 @@ def custom_forward(*inputs):

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
-                hidden_states = downsampler(hidden_states)
+                hidden_states = downsampler(hidden_states=hidden_states)

            output_states = output_states + (hidden_states,)
@@ -757,33 +750,26 @@ def custom_forward(*inputs):
                    temb,
                    **ckpt_kwargs,
                )
-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
            else:
-                hidden_states = resnet(hidden_states, temb)
+                hidden_states = resnet(input_tensor=hidden_states, temb=temb)
+
+            hidden_states = attn(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                cross_attention_kwargs=cross_attention_kwargs,
+                attention_mask=attention_mask,
+                encoder_attention_mask=encoder_attention_mask,
+                return_dict=False,
+            )[0]

-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
            hidden_states = motion_module(
                hidden_states,
                num_frames=num_frames,
            )

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states, upsample_size)
+                hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size)

        return hidden_states
@@ -929,13 +915,13 @@ def custom_forward(*inputs):
                    create_custom_forward(resnet), hidden_states, temb
                )
            else:
-                hidden_states = resnet(hidden_states, temb)
+                hidden_states = resnet(input_tensor=hidden_states, temb=temb)

            hidden_states = motion_module(hidden_states, num_frames=num_frames)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states, upsample_size)
+                hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size)

        return hidden_states
@@ -1080,10 +1066,19 @@ def forward(
            if cross_attention_kwargs.get("scale", None) is not None:
                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

-        hidden_states = self.resnets[0](hidden_states, temb)
+        hidden_states = self.resnets[0](input_tensor=hidden_states, temb=temb)

        blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
        for attn, resnet, motion_module in blocks:
+            hidden_states = attn(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                cross_attention_kwargs=cross_attention_kwargs,
+                attention_mask=attention_mask,
+                encoder_attention_mask=encoder_attention_mask,
+                return_dict=False,
+            )[0]
+
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
@@ -1096,14 +1091,6 @@ def custom_forward(*inputs):
                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(motion_module),
                    hidden_states,
@@ -1117,19 +1104,11 @@ def custom_forward(*inputs):
                    **ckpt_kwargs,
                )
            else:
-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
                hidden_states = motion_module(
                    hidden_states,
                    num_frames=num_frames,
                )
-                hidden_states = resnet(hidden_states, temb)
+                hidden_states = resnet(input_tensor=hidden_states, temb=temb)

        return hidden_states
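For reference, the gradient-checkpointing pattern that the unchanged context lines in these hunks rely on (a `create_custom_forward` wrapper passed to `torch.utils.checkpoint.checkpoint` with non-reentrant checkpointing on recent torch) can be sketched on its own. This is an illustrative standalone snippet, not part of the diff; the helper name `run_resnet_checkpointed` is hypothetical, and the unconditional `use_reentrant=False` assumes torch >= 1.11, where the diff otherwise falls back via `is_torch_version`.

# Illustrative sketch only, not part of the change above.
from typing import Any, Dict

import torch


def create_custom_forward(module, return_dict=None):
    # Wrap the module so torch.utils.checkpoint can re-run its forward
    # during the backward pass instead of storing intermediate activations.
    def custom_forward(*inputs):
        if return_dict is not None:
            return module(*inputs, return_dict=return_dict)
        return module(*inputs)

    return custom_forward


def run_resnet_checkpointed(resnet, hidden_states, temb):
    # Hypothetical helper: mirrors the checkpoint call in the hunks,
    # assuming a torch version where non-reentrant checkpointing exists.
    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False}
    return torch.utils.checkpoint.checkpoint(
        create_custom_forward(resnet),
        hidden_states,
        temb,
        **ckpt_kwargs,
    )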