Description
Similar to #11297, I was investigating potential recompilations in Flux when the input resolution changes.
Code:

```python
from diffusers import FluxTransformer2DModel, FluxPipeline
from diffusers.utils.torch_utils import randn_tensor
import torch.utils.benchmark as benchmark
from contextlib import nullcontext
import argparse
import torch
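# Disable duck shaping so dims that happen to share a value (e.g. height ==
# width at 1024x1024) get distinct symbols; with duck shaping on, a single
# shared symbol would force a recompile once the dims later differ (1536x768).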
torch.fx.experimental._config.use_duck_shape = False
HEIGHT_WIDTH = [(1024, 1024), (1536, 768), (2048, 2048)]
def benchmark_fn(f, *args, **kwargs):
t0 = benchmark.Timer(
stmt="f(*args, **kwargs)",
globals={"args": args, "kwargs": kwargs, "f": f},
num_threads=1,
)
return f"{(t0.blocked_autorange().mean):.3f}"
def prepare_latents(
batch_size=1,
num_channels_latents=16,
height=1024,
width=1024,
dtype=torch.bfloat16,
device="cuda",
):
vae_scale_factor = 8
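    # Latent spatial dims are the pixel dims divided by the VAE scale factor,
    # rounded down to an even number so the 2x2 packing below divides evenly.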
height = 2 * (int(height) // (vae_scale_factor * 2))
width = 2 * (int(width) // (vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width)
latents = randn_tensor(shape, device=device, dtype=dtype)
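    # Flux packs 2x2 latent patches into a token sequence, so the transformer
    # sees (height // 2) * (width // 2) image tokens per sample.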
latents = FluxPipeline._pack_latents(latents, batch_size, num_channels_latents, height, width)
latent_image_ids = FluxPipeline._prepare_latent_image_ids(
batch_size, height // 2, width // 2, device, dtype
)
return latents, latent_image_ids
def get_conditional_inputs(batch_size, dtype=torch.bfloat16, device="cuda"):
prompt_embeds = torch.randn(batch_size, 512, 4096, dtype=dtype, device=device)
pooled_prompt_embeds = torch.randn(batch_size, 768, dtype=dtype, device=device)
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
return prompt_embeds, pooled_prompt_embeds, text_ids
def load_transformer(do_compile=False):
transformer = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
).to("cuda")
if do_compile:
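        # dynamic=True compiles with dynamic shapes from the start, so later
        # resolution changes should reuse the same compiled graph.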
transformer = torch.compile(transformer, fullgraph=True, dynamic=True)
return transformer
def run_inference(transformer, **kwargs):
_ = transformer(**kwargs)
@torch.no_grad()
def main(transformer, batch_size, height, width):
latents, latent_image_ids = prepare_latents(batch_size=batch_size, height=height, width=width)
prompt_embeds, pooled_prompt_embeds, text_ids = get_conditional_inputs(batch_size=batch_size)
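    # FluxPipeline scales the scheduler timestep by 1/1000 before calling the
    # transformer; mimic that convention here.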
timestep = torch.full([1], 1.0, device="cuda", dtype=torch.float32)
timestep = timestep.expand(latents.shape[0]).to(latents.dtype)
timestep = timestep / 1000
guidance = torch.full([1], 4.5, device="cuda", dtype=torch.float32)
guidance = guidance.expand(latents.shape[0])
input_dict = {
"hidden_states": latents,
"timestep": timestep,
"guidance": guidance,
"pooled_projections": pooled_prompt_embeds,
"encoder_hidden_states": prompt_embeds,
"txt_ids": text_ids,
"img_ids": latent_image_ids
}
run_inference(transformer, **input_dict)
# time = benchmark_fn(run_inference, transformer, **input_dict)
# print(f"{height}x{width}: {time} secs")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", default=1, type=int)
parser.add_argument("--compile", action="store_true")
args = parser.parse_args()
transformer = load_transformer(args.compile)
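    # Turn any recompilation into a hard error so shape-driven recompiles
    # cannot slip by silently.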
context = torch._dynamo.config.patch(error_on_recompile=True) if args.compile else nullcontext()
with context:
for height, width in HEIGHT_WIDTH:
main(transformer=transformer, batch_size=args.batch_size, height=height, width=width)
```

It currently fails when run with `python check_flux_recompilation.py --compile`.
Trace:

```
Traceback (most recent call last):
File "/fsx/sayak/diffusers/check_flux_recompilation.py", line 99, in <module>
main(transformer=transformer, batch_size=args.batch_size, height=height, width=width)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/fsx/sayak/diffusers/check_flux_recompilation.py", line 82, in main
run_inference(transformer, **input_dict)
File "/fsx/sayak/diffusers/check_flux_recompilation.py", line 59, in run_inference
_ = transformer(**kwargs)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 675, in _fn
raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1583, in _call_user_compiler
raise BackendCompilerFailed(
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1558, in _call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 150, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/__init__.py", line 2365, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 2199, in compile_fx
return aot_autograd(
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 106, in __call__
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1175, in aot_module_simplified
compiled_fn = AOTAutogradCache.load(
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/autograd_cache.py", line 850, in load
compiled_fn = dispatch_and_compile()
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1160, in dispatch_and_compile
compiled_fn, _ = create_aot_dispatcher_function(
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 574, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 834, in _create_aot_dispatcher_function
compiled_fn, fw_metadata = compiler_fn(
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 240, in aot_dispatch_base
compiled_fw = compiler(fw_module, updated_flat_args)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 483, in __call__
return self.compiler_fn(gm, example_inputs)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1980, in fw_compiler_base
_recursive_joint_graph_passes(gm)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 402, in _recursive_joint_graph_passes
joint_graph_passes(gm)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_inductor/fx_passes/joint_graph.py", line 544, in joint_graph_passes
GraphTransformObserver(graph, "remove_noop_ops").apply_graph_pass(remove_noop_ops)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/fx/passes/graph_transform_observer.py", line 85, in apply_graph_pass
return pass_fn(self.gm.graph)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_inductor/fx_passes/post_grad.py", line 975, in remove_noop_ops
if same_meta(node, src) and cond(*args, **kwargs):
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_inductor/fx_passes/post_grad.py", line 804, in same_meta
and statically_known_true(sym_eq(val1.size(), val2.size()))
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AttributeError: 'SymFloat' object has no attribute 'size'
```
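
For what it's worth, the failing check in `same_meta` (`torch/_inductor/fx_passes/post_grad.py`, visible in the last trace frame) calls `.size()` on both nodes' meta vals unconditionally, which is only valid for tensor-like vals; a scalar `SymFloat` flowing through the graph has no `.size()`. Here is a minimal sketch of the broken assumption, using hypothetical stand-in values rather than the actual inductor code path:

```python
import torch

# Hypothetical stand-ins: same_meta() compares val1.size() and val2.size(),
# which works for tensors/FakeTensors but not for scalar symbolic values.
# A plain Python float behaves like a SymFloat here: neither has .size().
tensor_val = torch.empty(2, 3)
scalar_val = 3.14  # stand-in for the SymFloat meta val from the trace

print(tensor_val.size())  # torch.Size([2, 3])
try:
    scalar_val.size()
except AttributeError as e:
    print(e)  # 'float' object has no attribute 'size'
```

Presumably the pass needs an `isinstance(val, torch.Tensor)` check (or similar) before comparing sizes, but I may be missing context here.
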
My env:
- 🤗 Diffusers version: 0.34.0.dev0
- Platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31
- Running on Google Colab?: No
- Python version: 3.10.14
- PyTorch version (GPU?): 2.8.0.dev20250417+cu126 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.30.2
- Transformers version: 4.52.0.dev0
- Accelerate version: 1.4.0.dev0
- PEFT version: 0.15.2.dev0
- Bitsandbytes version: 0.45.3
- Safetensors version: 0.4.5
- xFormers version: not installed
- Accelerator: NVIDIA H100 80GB HBM3, 81559 MiB
@StrongerXi, @anijain2305 would you have any pointers?