
Commit 15f1bab

Fix gradient checkpointing bugs in freezing part of models (requires_grad=False) (#3404)
* gradient checkpointing bug fix
* bug fix; changes for reviews
* reformat
* reformat

---------

Co-authored-by: Patrick von Platen <[email protected]>
1 parent 415c616 commit 15f1bab
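
Background for the fix, with a minimal repro sketch (module names and shapes below are illustrative, not from the diffusers code): PyTorch's reentrant checkpoint implementation decides what to differentiate from the explicit tensor inputs of the checkpointed segment. When everything upstream of a segment is frozen with requires_grad=False, those inputs carry no grad requirement, so trainable parameters captured inside the segment get no gradients. The non-reentrant implementation, selected with use_reentrant=False and available from PyTorch 1.11.0, records the autograd graph during the first forward pass and handles captured parameters correctly.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

frozen = nn.Linear(8, 8).requires_grad_(False)  # frozen stage (requires_grad=False)
trainable = nn.Linear(8, 8)                     # stage we still want to train

x = torch.randn(2, 8)
h = frozen(x)  # h.requires_grad is False: everything upstream is frozen

# Reentrant checkpointing looks only at the explicit inputs of the segment.
# None of them require grad here, so the segment is detached from autograd
# and `trainable` never receives gradients (the backward call fails outright).
try:
    out = checkpoint(trainable, h)  # old reentrant behavior
    out.sum().backward()
except RuntimeError as err:
    print("reentrant:", err)        # trainable.weight.grad is still None

# Non-reentrant checkpointing (PyTorch >= 1.11) tracks captured parameters
# correctly even when no input tensor requires grad:
out = checkpoint(trainable, h, use_reentrant=False)
out.sum().backward()
print("non-reentrant grad is populated:", trainable.weight.grad is not None)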

3 files changed: +230 -65 lines changed


src/diffusers/models/unet_2d_blocks.py

Lines changed: 137 additions & 36 deletions
@@ -18,6 +18,7 @@
 import torch.nn.functional as F
 from torch import nn
 
+from ..utils import is_torch_version
 from .attention import AdaGroupNorm
 from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
 from .dual_transformer_2d import DualTransformer2DModel
@@ -866,13 +867,27 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
-                    hidden_states,
-                    encoder_hidden_states,
-                    cross_attention_kwargs,
-                )[0]
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(attn, return_dict=False),
+                        hidden_states,
+                        encoder_hidden_states,
+                        cross_attention_kwargs,
+                        use_reentrant=False,
+                    )[0]
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(attn, return_dict=False),
+                        hidden_states,
+                        encoder_hidden_states,
+                        cross_attention_kwargs,
+                    )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
@@ -957,7 +972,14 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
             else:
                 hidden_states = resnet(hidden_states, temb)
 
@@ -1361,7 +1383,14 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
             else:
                 hidden_states = resnet(hidden_states, temb)
 
@@ -1558,7 +1587,14 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
             else:
                 hidden_states = resnet(hidden_states, temb)
 
@@ -1653,14 +1689,29 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
-                    hidden_states,
-                    encoder_hidden_states,
-                    attention_mask,
-                    cross_attention_kwargs,
-                )
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(attn, return_dict=False),
+                        hidden_states,
+                        encoder_hidden_states,
+                        attention_mask,
+                        cross_attention_kwargs,
+                        use_reentrant=False,
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(attn, return_dict=False),
+                        hidden_states,
+                        encoder_hidden_states,
+                        attention_mask,
+                        cross_attention_kwargs,
+                    )
             else:
                 hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
@@ -1874,13 +1925,27 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
-                    hidden_states,
-                    encoder_hidden_states,
-                    cross_attention_kwargs,
-                )[0]
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(attn, return_dict=False),
+                        hidden_states,
+                        encoder_hidden_states,
+                        cross_attention_kwargs,
+                        use_reentrant=False,
+                    )[0]
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(attn, return_dict=False),
+                        hidden_states,
+                        encoder_hidden_states,
+                        cross_attention_kwargs,
+                    )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
@@ -1960,7 +2025,14 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
             else:
                 hidden_states = resnet(hidden_states, temb)
 
@@ -2388,7 +2460,14 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
             else:
                 hidden_states = resnet(hidden_states, temb)
 
@@ -2593,7 +2672,14 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
             else:
                 hidden_states = resnet(hidden_states, temb)
 
@@ -2714,14 +2800,29 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
-                    hidden_states,
-                    encoder_hidden_states,
-                    attention_mask,
-                    cross_attention_kwargs,
-                )[0]
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(attn, return_dict=False),
+                        hidden_states,
+                        encoder_hidden_states,
+                        attention_mask,
+                        cross_attention_kwargs,
+                        use_reentrant=False,
+                    )[0]
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(attn, return_dict=False),
+                        hidden_states,
+                        encoder_hidden_states,
+                        attention_mask,
+                        cross_attention_kwargs,
+                    )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
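
The same guarded call is inlined at every checkpoint site above rather than factored out. As a hedged sketch (no such helper exists in diffusers; the name is illustrative), the repeated pattern is equivalent to:

import torch
import torch.utils.checkpoint
from diffusers.utils import is_torch_version

def _gated_checkpoint(fn, *args):
    # Hypothetical helper showing the pattern each hunk inlines: the
    # `use_reentrant` kwarg only exists in torch.utils.checkpoint.checkpoint
    # from PyTorch 1.11.0; older versions reject the unknown keyword.
    if is_torch_version(">=", "1.11.0"):
        return torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=False)
    return torch.utils.checkpoint.checkpoint(fn, *args)

Passing use_reentrant=False is what actually fixes the frozen-submodule case: the non-reentrant implementation records autograd edges for parameters captured inside the checkpointed function instead of inferring them from the segment's explicit tensor inputs.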

src/diffusers/models/vae.py

Lines changed: 34 additions & 12 deletions
@@ -18,7 +18,7 @@
 import torch
 import torch.nn as nn
 
-from ..utils import BaseOutput, randn_tensor
+from ..utils import BaseOutput, is_torch_version, randn_tensor
 from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
 
 
@@ -117,11 +117,20 @@ def custom_forward(*inputs):
                 return custom_forward
 
             # down
-            for down_block in self.down_blocks:
-                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
-
-            # middle
-            sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+            if is_torch_version(">=", "1.11.0"):
+                for down_block in self.down_blocks:
+                    sample = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(down_block), sample, use_reentrant=False
+                    )
+                # middle
+                sample = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(self.mid_block), sample, use_reentrant=False
+                )
+            else:
+                for down_block in self.down_blocks:
+                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
+                # middle
+                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
 
         else:
             # down
@@ -221,13 +230,26 @@ def custom_forward(*inputs):
 
                 return custom_forward
 
-            # middle
-            sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
-            sample = sample.to(upscale_dtype)
+            if is_torch_version(">=", "1.11.0"):
+                # middle
+                sample = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(self.mid_block), sample, use_reentrant=False
+                )
+                sample = sample.to(upscale_dtype)
 
-            # up
-            for up_block in self.up_blocks:
-                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample)
+                # up
+                for up_block in self.up_blocks:
+                    sample = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(up_block), sample, use_reentrant=False
+                    )
+            else:
+                # middle
+                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+                sample = sample.to(upscale_dtype)
+
+                # up
+                for up_block in self.up_blocks:
+                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample)
         else:
             # middle
             sample = self.mid_block(sample)
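
A usage sketch of the setup these vae.py changes address (checkpoint id, shapes, and the frozen/trainable split are illustrative assumptions): freeze everything upstream of the decoder, enable gradient checkpointing, and train only the decoder. Before this commit, the decoder's checkpointed blocks saw only requires_grad=False inputs and its parameters received no gradients.

import torch
import torch.nn.functional as F
from diffusers import AutoencoderKL

# Illustrative checkpoint id; any AutoencoderKL behaves the same way.
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
vae.encoder.requires_grad_(False)          # freeze the encoder...
vae.quant_conv.requires_grad_(False)       # ...and the convs around the latent,
vae.post_quant_conv.requires_grad_(False)  # so the decoder input has requires_grad=False
vae.enable_gradient_checkpointing()
vae.train()

x = torch.randn(1, 3, 256, 256)
z = vae.encode(x).latent_dist.sample()
recon = vae.decode(z).sample
loss = F.mse_loss(recon, x)
loss.backward()

# With the fix (on PyTorch >= 1.11.0), the trainable decoder receives gradients
# even though its checkpointed segments see no grad-requiring inputs:
print(next(vae.decoder.parameters()).grad is not None)  # True

The same split applies to the UNet blocks patched above, e.g. freezing down_blocks while training up_blocks under unet.enable_gradient_checkpointing(). Note that on PyTorch older than 1.11.0 the else branches deliberately keep the previous reentrant behavior, so the fix only takes effect on torch >= 1.11.0.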
