Commit cb342b7
AstraliteHeart and hlky authored
Add AuraFlow GGUF support (#10463)
Add AuraFlow GGUF support (#10463)

* Add support for loading AuraFlow models from GGUF https://huggingface.co/city96/AuraFlow-v0.3-gguf

* Update AuraFlow documentation for GGUF, add GGUF tests and model detection.

* Address code review comments.

* Remove unused config.

---------

Co-authored-by: hlky <[email protected]>
1 parent 80fd926 commit cb342b7

File tree

6 files changed (+218, -3 lines)

docs/source/en/api/pipelines/aura_flow.md

Lines changed: 27 additions & 0 deletions
Lines changed: 27 additions & 0 deletions

````diff
@@ -62,6 +62,33 @@ image = pipeline(prompt).images[0]
 image.save("auraflow.png")
 ```

+Loading [GGUF checkpoints](https://huggingface.co/docs/diffusers/quantization/gguf) is also supported:
+
+```py
+import torch
+from diffusers import (
+    AuraFlowPipeline,
+    GGUFQuantizationConfig,
+    AuraFlowTransformer2DModel,
+)
+
+transformer = AuraFlowTransformer2DModel.from_single_file(
+    "https://huggingface.co/city96/AuraFlow-v0.3-gguf/blob/main/aura_flow_0.3-Q2_K.gguf",
+    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
+    torch_dtype=torch.bfloat16,
+)
+
+pipeline = AuraFlowPipeline.from_pretrained(
+    "fal/AuraFlow-v0.3",
+    transformer=transformer,
+    torch_dtype=torch.bfloat16,
+)
+
+prompt = "a cute pony in a field of flowers"
+image = pipeline(prompt).images[0]
+image.save("auraflow.png")
+```
+
 ## AuraFlowPipeline

 [[autodoc]] AuraFlowPipeline
````
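A side note not in the shipped docs: the quantized transformer also combines with model CPU offloading, which the new GGUF test below exercises. A minimal sketch reusing the same checkpoint as the documentation example above:

```py
import torch
from diffusers import AuraFlowPipeline, AuraFlowTransformer2DModel, GGUFQuantizationConfig

# Same Q2_K checkpoint as in the docs snippet; offloading lowers peak VRAM use
# at the cost of extra host<->device transfers.
transformer = AuraFlowTransformer2DModel.from_single_file(
    "https://huggingface.co/city96/AuraFlow-v0.3-gguf/blob/main/aura_flow_0.3-Q2_K.gguf",
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
pipeline = AuraFlowPipeline.from_pretrained(
    "fal/AuraFlow-v0.3", transformer=transformer, torch_dtype=torch.bfloat16
)
pipeline.enable_model_cpu_offload()  # as done in the new AuraFlow GGUF test
image = pipeline("a cute pony in a field of flowers").images[0]
```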

src/diffusers/loaders/single_file_model.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -25,6 +25,7 @@
 from .single_file_utils import (
     SingleFileComponentError,
     convert_animatediff_checkpoint_to_diffusers,
+    convert_auraflow_transformer_checkpoint_to_diffusers,
     convert_autoencoder_dc_checkpoint_to_diffusers,
     convert_controlnet_checkpoint,
     convert_flux_transformer_checkpoint_to_diffusers,
@@ -106,6 +107,10 @@
         "checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
         "default_subfolder": "transformer",
     },
+    "AuraFlowTransformer2DModel": {
+        "checkpoint_mapping_fn": convert_auraflow_transformer_checkpoint_to_diffusers,
+        "default_subfolder": "transformer",
+    },
 }

```
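For orientation, a toy sketch of the dispatch pattern this registry entry plugs into: `from_single_file` looks up the target class name and applies the mapped converter to the raw state dict. Names and control flow below are illustrative only, not the actual loader internals.

```py
# Illustrative registry dispatch; the real loader additionally resolves configs,
# applies quantization, and instantiates the model.
def convert_auraflow_stub(checkpoint, **kwargs):
    # Stand-in for convert_auraflow_transformer_checkpoint_to_diffusers.
    return {f"diffusers.{k}": v for k, v in checkpoint.items()}

LOADABLE_CLASSES = {
    "AuraFlowTransformer2DModel": {
        "checkpoint_mapping_fn": convert_auraflow_stub,
        "default_subfolder": "transformer",
    },
}

entry = LOADABLE_CLASSES["AuraFlowTransformer2DModel"]
converted = entry["checkpoint_mapping_fn"]({"t_embedder.mlp.0.weight": None})
print(list(converted))  # ['diffusers.t_embedder.mlp.0.weight']
```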

src/diffusers/loaders/single_file_utils.py

Lines changed: 103 additions & 0 deletions
```diff
@@ -94,6 +94,12 @@
     "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
     "animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
     "animatediff_rgb": "controlnet_cond_embedding.weight",
+    "auraflow": [
+        "double_layers.0.attn.w2q.weight",
+        "double_layers.0.attn.w1q.weight",
+        "cond_seq_linear.weight",
+        "t_embedder.mlp.0.weight",
+    ],
     "flux": [
         "double_blocks.0.img_attn.norm.key_norm.scale",
         "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
@@ -154,6 +160,7 @@
     "animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"},
     "animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
     "animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
+    "auraflow": {"pretrained_model_name_or_path": "fal/AuraFlow-v0.3"},
     "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
     "flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
     "flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
@@ -635,6 +642,9 @@ def infer_diffusers_model_type(checkpoint):
     elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
         model_type = "hunyuan-video"

+    elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["auraflow"]):
+        model_type = "auraflow"
+
     elif (
         CHECKPOINT_KEY_NAMES["instruct-pix2pix"] in checkpoint
         and checkpoint[CHECKPOINT_KEY_NAMES["instruct-pix2pix"]].shape[1] == 8
@@ -2090,6 +2100,7 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
 def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     converted_state_dict = {}
     keys = list(checkpoint.keys())
+
     for k in keys:
         if "model.diffusion_model." in k:
             checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
@@ -2689,3 +2700,95 @@ def update_state_dict_(state_dict, old_key, new_key):
         handler_fn_inplace(key, checkpoint)

     return checkpoint
+
+
+def convert_auraflow_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {}
+    state_dict_keys = list(checkpoint.keys())
+
+    # Handle register tokens and positional embeddings
+    converted_state_dict["register_tokens"] = checkpoint.pop("register_tokens", None)
+
+    # Handle time step projection
+    converted_state_dict["time_step_proj.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight", None)
+    converted_state_dict["time_step_proj.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias", None)
+    converted_state_dict["time_step_proj.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight", None)
+    converted_state_dict["time_step_proj.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias", None)
+
+    # Handle context embedder
+    converted_state_dict["context_embedder.weight"] = checkpoint.pop("cond_seq_linear.weight", None)
+
+    # Calculate the number of layers
+    def calculate_layers(keys, key_prefix):
+        layers = set()
+        for k in keys:
+            if key_prefix in k:
+                layer_num = int(k.split(".")[1])  # get the layer number
+                layers.add(layer_num)
+        return len(layers)
+
+    mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
+    single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")
+
+    # MMDiT blocks
+    for i in range(mmdit_layers):
+        # Feed-forward
+        path_mapping = {"mlpX": "ff", "mlpC": "ff_context"}
+        weight_mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
+        for orig_k, diffuser_k in path_mapping.items():
+            for k, v in weight_mapping.items():
+                converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{v}.weight"] = checkpoint.pop(
+                    f"double_layers.{i}.{orig_k}.{k}.weight", None
+                )
+
+        # Norms
+        path_mapping = {"modX": "norm1", "modC": "norm1_context"}
+        for orig_k, diffuser_k in path_mapping.items():
+            converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight"] = checkpoint.pop(
+                f"double_layers.{i}.{orig_k}.1.weight", None
+            )
+
+        # Attentions
+        x_attn_mapping = {"w2q": "to_q", "w2k": "to_k", "w2v": "to_v", "w2o": "to_out.0"}
+        context_attn_mapping = {"w1q": "add_q_proj", "w1k": "add_k_proj", "w1v": "add_v_proj", "w1o": "to_add_out"}
+        for attn_mapping in [x_attn_mapping, context_attn_mapping]:
+            for k, v in attn_mapping.items():
+                converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop(
+                    f"double_layers.{i}.attn.{k}.weight", None
+                )
+
+    # Single-DiT blocks
+    for i in range(single_dit_layers):
+        # Feed-forward
+        mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
+        for k, v in mapping.items():
+            converted_state_dict[f"single_transformer_blocks.{i}.ff.{v}.weight"] = checkpoint.pop(
+                f"single_layers.{i}.mlp.{k}.weight", None
+            )
+
+        # Norms
+        converted_state_dict[f"single_transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
+            f"single_layers.{i}.modCX.1.weight", None
+        )
+
+        # Attentions
+        x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"}
+        for k, v in x_attn_mapping.items():
+            converted_state_dict[f"single_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop(
+                f"single_layers.{i}.attn.{k}.weight", None
+            )
+
+    # Final blocks
+    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_linear.weight", None)
+
+    # Handle the final norm layer
+    norm_weight = checkpoint.pop("modF.1.weight", None)
+    if norm_weight is not None:
+        converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(norm_weight, dim=None)
+    else:
+        converted_state_dict["norm_out.linear.weight"] = None
+
+    converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("positional_encoding")
+    converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("init_x_linear.weight")
+    converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("init_x_linear.bias")
+
+    return converted_state_dict
```
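The converter can be exercised directly on a tiny synthetic state dict to see the key remapping. The sketch below imports the new helper by its private module path (an internal API, so subject to change) and uses arbitrary placeholder shapes; it only covers the unconditional keys, not a real AuraFlow checkpoint.

```py
import torch
from diffusers.loaders.single_file_utils import (
    convert_auraflow_transformer_checkpoint_to_diffusers,
)

# Minimal synthetic checkpoint: just the keys the converter pops unconditionally,
# plus the timestep-embedder weights. Shapes are arbitrary placeholders.
fake_checkpoint = {
    "t_embedder.mlp.0.weight": torch.zeros(8, 4),
    "t_embedder.mlp.0.bias": torch.zeros(8),
    "t_embedder.mlp.2.weight": torch.zeros(8, 8),
    "t_embedder.mlp.2.bias": torch.zeros(8),
    "cond_seq_linear.weight": torch.zeros(8, 4),
    "positional_encoding": torch.zeros(1, 16, 8),
    "init_x_linear.weight": torch.zeros(8, 4),
    "init_x_linear.bias": torch.zeros(8),
}

converted = convert_auraflow_transformer_checkpoint_to_diffusers(fake_checkpoint)
print(sorted(k for k, v in converted.items() if v is not None))
# e.g. ['context_embedder.weight', 'pos_embed.pos_embed', 'pos_embed.proj.bias', ...]
```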

src/diffusers/models/transformers/auraflow_transformer_2d.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -20,6 +20,7 @@
 import torch.nn.functional as F

 from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin
 from ...utils import is_torch_version, logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention_processor import (
@@ -253,7 +254,7 @@ def forward(
         return encoder_hidden_states, hidden_states


-class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
+class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     r"""
     A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/).
```
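Mixing in `FromOriginalModelMixin` is what exposes `from_single_file` on the transformer class. A quick sanity check, assuming a diffusers build that includes this commit:

```py
from diffusers import AuraFlowTransformer2DModel
from diffusers.loaders import FromOriginalModelMixin

# With this commit, the mixin is in the class MRO and contributes from_single_file.
assert issubclass(AuraFlowTransformer2DModel, FromOriginalModelMixin)
assert callable(getattr(AuraFlowTransformer2DModel, "from_single_file"))
```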

src/diffusers/quantizers/gguf/utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -450,7 +450,7 @@ def __init__(
     def forward(self, inputs):
         weight = dequantize_gguf_tensor(self.weight)
         weight = weight.to(self.compute_dtype)
-        bias = self.bias.to(self.compute_dtype)
+        bias = self.bias.to(self.compute_dtype) if self.bias is not None else None

         output = torch.nn.functional.linear(inputs, weight, bias)
         return output
```

tests/quantization/gguf/test_gguf.py

Lines changed: 80 additions & 1 deletion
```diff
@@ -6,6 +6,8 @@
 import torch.nn as nn

 from diffusers import (
+    AuraFlowPipeline,
+    AuraFlowTransformer2DModel,
     FluxPipeline,
     FluxTransformer2DModel,
     GGUFQuantizationConfig,
@@ -54,7 +56,8 @@ def test_gguf_linear_layers(self):
         for name, module in model.named_modules():
             if isinstance(module, torch.nn.Linear) and hasattr(module.weight, "quant_type"):
                 assert module.weight.dtype == torch.uint8
-                assert module.bias.dtype == torch.float32
+                if module.bias is not None:
+                    assert module.bias.dtype == torch.float32

     def test_gguf_memory_usage(self):
         quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
@@ -377,3 +380,79 @@ def test_pipeline_inference(self):
         )
         max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
         assert max_diff < 1e-4
+
+
+class AuraFlowGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
+    ckpt_path = "https://huggingface.co/city96/AuraFlow-v0.3-gguf/blob/main/aura_flow_0.3-Q2_K.gguf"
+    torch_dtype = torch.bfloat16
+    model_cls = AuraFlowTransformer2DModel
+    expected_memory_use_in_gb = 4
+
+    def setUp(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def tearDown(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def get_dummy_inputs(self):
+        return {
+            "hidden_states": torch.randn((1, 4, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
+                torch_device, self.torch_dtype
+            ),
+            "encoder_hidden_states": torch.randn(
+                (1, 512, 2048),
+                generator=torch.Generator("cpu").manual_seed(0),
+            ).to(torch_device, self.torch_dtype),
+            "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
+        }
+
+    def test_pipeline_inference(self):
+        quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
+        transformer = self.model_cls.from_single_file(
+            self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype
+        )
+        pipe = AuraFlowPipeline.from_pretrained(
+            "fal/AuraFlow-v0.3", transformer=transformer, torch_dtype=self.torch_dtype
+        )
+        pipe.enable_model_cpu_offload()
+
+        prompt = "a pony holding a sign that says hello"
+        output = pipe(
+            prompt=prompt, num_inference_steps=2, generator=torch.Generator("cpu").manual_seed(0), output_type="np"
+        ).images[0]
+        output_slice = output[:3, :3, :].flatten()
+        expected_slice = np.array(
+            [
+                0.46484375,
+                0.546875,
+                0.64453125,
+                0.48242188,
+                0.53515625,
+                0.59765625,
+                0.47070312,
+                0.5078125,
+                0.5703125,
+                0.42773438,
+                0.50390625,
+                0.5703125,
+                0.47070312,
+                0.515625,
+                0.57421875,
+                0.45898438,
+                0.48632812,
+                0.53515625,
+                0.4453125,
+                0.5078125,
+                0.56640625,
+                0.47851562,
+                0.5234375,
+                0.57421875,
+                0.48632812,
+                0.5234375,
+                0.56640625,
+            ]
+        )
+        max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
+        assert max_diff < 1e-4
```
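The assertion compares image slices with a cosine-distance style metric from the diffusers test utilities. A minimal stand-in (not the actual `numpy_cosine_similarity_distance` implementation) for readers who want to reproduce the check locally:

```py
import numpy as np

def cosine_distance(a: np.ndarray, b: np.ndarray) -> float:
    # 1 - cosine similarity of the flattened slices; 0.0 means identical direction.
    a, b = a.flatten(), b.flatten()
    return float(1.0 - np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

print(cosine_distance(np.array([1.0, 2.0]), np.array([1.0, 2.0])))  # 0.0
```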
