Skip to content

Commit 402f547

Browse files
a-r-r-o-w and sayakpaul
authored and committed
[core] Freenoise memory improvements (#9262)
* update * implement prompt interpolation * make style * resnet memory optimizations * more memory optimizations; todo: refactor * update * update animatediff controlnet with latest changes * refactor chunked inference changes * remove print statements * update * chunk -> split * remove changes from incorrect conflict resolution * remove changes from incorrect conflict resolution * add explanation of SplitInferenceModule * update docs * Revert "update docs" This reverts commit c55a50a. * update docstring for freenoise split inference * apply suggestions from review * add tests * apply suggestions from review
1 parent e48213c commit 402f547

File tree

5 files changed

+294
-64
lines changed

5 files changed

+294
-64
lines changed

src/diffusers/models/attention.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,8 +1104,26 @@ def forward(
11041104
accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
11051105
num_times_accumulated[:, frame_start:frame_end] += weights
11061106

1107-
hidden_states = torch.where(
1108-
num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
1107+
# TODO(aryan): Maybe this could be done in a better way.
1108+
#
1109+
# Previously, this was:
1110+
# hidden_states = torch.where(
1111+
# num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
1112+
# )
1113+
#
1114+
# The reasoning for the change here is `torch.where` became a bottleneck at some point when golfing memory
1115+
# spikes. It is particularly noticeable when the number of frames is high. My understanding is that this comes
1116+
# from tensors being copied - which is why we resort to splitting and concatenating here. I've not particularly
1117+
# looked into this deeply because other memory optimizations led to more pronounced reductions.
1118+
hidden_states = torch.cat(
1119+
[
1120+
torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split)
1121+
for accumulated_split, num_times_split in zip(
1122+
accumulated_values.split(self.context_length, dim=1),
1123+
num_times_accumulated.split(self.context_length, dim=1),
1124+
)
1125+
],
1126+
dim=1,
11091127
).to(dtype)
11101128

11111129
# 3. Feed-forward

src/diffusers/models/unets/unet_motion_model.py

Lines changed: 40 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -187,20 +187,20 @@ def forward(
187187
hidden_states = self.norm(hidden_states)
188188
hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
189189

190-
hidden_states = self.proj_in(hidden_states)
190+
hidden_states = self.proj_in(input=hidden_states)
191191

192192
# 2. Blocks
193193
for block in self.transformer_blocks:
194194
hidden_states = block(
195-
hidden_states,
195+
hidden_states=hidden_states,
196196
encoder_hidden_states=encoder_hidden_states,
197197
timestep=timestep,
198198
cross_attention_kwargs=cross_attention_kwargs,
199199
class_labels=class_labels,
200200
)
201201

202202
# 3. Output
203-
hidden_states = self.proj_out(hidden_states)
203+
hidden_states = self.proj_out(input=hidden_states)
204204
hidden_states = (
205205
hidden_states[None, None, :]
206206
.reshape(batch_size, height, width, num_frames, channel)
@@ -344,15 +344,15 @@ def custom_forward(*inputs):
344344
)
345345

346346
else:
347-
hidden_states = resnet(hidden_states, temb)
347+
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
348348

349349
hidden_states = motion_module(hidden_states, num_frames=num_frames)
350350

351351
output_states = output_states + (hidden_states,)
352352

353353
if self.downsamplers is not None:
354354
for downsampler in self.downsamplers:
355-
hidden_states = downsampler(hidden_states)
355+
hidden_states = downsampler(hidden_states=hidden_states)
356356

357357
output_states = output_states + (hidden_states,)
358358

@@ -531,25 +531,18 @@ def custom_forward(*inputs):
531531
temb,
532532
**ckpt_kwargs,
533533
)
534-
hidden_states = attn(
535-
hidden_states,
536-
encoder_hidden_states=encoder_hidden_states,
537-
cross_attention_kwargs=cross_attention_kwargs,
538-
attention_mask=attention_mask,
539-
encoder_attention_mask=encoder_attention_mask,
540-
return_dict=False,
541-
)[0]
542534
else:
543-
hidden_states = resnet(hidden_states, temb)
535+
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
536+
537+
hidden_states = attn(
538+
hidden_states=hidden_states,
539+
encoder_hidden_states=encoder_hidden_states,
540+
cross_attention_kwargs=cross_attention_kwargs,
541+
attention_mask=attention_mask,
542+
encoder_attention_mask=encoder_attention_mask,
543+
return_dict=False,
544+
)[0]
544545

545-
hidden_states = attn(
546-
hidden_states,
547-
encoder_hidden_states=encoder_hidden_states,
548-
cross_attention_kwargs=cross_attention_kwargs,
549-
attention_mask=attention_mask,
550-
encoder_attention_mask=encoder_attention_mask,
551-
return_dict=False,
552-
)[0]
553546
hidden_states = motion_module(
554547
hidden_states,
555548
num_frames=num_frames,
@@ -563,7 +556,7 @@ def custom_forward(*inputs):
563556

564557
if self.downsamplers is not None:
565558
for downsampler in self.downsamplers:
566-
hidden_states = downsampler(hidden_states)
559+
hidden_states = downsampler(hidden_states=hidden_states)
567560

568561
output_states = output_states + (hidden_states,)
569562

@@ -757,33 +750,26 @@ def custom_forward(*inputs):
757750
temb,
758751
**ckpt_kwargs,
759752
)
760-
hidden_states = attn(
761-
hidden_states,
762-
encoder_hidden_states=encoder_hidden_states,
763-
cross_attention_kwargs=cross_attention_kwargs,
764-
attention_mask=attention_mask,
765-
encoder_attention_mask=encoder_attention_mask,
766-
return_dict=False,
767-
)[0]
768753
else:
769-
hidden_states = resnet(hidden_states, temb)
754+
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
755+
756+
hidden_states = attn(
757+
hidden_states=hidden_states,
758+
encoder_hidden_states=encoder_hidden_states,
759+
cross_attention_kwargs=cross_attention_kwargs,
760+
attention_mask=attention_mask,
761+
encoder_attention_mask=encoder_attention_mask,
762+
return_dict=False,
763+
)[0]
770764

771-
hidden_states = attn(
772-
hidden_states,
773-
encoder_hidden_states=encoder_hidden_states,
774-
cross_attention_kwargs=cross_attention_kwargs,
775-
attention_mask=attention_mask,
776-
encoder_attention_mask=encoder_attention_mask,
777-
return_dict=False,
778-
)[0]
779765
hidden_states = motion_module(
780766
hidden_states,
781767
num_frames=num_frames,
782768
)
783769

784770
if self.upsamplers is not None:
785771
for upsampler in self.upsamplers:
786-
hidden_states = upsampler(hidden_states, upsample_size)
772+
hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size)
787773

788774
return hidden_states
789775

@@ -929,13 +915,13 @@ def custom_forward(*inputs):
929915
create_custom_forward(resnet), hidden_states, temb
930916
)
931917
else:
932-
hidden_states = resnet(hidden_states, temb)
918+
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
933919

934920
hidden_states = motion_module(hidden_states, num_frames=num_frames)
935921

936922
if self.upsamplers is not None:
937923
for upsampler in self.upsamplers:
938-
hidden_states = upsampler(hidden_states, upsample_size)
924+
hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size)
939925

940926
return hidden_states
941927

@@ -1080,10 +1066,19 @@ def forward(
10801066
if cross_attention_kwargs.get("scale", None) is not None:
10811067
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
10821068

1083-
hidden_states = self.resnets[0](hidden_states, temb)
1069+
hidden_states = self.resnets[0](input_tensor=hidden_states, temb=temb)
10841070

10851071
blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
10861072
for attn, resnet, motion_module in blocks:
1073+
hidden_states = attn(
1074+
hidden_states=hidden_states,
1075+
encoder_hidden_states=encoder_hidden_states,
1076+
cross_attention_kwargs=cross_attention_kwargs,
1077+
attention_mask=attention_mask,
1078+
encoder_attention_mask=encoder_attention_mask,
1079+
return_dict=False,
1080+
)[0]
1081+
10871082
if self.training and self.gradient_checkpointing:
10881083

10891084
def create_custom_forward(module, return_dict=None):
@@ -1096,14 +1091,6 @@ def custom_forward(*inputs):
10961091
return custom_forward
10971092

10981093
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1099-
hidden_states = attn(
1100-
hidden_states,
1101-
encoder_hidden_states=encoder_hidden_states,
1102-
cross_attention_kwargs=cross_attention_kwargs,
1103-
attention_mask=attention_mask,
1104-
encoder_attention_mask=encoder_attention_mask,
1105-
return_dict=False,
1106-
)[0]
11071094
hidden_states = torch.utils.checkpoint.checkpoint(
11081095
create_custom_forward(motion_module),
11091096
hidden_states,
@@ -1117,19 +1104,11 @@ def custom_forward(*inputs):
11171104
**ckpt_kwargs,
11181105
)
11191106
else:
1120-
hidden_states = attn(
1121-
hidden_states,
1122-
encoder_hidden_states=encoder_hidden_states,
1123-
cross_attention_kwargs=cross_attention_kwargs,
1124-
attention_mask=attention_mask,
1125-
encoder_attention_mask=encoder_attention_mask,
1126-
return_dict=False,
1127-
)[0]
11281107
hidden_states = motion_module(
11291108
hidden_states,
11301109
num_frames=num_frames,
11311110
)
1132-
hidden_states = resnet(hidden_states, temb)
1111+
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
11331112

11341113
return hidden_states
11351114

0 commit comments

Comments
 (0)