Skip to content

Commit ebd079a

Browse files
hlkyyiyixuxu
authored and committed
Support Flux IP Adapter (#10261)
* Flux IP-Adapter * test cfg * make style * temp remove copied from * fix test * fix test * v2 * fix * make style * temp remove copied from * Apply suggestions from code review Co-authored-by: YiYi Xu <[email protected]> * Move encoder_hid_proj to inside FluxTransformer2DModel * merge * separate encode_prompt, add copied from, image_encoder offload * make * fix test * fix * Update src/diffusers/pipelines/flux/pipeline_flux.py * test_flux_prompt_embeds change not needed * true_cfg -> true_cfg_scale * fix merge conflict * test_flux_ip_adapter_inference * add fast test * FluxIPAdapterMixin not test mixin * Update pipeline_flux.py Co-authored-by: YiYi Xu <[email protected]> --------- Co-authored-by: YiYi Xu <[email protected]>
1 parent e4701c1 commit ebd079a

File tree

12 files changed

+1157
-14
lines changed

12 files changed

+1157
-14
lines changed
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import argparse
2+
from contextlib import nullcontext
3+
4+
import safetensors.torch
5+
from accelerate import init_empty_weights
6+
from huggingface_hub import hf_hub_download
7+
8+
from diffusers.utils.import_utils import is_accelerate_available, is_transformers_available
9+
10+
11+
# `vision` records whether transformers is importable; main() skips exporting
# the CLIP image encoder when it is not.
vision = is_transformers_available()
if vision:
    from transformers import CLIPVisionModelWithProjection
17+
18+
"""
19+
python scripts/convert_flux_xlabs_ipadapter_to_diffusers.py \
20+
--original_state_dict_repo_id "XLabs-AI/flux-ip-adapter" \
21+
--filename "flux-ip-adapter.safetensors"
22+
--output_path "flux-ip-adapter-hf/"
23+
"""
24+
25+
26+
# Use accelerate's init_empty_weights when available, else a no-op context.
# Fix: is_accelerate_available must be *called* — the bare function object is
# always truthy, so the original always selected init_empty_weights and would
# crash when accelerate is not installed.
CTX = init_empty_weights if is_accelerate_available() else nullcontext

parser = argparse.ArgumentParser()
# Hub repo to download the original checkpoint from (preferred over a local path).
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
parser.add_argument("--filename", default="flux.safetensors", type=str)
# Local checkpoint path, used only when no repo id is given.
parser.add_argument("--checkpoint_path", default=None, type=str)
parser.add_argument("--output_path", type=str)
# CLIP vision tower to bundle as the image encoder.
parser.add_argument("--vision_pretrained_or_path", default="openai/clip-vit-large-patch14", type=str)

args = parser.parse_args()
36+
37+
38+
def load_original_checkpoint(args):
    """Resolve the original checkpoint location from CLI args and load it.

    Prefers a Hub download (``original_state_dict_repo_id`` + ``filename``)
    over a local ``checkpoint_path``; raises ValueError when neither is set.
    """
    repo_id = args.original_state_dict_repo_id
    if repo_id is not None:
        ckpt_path = hf_hub_download(repo_id=repo_id, filename=args.filename)
    else:
        ckpt_path = args.checkpoint_path
        if ckpt_path is None:
            raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")

    return safetensors.torch.load_file(ckpt_path)
48+
49+
50+
def convert_flux_ipadapter_checkpoint_to_diffusers(original_state_dict, num_layers):
    """Remap XLabs Flux IP-Adapter checkpoint keys to the diffusers layout.

    Args:
        original_state_dict: mapping of XLabs keys to tensors; converted
            entries are popped from it in place.
        num_layers: number of double-stream transformer blocks (19 for Flux).

    Returns:
        dict with diffusers-style keys (``image_proj.*``, ``ip_adapter.{i}.*``).
    """
    converted_state_dict = {}

    # image_proj
    ## norm
    converted_state_dict["image_proj.norm.weight"] = original_state_dict.pop("ip_adapter_proj_model.norm.weight")
    converted_state_dict["image_proj.norm.bias"] = original_state_dict.pop("ip_adapter_proj_model.norm.bias")
    ## proj
    # Fix: the proj parameters come from `proj.*`, not `norm.*` — the norm keys
    # were already popped above, so the original code raised KeyError here and
    # mapped the wrong tensors.
    converted_state_dict["image_proj.proj.weight"] = original_state_dict.pop("ip_adapter_proj_model.proj.weight")
    converted_state_dict["image_proj.proj.bias"] = original_state_dict.pop("ip_adapter_proj_model.proj.bias")

    # double transformer blocks
    for i in range(num_layers):
        block_prefix = f"ip_adapter.{i}."
        # to_k_ip
        converted_state_dict[f"{block_prefix}to_k_ip.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.bias"
        )
        converted_state_dict[f"{block_prefix}to_k_ip.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.weight"
        )
        # to_v_ip
        converted_state_dict[f"{block_prefix}to_v_ip.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.bias"
        )
        # Fix: the v-projection weight belongs under `to_v_ip.weight`; the
        # original wrote it to `to_k_ip.weight`, clobbering the k weight and
        # leaving `to_v_ip.weight` missing from the output.
        converted_state_dict[f"{block_prefix}to_v_ip.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.weight"
        )

    return converted_state_dict
80+
81+
82+
def main(args):
    """Download/load the XLabs checkpoint, convert it, and save it in diffusers format.

    Writes ``model.safetensors`` (and, when transformers is available, the CLIP
    image encoder) under ``args.output_path``.
    """
    import os

    original_ckpt = load_original_checkpoint(args)

    # Flux has 19 double-stream transformer blocks.
    num_layers = 19
    converted_ip_adapter_state_dict = convert_flux_ipadapter_checkpoint_to_diffusers(original_ckpt, num_layers)

    print("Saving Flux IP-Adapter in Diffusers format.")
    # Robustness: create the output directory up front so save_file does not
    # fail with FileNotFoundError on a missing path.
    os.makedirs(args.output_path, exist_ok=True)
    safetensors.torch.save_file(converted_ip_adapter_state_dict, f"{args.output_path}/model.safetensors")

    if vision:
        model = CLIPVisionModelWithProjection.from_pretrained(args.vision_pretrained_or_path)
        model.save_pretrained(f"{args.output_path}/image_encoder")


if __name__ == "__main__":
    main(args)

src/diffusers/loaders/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def text_encoder_attn_modules(text_encoder):
5555

5656
if is_torch_available():
5757
_import_structure["single_file_model"] = ["FromOriginalModelMixin"]
58-
58+
_import_structure["transformer_flux"] = ["FluxTransformer2DLoadersMixin"]
5959
_import_structure["transformer_sd3"] = ["SD3Transformer2DLoadersMixin"]
6060
_import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
6161
_import_structure["utils"] = ["AttnProcsLayers"]
@@ -77,6 +77,7 @@ def text_encoder_attn_modules(text_encoder):
7777
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
7878
_import_structure["ip_adapter"] = [
7979
"IPAdapterMixin",
80+
"FluxIPAdapterMixin",
8081
"SD3IPAdapterMixin",
8182
]
8283

@@ -86,12 +87,14 @@ def text_encoder_attn_modules(text_encoder):
8687
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
8788
if is_torch_available():
8889
from .single_file_model import FromOriginalModelMixin
90+
from .transformer_flux import FluxTransformer2DLoadersMixin
8991
from .transformer_sd3 import SD3Transformer2DLoadersMixin
9092
from .unet import UNet2DConditionLoadersMixin
9193
from .utils import AttnProcsLayers
9294

9395
if is_transformers_available():
9496
from .ip_adapter import (
97+
FluxIPAdapterMixin,
9598
IPAdapterMixin,
9699
SD3IPAdapterMixin,
97100
)

0 commit comments

Comments
 (0)