Skip to content

Commit 4890862

Browse files
author
Jimmy
committed
Ensure that hidden_state and shift/scale are on the same device when running with multiple GPUs
1 parent d55f411 commit 4890862

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,8 @@ def forward(
440440

441441
# 5. Output norm, projection & unpatchify
442442
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
443+
shift = shift.to(hidden_states.device)
444+
scale = scale.to(hidden_states.device)
443445
hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
444446
hidden_states = self.proj_out(hidden_states)
445447

0 commit comments

Comments
 (0)