Skip to content

Commit cb1b8b2

Browse files
authored
Resolve stride mismatch in UNet's ResNet to support Torch DDP (#11098)
Modify UNet's ResNet implementation to resolve stride mismatch in Torch's DDP
1 parent 2791682 commit cb1b8b2

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/diffusers/models/resnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwarg
366366
hidden_states = self.conv2(hidden_states)
367367

368368
if self.conv_shortcut is not None:
369-
input_tensor = self.conv_shortcut(input_tensor)
369+
input_tensor = self.conv_shortcut(input_tensor.contiguous())
370370

371371
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
372372

0 commit comments

Comments
 (0)