[Feat] Enable State Dict For Textual Inversion Loader #3439
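In short, `load_textual_inversion` can now take an already loaded torch state dict (or a list of them) in addition to a model id or file path. Below is a minimal usage sketch, assuming a pipeline that inherits the textual inversion loader mixin (e.g. `StableDiffusionPipeline`); the checkpoint id, embedding file, and placeholder token are illustrative, not part of this diff.

```python
import torch
from diffusers import StableDiffusionPipeline

# Illustrative checkpoint; any pipeline exposing load_textual_inversion works the same way.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Load (or construct) the embedding yourself ...
state_dict = torch.load("./my_text_inversions.pt", map_location="cpu")

# ... and hand the state dict over directly; before this change only a repo id,
# directory, or file path was accepted.
pipe.load_textual_inversion(state_dict, token="<my-token>")
```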

src/diffusers/loaders.py: 71 changes (38 additions & 33 deletions)
@@ -470,7 +470,7 @@ def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"):
 
     def load_textual_inversion(
         self,
-        pretrained_model_name_or_path: Union[str, List[str]],
+        pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
         token: Optional[Union[str, List[str]]] = None,
         **kwargs,
     ):
@@ -485,7 +485,7 @@ def load_textual_inversion(
         </Tip>
 
         Parameters:
-            pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]`):
+            pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`):
                 Can be either:
 
                   - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
@@ -494,6 +494,8 @@ def load_textual_inversion(
                   - A path to a *directory* containing textual inversion weights, e.g.
                     `./my_text_inversion_directory/`.
                   - A path to a *file* containing textual inversion weights, e.g. `./my_text_inversions.pt`.
+                  - A [torch state
+                    dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
 
                 Or a list of those elements.
             token (`str` or `List[str]`, *optional*):
@@ -618,7 +620,7 @@ def load_textual_inversion(
             "framework": "pytorch",
         }
 
-        if isinstance(pretrained_model_name_or_path, str):
+        if not isinstance(pretrained_model_name_or_path, list):
            pretrained_model_name_or_paths = [pretrained_model_name_or_path]
         else:
            pretrained_model_name_or_paths = pretrained_model_name_or_path
@@ -643,16 +645,38 @@ def load_textual_inversion(
         token_ids_and_embeddings = []
 
         for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens):
-            # 1. Load textual inversion file
-            model_file = None
-            # Let's first try to load .safetensors weights
-            if (use_safetensors and weight_name is None) or (
-                weight_name is not None and weight_name.endswith(".safetensors")
-            ):
-                try:
+            if not isinstance(pretrained_model_name_or_path, dict):
+                # 1. Load textual inversion file
+                model_file = None
+                # Let's first try to load .safetensors weights
+                if (use_safetensors and weight_name is None) or (
+                    weight_name is not None and weight_name.endswith(".safetensors")
+                ):
+                    try:
+                        model_file = _get_model_file(
+                            pretrained_model_name_or_path,
+                            weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
+                            cache_dir=cache_dir,
+                            force_download=force_download,
+                            resume_download=resume_download,
+                            proxies=proxies,
+                            local_files_only=local_files_only,
+                            use_auth_token=use_auth_token,
+                            revision=revision,
+                            subfolder=subfolder,
+                            user_agent=user_agent,
+                        )
+                        state_dict = safetensors.torch.load_file(model_file, device="cpu")
+                    except Exception as e:
+                        if not allow_pickle:
+                            raise e
+
+                        model_file = None
+
+                if model_file is None:
                     model_file = _get_model_file(
                         pretrained_model_name_or_path,
-                        weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
+                        weights_name=weight_name or TEXT_INVERSION_NAME,
                         cache_dir=cache_dir,
                         force_download=force_download,
                         resume_download=resume_download,
@@ -663,28 +687,9 @@ def load_textual_inversion(
                         subfolder=subfolder,
                         user_agent=user_agent,
                     )
-                state_dict = safetensors.torch.load_file(model_file, device="cpu")
-            except Exception as e:
-                if not allow_pickle:
-                    raise e
-
-                model_file = None
-
-            if model_file is None:
-                model_file = _get_model_file(
-                    pretrained_model_name_or_path,
-                    weights_name=weight_name or TEXT_INVERSION_NAME,
-                    cache_dir=cache_dir,
-                    force_download=force_download,
-                    resume_download=resume_download,
-                    proxies=proxies,
-                    local_files_only=local_files_only,
-                    use_auth_token=use_auth_token,
-                    revision=revision,
-                    subfolder=subfolder,
-                    user_agent=user_agent,
-                )
-                state_dict = torch.load(model_file, map_location="cpu")
+                    state_dict = torch.load(model_file, map_location="cpu")
+            else:
+                state_dict = pretrained_model_name_or_path
 
             # 2. Load token and embedding correcly from file
             loaded_token = None
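Note that the entry check changed from `isinstance(..., str)` to `not isinstance(..., list)`, so a single state dict is wrapped just like a single string, and a list may now mix hub repo ids, local paths, and in-memory dicts. A hedged sketch of that combined usage (the concept repo id is an illustrative placeholder, and `pipe`/`state_dict` are reused from the example above):

```python
# Sketch only: load one embedding from the Hub and one from an in-memory state dict
# in a single call; each entry's token is taken from the loaded weights when no
# explicit `token` list is passed.
pipe.load_textual_inversion(["sd-concepts-library/cat-toy", state_dict])
```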