Skip to content

Fix missing **kwargs in lora_pipeline.py #11011

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 11, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 72 additions & 24 deletions src/diffusers/loaders/lora_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,11 @@ def fuse_lora(
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)

def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs):
Expand All @@ -472,7 +476,7 @@ def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)


class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
Expand Down Expand Up @@ -891,7 +895,11 @@ def fuse_lora(
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)

def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs):
Expand All @@ -912,7 +920,7 @@ def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_enc
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)


class SD3LoraLoaderMixin(LoraBaseMixin):
Expand Down Expand Up @@ -1290,7 +1298,11 @@ def fuse_lora(
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)

# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer
Expand All @@ -1312,7 +1324,7 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "t
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)


class FluxLoraLoaderMixin(LoraBaseMixin):
Expand Down Expand Up @@ -1828,7 +1840,11 @@ def fuse_lora(
)

super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)

def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
Expand All @@ -1849,7 +1865,7 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], *
if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)

super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)

# We override this here account for `_transformer_norm_layers` and `_overwritten_params`.
def unload_lora_weights(self, reset_to_overwritten_params=False):
Expand Down Expand Up @@ -2548,7 +2564,11 @@ def fuse_lora(
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)

def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
Expand All @@ -2566,7 +2586,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)


class Mochi1LoraLoaderMixin(LoraBaseMixin):
Expand Down Expand Up @@ -2852,7 +2872,11 @@ def fuse_lora(
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)

# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
Expand All @@ -2871,7 +2895,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)


class LTXVideoLoraLoaderMixin(LoraBaseMixin):
Expand Down Expand Up @@ -3157,7 +3181,11 @@ def fuse_lora(
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)

# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
Expand All @@ -3176,7 +3204,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)


class SanaLoraLoaderMixin(LoraBaseMixin):
Expand Down Expand Up @@ -3462,7 +3490,11 @@ def fuse_lora(
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)

# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
Expand All @@ -3481,7 +3513,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)


class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
Expand Down Expand Up @@ -3770,7 +3802,11 @@ def fuse_lora(
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)

# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
Expand All @@ -3789,7 +3825,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)


class Lumina2LoraLoaderMixin(LoraBaseMixin):
Expand Down Expand Up @@ -4079,7 +4115,11 @@ def fuse_lora(
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)

# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
Expand All @@ -4098,7 +4138,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)


class WanLoraLoaderMixin(LoraBaseMixin):
Expand Down Expand Up @@ -4384,7 +4424,11 @@ def fuse_lora(
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)

# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
Expand All @@ -4403,7 +4447,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)


class CogView4LoraLoaderMixin(LoraBaseMixin):
Expand Down Expand Up @@ -4689,7 +4733,11 @@ def fuse_lora(
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)

# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
Expand All @@ -4708,7 +4756,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)


class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
Expand Down
Loading