[performance] investigating FluxPipeline for recompilations on resolution changes #11360

Open
@sayakpaul

Description

Similar to #11297, I was investigating potential recompilations for Flux on resolution changes.

Code
from diffusers import FluxTransformer2DModel, FluxPipeline
from diffusers.utils.torch_utils import randn_tensor
import torch.utils.benchmark as benchmark
from contextlib import nullcontext
import argparse

import torch

# Disable duck shaping so height and width are tracked as separate dynamic
# symbols even when they happen to be equal (e.g. 1024x1024).
torch.fx.experimental._config.use_duck_shape = False

HEIGHT_WIDTH = [(1024, 1024), (1536, 768), (2048, 2048)]

def benchmark_fn(f, *args, **kwargs):
    # Time f(*args, **kwargs) with torch.utils.benchmark; returns the mean runtime in seconds.
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
        num_threads=1,
    )
    return f"{(t0.blocked_autorange().mean):.3f}"


def prepare_latents(
    batch_size=1,
    num_channels_latents=16,
    height=1024,
    width=1024,
    dtype=torch.bfloat16,
    device="cuda",
):
    # The VAE downsamples by 8x; round the latent resolution to a multiple of 2
    # so the latents can be packed into 2x2 patches, as the Flux pipeline does.
    vae_scale_factor = 8
    height = 2 * (int(height) // (vae_scale_factor * 2))
    width = 2 * (int(width) // (vae_scale_factor * 2))

    shape = (batch_size, num_channels_latents, height, width)

    latents = randn_tensor(shape, device=device, dtype=dtype)
    latents = FluxPipeline._pack_latents(latents, batch_size, num_channels_latents, height, width)

    latent_image_ids = FluxPipeline._prepare_latent_image_ids(
        batch_size, height // 2, width // 2, device, dtype
    )

    return latents, latent_image_ids

def get_conditional_inputs(batch_size, dtype=torch.bfloat16, device="cuda"):
    # Random stand-ins for the T5 sequence embeddings and the CLIP pooled embeddings.
    prompt_embeds = torch.randn(batch_size, 512, 4096, dtype=dtype, device=device)
    pooled_prompt_embeds = torch.randn(batch_size, 768, dtype=dtype, device=device)
    text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
    return prompt_embeds, pooled_prompt_embeds, text_ids

def load_transformer(do_compile=False):
    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
    ).to("cuda")
    if do_compile:
        # dynamic=True so that resolution changes should reuse the same compiled
        # graph instead of triggering a recompile.
        transformer = torch.compile(transformer, fullgraph=True, dynamic=True)
    return transformer

def run_inference(transformer, **kwargs):
    _ = transformer(**kwargs)

@torch.no_grad()
def main(transformer, batch_size, height, width):
    latents, latent_image_ids = prepare_latents(batch_size=batch_size, height=height, width=width)
    prompt_embeds, pooled_prompt_embeds, text_ids = get_conditional_inputs(batch_size=batch_size)
    
    # Scalar conditioning inputs: the timestep (divided by 1000, as in the pipeline) and the guidance scale.
    timestep = torch.full([1], 1.0, device="cuda", dtype=torch.float32)
    timestep = timestep.expand(latents.shape[0]).to(latents.dtype)
    timestep = timestep / 1000
    guidance = torch.full([1], 4.5, device="cuda", dtype=torch.float32)
    guidance = guidance.expand(latents.shape[0])

    input_dict = {
        "hidden_states": latents,
        "timestep": timestep,
        "guidance": guidance,
        "pooled_projections": pooled_prompt_embeds,
        "encoder_hidden_states": prompt_embeds,
        "txt_ids": text_ids,
        "img_ids": latent_image_ids
    }

    run_inference(transformer, **input_dict)
    # time = benchmark_fn(run_inference, transformer, **input_dict)
    # print(f"{height}x{width}: {time} secs")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", default=1, type=int)
    parser.add_argument("--compile", action="store_true")
    args = parser.parse_args()

    transformer = load_transformer(args.compile)
    context = torch._dynamo.config.patch(error_on_recompile=True) if args.compile else nullcontext()
    with context:
        for height, width in HEIGHT_WIDTH:
            main(transformer=transformer, batch_size=args.batch_size, height=height, width=width)
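
For reference (not part of the script above), recompilations can also be surfaced with dynamo's recompile logging instead of raising on them; a minimal sketch, assuming the same compiled transformer and resolution loop:

import torch

# Log a reason every time dynamo recompiles, rather than erroring out.
torch._logging.set_logs(recompiles=True)

# ...then run the HEIGHT_WIDTH loop above; any resolution-dependent guard
# failure is logged together with the guard that failed.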

It currently fails when run with python check_flux_recompilation.py --compile:

Trace
Traceback (most recent call last):
  File "/fsx/sayak/diffusers/check_flux_recompilation.py", line 99, in <module>
    main(transformer=transformer, batch_size=args.batch_size, height=height, width=width)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/fsx/sayak/diffusers/check_flux_recompilation.py", line 82, in main
    run_inference(transformer, **input_dict)
  File "/fsx/sayak/diffusers/check_flux_recompilation.py", line 59, in run_inference
    _ = transformer(**kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 675, in _fn
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1583, in _call_user_compiler
    raise BackendCompilerFailed(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1558, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 150, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/__init__.py", line 2365, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 2199, in compile_fx
    return aot_autograd(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 106, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1175, in aot_module_simplified
    compiled_fn = AOTAutogradCache.load(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/autograd_cache.py", line 850, in load
    compiled_fn = dispatch_and_compile()
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1160, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 574, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 834, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 240, in aot_dispatch_base
    compiled_fw = compiler(fw_module, updated_flat_args)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 483, in __call__
    return self.compiler_fn(gm, example_inputs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1980, in fw_compiler_base
    _recursive_joint_graph_passes(gm)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 402, in _recursive_joint_graph_passes
    joint_graph_passes(gm)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_inductor/fx_passes/joint_graph.py", line 544, in joint_graph_passes
    GraphTransformObserver(graph, "remove_noop_ops").apply_graph_pass(remove_noop_ops)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/fx/passes/graph_transform_observer.py", line 85, in apply_graph_pass
    return pass_fn(self.gm.graph)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_inductor/fx_passes/post_grad.py", line 975, in remove_noop_ops
    if same_meta(node, src) and cond(*args, **kwargs):
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_inductor/fx_passes/post_grad.py", line 804, in same_meta
    and statically_known_true(sym_eq(val1.size(), val2.size()))
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AttributeError: 'SymFloat' object has no attribute 'size'
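
For what it's worth, the last frames show the inductor noop-removal pass comparing node metadata sizes, and one of the meta values here is a SymFloat (a scalar symbol) rather than a fake tensor, so it has no .size(). A hedged, hypothetical sketch of the kind of type check that would avoid the crash (this is not the actual PyTorch code, just an illustration of the failing comparison):

import torch

def same_size(val1, val2) -> bool:
    # Hypothetical guard: only compare sizes when both metadata values are
    # (fake) tensors; scalar symbols such as SymFloat have no .size().
    if not (isinstance(val1, torch.Tensor) and isinstance(val2, torch.Tensor)):
        return False
    return val1.size() == val2.size()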

My env:

- 🤗 Diffusers version: 0.34.0.dev0
- Platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31
- Running on Google Colab?: No
- Python version: 3.10.14
- PyTorch version (GPU?): 2.8.0.dev20250417+cu126 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.30.2
- Transformers version: 4.52.0.dev0
- Accelerate version: 1.4.0.dev0
- PEFT version: 0.15.2.dev0
- Bitsandbytes version: 0.45.3
- Safetensors version: 0.4.5
- xFormers version: not installed
- Accelerator: NVIDIA H100 80GB HBM3, 81559 MiB
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>

@StrongerXi, @anijain2305 would you have any pointers?


Labels: performance, torch.compile
