Skip to content

Commit 34ab1af

Browse files
committed
inplace scatter
1 parent 7fb61dd commit 34ab1af

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1245,10 +1245,10 @@ def __call__(
12451245
)
12461246

12471247
if isinstance(controlnet, ControlNetUnionModel):
1248-
control_type = torch.zeros(controlnet.config.num_control_type).scatter(0, torch.tensor(control_mode), 1)
1248+
control_type = torch.zeros(controlnet.config.num_control_type).scatter_(0, torch.tensor(control_mode), 1)
12491249
elif isinstance(controlnet, MultiControlNetUnionModel):
12501250
control_type = [
1251-
torch.zeros(controlnet_.config.num_control_type).scatter(0, torch.tensor(control_mode_), 1)
1251+
torch.zeros(controlnet_.config.num_control_type).scatter_(0, torch.tensor(control_mode_), 1)
12521252
for control_mode_, controlnet_ in zip(control_mode, self.controlnet.nets)
12531253
]
12541254

0 commit comments

Comments
 (0)