
Commit 30e0dda

Fixed typo and reverted removal of skip_layers in SD3Transformer2DModel
1 parent 50d09d9 commit 30e0dda

File tree

1 file changed: +8 -3 lines changed


src/diffusers/models/transformers/transformer_sd3.py

Lines changed: 8 additions & 3 deletions
@@ -341,6 +341,7 @@ def forward(
         block_controlnet_hidden_states: List = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
+        skip_layers: Optional[List[int]] = None,
     ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
         """
         The [`SD3Transformer2DModel`] forward method.
@@ -363,6 +364,8 @@ def forward(
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                 tuple.
+            skip_layers (`list` of `int`, *optional*):
+                A list of layer indices to skip during the forward pass.
 
         Returns:
             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
@@ -390,7 +393,10 @@ def forward(
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
 
         for index_block, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            # Skip specified layers
+            is_skip = True if skip_layers is not None and index_block in skip_layers else False
+
+            if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -410,8 +416,7 @@ def custom_forward(*inputs):
                     joint_attention_kwargs,
                     **ckpt_kwargs,
                 )
-
-            else:
+            elif not is_skip:
                 encoder_hidden_states, hidden_states = block(
                     hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb,
                     joint_attention_kwargs=joint_attention_kwargs,
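
For readers skimming the diff: the change restores skip_layers, which lets callers bypass selected transformer blocks, and it also replaces self.training with torch.is_grad_enabled(), so gradient checkpointing engages whenever gradients are being computed rather than only in train mode. The sketch below is a minimal, self-contained illustration of the skip-layer gating; ToyTransformer is a hypothetical stand-in, not the real SD3Transformer2DModel:

import torch
from torch import nn

# Hypothetical toy block stack; illustrates only the skip-layer
# gating pattern this commit restores, not the real SD3 model.
class ToyTransformer(nn.Module):
    def __init__(self, num_layers=4, dim=8):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])

    def forward(self, hidden_states, skip_layers=None):
        for index_block, block in enumerate(self.blocks):
            # Same per-block test the diff performs
            is_skip = skip_layers is not None and index_block in skip_layers
            if is_skip:
                # Neither the checkpointed nor the plain path runs for a skipped block
                continue
            hidden_states = block(hidden_states)
        return hidden_states

model = ToyTransformer()
x = torch.randn(2, 8)
out_full = model(x)                   # all four blocks run
out_skip = model(x, skip_layers=[1])  # block 1 is bypassed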
