
Commit 053a48d

linjiapro, yiyixuxu, and sayakpaul committed
Fix a bug for SD35 control net training and improve control net block index (#10065)
* wip

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
1 parent ba0f59b commit 053a48d

File tree

2 files changed: +14 −10 lines


src/diffusers/models/controlnets/controlnet_sd3.py

Lines changed: 13 additions & 7 deletions
@@ -393,13 +393,19 @@ def custom_forward(*inputs):
                     return custom_forward
 
                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
-                    hidden_states,
-                    encoder_hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                if self.context_embedder is not None:
+                    encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(block),
+                        hidden_states,
+                        encoder_hidden_states,
+                        temb,
+                        **ckpt_kwargs,
+                    )
+                else:
+                    # SD3.5 8b controlnet use single transformer block, which does not use `encoder_hidden_states`
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(block), hidden_states, temb, **ckpt_kwargs
+                    )
 
             else:
                 if self.context_embedder is not None:
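Why the new branch is needed: joint transformer blocks return an (encoder_hidden_states, hidden_states) pair, while the single-stream blocks used by the SD3.5-Large (8B) controlnet take only (hidden_states, temb) and return one tensor. The old checkpointed path always unpacked two values, so training with gradient checkpointing crashed on the single-block variant; the non-checkpointed path (visible in the trailing context lines above) already branched on self.context_embedder, and the fix mirrors that branch. A minimal sketch of the two call shapes, using hypothetical JointBlock/SingleBlock stand-ins rather than the real diffusers classes:

    import torch

    # Hypothetical stand-ins (not the real diffusers classes): a joint block
    # that also transforms text tokens, and a single-stream block that does not.
    class JointBlock(torch.nn.Module):
        def forward(self, hidden_states, encoder_hidden_states, temb):
            # returns two tensors
            return encoder_hidden_states + temb, hidden_states + temb

    class SingleBlock(torch.nn.Module):
        def forward(self, hidden_states, temb):
            # returns a single tensor; there is no encoder_hidden_states output
            return hidden_states + temb

    def checkpointed_step(block, hidden_states, encoder_hidden_states, temb, has_context_embedder):
        # Mirrors the fixed code path: only unpack two outputs when the model
        # actually routes text tokens through its blocks.
        if has_context_embedder:
            encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                block, hidden_states, encoder_hidden_states, temb, use_reentrant=False
            )
        else:
            hidden_states = torch.utils.checkpoint.checkpoint(
                block, hidden_states, temb, use_reentrant=False
            )
        return encoder_hidden_states, hidden_states

    hs = torch.randn(1, 4, 8, requires_grad=True)
    ehs = torch.randn(1, 4, 8, requires_grad=True)
    temb = torch.randn(1, 4, 8)
    checkpointed_step(JointBlock(), hs, ehs, temb, has_context_embedder=True)
    checkpointed_step(SingleBlock(), hs, ehs, temb, has_context_embedder=False)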

src/diffusers/models/transformers/transformer_sd3.py

Lines changed: 1 addition & 3 deletions
@@ -15,7 +15,6 @@
 
 from typing import Any, Dict, List, Optional, Tuple, Union
 
-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -424,8 +423,7 @@ def custom_forward(*inputs):
             # controlnet residual
             if block_controlnet_hidden_states is not None and block.context_pre_only is False:
                 interval_control = len(self.transformer_blocks) / len(block_controlnet_hidden_states)
-                interval_control = int(np.ceil(interval_control))
-                hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control]
+                hidden_states = hidden_states + block_controlnet_hidden_states[int(index_block / interval_control)]
 
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)
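Why the index change matters: interval_control is in general a non-integer ratio of transformer blocks to controlnet outputs. Rounding it up with np.ceil before floor-dividing can leave the trailing controlnet outputs unused whenever the counts do not divide evenly; dividing by the exact ratio first and then truncating spreads the transformer blocks across all controlnet outputs. A small sketch with hypothetical layer counts (not the real SD3.5 ones):

    import math

    def old_index(i: int, n_blocks: int, n_ctrl: int) -> int:
        # previous behavior: round the ratio up, then floor-divide
        interval = int(math.ceil(n_blocks / n_ctrl))
        return i // interval

    def new_index(i: int, n_blocks: int, n_ctrl: int) -> int:
        # fixed behavior: divide by the exact ratio, then truncate
        return int(i / (n_blocks / n_ctrl))

    # Hypothetical counts: 24 transformer blocks, 23 controlnet outputs.
    n_blocks, n_ctrl = 24, 23
    print(max(old_index(i, n_blocks, n_ctrl) for i in range(n_blocks)))  # 11: outputs 12..22 never applied
    print(max(new_index(i, n_blocks, n_ctrl) for i in range(n_blocks)))  # 22: every output is used

This was also the module's only remaining use of numpy, which is why the import is dropped in the first hunk.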
