-
Notifications
You must be signed in to change notification settings - Fork 6k
[feat] add load_lora_adapter()
for compatible models
#9712
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
07123e1
d099b84
4d307cc
984b8c9
2e70a93
c0f4585
8e8e6b1
0f6ce88
3541495
37acd79
bc74fe8
d3afa26
c28d6f3
c89c318
e204844
e187b70
431fc53
6ca6c65
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -51,6 +51,9 @@ | |
|
||
logger = logging.get_logger(__name__) | ||
|
||
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" | ||
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" | ||
|
||
|
||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None): | ||
""" | ||
|
@@ -181,6 +184,119 @@ def _remove_text_encoder_monkey_patch(text_encoder): | |
text_encoder._hf_peft_config_loaded = None | ||
|
||
|
||
def _fetch_state_dict( | ||
pretrained_model_name_or_path_or_dict, | ||
weight_name, | ||
use_safetensors, | ||
local_files_only, | ||
cache_dir, | ||
force_download, | ||
proxies, | ||
token, | ||
revision, | ||
subfolder, | ||
user_agent, | ||
allow_pickle, | ||
): | ||
model_file = None | ||
if not isinstance(pretrained_model_name_or_path_or_dict, dict): | ||
# Let's first try to load .safetensors weights | ||
if (use_safetensors and weight_name is None) or ( | ||
weight_name is not None and weight_name.endswith(".safetensors") | ||
): | ||
try: | ||
# Here we're relaxing the loading check to enable more Inference API | ||
# friendliness where sometimes, it's not at all possible to automatically | ||
# determine `weight_name`. | ||
if weight_name is None: | ||
weight_name = _best_guess_weight_name( | ||
pretrained_model_name_or_path_or_dict, | ||
file_extension=".safetensors", | ||
local_files_only=local_files_only, | ||
) | ||
model_file = _get_model_file( | ||
pretrained_model_name_or_path_or_dict, | ||
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, | ||
cache_dir=cache_dir, | ||
force_download=force_download, | ||
proxies=proxies, | ||
local_files_only=local_files_only, | ||
token=token, | ||
revision=revision, | ||
subfolder=subfolder, | ||
user_agent=user_agent, | ||
) | ||
state_dict = safetensors.torch.load_file(model_file, device="cpu") | ||
except (IOError, safetensors.SafetensorError) as e: | ||
if not allow_pickle: | ||
raise e | ||
# try loading non-safetensors weights | ||
model_file = None | ||
pass | ||
|
||
if model_file is None: | ||
if weight_name is None: | ||
weight_name = _best_guess_weight_name( | ||
pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only | ||
) | ||
model_file = _get_model_file( | ||
pretrained_model_name_or_path_or_dict, | ||
weights_name=weight_name or LORA_WEIGHT_NAME, | ||
cache_dir=cache_dir, | ||
force_download=force_download, | ||
proxies=proxies, | ||
local_files_only=local_files_only, | ||
token=token, | ||
revision=revision, | ||
subfolder=subfolder, | ||
user_agent=user_agent, | ||
) | ||
state_dict = load_state_dict(model_file) | ||
else: | ||
state_dict = pretrained_model_name_or_path_or_dict | ||
|
||
return state_dict | ||
|
||
|
||
def _best_guess_weight_name( | ||
pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False | ||
): | ||
if local_files_only or HF_HUB_OFFLINE: | ||
raise ValueError("When using the offline mode, you must specify a `weight_name`.") | ||
|
||
targeted_files = [] | ||
|
||
if os.path.isfile(pretrained_model_name_or_path_or_dict): | ||
return | ||
elif os.path.isdir(pretrained_model_name_or_path_or_dict): | ||
targeted_files = [f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)] | ||
else: | ||
files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings | ||
targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)] | ||
if len(targeted_files) == 0: | ||
return | ||
|
||
# "scheduler" does not correspond to a LoRA checkpoint. | ||
# "optimizer" does not correspond to a LoRA checkpoint | ||
# only top-level checkpoints are considered and not the other ones, hence "checkpoint". | ||
unallowed_substrings = {"scheduler", "optimizer", "checkpoint"} | ||
targeted_files = list( | ||
filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files) | ||
) | ||
|
||
if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files): | ||
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files)) | ||
elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files): | ||
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files)) | ||
|
||
if len(targeted_files) > 1: | ||
raise ValueError( | ||
f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}." | ||
) | ||
weight_name = targeted_files[0] | ||
return weight_name | ||
|
||
|
||
class LoraBaseMixin: | ||
"""Utility class for handling LoRAs.""" | ||
|
||
|
@@ -234,124 +350,16 @@ def _optionally_disable_offloading(cls, _pipeline): | |
return (is_model_cpu_offload, is_sequential_cpu_offload) | ||
|
||
@classmethod | ||
def _fetch_state_dict( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These are internal methods, so it should be okay to move them around. But would be good to run a quick Github search to see if they aren't being used directly somewhere? Just to sanity check that we don't backwards break anything. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Valid. I deprecated and added tests. |
||
cls, | ||
pretrained_model_name_or_path_or_dict, | ||
weight_name, | ||
use_safetensors, | ||
local_files_only, | ||
cache_dir, | ||
force_download, | ||
proxies, | ||
token, | ||
revision, | ||
subfolder, | ||
user_agent, | ||
allow_pickle, | ||
): | ||
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE | ||
|
||
model_file = None | ||
if not isinstance(pretrained_model_name_or_path_or_dict, dict): | ||
# Let's first try to load .safetensors weights | ||
if (use_safetensors and weight_name is None) or ( | ||
weight_name is not None and weight_name.endswith(".safetensors") | ||
): | ||
try: | ||
# Here we're relaxing the loading check to enable more Inference API | ||
# friendliness where sometimes, it's not at all possible to automatically | ||
# determine `weight_name`. | ||
if weight_name is None: | ||
weight_name = cls._best_guess_weight_name( | ||
pretrained_model_name_or_path_or_dict, | ||
file_extension=".safetensors", | ||
local_files_only=local_files_only, | ||
) | ||
model_file = _get_model_file( | ||
pretrained_model_name_or_path_or_dict, | ||
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, | ||
cache_dir=cache_dir, | ||
force_download=force_download, | ||
proxies=proxies, | ||
local_files_only=local_files_only, | ||
token=token, | ||
revision=revision, | ||
subfolder=subfolder, | ||
user_agent=user_agent, | ||
) | ||
state_dict = safetensors.torch.load_file(model_file, device="cpu") | ||
except (IOError, safetensors.SafetensorError) as e: | ||
if not allow_pickle: | ||
raise e | ||
# try loading non-safetensors weights | ||
model_file = None | ||
pass | ||
|
||
if model_file is None: | ||
if weight_name is None: | ||
weight_name = cls._best_guess_weight_name( | ||
pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only | ||
) | ||
model_file = _get_model_file( | ||
pretrained_model_name_or_path_or_dict, | ||
weights_name=weight_name or LORA_WEIGHT_NAME, | ||
cache_dir=cache_dir, | ||
force_download=force_download, | ||
proxies=proxies, | ||
local_files_only=local_files_only, | ||
token=token, | ||
revision=revision, | ||
subfolder=subfolder, | ||
user_agent=user_agent, | ||
) | ||
state_dict = load_state_dict(model_file) | ||
else: | ||
state_dict = pretrained_model_name_or_path_or_dict | ||
|
||
return state_dict | ||
def _fetch_state_dict(cls, *args, **kwargs): | ||
deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`." | ||
deprecate("_fetch_state_dict", "0.35.0", deprecation_message) | ||
return _fetch_state_dict(*args, **kwargs) | ||
|
||
@classmethod | ||
def _best_guess_weight_name( | ||
cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False | ||
): | ||
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE | ||
sayakpaul marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
if local_files_only or HF_HUB_OFFLINE: | ||
raise ValueError("When using the offline mode, you must specify a `weight_name`.") | ||
|
||
targeted_files = [] | ||
|
||
if os.path.isfile(pretrained_model_name_or_path_or_dict): | ||
return | ||
elif os.path.isdir(pretrained_model_name_or_path_or_dict): | ||
targeted_files = [ | ||
f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension) | ||
] | ||
else: | ||
files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings | ||
targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)] | ||
if len(targeted_files) == 0: | ||
return | ||
|
||
# "scheduler" does not correspond to a LoRA checkpoint. | ||
# "optimizer" does not correspond to a LoRA checkpoint | ||
# only top-level checkpoints are considered and not the other ones, hence "checkpoint". | ||
unallowed_substrings = {"scheduler", "optimizer", "checkpoint"} | ||
targeted_files = list( | ||
filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files) | ||
) | ||
|
||
if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files): | ||
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files)) | ||
elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files): | ||
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files)) | ||
|
||
if len(targeted_files) > 1: | ||
raise ValueError( | ||
f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}." | ||
) | ||
weight_name = targeted_files[0] | ||
return weight_name | ||
def _best_guess_weight_name(cls, *args, **kwargs): | ||
deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`." | ||
deprecate("_best_guess_weight_name", "0.35.0", deprecation_message) | ||
return _best_guess_weight_name(*args, **kwargs) | ||
|
||
def unload_lora_weights(self): | ||
""" | ||
|
@@ -725,8 +733,6 @@ def write_lora_layers( | |
save_function: Callable, | ||
safe_serialization: bool, | ||
): | ||
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE | ||
|
||
if os.path.isfile(save_directory): | ||
logger.error(f"Provided path ({save_directory}) should be a directory, not a file") | ||
return | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just taking it out of the class to be able to better reuse.