Fix gradient checkpointing bugs in freezing part of models (requires_grad=False) #3404

Merged: 5 commits, May 17, 2023
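Context for the fix: with the reentrant form of torch.utils.checkpoint.checkpoint (the pre-1.11 behaviour and long the default), whether a checkpointed segment builds an autograd graph is decided from its tensor inputs alone. When an earlier part of the model is frozen with requires_grad=False, the hidden states entering a checkpointed block no longer require gradients, so the block's own trainable weights silently receive no gradients. The non-reentrant implementation (use_reentrant=False, added in PyTorch 1.11.0) also tracks the parameters used inside the segment, which is what this PR switches the diffusers blocks to whenever the installed PyTorch supports it. A minimal sketch of the failure mode in plain PyTorch (module names are illustrative, not diffusers code; requires torch >= 1.11 for the use_reentrant keyword):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# `frozen` stands in for the requires_grad=False part of the network,
# `checkpointed` for a checkpointed block that should still be trained,
# `head` for whatever trainable layers come afterwards.
frozen = nn.Linear(8, 8).requires_grad_(False)
checkpointed = nn.Linear(8, 8)
head = nn.Linear(8, 1)

x = torch.randn(2, 8)
h = frozen(x)  # h.requires_grad is False because everything upstream is frozen

# Reentrant checkpointing only looks at the tensor inputs, so the weights
# inside the checkpointed segment get no gradients:
out = checkpoint(checkpointed, h, use_reentrant=True)
head(out).sum().backward()
print(checkpointed.weight.grad)              # None  -> the bug
print(head.weight.grad is not None)          # True  -> only the tail trains

checkpointed.zero_grad()
head.zero_grad()

# Non-reentrant checkpointing (PyTorch >= 1.11) also tracks the parameters
# used inside the segment, so gradients flow despite the frozen inputs:
out = checkpoint(checkpointed, h, use_reentrant=False)
head(out).sum().backward()
print(checkpointed.weight.grad is not None)  # True  -> the fix
```

The diff below applies exactly this change to every checkpointed call in unet_2d_blocks.py and vae.py, guarded by is_torch_version so older PyTorch releases keep the previous behaviour.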
173 changes: 137 additions & 36 deletions src/diffusers/models/unet_2d_blocks.py
@@ -18,6 +18,7 @@
 import torch.nn.functional as F
 from torch import nn
 
+from ..utils import is_torch_version
 from .attention import AdaGroupNorm
 from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
 from .dual_transformer_2d import DualTransformer2DModel
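The new import is needed because the use_reentrant keyword of torch.utils.checkpoint.checkpoint only exists from PyTorch 1.11.0 onward, so every call site below is gated on is_torch_version(">=", "1.11.0") and falls back to the old reentrant call on older releases. A simplified stand-in for that check, for illustration only (the real helper lives in diffusers.utils):

```python
from packaging import version
import torch

def torch_is_at_least(required: str) -> bool:
    # Simplified stand-in for diffusers.utils.is_torch_version(">=", required).
    return version.parse(torch.__version__).release >= version.parse(required).release
```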
@@ -866,13 +867,27 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
-                    hidden_states,
-                    encoder_hidden_states,
-                    cross_attention_kwargs,
-                )[0]
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(attn, return_dict=False),
+                        hidden_states,
+                        encoder_hidden_states,
+                        cross_attention_kwargs,
+                        use_reentrant=False,
+                    )[0]
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(attn, return_dict=False),
+                        hidden_states,
+                        encoder_hidden_states,
+                        cross_attention_kwargs,
+                    )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
@@ -957,7 +972,14 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
             else:
                 hidden_states = resnet(hidden_states, temb)
 
@@ -1361,7 +1383,14 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
             else:
                 hidden_states = resnet(hidden_states, temb)
 
@@ -1558,7 +1587,14 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
             else:
                 hidden_states = resnet(hidden_states, temb)
 
@@ -1653,14 +1689,29 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
-                    hidden_states,
-                    encoder_hidden_states,
-                    attention_mask,
-                    cross_attention_kwargs,
-                )
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(attn, return_dict=False),
+                        hidden_states,
+                        encoder_hidden_states,
+                        attention_mask,
+                        cross_attention_kwargs,
+                        use_reentrant=False,
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(attn, return_dict=False),
+                        hidden_states,
+                        encoder_hidden_states,
+                        attention_mask,
+                        cross_attention_kwargs,
+                    )
             else:
                 hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
@@ -1874,13 +1925,27 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
-                    hidden_states,
-                    encoder_hidden_states,
-                    cross_attention_kwargs,
-                )[0]
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(attn, return_dict=False),
+                        hidden_states,
+                        encoder_hidden_states,
+                        cross_attention_kwargs,
+                        use_reentrant=False,
+                    )[0]
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(attn, return_dict=False),
+                        hidden_states,
+                        encoder_hidden_states,
+                        cross_attention_kwargs,
+                    )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
@@ -1960,7 +2025,14 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
             else:
                 hidden_states = resnet(hidden_states, temb)
 
@@ -2388,7 +2460,14 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
             else:
                 hidden_states = resnet(hidden_states, temb)
 
@@ -2593,7 +2672,14 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
             else:
                 hidden_states = resnet(hidden_states, temb)
 
@@ -2714,14 +2800,29 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
-                    hidden_states,
-                    encoder_hidden_states,
-                    attention_mask,
-                    cross_attention_kwargs,
-                )[0]
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(attn, return_dict=False),
+                        hidden_states,
+                        encoder_hidden_states,
+                        attention_mask,
+                        cross_attention_kwargs,
+                        use_reentrant=False,
+                    )[0]
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(attn, return_dict=False),
+                        hidden_states,
+                        encoder_hidden_states,
+                        attention_mask,
+                        cross_attention_kwargs,
+                    )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
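With the blocks above patched, the partially frozen setups that motivated the issue work again. A rough usage sketch (the checkpoint id and the choice of which sub-module to keep trainable are illustrative, and it assumes a diffusers build that includes this PR):

```python
import torch
from diffusers import UNet2DConditionModel

# Illustrative checkpoint id; any Stable Diffusion UNet behaves the same way.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# Freeze everything except the mid block: the partially frozen setup that
# used to break under reentrant gradient checkpointing.
unet.requires_grad_(False)
for param in unet.mid_block.parameters():
    param.requires_grad_(True)

unet.train()
unet.enable_gradient_checkpointing()

sample = torch.randn(1, 4, 64, 64)                 # SD v1 latent shape
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 768)    # SD v1 text embedding shape

out = unet(sample, timestep, encoder_hidden_states).sample
out.mean().backward()

# Gradients now reach the trainable parameters even though the blocks
# feeding them are frozen.
assert any(p.grad is not None for p in unet.mid_block.parameters())
```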
46 changes: 34 additions & 12 deletions src/diffusers/models/vae.py
@@ -18,7 +18,7 @@
 import torch
 import torch.nn as nn
 
-from ..utils import BaseOutput, randn_tensor
+from ..utils import BaseOutput, is_torch_version, randn_tensor
 from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
 
 
@@ -117,11 +117,20 @@ def custom_forward(*inputs):
                 return custom_forward
 
             # down
-            for down_block in self.down_blocks:
-                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
-
-            # middle
-            sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+            if is_torch_version(">=", "1.11.0"):
+                for down_block in self.down_blocks:
+                    sample = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(down_block), sample, use_reentrant=False
+                    )
+                # middle
+                sample = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(self.mid_block), sample, use_reentrant=False
+                )
+            else:
+                for down_block in self.down_blocks:
+                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
+                # middle
+                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
 
         else:
             # down
@@ -221,13 +230,26 @@ def custom_forward(*inputs):
 
                 return custom_forward
 
-            # middle
-            sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
-            sample = sample.to(upscale_dtype)
-
-            # up
-            for up_block in self.up_blocks:
-                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample)
+            if is_torch_version(">=", "1.11.0"):
+                # middle
+                sample = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(self.mid_block), sample, use_reentrant=False
+                )
+                sample = sample.to(upscale_dtype)
+
+                # up
+                for up_block in self.up_blocks:
+                    sample = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(up_block), sample, use_reentrant=False
+                    )
+            else:
+                # middle
+                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+                sample = sample.to(upscale_dtype)
+
+                # up
+                for up_block in self.up_blocks:
+                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample)
         else:
             # middle
             sample = self.mid_block(sample)
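The vae.py changes cover the same scenario on the autoencoder side, for example a frozen encoder feeding a trainable, checkpointed decoder. A small sanity-check sketch using the default, randomly initialised AutoencoderKL config (illustrative only, again assuming a diffusers build with this fix):

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL()                     # tiny default config, random weights
vae.encoder.requires_grad_(False)         # freeze half of the model
vae.train()
vae.enable_gradient_checkpointing()

image = torch.randn(1, 3, 32, 32)
latents = vae.encode(image).latent_dist.sample()
recon = vae.decode(latents).sample
recon.mean().backward()

# The checkpointed decoder blocks still receive gradients,
# while the frozen encoder stays untouched.
assert any(p.grad is not None for p in vae.decoder.up_blocks.parameters())
assert all(p.grad is None for p in vae.encoder.parameters())
```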