94 | 94 | "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
95 | 95 | "animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
96 | 96 | "animatediff_rgb": "controlnet_cond_embedding.weight",
| 97 | + "auraflow": [ |
| 98 | + "double_layers.0.attn.w2q.weight", |
| 99 | + "double_layers.0.attn.w1q.weight", |
| 100 | + "cond_seq_linear.weight", |
| 101 | + "t_embedder.mlp.0.weight", |
| 102 | + ], |
97 | 103 | "flux": [
98 | 104 | "double_blocks.0.img_attn.norm.key_norm.scale",
99 | 105 | "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
@@ -154,6 +160,7 @@
154 | 160 | "animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"},
155 | 161 | "animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
156 | 162 | "animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
| 163 | + "auraflow": {"pretrained_model_name_or_path": "fal/AuraFlow-v0.3"}, |
157 | 164 | "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
158 | 165 | "flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
159 | 166 | "flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
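
The new default entry is what an inferred model type of "auraflow" resolves to when no diffusers-format repo is given explicitly. A minimal lookup sketch, assuming the enclosing table is this file's DIFFUSERS_DEFAULT_PIPELINE_PATHS dict (its name is not visible in this hunk):

# Assumed dict name; only the "auraflow" entry is taken verbatim from the hunk above.
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
    "auraflow": {"pretrained_model_name_or_path": "fal/AuraFlow-v0.3"},
    # ... other entries elided ...
}

model_type = "auraflow"  # as returned by infer_diffusers_model_type (next hunk)
default_repo = DIFFUSERS_DEFAULT_PIPELINE_PATHS[model_type]["pretrained_model_name_or_path"]
# default_repo == "fal/AuraFlow-v0.3"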
@@ -635,6 +642,9 @@ def infer_diffusers_model_type(checkpoint):
635 | 642 | elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
636 | 643 | model_type = "hunyuan-video"
637 | 644 |
| 645 | + elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["auraflow"]): |
| 646 | + model_type = "auraflow" |
| 647 | + |
638 | 648 | elif (
639 | 649 | CHECKPOINT_KEY_NAMES["instruct-pix2pix"] in checkpoint
640 | 650 | and checkpoint[CHECKPOINT_KEY_NAMES["instruct-pix2pix"]].shape[1] == 8
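
The new "auraflow" branch above only fires when every probe key listed under CHECKPOINT_KEY_NAMES["auraflow"] (first hunk) is present in the original state dict. A minimal sketch of that check against a locally downloaded single-file checkpoint; the file name is a placeholder and the probe list is copied from the hunk above:

from safetensors.torch import load_file

# Probe keys copied from CHECKPOINT_KEY_NAMES["auraflow"] in the first hunk
AURAFLOW_PROBE_KEYS = [
    "double_layers.0.attn.w2q.weight",
    "double_layers.0.attn.w1q.weight",
    "cond_seq_linear.weight",
    "t_embedder.mlp.0.weight",
]

checkpoint = load_file("aura_flow.safetensors")  # placeholder path
if all(key in checkpoint for key in AURAFLOW_PROBE_KEYS):
    model_type = "auraflow"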
@@ -2090,6 +2100,7 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
2090 | 2100 | def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
2091 | 2101 | converted_state_dict = {}
2092 | 2102 | keys = list(checkpoint.keys())
| 2103 | + |
2093 | 2104 | for k in keys:
2094 | 2105 | if "model.diffusion_model." in k:
2095 | 2106 | checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
@@ -2689,3 +2700,95 @@ def update_state_dict_(state_dict, old_key, new_key):
2689 | 2700 | handler_fn_inplace(key, checkpoint)
2690 | 2701 |
2691 | 2702 | return checkpoint
| 2703 | + |
| 2704 | + |
| 2705 | +def convert_auraflow_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): |
| 2706 | + converted_state_dict = {} |
| 2707 | + state_dict_keys = list(checkpoint.keys()) |
| 2708 | + |
| 2710 | + # Handle register tokens |
| 2710 | + converted_state_dict["register_tokens"] = checkpoint.pop("register_tokens", None) |
| 2711 | + |
| 2712 | + # Handle time step projection |
| 2713 | + converted_state_dict["time_step_proj.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight", None) |
| 2714 | + converted_state_dict["time_step_proj.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias", None) |
| 2715 | + converted_state_dict["time_step_proj.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight", None) |
| 2716 | + converted_state_dict["time_step_proj.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias", None) |
| 2717 | + |
| 2718 | + # Handle context embedder |
| 2719 | + converted_state_dict["context_embedder.weight"] = checkpoint.pop("cond_seq_linear.weight", None) |
| 2720 | + |
| 2721 | + # Calculate the number of layers |
| 2722 | + def calculate_layers(keys, key_prefix): |
| 2723 | + layers = set() |
| 2724 | + for k in keys: |
| 2725 | + if key_prefix in k: |
| 2726 | + layer_num = int(k.split(".")[1]) # get the layer number |
| 2727 | + layers.add(layer_num) |
| 2728 | + return len(layers) |
| 2729 | + |
| 2730 | + mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers") |
| 2731 | + single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers") |
| 2732 | + |
| 2733 | + # MMDiT blocks |
| 2734 | + for i in range(mmdit_layers): |
| 2735 | + # Feed-forward |
| 2736 | + path_mapping = {"mlpX": "ff", "mlpC": "ff_context"} |
| 2737 | + weight_mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"} |
| 2738 | + for orig_k, diffuser_k in path_mapping.items(): |
| 2739 | + for k, v in weight_mapping.items(): |
| 2740 | + converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{v}.weight"] = checkpoint.pop( |
| 2741 | + f"double_layers.{i}.{orig_k}.{k}.weight", None |
| 2742 | + ) |
| 2743 | + |
| 2744 | + # Norms |
| 2745 | + path_mapping = {"modX": "norm1", "modC": "norm1_context"} |
| 2746 | + for orig_k, diffuser_k in path_mapping.items(): |
| 2747 | + converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight"] = checkpoint.pop( |
| 2748 | + f"double_layers.{i}.{orig_k}.1.weight", None |
| 2749 | + ) |
| 2750 | + |
| 2751 | + # Attentions |
| 2752 | + x_attn_mapping = {"w2q": "to_q", "w2k": "to_k", "w2v": "to_v", "w2o": "to_out.0"} |
| 2753 | + context_attn_mapping = {"w1q": "add_q_proj", "w1k": "add_k_proj", "w1v": "add_v_proj", "w1o": "to_add_out"} |
| 2754 | + for attn_mapping in [x_attn_mapping, context_attn_mapping]: |
| 2755 | + for k, v in attn_mapping.items(): |
| 2756 | + converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop( |
| 2757 | + f"double_layers.{i}.attn.{k}.weight", None |
| 2758 | + ) |
| 2759 | + |
| 2760 | + # Single-DiT blocks |
| 2761 | + for i in range(single_dit_layers): |
| 2762 | + # Feed-forward |
| 2763 | + mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"} |
| 2764 | + for k, v in mapping.items(): |
| 2765 | + converted_state_dict[f"single_transformer_blocks.{i}.ff.{v}.weight"] = checkpoint.pop( |
| 2766 | + f"single_layers.{i}.mlp.{k}.weight", None |
| 2767 | + ) |
| 2768 | + |
| 2769 | + # Norms |
| 2770 | + converted_state_dict[f"single_transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop( |
| 2771 | + f"single_layers.{i}.modCX.1.weight", None |
| 2772 | + ) |
| 2773 | + |
| 2774 | + # Attentions |
| 2775 | + x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"} |
| 2776 | + for k, v in x_attn_mapping.items(): |
| 2777 | + converted_state_dict[f"single_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop( |
| 2778 | + f"single_layers.{i}.attn.{k}.weight", None |
| 2779 | + ) |
| 2780 | + # Final blocks |
| 2781 | + converted_state_dict["proj_out.weight"] = checkpoint.pop("final_linear.weight", None) |
| 2782 | + |
| 2783 | + # Handle the final norm layer |
| 2784 | + norm_weight = checkpoint.pop("modF.1.weight", None) |
| 2785 | + if norm_weight is not None: |
| 2786 | + converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(norm_weight, dim=None) |
| 2787 | + else: |
| 2788 | + converted_state_dict["norm_out.linear.weight"] = None |
| 2789 | + |
| 2790 | + converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("positional_encoding") |
| 2791 | + converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("init_x_linear.weight") |
| 2792 | + converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("init_x_linear.bias") |
| 2793 | + |
| 2794 | + return converted_state_dict |
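
A hedged usage sketch for the new converter: it consumes the original AuraFlow transformer state dict and returns the same tensors under diffusers-style key names. The checkpoint path is a placeholder; in practice the function is invoked by the single-file loading machinery rather than called directly.

from safetensors.torch import load_file

original = load_file("aura_flow.safetensors")  # placeholder path to an AuraFlow single-file checkpoint
converted = convert_auraflow_transformer_checkpoint_to_diffusers(original)

# Spot-check a few of the remapped names produced by the function above
assert "time_step_proj.linear_1.weight" in converted
assert "joint_transformer_blocks.0.attn.to_q.weight" in converted
assert "pos_embed.proj.weight" in converted

Together with the detection and default-repo entries from the earlier hunks, the intent is that AuraFlowTransformer2DModel.from_single_file(...) can load such checkpoints, although the registration that wires this converter to that class is not among the lines shown here.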