Describe the bug
When I try to load a LoRA, such as alimama-creative/FLUX.1-Turbo-Alpha, into an NF4-quantized Flux Fill pipeline, it raises an error.
Reproduction
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from diffusers import FluxFillPipeline, FluxTransformer2DModel

dtype = torch.bfloat16

# NF4 quantization config for the transformer
quant_config = DiffusersBitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=dtype
)

# Load the Fill transformer in 4-bit NF4
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=dtype,
)

pipeline = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev",
    transformer=transformer,
    torch_dtype=dtype,
).to("cuda")

# Fails here
pipeline.load_lora_weights("alimama-creative/FLUX.1-Turbo-Alpha")
Logs
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[5], line 1
----> 1 pipeline.load_lora_weights("alimama-creative/FLUX.1-Turbo-Alpha", adapter_name=f"lora_")
File ~/.pyenv/versions/3.10.0/envs/jupyter/lib/python3.10/site-packages/diffusers/loaders/lora_pipeline.py:1550, in FluxLoraLoaderMixin.load_lora_weights(self, pretrained_model_name_or_path_or_dict, adapter_name, **kwargs)
1543 transformer_norm_state_dict = {
1544 k: state_dict.pop(k)
1545 for k in list(state_dict.keys())
1546 if "transformer." in k and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys)
1547 }
1549 transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
-> 1550 has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
1551 transformer, transformer_lora_state_dict, transformer_norm_state_dict
1552 )
1554 if has_param_with_expanded_shape:
1555 logger.info(
1556 "The LoRA weights contain parameters that have different shapes that expected by the transformer. "
1557 "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
1558 "To get a comprehensive list of parameter names that were modified, enable debug logging."
1559 )
File ~/.pyenv/versions/3.10.0/envs/jupyter/lib/python3.10/site-packages/diffusers/loaders/lora_pipeline.py:2020, in FluxLoraLoaderMixin._maybe_expand_transformer_param_shape_or_error_(cls, transformer, lora_state_dict, norm_state_dict, prefix)
2017 parent_module = transformer.get_submodule(parent_module_name)
2019 with torch.device("meta"):
-> 2020 expanded_module = torch.nn.Linear(
2021 in_features, out_features, bias=bias, dtype=module_weight.dtype
2022 )
2023 # Only weights are expanded and biases are not. This is because only the input dimensions
2024 # are changed while the output dimensions remain the same. The shape of the weight tensor
2025 # is (out_features, in_features), while the shape of bias tensor is (out_features,), which
2026 # explains the reason why only weights are expanded.
2027 new_weight = torch.zeros_like(
2028 expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
2029 )
File ~/.pyenv/versions/3.10.0/envs/jupyter/lib/python3.10/site-packages/torch/nn/modules/linear.py:105, in Linear.__init__(self, in_features, out_features, bias, device, dtype)
103 self.in_features = in_features
104 self.out_features = out_features
--> 105 self.weight = Parameter(
106 torch.empty((out_features, in_features), **factory_kwargs)
107 )
108 if bias:
109 self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
File ~/.pyenv/versions/3.10.0/envs/jupyter/lib/python3.10/site-packages/torch/nn/parameter.py:46, in Parameter.__new__(cls, data, requires_grad)
42 data = torch.empty(0)
43 if type(data) is torch.Tensor or type(data) is Parameter:
44 # For ease of BC maintenance, keep this path for standard Tensor.
45 # Eventually (tm), we should change the behavior for standard Tensor to match.
---> 46 return torch.Tensor._make_subclass(cls, data, requires_grad)
48 # Path for custom tensors: set a flag on the instance to indicate parameter-ness.
49 t = data.detach().requires_grad_(requires_grad)
RuntimeError: Only Tensors of floating point and complex dtype can require gradients
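The final frame looks like the root cause: torch.nn.Linear is constructed with dtype=module_weight.dtype, which for the quantized module is presumably uint8, and nn.Parameter defaults to requires_grad=True, which PyTorch rejects for integer dtypes. A minimal sketch of just that failure (plain PyTorch, no diffusers involved):

import torch

# nn.Parameter defaults to requires_grad=True, which is not allowed for integer dtypes.
try:
    torch.nn.Parameter(torch.empty(4, 4, dtype=torch.uint8))
except RuntimeError as e:
    print(e)  # Only Tensors of floating point and complex dtype can require gradients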
System Info
- 🤗 Diffusers version: 0.32.2
- Platform: Linux-6.8.0-1019-aws-x86_64-with-glibc2.35
- Running on Google Colab?: No
- Python version: 3.10.0
- PyTorch version (GPU?): 2.5.1+cu124 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.27.1
- Transformers version: 4.47.1
- Accelerate version: 1.2.1
- PEFT version: 0.14.0
- Bitsandbytes version: 0.45.0
- Safetensors version: 0.5.2
- xFormers version: not installed