
Flux Fill, Canny, Depth, Redux #9985


Merged

merged 23 commits into main on Nov 23, 2024

Conversation

a-r-r-o-w
Member

@a-r-r-o-w a-r-r-o-w commented Nov 21, 2024

We've opened PRs to add the diffusers checkpoints to the Hub. Before they are merged, here is how you can test these four models:

# test all model card examples for flux redux/fill/canny/depth

# flux redux
import torch
from diffusers import FluxPriorReduxPipeline, FluxPipeline
from diffusers.utils import load_image

pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained("black-forest-labs/FLUX.1-Redux-dev", revision="refs/pr/8", torch_dtype=torch.bfloat16).to("cuda")
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev" , 
    text_encoder=None,
    text_encoder_2=None,
    torch_dtype=torch.bfloat16
).to("cuda")

image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
pipe_prior_output = pipe_prior_redux(image)
images = pipe(
    guidance_scale=2.5,
    num_inference_steps=50,
    generator=torch.Generator("cpu").manual_seed(0),
    **pipe_prior_output,
).images
images[0].save("flux-dev-redux.png")

del pipe_prior_redux
del pipe
import gc
gc.collect()
torch.cuda.empty_cache()

# flux fill
import torch
from diffusers import FluxFillPipeline
from diffusers.utils import load_image

image = load_image("https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/cup.png")
mask = load_image("https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/cup_mask.png")

pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16, revision="refs/pr/4").to("cuda")
image = pipe(
    prompt="a white paper cup",
    image=image,
    mask_image=mask,
    height=1632,
    width=1232,
    guidance_scale=30,
    num_inference_steps=50,
    max_sequence_length=512,
    generator=torch.Generator("cpu").manual_seed(0)
).images[0]
image.save(f"flux-fill-dev.png")

del pipe
import gc
gc.collect()
torch.cuda.empty_cache()

# flux canny

# please install controlnet-aux if you haven't:
# !pip install -U controlnet-aux

import torch
from controlnet_aux import CannyDetector
from diffusers import FluxControlPipeline
from diffusers.utils import load_image

pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-Canny-dev", torch_dtype=torch.bfloat16, revision="refs/pr/1").to("cuda")

prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")

processor = CannyDetector()
control_image = processor(control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024)

image = pipe(
    prompt=prompt,
    control_image=control_image,
    height=1024,
    width=1024,
    num_inference_steps=50,
    guidance_scale=30.0,
).images[0]
image.save("flux-canny-dev.png")

del pipe
import gc
gc.collect()
torch.cuda.empty_cache()

# flux depth

# please install image_gen_aux if you haven't
# !pip install git+https://github.com/asomoza/image_gen_aux.git

import torch
from diffusers import FluxControlPipeline
from diffusers.utils import load_image
from image_gen_aux import DepthPreprocessor

pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-Depth-dev", torch_dtype=torch.bfloat16, revision="refs/pr/1").to("cuda")

prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")

processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
control_image = processor(control_image)[0].convert("RGB")

image = pipe(
    prompt=prompt,
    control_image=control_image,
    height=1024,
    width=1024,
    num_inference_steps=30,
    guidance_scale=10.0,
    generator=torch.Generator().manual_seed(42),
).images[0]
image.save("flux-depth-dev.png")

@a-r-r-o-w a-r-r-o-w requested a review from yiyixuxu November 21, 2024 16:49
@yiyixuxu
Collaborator

To test fill, @asomoza:

# flux 
import torch
from diffusers import FluxFillPipeline
from diffusers.utils import load_image

img = load_image("/raid/yiyi/flux-new/assets/cup.png")
mask = load_image("/raid/yiyi/flux-new/assets/cup_mask.png")

repo_id = "diffusers-internal-dev/dummy-fill"

pipe = FluxFillPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # save some VRAM by offloading the model to CPU; remove this if you have enough GPU power

image = pipe(
    prompt="a white paper cup",
    image=img,
    mask_image=mask,
    height=1632,
    width=1232,
    guidance_scale=30,
    num_inference_steps=50,
    max_sequence_length=512,
    generator=torch.Generator("cpu").manual_seed(0)
).images[0]
image.save("yiyi_test_2_out.png")

@a-r-r-o-w
Member Author

I don't exactly have minimal examples with image_gen_aux or controlnet_aux, but hopefully these are helpful to get started:

ControlNet Depth
import numpy as np
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel
from diffusers.utils import load_image
from einops import repeat
from transformers import AutoModelForDepthEstimation, AutoProcessor
from PIL import Image


cache_dir = "/raid/.cache/huggingface"

transformer = FluxTransformer2DModel.from_pretrained("/raid/aryan/flux1-depth-dev-diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16, cache_dir=cache_dir)
pipe.to("cuda")

prompt = "A rabbit made of gold standing strong in a dream-like cosmic world. The buildings, in the background, are made of a pink liquid-like substance but maintain solid structure"
control_image = load_image("inputs/rabbit.jpg")


class DepthImageEncoder:
    depth_model_name = "LiheYoung/depth-anything-large-hf"

    def __init__(self, device):
        self.device = device
        self.depth_model = AutoModelForDepthEstimation.from_pretrained(self.depth_model_name).to(device)
        self.processor = AutoProcessor.from_pretrained(self.depth_model_name)

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        hw = img.shape[-2:]

        img = torch.clamp(img, -1.0, 1.0)
        img_byte = ((img + 1.0) * 127.5).byte()

        img = self.processor(img_byte, return_tensors="pt")["pixel_values"]
        depth = self.depth_model(img.to(self.device)).predicted_depth
        depth = repeat(depth, "b h w -> b 3 h w")
        depth = torch.nn.functional.interpolate(depth, hw, mode="bicubic", antialias=True)

        depth = depth / 127.5 - 1.0
        return depth


processor = DepthImageEncoder(device="cuda")
control_image = torch.from_numpy(np.array(control_image)).unsqueeze(0).permute(0, 3, 1, 2)
control_image = control_image / 127.5 - 1.0
control_image = processor(control_image)
control_image = (control_image + 1) * 127.5
control_image = control_image.to(torch.uint8).clamp(0, 255)
control_image = Image.fromarray(control_image.cpu().permute(0, 2, 3, 1).numpy().astype(np.uint8)[0])
control_image.save("output_depth.png")

image = pipe(
    prompt=prompt,
    control_image=control_image,
    height=1024,
    width=768,
    num_inference_steps=30,
    guidance_scale=10.0,
    generator=torch.Generator().manual_seed(42),
).images[0]
image.save("output.png")
ControlNet Canny
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel
from diffusers.utils import load_image
from PIL import Image


cache_dir = "/raid/.cache/huggingface"

transformer = FluxTransformer2DModel.from_pretrained("/raid/aryan/flux1-canny-dev-diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16, cache_dir=cache_dir)
pipe.to("cuda")

prompt = "A rabbit made of gold standing strong in a dream-like cosmic world. The buildings, in the background, are made of a pink liquid-like substance but maintain solid structure"
control_image = load_image("inputs/rabbit.jpg")

if False:
    from controlnet_aux import CannyDetector
    
    
    processor = CannyDetector()
    control_image = processor(control_image)
else:
    import cv2
    import numpy as np
    from einops import rearrange, repeat

    
    class CannyImageEncoder:
        def __init__(
            self,
            device,
            min_t: int = 50,
            max_t: int = 200,
        ):
            self.device = device
            self.min_t = min_t
            self.max_t = max_t

        def __call__(self, img: torch.Tensor) -> torch.Tensor:
            assert img.shape[0] == 1, "Only batch size 1 is supported"

            img = rearrange(img[0], "c h w -> h w c")
            img = torch.clamp(img, -1.0, 1.0)
            img_np = ((img + 1.0) * 127.5).numpy().astype(np.uint8)

            # Apply Canny edge detection
            canny = cv2.Canny(img_np, self.min_t, self.max_t)

            # Convert back to torch tensor and reshape
            canny = torch.from_numpy(canny).float() / 127.5 - 1.0
            canny = rearrange(canny, "h w -> 1 1 h w")
            canny = repeat(canny, "b 1 ... -> b 3 ...")
            return canny.to(self.device)


    processor = CannyImageEncoder(device="cuda")
    control_image = torch.from_numpy(np.array(control_image)).unsqueeze(0).permute(0, 3, 1, 2)
    control_image = control_image / 127.5 - 1.0
    control_image = processor(control_image)
    control_image = (control_image + 1) * 127.5
    control_image = control_image.to(torch.uint8).clamp(0, 255)
    control_image = Image.fromarray(control_image.cpu().permute(0, 2, 3, 1).numpy().astype(np.uint8)[0])
    control_image.save("output_canny.png")

if False:
    image = pipe(
        prompt=prompt,
        control_image=control_image,
        height=1024,
        width=768,
        num_inference_steps=30,
        guidance_scale=30.0,
        generator=torch.Generator().manual_seed(42),
    ).images[0]
    image.save("output.png")
else:
    latents = torch.load("/raid/aryan/flux-tools-main/original_impl_latents.pt")
    control_latents = torch.load("/raid/aryan/flux-tools-main/original_impl_img_cond.pt")
    print(latents.shape, control_latents.shape)
    image = pipe(
        prompt=prompt,
        height=1024,
        width=768,
        num_inference_steps=30,
        guidance_scale=30.0,
        latents=latents,
        control_latents=control_latents,
        generator=torch.Generator().manual_seed(42),
    ).images[0]
    image.save("output.png")

This is on the DGX so the hardcoded paths should hopefully work. I'll also upload to our internal group to make it easier to test

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@a-r-r-o-w a-r-r-o-w changed the title from "New Flux ControlNet, Control LoRA, Redux" to "New Flux Fill, ControlNet, Control LoRA, Redux" on Nov 21, 2024
@@ -529,6 +529,41 @@ def prepare_latents(

return latents, latent_image_ids

# Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
Collaborator

The easiest way is probably to have 3 separate FluxControlPipelines? But maybe @DN6 can come up with better ideas!

Comment on lines 160 to 169
sample_lora_weight = original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight")
converted_state_dict[f"{block_prefix}attn.to_v.{diffusers_lora_key}.weight"] = torch.cat(
    [sample_lora_weight]
)
converted_state_dict[f"{block_prefix}attn.to_q.{diffusers_lora_key}.weight"] = torch.cat(
    [sample_lora_weight]
)
converted_state_dict[f"{block_prefix}attn.to_k.{diffusers_lora_key}.weight"] = torch.cat(
    [sample_lora_weight]
)
Member Author

@a-r-r-o-w a-r-r-o-w Nov 21, 2024

Very important to take note of this conversion.

The original repository calculates QKV (3 * hidden_size) + MLP (4 * hidden_size) with a single nn.Linear and later splits them. With LoRA on this nn.Linear, it is not trivial to remap the lora_A and lora_B layers because of the different dimensions: lora_A has shape (rank, hidden_size), while lora_B has shape (7 * hidden_size, rank). Splitting lora_B is trivial, but lora_A needs to be the same for all of QKV and MLP (a small numerical sketch follows below).

I think it is mathematically equivalent but if someone could give another look, it would be super helpful!


While writing this comment, I just realized that I made a mistake here and did not account for the MLP LoRA. Will add immediately.

Edit: Not a mistake. Only the single transformer blocks need the change I described for MLP. Double blocks have a separate MLP layer
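
A minimal numerical sketch of the equivalence (toy dimensions and made-up tensors, not the conversion script itself; Flux uses hidden_size=3072 and rank=128 here):

import torch

# Toy check: splitting a fused [Q; K; V; MLP] lora_B along dim 0, while every split
# target reuses the same lora_A, reproduces the fused LoRA contribution exactly.
hidden_size, rank = 64, 4
mlp_hidden = 4 * hidden_size

lora_A = torch.randn(rank, hidden_size)                    # shared across all targets
lora_B = torch.randn(3 * hidden_size + mlp_hidden, rank)   # fused output projection

# Slice lora_B into the per-target pieces (to_q, to_k, to_v, proj_mlp)
q_B, k_B, v_B, mlp_B = torch.split(lora_B, [hidden_size, hidden_size, hidden_size, mlp_hidden], dim=0)

x = torch.randn(2, hidden_size)
fused_out = x @ lora_A.T @ lora_B.T
split_out = torch.cat([x @ lora_A.T @ piece.T for piece in (q_B, k_B, v_B, mlp_B)], dim=-1)
assert torch.allclose(fused_out, split_out, atol=1e-5)

So copying the same lora_A to every target (as the snippets here do) and splitting lora_B per target should give the same result as the fused layer, which is what the comment above argues.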

Comment on lines 305 to 317
if lora_key == "lora_A":
lora_weight = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight")
converted_state_dict[f"{block_prefix}attn.to_q.{diffusers_lora_key}.weight"] = torch.cat([lora_weight])
converted_state_dict[f"{block_prefix}attn.to_k.{diffusers_lora_key}.weight"] = torch.cat([lora_weight])
converted_state_dict[f"{block_prefix}attn.to_v.{diffusers_lora_key}.weight"] = torch.cat([lora_weight])
converted_state_dict[f"{block_prefix}proj_mlp.{diffusers_lora_key}.weight"] = torch.cat([lora_weight])

if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict.keys():
lora_bias = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias")
converted_state_dict[f"{block_prefix}attn.to_q.{diffusers_lora_key}.bias"] = torch.cat([lora_bias])
converted_state_dict[f"{block_prefix}attn.to_k.{diffusers_lora_key}.bias"] = torch.cat([lora_bias])
converted_state_dict[f"{block_prefix}attn.to_v.{diffusers_lora_key}.bias"] = torch.cat([lora_bias])
converted_state_dict[f"{block_prefix}proj_mlp.{diffusers_lora_key}.bias"] = torch.cat([lora_bias])
Member Author

Same comment as above applies here

Comment on lines 1792 to 1794
# Flux Control LoRAs also have norm keys
supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]
has_norm_keys = any(norm_key in key for key in state_dict.keys() for norm_key in supported_norm_keys)
Member Author

@a-r-r-o-w a-r-r-o-w Nov 21, 2024

For supporting the additional norm layers. Also FYI, the norm layers from the LoRA are numerically identical to the ones in Flux1-Canny-Dev and Flux1-Depth-Dev, but different from Flux1-Dev (the model for which the LoRA is intended), so we cannot do without this change (a rough sketch of such a check is below).
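
As a rough illustration of that check (file paths and key alignment here are assumptions, not the actual verification used):

import torch
from safetensors.torch import load_file

# Hypothetical comparison: the norm scales shipped with the Control LoRA should match
# the full Flux1-Canny-Dev/Depth-Dev transformer weights but not the Flux1-Dev ones.
# All paths are placeholders, and both dicts are assumed to share the same key names.
lora_sd = load_file("flux1-canny-dev-lora-converted.safetensors")
canny_sd = load_file("flux1-canny-dev-transformer.safetensors")
dev_sd = load_file("flux1-dev-transformer.safetensors")

norm_markers = ("norm_q", "norm_k", "norm_added_q", "norm_added_k")
for key, value in lora_sd.items():
    if any(marker in key for marker in norm_markers):
        print(
            key,
            "matches Canny-Dev:", torch.allclose(value, canny_sd[key]),
            "matches Dev:", torch.allclose(value, dev_sd[key]),
        )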

Member

Exactly. Thanks for confirming!

Comment on lines 1799 to 1820
def prune_state_dict_(state_dict):
    pruned_keys = []
    for key in list(state_dict.keys()):
        is_lora_key_present = "lora" in key
        is_norm_key_present = any(norm_key in key for norm_key in supported_norm_keys)
        if not is_lora_key_present and not is_norm_key_present:
            state_dict.pop(key)
            pruned_keys.append(key)
    return pruned_keys

pruned_keys = prune_state_dict_(state_dict)
if len(pruned_keys) > 0:
    logger.warning(
        f"The provided LoRA state dict contains additional weights that are not compatible with Flux. The following are the incompatible weights:\n{pruned_keys}"
    )

transformer_lora_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k and "lora" in k}
transformer_norm_state_dict = {
    k: v
    for k, v in state_dict.items()
    if "transformer." in k and any(norm_key in k for norm_key in supported_norm_keys)
}
Member Author

@a-r-r-o-w a-r-r-o-w Nov 21, 2024

I hope this is self-explanatory. load_lora_adapter is incompatible for anything without lora keys, so we separate the state dict into the norm and lora dicts.

We also remove any other incompatible layers (this was just for sanity checking that I was doing things correctly) while raising a warning if there are any additional keys (there are none at the moment, but good to have IMO).

Member

Indeed!

Should we instead pop from the state_dict and raise a warning and error if there's anything remaining inside it?

Comment on lines 219 to 221
# Cannot figure out rank from lora layers that don't have at least 2 dimensions.
# Bias layers in LoRA only have a single dimension
if "lora_B" in key and val.ndim > 1:
Member Author

Normally, we don't train the lora_B bias, but the Control LoRAs have trained biases. In this case, the assumption that we can index val.shape[1] is incorrect, since the bias is an ndim=1 tensor (small illustration below).
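
A tiny illustration (made-up keys) of why the extra ndim guard is needed when inferring ranks:

import torch

# lora_B weights have shape (out_features, rank), so the rank can be read from shape[1];
# the trained lora_B biases in the Control LoRAs are 1-D and carry no rank dimension,
# hence the `val.ndim > 1` guard before indexing.
state_dict = {
    "transformer_blocks.0.attn.to_q.lora_B.weight": torch.zeros(3072, 128),  # rank 128
    "transformer_blocks.0.attn.to_q.lora_B.bias": torch.zeros(3072),         # 1-D bias
}

rank = {}
for key, val in state_dict.items():
    if "lora_B" in key and val.ndim > 1:
        rank[key] = val.shape[1]

print(rank)  # {'transformer_blocks.0.attn.to_q.lora_B.weight': 128}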

@yiyixuxu yiyixuxu mentioned this pull request Nov 22, 2024
@asfiyab-nvidia
Contributor

Does this PR also add support for the newly released controlnet models like https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev?

@a-r-r-o-w
Member Author

@asfiyab-nvidia Yes, this PR adds support for all six models that were released a few hours ago.

@a-r-r-o-w
Member Author

a-r-r-o-w commented Nov 22, 2024

@sayakpaul @yiyixuxu @DN6 The Control LoRAs require some custom logic to work. I've pushed all the changes I needed to make them work except one (which was something hardcoded that we won't be able to ship). Without this change, you can't get the LoRA weights to load, so it is quite important.

To explain what the problem is, first note that there are two different layers that we call "proj_out": the final proj_out of the transformer and the proj_out inside each single transformer block.

When we parse the lora_B layer shapes and try to figure out the LoRA rank dict, we find the expected LoRA rank values. This is the code that does it. Now, because the ranks for these proj_out layers are different, when we try to create the PEFT LoRA init config, we get something like:

LoRA init config
{'r': 128, 'lora_alpha': 128, 'rank_pattern': {'proj_out': 64}, 'alpha_pattern': {}, 'target_modules': ['single_transformer_blocks.12.proj_out', 'single_transformer_blocks.16.norm.linear', 'single_transformer_blocks.8.attn.to_q', 'transformer_blocks.3.ff.net.0.proj', 'transformer_blocks.0.ff_context.net.0.proj', 'single_transformer_blocks.14.attn.to_q', 'transformer_blocks.1.ff.net.0.proj', 'single_transformer_blocks.33.proj_mlp', 'transformer_blocks.8.ff.net.0.proj', 'single_transformer_blocks.14.proj_out', 'transformer_blocks.10.ff_context.net.0.proj', 'single_transformer_blocks.22.norm.linear', 'single_transformer_blocks.35.attn.to_k', 'single_transformer_blocks.2.norm.linear', 'transformer_blocks.10.ff.net.2', 'transformer_blocks.11.ff.net.0.projlora_B..bias', 'transformer_blocks.10.norm1.linear', 'transformer_blocks.5.ff.net.0.projlora_B..bias', 'transformer_blocks.10.attn.add_v_proj', 'single_transformer_blocks.31.proj_out', 'transformer_blocks.5.attn.add_q_proj', 'single_transformer_blocks.18.attn.to_v', 'single_transformer_blocks.32.attn.to_q', 'single_transformer_blocks.7.attn.to_q', 'single_transformer_blocks.8.proj_out', 'transformer_blocks.1.attn.to_k', 'single_transformer_blocks.15.proj_mlp', 'time_text_embed.timestep_embedder.linear_2', 'transformer_blocks.2.attn.add_v_proj', 'transformer_blocks.11.attn.add_k_proj', 'single_transformer_blocks.5.proj_out', 'single_transformer_blocks.2.attn.to_q', 'transformer_blocks.9.ff.net.0.proj', 'single_transformer_blocks.18.proj_out', 'single_transformer_blocks.6.proj_mlp', 'single_transformer_blocks.12.proj_mlp', 'transformer_blocks.0.ff.net.0.projlora_B..bias', 'single_transformer_blocks.18.norm.linear', 'transformer_blocks.5.attn.to_q', 'single_transformer_blocks.25.proj_out', 'transformer_blocks.17.attn.to_v', 'transformer_blocks.16.attn.add_k_proj', 'single_transformer_blocks.34.attn.to_q', 'transformer_blocks.15.attn.to_k', 'transformer_blocks.15.ff.net.0.projlora_B..bias', 'transformer_blocks.3.attn.to_add_out', 'transformer_blocks.5.norm1_context.linear', 'single_transformer_blocks.9.norm.linear', 'transformer_blocks.3.attn.add_q_proj', 'transformer_blocks.12.attn.to_out.0', 'single_transformer_blocks.15.attn.to_k', 'transformer_blocks.10.norm1_context.linear', 'transformer_blocks.4.attn.to_q', 'transformer_blocks.13.attn.to_out.0', 'transformer_blocks.5.ff_context.net.2', 'transformer_blocks.17.attn.add_q_proj', 'proj_out', 'transformer_blocks.18.norm1_context.linear', 'single_transformer_blocks.5.attn.to_q', 'transformer_blocks.13.norm1_context.linear', 'transformer_blocks.16.attn.to_v', 'transformer_blocks.1.attn.to_q', 'transformer_blocks.15.attn.add_q_proj', 'single_transformer_blocks.19.proj_mlp', 'single_transformer_blocks.35.proj_mlp', 'single_transformer_blocks.17.norm.linear', 'single_transformer_blocks.31.attn.to_v', 'single_transformer_blocks.35.norm.linear', 'single_transformer_blocks.17.attn.to_k', 'single_transformer_blocks.2.attn.to_k', 'transformer_blocks.7.attn.to_v', 'single_transformer_blocks.24.proj_out', 'single_transformer_blocks.26.attn.to_q', 'single_transformer_blocks.15.norm.linear', 'transformer_blocks.0.norm1_context.linear', 'transformer_blocks.9.attn.add_q_proj', 'single_transformer_blocks.32.attn.to_v', 'transformer_blocks.5.attn.add_v_proj', 'single_transformer_blocks.28.proj_out', 'transformer_blocks.3.norm1.linear', 'transformer_blocks.4.ff.net.2', 'single_transformer_blocks.22.proj_mlp', 'transformer_blocks.4.attn.to_v', 'transformer_blocks.14.attn.to_add_out', 
'transformer_blocks.9.attn.to_add_out', 'single_transformer_blocks.27.proj_mlp', 'transformer_blocks.3.attn.to_k', 'transformer_blocks.2.attn.to_v', 'single_transformer_blocks.21.norm.linear', 'single_transformer_blocks.5.attn.to_v', 'single_transformer_blocks.29.proj_out', 'transformer_blocks.1.attn.add_k_proj', 'transformer_blocks.1.norm1.linear', 'single_transformer_blocks.34.proj_out', 'single_transformer_blocks.16.attn.to_q', 'transformer_blocks.1.ff.net.0.projlora_B..bias', 'single_transformer_blocks.37.norm.linear', 'single_transformer_blocks.23.proj_mlp', 'single_transformer_blocks.10.proj_out', 'transformer_blocks.5.attn.to_out.0', 'single_transformer_blocks.30.proj_mlp', 'single_transformer_blocks.22.proj_out', 'single_transformer_blocks.31.proj_mlp', 'single_transformer_blocks.26.attn.to_v', 'single_transformer_blocks.0.attn.to_q', 'single_transformer_blocks.6.norm.linear', 'transformer_blocks.13.attn.add_k_proj', 'transformer_blocks.7.norm1_context.linear', 'single_transformer_blocks.4.attn.to_k', 'single_transformer_blocks.8.attn.to_v', 'single_transformer_blocks.34.proj_mlp', 'transformer_blocks.4.attn.add_k_proj', 'single_transformer_blocks.1.attn.to_v', 'transformer_blocks.14.attn.to_q', 'single_transformer_blocks.11.attn.to_q', 'single_transformer_blocks.13.attn.to_k', 'single_transformer_blocks.6.attn.to_k', 'transformer_blocks.17.ff_context.net.0.proj', 'transformer_blocks.0.ff.net.0.proj', 'single_transformer_blocks.27.attn.to_k', 'transformer_blocks.6.norm1_context.linear', 'single_transformer_blocks.1.attn.to_q', 'single_transformer_blocks.27.proj_out', 'single_transformer_blocks.24.attn.to_v', 'time_text_embed.text_embedder.linear_1', 'transformer_blocks.12.attn.add_k_proj', 'transformer_blocks.16.attn.add_q_proj', 'transformer_blocks.7.ff.net.0.proj', 'transformer_blocks.17.attn.add_v_proj', 'transformer_blocks.0.attn.to_add_out', 'transformer_blocks.14.ff_context.net.0.proj', 'transformer_blocks.14.attn.to_k', 'single_transformer_blocks.24.proj_mlp', 'transformer_blocks.16.attn.to_out.0', 'single_transformer_blocks.34.attn.to_k', 'single_transformer_blocks.0.proj_out', 'single_transformer_blocks.25.attn.to_q', 'single_transformer_blocks.35.attn.to_q', 'transformer_blocks.18.attn.to_out.0', 'transformer_blocks.3.ff_context.net.0.proj', 'single_transformer_blocks.10.attn.to_q', 'transformer_blocks.18.attn.to_add_out', 'single_transformer_blocks.33.attn.to_v', 'single_transformer_blocks.37.attn.to_k', 'transformer_blocks.6.attn.to_q', 'transformer_blocks.8.ff.net.2', 'single_transformer_blocks.24.attn.to_q', 'single_transformer_blocks.20.attn.to_q', 'single_transformer_blocks.21.attn.to_v', 'transformer_blocks.16.ff.net.0.proj', 'transformer_blocks.0.attn.add_k_proj', 'transformer_blocks.0.attn.to_k', 'time_text_embed.guidance_embedder.linear_1', 'transformer_blocks.15.ff.net.0.proj', 'transformer_blocks.3.attn.to_out.0', 'transformer_blocks.9.attn.to_q', 'single_transformer_blocks.5.proj_mlp', 'transformer_blocks.14.attn.to_v', 'transformer_blocks.1.attn.add_q_proj', 'transformer_blocks.3.attn.to_q', 'single_transformer_blocks.27.attn.to_q', 'transformer_blocks.2.norm1.linear', 'single_transformer_blocks.36.proj_out', 'single_transformer_blocks.6.proj_out', 'single_transformer_blocks.14.proj_mlp', 'transformer_blocks.0.attn.add_q_proj', 'transformer_blocks.15.attn.to_out.0', 'single_transformer_blocks.23.attn.to_q', 'transformer_blocks.7.ff.net.2', 'single_transformer_blocks.5.attn.to_k', 'transformer_blocks.11.attn.to_q', 'transformer_blocks.13.attn.to_v', 
'single_transformer_blocks.7.attn.to_k', 'single_transformer_blocks.30.attn.to_v', 'single_transformer_blocks.11.proj_out', 'transformer_blocks.13.attn.add_v_proj', 'single_transformer_blocks.23.attn.to_v', 'transformer_blocks.2.norm1_context.linear', 'transformer_blocks.1.attn.to_v', 'transformer_blocks.10.attn.add_k_proj', 'single_transformer_blocks.17.proj_out', 'transformer_blocks.10.attn.to_q', 'transformer_blocks.13.ff.net.2', 'single_transformer_blocks.11.norm.linear', 'transformer_blocks.17.norm1.linear', 'single_transformer_blocks.9.proj_mlp', 'single_transformer_blocks.9.proj_out', 'single_transformer_blocks.20.proj_mlp', 'transformer_blocks.12.norm1.linear', 'transformer_blocks.9.attn.to_out.0', 'transformer_blocks.8.attn.to_out.0', 'single_transformer_blocks.14.attn.to_k', 'single_transformer_blocks.2.proj_mlp', 'transformer_blocks.9.norm1.linear', 'transformer_blocks.7.norm1.linear', 'transformer_blocks.12.ff.net.0.projlora_B..bias', 'single_transformer_blocks.17.attn.to_q', 'single_transformer_blocks.11.proj_mlp', 'transformer_blocks.4.ff.net.0.proj', 'transformer_blocks.8.norm1.linear', 'single_transformer_blocks.4.proj_mlp', 'single_transformer_blocks.32.proj_out', 'transformer_blocks.10.attn.to_k', 'transformer_blocks.3.ff.net.2', 'single_transformer_blocks.19.attn.to_k', 'single_transformer_blocks.12.attn.to_k', 'single_transformer_blocks.28.attn.to_k', 'single_transformer_blocks.6.attn.to_v', 'transformer_blocks.1.ff.net.2', 'single_transformer_blocks.28.proj_mlp', 'transformer_blocks.5.ff_context.net.0.proj', 'single_transformer_blocks.23.proj_out', 'single_transformer_blocks.30.attn.to_q', 'transformer_blocks.13.ff_context.net.2', 'single_transformer_blocks.36.norm.linear', 'transformer_blocks.16.norm1.linear', 'transformer_blocks.8.attn.add_k_proj', 'single_transformer_blocks.5.norm.linear', 'transformer_blocks.8.ff_context.net.2', 'transformer_blocks.10.ff.net.0.proj', 'transformer_blocks.15.ff_context.net.2', 'single_transformer_blocks.19.attn.to_v', 'single_transformer_blocks.29.norm.linear', 'transformer_blocks.16.ff_context.net.0.proj', 'transformer_blocks.3.norm1_context.linear', 'transformer_blocks.13.attn.add_q_proj', 'transformer_blocks.8.attn.to_k', 'single_transformer_blocks.8.norm.linear', 'single_transformer_blocks.15.attn.to_v', 'transformer_blocks.8.attn.add_q_proj', 'transformer_blocks.15.norm1_context.linear', 'transformer_blocks.7.attn.to_out.0', 'transformer_blocks.4.norm1.linear', 'single_transformer_blocks.4.attn.to_q', 'transformer_blocks.16.attn.add_v_proj', 'transformer_blocks.14.attn.to_out.0', 'transformer_blocks.15.attn.add_k_proj', 'transformer_blocks.9.attn.to_v', 'single_transformer_blocks.10.proj_mlp', 'transformer_blocks.16.ff_context.net.2', 'transformer_blocks.11.attn.add_q_proj', 'transformer_blocks.17.ff_context.net.2', 'single_transformer_blocks.14.attn.to_v', 'single_transformer_blocks.0.attn.to_k', 'transformer_blocks.18.ff_context.net.0.proj', 'single_transformer_blocks.34.attn.to_v', 'transformer_blocks.1.ff_context.net.0.proj', 'transformer_blocks.11.ff_context.net.0.proj', 'single_transformer_blocks.15.attn.to_q', 'transformer_blocks.11.attn.to_v', 'transformer_blocks.9.norm1_context.linear', 'single_transformer_blocks.16.proj_mlp', 'single_transformer_blocks.33.norm.linear', 'transformer_blocks.15.attn.to_q', 'transformer_blocks.7.attn.add_q_proj', 'transformer_blocks.7.attn.to_q', 'transformer_blocks.5.attn.add_k_proj', 'single_transformer_blocks.20.proj_out', 'single_transformer_blocks.37.proj_out', 
'single_transformer_blocks.17.proj_mlp', 'transformer_blocks.18.ff.net.2', 'transformer_blocks.0.norm1.linear', 'single_transformer_blocks.13.attn.to_q', 'transformer_blocks.8.ff_context.net.0.proj', 'single_transformer_blocks.28.attn.to_v', 'single_transformer_blocks.1.attn.to_k', 'transformer_blocks.10.attn.add_q_proj', 'transformer_blocks.18.attn.add_v_proj', 'transformer_blocks.11.ff.net.0.proj', 'single_transformer_blocks.33.attn.to_k', 'transformer_blocks.17.attn.add_k_proj', 'transformer_blocks.7.ff_context.net.2', 'single_transformer_blocks.25.attn.to_v', 'transformer_blocks.0.ff_context.net.2', 'transformer_blocks.11.norm1.linear', 'transformer_blocks.17.attn.to_add_out', 'transformer_blocks.6.attn.to_add_out', 'transformer_blocks.12.ff.net.0.proj', 'single_transformer_blocks.15.proj_out', 'transformer_blocks.0.attn.to_out.0', 'transformer_blocks.4.attn.to_out.0', 'transformer_blocks.8.attn.add_v_proj', 'transformer_blocks.8.ff.net.0.projlora_B..bias', 'transformer_blocks.14.norm1_context.linear', 'time_text_embed.text_embedder.linear_2', 'transformer_blocks.12.attn.to_q', 'single_transformer_blocks.4.norm.linear', 'transformer_blocks.18.ff.net.0.proj', 'transformer_blocks.18.attn.to_k', 'single_transformer_blocks.20.norm.linear', 'transformer_blocks.1.attn.add_v_proj', 'single_transformer_blocks.3.proj_mlp', 'single_transformer_blocks.26.proj_mlp', 'single_transformer_blocks.30.attn.to_k', 'x_embedder', 'transformer_blocks.16.attn.to_add_out', 'single_transformer_blocks.35.proj_out', 'single_transformer_blocks.9.attn.to_q', 'transformer_blocks.18.norm1.linear', 'single_transformer_blocks.32.proj_mlp', 'transformer_blocks.15.attn.to_add_out', 'single_transformer_blocks.24.norm.linear', 'transformer_blocks.4.attn.add_q_proj', 'single_transformer_blocks.10.norm.linear', 'single_transformer_blocks.22.attn.to_k', 'single_transformer_blocks.16.attn.to_v', 'transformer_blocks.11.attn.add_v_proj', 'single_transformer_blocks.7.norm.linear', 'single_transformer_blocks.2.attn.to_v', 'transformer_blocks.5.ff.net.0.proj', 'single_transformer_blocks.36.attn.to_q', 'single_transformer_blocks.31.norm.linear', 'single_transformer_blocks.10.attn.to_v', 'transformer_blocks.13.ff.net.0.projlora_B..bias', 'single_transformer_blocks.37.proj_mlp', 'single_transformer_blocks.29.attn.to_k', 'single_transformer_blocks.16.attn.to_k', 'single_transformer_blocks.19.proj_out', 'transformer_blocks.2.ff_context.net.0.proj', 'transformer_blocks.6.ff.net.2', 'transformer_blocks.2.ff_context.net.2', 'single_transformer_blocks.1.proj_out', 'transformer_blocks.2.ff.net.0.projlora_B..bias', 'transformer_blocks.5.attn.to_k', 'transformer_blocks.10.ff_context.net.2', 'transformer_blocks.10.attn.to_v', 'single_transformer_blocks.6.attn.to_q', 'transformer_blocks.13.attn.to_add_out', 'transformer_blocks.7.attn.to_add_out', 'single_transformer_blocks.1.proj_mlp', 'transformer_blocks.6.attn.to_v', 'transformer_blocks.8.attn.to_q', 'transformer_blocks.5.attn.to_add_out', 'transformer_blocks.16.ff.net.2', 'transformer_blocks.0.attn.to_v', 'context_embedder', 'single_transformer_blocks.25.attn.to_k', 'single_transformer_blocks.33.attn.to_q', 'single_transformer_blocks.37.attn.to_q', 'transformer_blocks.14.norm1.linear', 'transformer_blocks.6.ff_context.net.0.proj', 'transformer_blocks.16.attn.to_q', 'transformer_blocks.3.ff_context.net.2', 'transformer_blocks.2.attn.to_k', 'transformer_blocks.14.ff.net.0.projlora_B..bias', 'transformer_blocks.15.ff_context.net.0.proj', 'transformer_blocks.0.attn.add_v_proj', 
'transformer_blocks.17.attn.to_out.0', 'single_transformer_blocks.36.attn.to_k', 'transformer_blocks.6.attn.add_q_proj', 'single_transformer_blocks.29.proj_mlp', 'single_transformer_blocks.23.attn.to_k', 'transformer_blocks.12.norm1_context.linear', 'time_text_embed.guidance_embedder.linear_2', 'transformer_blocks.1.attn.to_out.0', 'transformer_blocks.4.ff_context.net.0.proj', 'transformer_blocks.7.attn.add_k_proj', 'transformer_blocks.17.ff.net.2', 'single_transformer_blocks.17.attn.to_v', 'transformer_blocks.11.norm1_context.linear', 'single_transformer_blocks.33.proj_out', 'single_transformer_blocks.3.attn.to_q', 'transformer_blocks.9.attn.add_k_proj', 'transformer_blocks.3.ff.net.0.projlora_B..bias', 'transformer_blocks.10.ff.net.0.projlora_B..bias', 'single_transformer_blocks.4.proj_out', 'single_transformer_blocks.31.attn.to_k', 'transformer_blocks.12.attn.to_k', 'transformer_blocks.6.ff.net.0.projlora_B..bias', 'single_transformer_blocks.21.proj_out', 'single_transformer_blocks.12.attn.to_q', 'transformer_blocks.14.attn.add_v_proj', 'transformer_blocks.4.ff.net.0.projlora_B..bias', 'transformer_blocks.13.ff.net.0.proj', 'transformer_blocks.10.attn.to_add_out', 'single_transformer_blocks.8.proj_mlp', 'transformer_blocks.17.ff.net.0.projlora_B..bias', 'transformer_blocks.18.attn.add_k_proj', 'transformer_blocks.4.attn.add_v_proj', 'transformer_blocks.9.attn.add_v_proj', 'single_transformer_blocks.3.proj_out', 'transformer_blocks.15.ff.net.2', 'single_transformer_blocks.26.attn.to_k', 'transformer_blocks.17.attn.to_k', 'single_transformer_blocks.14.norm.linear', 'transformer_blocks.13.attn.to_q', 'single_transformer_blocks.0.norm.linear', 'single_transformer_blocks.11.attn.to_v', 'transformer_blocks.12.ff_context.net.2', 'transformer_blocks.14.ff.net.0.proj', 'single_transformer_blocks.7.attn.to_v', 'single_transformer_blocks.3.attn.to_k', 'transformer_blocks.7.ff_context.net.0.proj', 'single_transformer_blocks.27.norm.linear', 'single_transformer_blocks.4.attn.to_v', 'transformer_blocks.3.attn.to_v', 'single_transformer_blocks.32.norm.linear', 'single_transformer_blocks.20.attn.to_v', 'transformer_blocks.9.ff.net.2', 'transformer_blocks.16.norm1_context.linear', 'single_transformer_blocks.0.attn.to_v', 'single_transformer_blocks.2.proj_out', 'single_transformer_blocks.19.attn.to_q', 'transformer_blocks.10.attn.to_out.0', 'single_transformer_blocks.26.norm.linear', 'single_transformer_blocks.0.proj_mlp', 'transformer_blocks.6.norm1.linear', 'single_transformer_blocks.21.proj_mlp', 'single_transformer_blocks.36.attn.to_v', 'transformer_blocks.18.ff_context.net.2', 'transformer_blocks.4.ff_context.net.2', 'transformer_blocks.9.ff.net.0.projlora_B..bias', 'transformer_blocks.5.ff.net.2', 'transformer_blocks.6.ff_context.net.2', 'transformer_blocks.4.attn.to_k', 'transformer_blocks.5.norm1.linear', 'single_transformer_blocks.7.proj_mlp', 'transformer_blocks.18.attn.to_q', 'transformer_blocks.11.attn.to_add_out', 'transformer_blocks.7.attn.add_v_proj', 'single_transformer_blocks.16.proj_out', 'transformer_blocks.18.attn.add_q_proj', 'transformer_blocks.0.attn.to_q', 'transformer_blocks.5.attn.to_v', 'transformer_blocks.6.attn.to_out.0', 'transformer_blocks.11.attn.to_out.0', 'single_transformer_blocks.21.attn.to_k', 'single_transformer_blocks.25.proj_mlp', 'single_transformer_blocks.9.attn.to_v', 'single_transformer_blocks.13.proj_out', 'single_transformer_blocks.13.norm.linear', 'transformer_blocks.2.attn.to_out.0', 'single_transformer_blocks.26.proj_out', 
'single_transformer_blocks.19.norm.linear', 'transformer_blocks.15.attn.to_v', 'transformer_blocks.4.attn.to_add_out', 'transformer_blocks.18.ff.net.0.projlora_B..bias', 'transformer_blocks.0.ff.net.2', 'transformer_blocks.12.ff_context.net.0.proj', 'single_transformer_blocks.31.attn.to_q', 'transformer_blocks.8.attn.to_add_out', 'single_transformer_blocks.27.attn.to_v', 'transformer_blocks.14.ff_context.net.2', 'single_transformer_blocks.18.proj_mlp', 'transformer_blocks.13.attn.to_k', 'transformer_blocks.2.attn.to_add_out', 'single_transformer_blocks.13.attn.to_v', 'single_transformer_blocks.18.attn.to_q', 'single_transformer_blocks.23.norm.linear', 'single_transformer_blocks.32.attn.to_k', 'transformer_blocks.1.norm1_context.linear', 'norm_out.linear', 'transformer_blocks.9.ff_context.net.0.proj', 'transformer_blocks.2.attn.add_k_proj', 'transformer_blocks.16.ff.net.0.projlora_B..bias', 'transformer_blocks.6.attn.add_k_proj', 'transformer_blocks.2.attn.to_q', 'single_transformer_blocks.28.attn.to_q', 'single_transformer_blocks.35.attn.to_v', 'transformer_blocks.12.ff.net.2', 'transformer_blocks.11.ff.net.2', 'transformer_blocks.14.ff.net.2', 'transformer_blocks.17.norm1_context.linear', 'transformer_blocks.12.attn.add_v_proj', 'single_transformer_blocks.1.norm.linear', 'transformer_blocks.15.attn.add_v_proj', 'transformer_blocks.12.attn.to_v', 'transformer_blocks.15.norm1.linear', 'transformer_blocks.6.attn.add_v_proj', 'single_transformer_blocks.21.attn.to_q', 'transformer_blocks.3.attn.add_v_proj', 'single_transformer_blocks.22.attn.to_v', 'transformer_blocks.17.ff.net.0.proj', 'single_transformer_blocks.7.proj_out', 'single_transformer_blocks.12.norm.linear', 'single_transformer_blocks.29.attn.to_v', 'single_transformer_blocks.3.norm.linear', 'transformer_blocks.12.attn.add_q_proj', 'single_transformer_blocks.18.attn.to_k', 'time_text_embed.timestep_embedder.linear_1', 'single_transformer_blocks.11.attn.to_k', 'transformer_blocks.17.attn.to_q', 'transformer_blocks.4.norm1_context.linear', 'transformer_blocks.11.attn.to_k', 'single_transformer_blocks.25.norm.linear', 'transformer_blocks.14.attn.add_q_proj', 'single_transformer_blocks.29.attn.to_q', 'single_transformer_blocks.22.attn.to_q', 'single_transformer_blocks.30.norm.linear', 'transformer_blocks.8.norm1_context.linear', 'transformer_blocks.7.ff.net.0.projlora_B..bias', 'transformer_blocks.9.attn.to_k', 'transformer_blocks.16.attn.to_k', 'transformer_blocks.2.attn.add_q_proj', 'single_transformer_blocks.9.attn.to_k', 'transformer_blocks.3.attn.add_k_proj', 'transformer_blocks.18.attn.to_v', 'single_transformer_blocks.34.norm.linear', 'transformer_blocks.2.ff.net.0.proj', 'single_transformer_blocks.37.attn.to_v', 'transformer_blocks.6.ff.net.0.proj', 'transformer_blocks.9.ff_context.net.2', 'single_transformer_blocks.3.attn.to_v', 'transformer_blocks.8.attn.to_v', 'single_transformer_blocks.12.attn.to_v', 'transformer_blocks.12.attn.to_add_out', 'transformer_blocks.13.norm1.linear', 'single_transformer_blocks.13.proj_mlp', 'transformer_blocks.1.attn.to_add_out', 'single_transformer_blocks.8.attn.to_k', 'transformer_blocks.13.ff_context.net.0.proj', 'transformer_blocks.7.attn.to_k', 'single_transformer_blocks.36.proj_mlp', 'transformer_blocks.1.ff_context.net.2', 'single_transformer_blocks.10.attn.to_k', 'single_transformer_blocks.28.norm.linear', 'single_transformer_blocks.30.proj_out', 'transformer_blocks.2.ff.net.2', 'transformer_blocks.6.attn.to_k', 'transformer_blocks.11.ff_context.net.2', 
'transformer_blocks.14.attn.add_k_proj', 'single_transformer_blocks.24.attn.to_k', 'single_transformer_blocks.20.attn.to_k'], 'use_dora': False}

As you can see, the rank is correctly set to 128 (most layers in the control loras have rank=128), but the rank_pattern is incorrectly set to proj_out: 64 (only the final proj_out has rank 64, but all the other proj_out layers in single transformer blocks need rank=128).

Since PEFT works with regex strings, proj_out ends up setting rank=64 LoRAs for all proj_out layers despite the ones in single_transformer_blocks needing to be 128. This seems like a long-standing bug, because you can't use different ranks on layers with the same name without this method breaking (but again, it is very rare that this would be the case). I haven't poked around in the PEFT code enough to be able to make changes without potentially breaking backwards compatibility when trying to fix this, so I'm looking for suggestions on what to do, or if one of the more experienced people would like to take this up.


I've tried my best to ensure that no backwards compatibility is broken with the current changes in the LoRA, but I think we might have to do some additional testing to ensure everything is okay

@a-r-r-o-w a-r-r-o-w requested review from DN6 and sayakpaul November 22, 2024 06:45
@sayakpaul
Member

Since PEFT works with regex strings, proj_out ends up setting rank=64 LoRAs for all proj_out layers despite the ones in single_transformer_blocks needing to be 128. This seems like a long-standing bug, because you can't use different ranks on layers with the same name without this method breaking (but again, it is very rare that this would be the case). I haven't poked around in the PEFT code enough to be able to make changes without potentially breaking backwards compatibility when trying to fix this, so I'm looking for suggestions on what to do, or if one of the more experienced people would like to take this up.

First of all, thanks for all your hard work here! This indeed seems like a problem. Minimal reproducer:

Code
from peft import LoraConfig, get_peft_model
import torch 
import re

class DummyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj_out = torch.nn.Linear(10, 10)
    
    def forward(self, x):
        return self.proj_out(x)

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([DummyBlock() for _ in range(3)])
        self.intermediate = torch.nn.Linear(10, 10)
        self.proj_out = torch.nn.Linear(10, 3)
    
    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        x = self.intermediate(x)
        return self.proj_out(x)

model = MyModel()
x = torch.randn(10, 10)
out = model(x)
print(model.state_dict().keys())
print(out.shape)

config = LoraConfig(
    r=2, 
    target_modules=["intermediate", "proj_out"], 
    rank_pattern={
        "proj_out": 4,
        "blocks.*": 3
    }
)
model = get_peft_model(model, config)

for name, p in model.named_parameters():
    if "lora" in name:
        print(f"{name}: {p.data.shape}")

Print:

base_model.model.blocks.0.proj_out.lora_A.default.weight: torch.Size([4, 10])
base_model.model.blocks.0.proj_out.lora_B.default.weight: torch.Size([10, 4])
base_model.model.blocks.1.proj_out.lora_A.default.weight: torch.Size([4, 10])
base_model.model.blocks.1.proj_out.lora_B.default.weight: torch.Size([10, 4])
base_model.model.blocks.2.proj_out.lora_A.default.weight: torch.Size([4, 10])
base_model.model.blocks.2.proj_out.lora_B.default.weight: torch.Size([10, 4])
base_model.model.intermediate.lora_A.default.weight: torch.Size([2, 10])
base_model.model.intermediate.lora_B.default.weight: torch.Size([10, 2])
base_model.model.proj_out.lora_A.default.weight: torch.Size([4, 10])
base_model.model.proj_out.lora_B.default.weight: torch.Size([3, 4])

Would expect the blocks-prefixed keys to have a rank of 3.

@BenjaminBossan thoughts?

Member

@sayakpaul sayakpaul left a comment


Thanks for working on LoRA. I have left a couple of comments.

Some other comments:

  • How are we going to load the converted LoRA state dict into a transformer? Which transformer checkpoint should be used? In the original codebase, they use the Dev checkpoint. See here and here. Confirmed after printing as well. So, if we were to use the pre-trained Flux.1 Dev checkpoint, the input linear layer needs to be rejigged to have a dimensionality of (3072, 128) (instead of (3072, 64)) (reference) and the state dict needs to be expanded like so (a rough sketch of this expansion follows after this list).
  • The norm scales should not be the pre-trained Flux.1 Dev ones but the ones supplied in the LoRA state dict instead (as they differ in magnitude, and we can confirm that).

LMK if anything is unclear or if I missed something.
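
On the first point, here is a rough sketch of what the expansion could look like (the module name and shapes are assumptions for illustration; this is not the actual diffusers helper):

import torch
import torch.nn as nn

# The Control models concatenate packed control latents to the image latents, so the
# input projection must accept 128 features instead of 64. Copying the pre-trained
# columns and zero-initializing the new ones keeps the base behavior unchanged until
# the (expanded) LoRA weights contribute the control signal.
old_linear = nn.Linear(64, 3072)   # stand-in for the pre-trained Flux.1 Dev x_embedder
new_linear = nn.Linear(128, 3072)

with torch.no_grad():
    new_linear.weight.zero_()
    new_linear.weight[:, :64].copy_(old_linear.weight)  # keep the pre-trained columns
    new_linear.bias.copy_(old_linear.bias)

# Sanity check: with zeroed control channels, the expanded layer matches the original.
x = torch.randn(1, 64)
assert torch.allclose(old_linear(x), new_linear(torch.cat([x, torch.zeros(1, 64)], dim=-1)), atol=1e-6)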

@@ -0,0 +1,393 @@
import argparse
Member

We should do the conversion on the fly like we do for the other LoRA checkpoints. Why are we doing this here?

Member Author

We can move the conversion code to do it on the fly after we figure out the changes needed to make it work, no?

converted_state_dict[
    f"time_text_embed.timestep_embedder.linear_1.{diffusers_lora_key}.weight"
] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight")
if f"time_in.in_layer.{lora_key}.bias" in original_state_dict.keys():
Member

Could keep original_state_dict.keys() in a variable and reuse that.

Member Author

Sounds good!

converted_state_dict = {}

## time_text_embed.timestep_embedder <- time_in
for lora_key, diffusers_lora_key in zip(["lora_A", "lora_B"], ["lora_A", "lora_B"]):
Member

The LoRA keys are essentially the same, no? Why separate them out?

Member Author

Sorry, I had the diffusers_lora_key as [lora_down, lora_up] before this. Will update

Comment on lines 128 to 130
converted_state_dict[f"x_embedder.{diffusers_lora_key}.weight"] = original_state_dict.pop(
f"img_in.{lora_key}.weight"
)
Member

The Control model has doubled the input dimension from 64 to 128. How are we accounting for that?

Member Author

We don't have to do it for the state dict conversion. The underlying LoRA shapes remain the same, as we are only renaming layers here.

original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.bias")
)

print("Remaining:", original_state_dict.keys())
Member

There should not be any remaining keys. I am aware that there are non-LoRA keys inside the original Control LoRA state dicts; they are all norm scales, and they are different from the Flux.1 Dev pre-trained ones. So, we need to account for them as well.

If there are any remaining keys, we should error out like we do for other non-diffusers LoRA checkpoints.

Refer to https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders/lora_conversion_utils.py.

Member Author

This prints an empty list and was just for sanity checking. Conversion is working as expected

@a-r-r-o-w a-r-r-o-w changed the title from "New Flux Fill, ControlNet, Control LoRA, Redux" to "Flux Fill, Canny, Depth, Redux" on Nov 22, 2024
@a-r-r-o-w a-r-r-o-w mentioned this pull request Nov 23, 2024
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 8

control_image = self.prepare_image(
Collaborator

I realize that for Flux, pre-generated latents are packed into a 3D tensor, so they won't be recognized by the image processor. We can refactor the image processor in a follow-up PR so that all image inputs can also accept the latent form.

Member Author

@a-r-r-o-w a-r-r-o-w Nov 23, 2024

Actually, passing latents to control_image works as expected (I've verified numerically against the original implementation as well). It is thanks to line 576 (https://github.com/huggingface/diffusers/pull/9985/files#diff-8f811cfa985866e8aeece00b90679d52566e97767d6ba74f8974be8569b60ee5R576), where we simply return if it's already a torch tensor.

Member Author

@a-r-r-o-w a-r-r-o-w Nov 23, 2024

And line 777 takes care of the normal image tensor input case, because control_image is expected to be an ndim=3 tensor (already packed) if it is in latent form (sketched below).
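
Roughly, the dispatch being described is the following (a sketch of the intended behavior, not the pipeline source):

import torch

def prepare_control_input(control_image):
    if isinstance(control_image, torch.Tensor) and control_image.ndim == 3:
        # already packed latents of shape (batch, num_patches, channels): use directly
        return control_image
    # otherwise: image-process, VAE-encode, and pack (omitted in this sketch)
    raise NotImplementedError("preprocessing path omitted in this sketch")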


Collaborator

Oh, I missed that. But it is still wrong (even though it works for latent input), because we also allow passing the image (4D) as a torch tensor, which needs to be put through the image processor for some operations such as normalization. It is not a big deal, though; I think most of the time people pass a PIL image. We will need to clean it up across the pipeline.

Collaborator

@yiyixuxu yiyixuxu left a comment

will merge soon!

@yiyixuxu yiyixuxu merged commit 7ac6e28 into main Nov 23, 2024
18 checks passed
@yiyixuxu yiyixuxu deleted the flux-new branch November 23, 2024 11:41
@MohamedAliRashad

I have been testing this PR and there are a couple of things I want to point out:

  1. FluxPriorReduxPipeline has an encode_prompt method that requires both prompt and prompt_2 to work, while the documentation says that prompt_2 is optional and will be overwritten with prompt if not available.
  2. There is no function to load the text_encoder/tokenizer from the base Flux pipeline into the Redux pipeline:
pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Redux-dev",
    revision="refs/pr/8",
    torch_dtype=torch.bfloat16,
)

This gives the following output:

FluxPriorReduxPipeline {
  "_class_name": "FluxPriorReduxPipeline",
  "_diffusers_version": "0.32.0.dev0",
  "_name_or_path": "black-forest-labs/FLUX.1-Redux-dev",
  "feature_extractor": [
    "transformers",
    "SiglipImageProcessor"
  ],
  "image_embedder": [
    "flux",
    "ReduxImageEncoder"
  ],
  "image_encoder": [
    "transformers",
    "SiglipVisionModel"
  ],
  "text_encoder": [
    null,
    null
  ],
  "text_encoder_2": [
    null,
    null
  ],
  "tokenizer": [
    null,
    null
  ],
  "tokenizer_2": [
    null,
    null
  ]
}

I want a way to load the text encoders and tokenizers into it so I can use a text prompt to guide the output.

@yiyixuxu
Collaborator

@MohamedAliRashad
Hi! See more details about Redux here: #9988

It is an image variation pipeline and cannot be guided by a text prompt. However, if you want to experiment with prompts, you can follow these two steps:

  1. you will need to modify the pipeline to take a prompt input, or hardcode the "" string here
  2. load the text encoders into the redux pipeline like this
import torch
from PIL import Image
from diffusers import FluxPriorReduxPipeline, FluxPipeline
from diffusers.utils import load_image

device = "cuda"
dtype = torch.bfloat16

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype).to(device)
pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Redux-dev",
    text_encoder=pipe.text_encoder,
    tokenizer=pipe.tokenizer,
    text_encoder_2=pipe.text_encoder_2,
    tokenizer_2=pipe.tokenizer_2,
    revision="refs/pr/8",
    torch_dtype=dtype,
)
pipe_prior_redux.to(device)

img_path = "/raid/yiyi/flux-new/assets/robot.webp"
image = Image.open(img_path).convert("RGB")

pipe_prior_output = pipe_prior_redux(image)

image = pipe(
    guidance_scale=2.5,
    height=768,
    width=1360,
    num_inference_steps=50,
    max_sequence_length=512,
    generator=torch.Generator("cpu").manual_seed(0),
    **pipe_prior_output,
).images[0]

@MohamedAliRashad

@yiyixuxu
It is possible to steer the Redux output with text. It just may not be available to the public.

@Clement-Lelievre
Contributor

Clement-Lelievre commented Nov 28, 2024

Hi,

Thanks for supporting the new Flux-Fill BFL model.

Do you know if it allows object removal without fine-tuning, e.g. only by carefully selecting the inference params? I couldn't make this idea work so far.

@asomoza

Thanks

@chuck-ma

I think if Flux-Fill supported strength, we could make it allow object removal more easily.

@asomoza
Member

asomoza commented Nov 29, 2024

Actually, that was one of the first generations I tried; it works right out of the box for me. I just leave the prompt empty and it removes the object, so it probably depends on the subject.

With the guide and Space that I made for SDXL, I noticed that a lot of people try to "remove an object" when that object is like 90% of the image, which is not really an object removal; practically, it's just a new generation, so you have to describe the complete scene and what you want for it to work.

Also, it works and it's good, but not better than SDXL, and it takes a lot of VRAM to run; it's slower and not commercial. On the other hand, it's easier to use and doesn't need a custom pipeline or workflow to work, so now you have options for quality "fills", which is really good.

[Images: original | mask | SDXL result | Flux result]

@Clement-Lelievre
Contributor

Clement-Lelievre commented Nov 30, 2024

Actually, that was one of the first generations I tried; it works right out of the box for me. I just leave the prompt empty and it removes the object, so it probably depends on the subject.

With the guide and Space that I made for SDXL, I noticed that a lot of people try to "remove an object" when that object is like 90% of the image, which is not really an object removal; practically, it's just a new generation, so you have to describe the complete scene and what you want for it to work.

Good to know, I'll have to try again with a negative prompt, because I tried the demo snippet (below, given as is) for the Flux-Fill model with a variety of parameter combinations and it didn't work. I tried erasing the cup in the image, so it's a reasonably small portion of the image.

import torch
from diffusers import FluxFillPipeline
from diffusers.utils import load_image

image = load_image("https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/cup.png")
mask = load_image("https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/cup_mask.png")

pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16).to("cuda")
image = pipe(
    prompt="a white paper cup",
    image=image,
    mask_image=mask,
    height=1632,
    width=1232,
    guidance_scale=30,
    num_inference_steps=50,
    max_sequence_length=512,
    generator=torch.Generator("cpu").manual_seed(0)
).images[0]
image.save(f"flux-fill-dev.png")

@asomoza
Member

asomoza commented Nov 30, 2024

I tried it with the cup and you're right: probably because of the training, it didn't completely erase the object on the sand.

[Images: mask | result 1 | result 2]

Changing the mask to a bigger and better-defined one helped, but it still generated objects half of the time.

[Images: mask | result 1 | result 2]

Changing the prompt to "sand" got me better results and only one time it generated an object

[Images: three results with the "sand" prompt]

Also, the generated sand is weird: finer and not consistent with the rest of the sand in the image, so this example is probably one of the things this model is not good at (nice to know for future reference).

For example, with SDXL it always erases the object and the sand is more consistent with the rest of the image, but since it uses a Lightning model it also seems kind of low quality:

[Images: three SDXL results]

sayakpaul added a commit that referenced this pull request Dec 23, 2024