
Torch compile graph fix #3286


Merged · 15 commits · May 1, 2023
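The changes below are aimed at letting the UNet run under `torch.compile` without graph breaks. As a point of reference, this is the usage pattern they are meant to support; a minimal sketch (model id and compile settings are illustrative, not taken from this PR):

```python
import torch
from diffusers import StableDiffusionPipeline

# Minimal sketch of the target usage, assuming a CUDA device and PyTorch 2.0+.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Compile only the UNet; with the tuple-returning calls introduced in this PR,
# the denoising loop should no longer break the captured graph.
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

image = pipe("a photo of an astronaut riding a horse on mars").images[0]
image.save("astronaut.png")
```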
2 changes: 2 additions & 0 deletions src/diffusers/models/attention.py
@@ -18,6 +18,7 @@
import torch.nn.functional as F
from torch import nn

from ..utils import maybe_allow_in_graph
from ..utils.import_utils import is_xformers_available
from .attention_processor import Attention
from .embeddings import CombinedTimestepLabelEmbeddings
@@ -193,6 +194,7 @@ def forward(self, hidden_states):
return hidden_states


@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
r"""
A basic Transformer block.
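`maybe_allow_in_graph` marks the decorated class so TorchDynamo keeps calls to it inside the captured graph rather than breaking on them, and degrades to a no-op when that machinery is unavailable. A hedged sketch of what such a helper might look like (the real implementation lives in `diffusers.utils` and may differ):

```python
def maybe_allow_in_graph(cls):
    # Sketch only: register the class with TorchDynamo when it is available,
    # otherwise return the class unchanged (e.g. on PyTorch < 2.0).
    try:
        import torch._dynamo

        return torch._dynamo.allow_in_graph(cls)
    except ImportError:
        return cls
```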
3 changes: 2 additions & 1 deletion src/diffusers/models/attention_processor.py
@@ -17,7 +17,7 @@
import torch.nn.functional as F
from torch import nn

from ..utils import deprecate, logging
from ..utils import deprecate, logging, maybe_allow_in_graph
from ..utils.import_utils import is_xformers_available


@@ -31,6 +31,7 @@
xformers = None


@maybe_allow_in_graph
class Attention(nn.Module):
r"""
A cross attention layer.
10 changes: 8 additions & 2 deletions src/diffusers/models/modeling_utils.py
@@ -77,8 +77,14 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:

def get_parameter_dtype(parameter: torch.nn.Module):
try:
parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
return next(parameters_and_buffers).dtype
params = tuple(parameter.parameters())
if len(params) > 0:
return params[0].dtype

buffers = tuple(parameter.buffers())
if len(buffers) > 0:
return buffers[0].dtype

except StopIteration:
# For torch.nn.DataParallel compatibility in PyTorch 1.5

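The rewritten `get_parameter_dtype` avoids `itertools.chain(...)` plus `next(...)`, which TorchDynamo traced poorly, and instead materializes the parameters (then buffers) into tuples before reading a dtype. A standalone sketch of the same pattern; the function name and the final fallback are illustrative, not part of the diff:

```python
import torch
from torch import nn


def module_dtype(module: nn.Module) -> torch.dtype:
    # Prefer the first parameter's dtype, then the first buffer's.
    params = tuple(module.parameters())
    if params:
        return params[0].dtype
    buffers = tuple(module.buffers())
    if buffers:
        return buffers[0].dtype
    # Illustrative fallback for modules with no parameters or buffers.
    return torch.get_default_dtype()


print(module_dtype(nn.Linear(2, 2)))  # torch.float32
```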
25 changes: 14 additions & 11 deletions src/diffusers/models/unet_2d_blocks.py
@@ -560,7 +560,8 @@ def forward(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
).sample
return_dict=False,
)[0]
hidden_states = resnet(hidden_states, temb)

return hidden_states
@@ -868,15 +869,16 @@ def custom_forward(*inputs):
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
).sample
return_dict=False,
)[0]

output_states += (hidden_states,)
output_states = output_states + (hidden_states,)
Contributor (author) comment: I had to do these changes as well to get torch compile to work.

if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)

output_states += (hidden_states,)
output_states = output_states + (hidden_states,)

return hidden_states, output_states

@@ -949,13 +951,13 @@ def custom_forward(*inputs):
else:
hidden_states = resnet(hidden_states, temb)

output_states += (hidden_states,)
output_states = output_states + (hidden_states,)

if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)

output_states += (hidden_states,)
output_states = output_states + (hidden_states,)

return hidden_states, output_states

@@ -1342,13 +1344,13 @@ def custom_forward(*inputs):
else:
hidden_states = resnet(hidden_states, temb)

output_states += (hidden_states,)
output_states = output_states + (hidden_states,)

if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states, temb)

output_states += (hidden_states,)
output_states = output_states + (hidden_states,)

return hidden_states, output_states

@@ -1466,13 +1468,13 @@ def forward(
**cross_attention_kwargs,
)

output_states += (hidden_states,)
output_states = output_states + (hidden_states,)

if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states, temb)

output_states += (hidden_states,)
output_states = output_states + (hidden_states,)

return hidden_states, output_states

@@ -1859,7 +1861,8 @@ def custom_forward(*inputs):
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
).sample
return_dict=False,
)[0]

if self.upsamplers is not None:
for upsampler in self.upsamplers:
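The recurring edits in this file are (1) calling sub-modules with `return_dict=False` and indexing the returned tuple instead of reading `.sample` off an output dataclass, and (2) rebuilding `output_states` with explicit concatenation instead of `+=`; for tuples the two spellings are equivalent, but per the author's comment above the explicit form was needed for `torch.compile` to trace the blocks. A hedged, simplified sketch of the calling pattern (the block and layer here are placeholders, not the real `unet_2d_blocks` code):

```python
import torch
from torch import nn


class TinyBlock(nn.Module):
    # Placeholder block: a single linear layer stands in for the resnets and
    # attention so only the compile-friendly calling pattern is shown.
    def __init__(self, dim: int = 8):
        super().__init__()
        self.layer = nn.Linear(dim, dim)

    def forward(self, hidden_states: torch.Tensor):
        output_states = ()
        hidden_states = self.layer(hidden_states)
        # Explicit concatenation rather than `output_states += (hidden_states,)`.
        output_states = output_states + (hidden_states,)
        return hidden_states, output_states


block = torch.compile(TinyBlock())  # requires PyTorch 2.0+
hidden, states = block(torch.randn(2, 8))
```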
4 changes: 2 additions & 2 deletions src/diffusers/models/unet_2d_condition.py
@@ -682,7 +682,7 @@ def forward(
# `Timesteps` does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=self.dtype)
t_emb = t_emb.to(dtype=sample.dtype)

emb = self.time_embedding(t_emb, timestep_cond)

@@ -697,7 +697,7 @@ def forward(
# there might be better ways to encapsulate this.
class_labels = class_labels.to(dtype=sample.dtype)

class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)

if self.config.class_embeddings_concat:
emb = torch.cat([emb, class_emb], dim=-1)
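Casting the time and class embeddings to `sample.dtype` instead of `self.dtype` avoids a runtime walk over the module's parameters (`self.dtype` goes through `get_parameter_dtype` above) inside the traced forward; the input tensor's dtype is already available to the tracer. A small hedged illustration of the cast (shapes and the helper name are made up):

```python
import torch


def cast_time_embedding(sample: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
    # Timestep embeddings are produced in float32; cast them to the dtype of
    # the incoming latents (e.g. float16) before they are consumed.
    return t_emb.to(dtype=sample.dtype)


sample = torch.randn(1, 4, 8, 8, dtype=torch.float16)
t_emb = torch.randn(1, 320)  # float32 by default
print(cast_time_embedding(sample, t_emb).dtype)  # torch.float16
```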
@@ -437,7 +437,7 @@ def run_safety_checker(self, image, device, dtype):

def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -683,15 +683,16 @@ def __call__(
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
).sample
return_dict=False,
)[0]

# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
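Across the pipelines, the same idea is pushed through every hot call in the denoising loop: the UNet forward, `scheduler.step`, and `vae.decode` are called with `return_dict=False` and indexed as tuples, so a compiled UNet is not interrupted by attribute access on output dataclasses. A runnable toy version of that loop, using tiny randomly initialized components rather than the Stable Diffusion pipeline (the model configuration here is illustrative only):

```python
import torch
from diffusers import DDPMScheduler, UNet2DModel

# Tiny unconditional UNet just to exercise the tuple-returning call pattern.
unet = UNet2DModel(
    sample_size=8,
    in_channels=3,
    out_channels=3,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "UpBlock2D"),
)
scheduler = DDPMScheduler(num_train_timesteps=100)
scheduler.set_timesteps(5)

latents = torch.randn(1, 3, 8, 8)
for t in scheduler.timesteps:
    noise_pred = unet(latents, t, return_dict=False)[0]                     # tuple instead of `.sample`
    latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]  # tuple instead of `.prev_sample`

print(latents.shape)  # torch.Size([1, 3, 8, 8])
```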
9 changes: 5 additions & 4 deletions src/diffusers/pipelines/deepfloyd_if/pipeline_if.py
@@ -793,7 +793,8 @@ def __call__(
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
).sample
return_dict=False,
)[0]

# perform guidance
if do_classifier_free_guidance:
@@ -805,8 +806,8 @@

# compute the previous noisy sample x_t -> x_t-1
intermediate_images = self.scheduler.step(
noise_pred, t, intermediate_images, **extra_step_kwargs
).prev_sample
noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
)[0]

# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -829,7 +830,7 @@ def __call__(

# 11. Apply watermark
if self.watermarker is not None:
self.watermarker.apply_watermark(image, self.unet.config.sample_size)
image = self.watermarker.apply_watermark(image, self.unet.config.sample_size)
elif output_type == "pt":
nsfw_detected = None
watermark_detected = None
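One fix in this file is not about `return_dict`: the result of `apply_watermark` is now assigned back to `image`, so the watermarked images are what the pipeline actually returns. A trivial, hedged illustration of that bug class (the function here is a placeholder, not the IF watermarker):

```python
def apply_watermark(images):
    # Placeholder: returns new, watermarked items instead of editing in place.
    return [img + " [wm]" for img in images]


images = ["frame-0", "frame-1"]
apply_watermark(images)            # return value dropped: images is unchanged
images = apply_watermark(images)   # fixed: images now holds the watermarked output
print(images)                      # ['frame-0 [wm]', 'frame-1 [wm]']
```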
@@ -256,7 +256,7 @@ def prepare_extra_step_kwargs(self, generator, eta):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -134,7 +134,7 @@ def __init__(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -516,7 +516,7 @@ def run_safety_checker(self, image, device, dtype):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -440,7 +440,7 @@ def run_safety_checker(self, image, device, dtype):

def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -686,15 +686,16 @@ def __call__(
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
).sample
return_dict=False,
)[0]

# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -454,7 +454,7 @@ def run_safety_checker(self, image, device, dtype):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -496,7 +496,7 @@ def run_safety_checker(self, image, device, dtype):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -326,7 +326,7 @@ def run_safety_checker(self, image, device, dtype):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -648,7 +648,7 @@ def prepare_extra_step_kwargs(self, generator, eta):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -195,7 +195,7 @@ def run_safety_checker(self, image, device, dtype):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -525,7 +525,7 @@ def prepare_extra_step_kwargs(self, generator, eta):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -446,7 +446,7 @@ def run_safety_checker(self, image, device, dtype):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -656,7 +656,7 @@ def prepare_extra_step_kwargs(self, generator, eta):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -358,7 +358,7 @@ def run_safety_checker(self, image, device, dtype):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -221,7 +221,7 @@ def _encode_prompt(self, prompt, device, do_classifier_free_guidance, negative_p
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -385,7 +385,7 @@ def run_safety_checker(self, image, device, dtype):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -349,7 +349,7 @@ def run_safety_checker(self, image, device, dtype):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -590,7 +590,7 @@ def run_safety_checker(self, image, device, dtype):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()