
Commit c4d4ac2

Refactor gradient checkpointing (#10611)
* update
* remove unused fn
* apply suggestions based on review
* update + cleanup 🧹
* more cleanup 🧹
* make fix-copies
* update test
1 parent f295e2e commit c4d4ac2

53 files changed: +309 / -1790 lines
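The change applied across these files follows one pattern, sketched below for orientation. This is an illustrative toy module, not diffusers source: `OldBlock`/`NewBlock` are hypothetical names, and the `functools.partial` wiring merely approximates whatever the model base class actually installs as `_gradient_checkpointing_func`. Previously every block defined a local `create_custom_forward` closure and called `torch.utils.checkpoint.checkpoint` itself, repeating the `use_reentrant` bookkeeping per call site; after the refactor a block hands the submodule and its inputs to `self._gradient_checkpointing_func`, and the per-model `_set_gradient_checkpointing` overrides are dropped.

# Illustrative sketch only (not diffusers source). OldBlock/NewBlock are
# hypothetical; the functools.partial below stands in for whatever the
# base class installs as `_gradient_checkpointing_func`.
import functools

import torch
import torch.nn as nn
import torch.utils.checkpoint


class OldBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.resnet = nn.Linear(8, 8)
        self.gradient_checkpointing = True

    def forward(self, hidden_states):
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            # Before: each block wraps the submodule in a closure and calls
            # torch.utils.checkpoint.checkpoint directly.
            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            return torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.resnet), hidden_states, use_reentrant=False
            )
        return self.resnet(hidden_states)


class NewBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.resnet = nn.Linear(8, 8)
        self.gradient_checkpointing = True
        # Assumed stand-in for the helper the model base class provides.
        self._gradient_checkpointing_func = functools.partial(
            torch.utils.checkpoint.checkpoint, use_reentrant=False
        )

    def forward(self, hidden_states):
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            # After: one shared helper replaces the per-block boilerplate.
            return self._gradient_checkpointing_func(self.resnet, hidden_states)
        return self.resnet(hidden_states)

Centralizing the `use_reentrant`/`is_torch_version` handling in one helper is what lets the commit remove far more lines than it adds.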

examples/community/matryoshka.py

Lines changed: 5 additions & 74 deletions
@@ -80,7 +80,6 @@
     USE_PEFT_BACKEND,
     BaseOutput,
     deprecate,
-    is_torch_version,
     is_torch_xla_available,
     logging,
     replace_example_docstring,
@@ -869,23 +868,7 @@ def forward(
 
         for i, (resnet, attn) in enumerate(blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -1030,17 +1013,6 @@ def forward(
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -1049,12 +1021,7 @@ def custom_forward(*inputs):
                     encoder_attention_mask=encoder_attention_mask,
                     return_dict=False,
                 )[0]
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
             else:
                 hidden_states = attn(
                     hidden_states,
@@ -1192,23 +1159,7 @@ def forward(
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -1282,10 +1233,6 @@ def __init__(
             ]
         )
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -1365,27 +1312,15 @@ def forward(
         # Blocks
         for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
                     encoder_attention_mask,
                     timestep,
                     cross_attention_kwargs,
                     class_labels,
-                    **ckpt_kwargs,
                 )
             else:
                 hidden_states = block(
@@ -2724,10 +2659,6 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i
         for module in self.children():
             fn_recursive_set_attention_slice(module, reversed_slice_size)
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
         r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
 
examples/research_projects/pixart/controlnet_pixart_alpha.py

Lines changed: 2 additions & 18 deletions
@@ -8,7 +8,6 @@
 from diffusers.models.attention import BasicTransformerBlock
 from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.models.modeling_utils import ModelMixin
-from diffusers.utils.torch_utils import is_torch_version
 
 
 class PixArtControlNetAdapterBlock(nn.Module):
@@ -151,10 +150,6 @@ def __init__(
         self.transformer = transformer
         self.controlnet = controlnet
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -220,26 +215,15 @@ def forward(
                 print("Gradient checkpointing is not supported for the controlnet transformer model, yet.")
                 exit(1)
 
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
                     encoder_attention_mask,
                     timestep,
                     cross_attention_kwargs,
                     None,
-                    **ckpt_kwargs,
                 )
             else:
                 # the control nets are only used for the blocks 1 to self.blocks_num

src/diffusers/models/autoencoders/autoencoder_kl.py

Lines changed: 0 additions & 4 deletions
@@ -138,10 +138,6 @@ def __init__(
         self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
         self.tile_overlap_factor = 0.25
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (Encoder, Decoder)):
-            module.gradient_checkpointing = value
-
     def enable_tiling(self, use_tiling: bool = True):
         r"""
         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
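For the autoencoders the only change is deleting the per-model `_set_gradient_checkpointing` override, which toggled `gradient_checkpointing` on a hard-coded set of submodule classes. A generic traversal can achieve the same effect without each model listing its classes; the helper below is only an illustration of that idea, not the actual `ModelMixin` code (which is outside this diff).

# Illustration only: a generic stand-in for the removed per-model
# _set_gradient_checkpointing overrides. Not the actual ModelMixin code.
import torch.nn as nn


def set_gradient_checkpointing(model: nn.Module, enable: bool = True) -> None:
    # Flip the flag on every submodule that exposes it, instead of each model
    # enumerating its checkpointable classes (Encoder, Decoder, ...) by hand.
    for module in model.modules():
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = enable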

src/diffusers/models/autoencoders/autoencoder_kl_allegro.py

Lines changed: 4 additions & 22 deletions
@@ -507,19 +507,12 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:
         sample = sample + residual
 
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    return module(*inputs)
-
-                return custom_forward
-
             # Down blocks
             for down_block in self.down_blocks:
-                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
+                sample = self._gradient_checkpointing_func(down_block, sample)
 
             # Mid block
-            sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+            sample = self._gradient_checkpointing_func(self.mid_block, sample)
         else:
             # Down blocks
             for down_block in self.down_blocks:
@@ -647,19 +640,12 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:
         upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
 
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    return module(*inputs)
-
-                return custom_forward
-
             # Mid block
-            sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+            sample = self._gradient_checkpointing_func(self.mid_block, sample)
 
             # Up blocks
             for up_block in self.up_blocks:
-                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample)
+                sample = self._gradient_checkpointing_func(up_block, sample)
 
         else:
             # Mid block
@@ -809,10 +795,6 @@ def __init__(
             sample_size - self.tile_overlap_w,
         )
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (AllegroEncoder3D, AllegroDecoder3D)):
-            module.gradient_checkpointing = value
-
     def enable_tiling(self) -> None:
         r"""
         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to

src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py

Lines changed: 14 additions & 53 deletions
@@ -421,15 +421,8 @@ def forward(
             conv_cache_key = f"resnet_{i}"
 
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def create_forward(*inputs):
-                        return module(*inputs)
-
-                    return create_forward
-
-                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
+                hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
+                    resnet,
                     hidden_states,
                     temb,
                     zq,
@@ -523,15 +516,8 @@ def forward(
             conv_cache_key = f"resnet_{i}"
 
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def create_forward(*inputs):
-                        return module(*inputs)
-
-                    return create_forward
-
-                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, zq, conv_cache.get(conv_cache_key)
+                hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
+                    resnet, hidden_states, temb, zq, conv_cache.get(conv_cache_key)
                 )
             else:
                 hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -637,15 +623,8 @@ def forward(
             conv_cache_key = f"resnet_{i}"
 
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def create_forward(*inputs):
-                        return module(*inputs)
-
-                    return create_forward
-
-                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
+                hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
+                    resnet,
                     hidden_states,
                     temb,
                     zq,
@@ -774,27 +753,20 @@ def forward(
         hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
 
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    return module(*inputs)
-
-                return custom_forward
-
             # 1. Down
             for i, down_block in enumerate(self.down_blocks):
                 conv_cache_key = f"down_block_{i}"
-                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(down_block),
+                hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
+                    down_block,
                     hidden_states,
                     temb,
                     None,
                     conv_cache.get(conv_cache_key),
                 )
 
             # 2. Mid
-            hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block),
+            hidden_states, new_conv_cache["mid_block"] = self._gradient_checkpointing_func(
+                self.mid_block,
                 hidden_states,
                 temb,
                 None,
@@ -940,16 +912,9 @@ def forward(
         hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
 
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    return module(*inputs)
-
-                return custom_forward
-
             # 1. Mid
-            hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block),
+            hidden_states, new_conv_cache["mid_block"] = self._gradient_checkpointing_func(
+                self.mid_block,
                 hidden_states,
                 temb,
                 sample,
@@ -959,8 +924,8 @@ def custom_forward(*inputs):
             # 2. Up
             for i, up_block in enumerate(self.up_blocks):
                 conv_cache_key = f"up_block_{i}"
-                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(up_block),
+                hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
+                    up_block,
                     hidden_states,
                     temb,
                     sample,
@@ -1122,10 +1087,6 @@ def __init__(
         self.tile_overlap_factor_height = 1 / 6
         self.tile_overlap_factor_width = 1 / 5
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
-            module.gradient_checkpointing = value
-
     def enable_tiling(
         self,
         tile_sample_min_height: Optional[int] = None,
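The public API is unchanged by this refactor: callers still toggle checkpointing through `enable_gradient_checkpointing()`, and the blocks shown above then route their submodule calls through `self._gradient_checkpointing_func` whenever gradients are enabled. A minimal usage sketch follows; the checkpoint ID is a placeholder and the training loop is omitted.

# Usage sketch: enabling gradient checkpointing on a diffusers model.
# "some/checkpoint" is a placeholder model ID, not a real repository.
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("some/checkpoint", subfolder="vae")  # placeholder ID
vae.enable_gradient_checkpointing()  # flips gradient_checkpointing on supporting blocks
vae.train()

sample = torch.randn(1, 3, 256, 256, requires_grad=True)
# With gradients enabled, the Encoder/Decoder blocks recompute activations
# during the backward pass instead of storing them, trading compute for memory.
posterior = vae.encode(sample).latent_dist
reconstruction = vae.decode(posterior.sample()).sample
reconstruction.mean().backward()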
