@@ -341,6 +341,7 @@ def forward(
         block_controlnet_hidden_states: List = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
+        skip_layers: Optional[List[int]] = None,
     ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
         """
         The [`SD3Transformer2DModel`] forward method.
@@ -363,6 +364,8 @@ def forward(
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                 tuple.
+            skip_layers (`list` of `int`, *optional*):
+                A list of layer indices to skip during the forward pass.
 
         Returns:
             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
@@ -390,7 +393,10 @@ def forward(
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
 
         for index_block, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            # Skip specified layers
+            is_skip = True if skip_layers is not None and index_block in skip_layers else False
+
+            if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -410,8 +416,7 @@ def custom_forward(*inputs):
                     joint_attention_kwargs,
                     **ckpt_kwargs,
                 )
-
-            else:
+            elif not is_skip:
                 encoder_hidden_states, hidden_states = block(
                     hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb,
                     joint_attention_kwargs=joint_attention_kwargs,
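
For orientation, here is a minimal sketch of how the new `skip_layers` argument could be exercised directly on the transformer. The tiny config values, tensor shapes, and layer indices below are illustrative assumptions only and are not taken from this diff; real SD3/SD3.5 checkpoints use much larger settings.

```python
# Sketch only: a deliberately tiny, randomly initialized SD3Transformer2DModel,
# small enough to run on CPU. Config values here are assumptions for illustration.
import torch
from diffusers import SD3Transformer2DModel

transformer = SD3Transformer2DModel(
    sample_size=32,
    patch_size=2,
    in_channels=16,
    num_layers=4,               # four transformer_blocks, indexed 0..3
    attention_head_dim=8,
    num_attention_heads=4,
    joint_attention_dim=32,
    caption_projection_dim=32,  # matches num_attention_heads * attention_head_dim
    pooled_projection_dim=64,
    out_channels=16,
    pos_embed_max_size=64,
)

batch = 1
hidden_states = torch.randn(batch, 16, 32, 32)       # latent input
encoder_hidden_states = torch.randn(batch, 77, 32)   # text-token embeddings
pooled_projections = torch.randn(batch, 64)          # pooled text embedding
timestep = torch.tensor([500])

with torch.no_grad():
    # Blocks 1 and 2 are simply bypassed in the forward loop shown in the diff.
    out = transformer(
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        pooled_projections=pooled_projections,
        timestep=timestep,
        skip_layers=[1, 2],
    )

print(out.sample.shape)  # same spatial shape as the input latents
```

Note that a skipped index bypasses the block entirely: neither the gradient-checkpointed branch nor the regular branch runs, so the hidden states pass through unchanged, which is the behaviour that skip-layer-guidance-style sampling relies on.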