Commit 8e0cb4e

[Feat] Enable State Dict For Textual Inversion Loader (huggingface#3439)
* enable state dict for textual inversion loader
* Empty-Commit | restart CI
* Empty-Commit | restart CI
* Empty-Commit | restart CI
* Empty-Commit | restart CI
* add tests
* fix tests
* fix tests
* fix tests

---------

Co-authored-by: Patrick von Platen <[email protected]>
1 parent 7bbc036 commit 8e0cb4e
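
In practice, the change lets `load_textual_inversion` accept an embedding that has already been loaded into memory, not just a Hub model id or a local path. A minimal sketch of the new input type, assuming an illustrative pipeline id, file path, and token name (none of these are taken from the commit itself):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Load the textual inversion weights yourself ...
state_dict = torch.load("./my_text_inversions.pt", map_location="cpu")

# ... and hand the in-memory state dict directly to the loader.
# `token` is optional, as in the signature shown in the diff below.
pipe.load_textual_inversion(state_dict, token="<my-token>")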

File tree

1 file changed (+38, -33 lines)

loaders.py

Lines changed: 38 additions & 33 deletions
@@ -470,7 +470,7 @@ def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"):
 
     def load_textual_inversion(
         self,
-        pretrained_model_name_or_path: Union[str, List[str]],
+        pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
         token: Optional[Union[str, List[str]]] = None,
         **kwargs,
     ):
@@ -485,7 +485,7 @@ def load_textual_inversion(
         </Tip>
 
         Parameters:
-            pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]`):
+            pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`):
                 Can be either:
 
                     - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
@@ -494,6 +494,8 @@ def load_textual_inversion(
                     - A path to a *directory* containing textual inversion weights, e.g.
                       `./my_text_inversion_directory/`.
                     - A path to a *file* containing textual inversion weights, e.g. `./my_text_inversions.pt`.
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
 
                 Or a list of those elements.
             token (`str` or `List[str]`, *optional*):
@@ -618,7 +620,7 @@ def load_textual_inversion(
             "framework": "pytorch",
         }
 
-        if isinstance(pretrained_model_name_or_path, str):
+        if not isinstance(pretrained_model_name_or_path, list):
             pretrained_model_name_or_paths = [pretrained_model_name_or_path]
         else:
             pretrained_model_name_or_paths = pretrained_model_name_or_path
@@ -643,16 +645,38 @@ def load_textual_inversion(
         token_ids_and_embeddings = []
 
         for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens):
-            # 1. Load textual inversion file
-            model_file = None
-            # Let's first try to load .safetensors weights
-            if (use_safetensors and weight_name is None) or (
-                weight_name is not None and weight_name.endswith(".safetensors")
-            ):
-                try:
+            if not isinstance(pretrained_model_name_or_path, dict):
+                # 1. Load textual inversion file
+                model_file = None
+                # Let's first try to load .safetensors weights
+                if (use_safetensors and weight_name is None) or (
+                    weight_name is not None and weight_name.endswith(".safetensors")
+                ):
+                    try:
+                        model_file = _get_model_file(
+                            pretrained_model_name_or_path,
+                            weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
+                            cache_dir=cache_dir,
+                            force_download=force_download,
+                            resume_download=resume_download,
+                            proxies=proxies,
+                            local_files_only=local_files_only,
+                            use_auth_token=use_auth_token,
+                            revision=revision,
+                            subfolder=subfolder,
+                            user_agent=user_agent,
+                        )
+                        state_dict = safetensors.torch.load_file(model_file, device="cpu")
+                    except Exception as e:
+                        if not allow_pickle:
+                            raise e
+
+                        model_file = None
+
+                if model_file is None:
                     model_file = _get_model_file(
                         pretrained_model_name_or_path,
-                        weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
+                        weights_name=weight_name or TEXT_INVERSION_NAME,
                         cache_dir=cache_dir,
                         force_download=force_download,
                         resume_download=resume_download,
@@ -663,28 +687,9 @@ def load_textual_inversion(
                         subfolder=subfolder,
                         user_agent=user_agent,
                     )
-                    state_dict = safetensors.torch.load_file(model_file, device="cpu")
-                except Exception as e:
-                    if not allow_pickle:
-                        raise e
-
-                    model_file = None
-
-            if model_file is None:
-                model_file = _get_model_file(
-                    pretrained_model_name_or_path,
-                    weights_name=weight_name or TEXT_INVERSION_NAME,
-                    cache_dir=cache_dir,
-                    force_download=force_download,
-                    resume_download=resume_download,
-                    proxies=proxies,
-                    local_files_only=local_files_only,
-                    use_auth_token=use_auth_token,
-                    revision=revision,
-                    subfolder=subfolder,
-                    user_agent=user_agent,
-                )
-                state_dict = torch.load(model_file, map_location="cpu")
+                    state_dict = torch.load(model_file, map_location="cpu")
+            else:
+                state_dict = pretrained_model_name_or_path
 
             # 2. Load token and embedding correcly from file
             loaded_token = None
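
Because the normalization now checks `not isinstance(pretrained_model_name_or_path, list)` instead of `isinstance(..., str)`, a single dict is wrapped into a one-element list, and an explicit list may mix sources. A sketch under assumptions: the repo id below is a public textual inversion concept used only as an example, and the hand-built dict maps a made-up token to a random 768-wide vector (the SD 1.x text encoder width):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# A diffusers-style textual inversion state dict built by hand: a single token
# mapped to its embedding vector. The tensor is random and purely illustrative.
custom_embedding = {"<custom-token>": torch.randn(768)}

# A list may now mix a Hub repo id (downloaded as before) with in-memory dicts;
# each dict skips the file-loading branch and is used as the state dict directly.
pipe.load_textual_inversion(["sd-concepts-library/cat-toy", custom_embedding])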
