Skip to content

Commit 799f5b4

Browse files
[Feat] Enable State Dict For Textual Inversion Loader (#3439)
* enable state dict for textual inversion loader * Empty-Commit | restart CI * Empty-Commit | restart CI * Empty-Commit | restart CI * Empty-Commit | restart CI * add tests * fix tests * fix tests * fix tests --------- Co-authored-by: Patrick von Platen <[email protected]>
1 parent 07ef485 commit 799f5b4

File tree

2 files changed

+97
-33
lines changed

2 files changed

+97
-33
lines changed

src/diffusers/loaders.py

Lines changed: 38 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,7 @@ def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"):
470470

471471
def load_textual_inversion(
472472
self,
473-
pretrained_model_name_or_path: Union[str, List[str]],
473+
pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
474474
token: Optional[Union[str, List[str]]] = None,
475475
**kwargs,
476476
):
@@ -485,7 +485,7 @@ def load_textual_inversion(
485485
</Tip>
486486
487487
Parameters:
488-
pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]`):
488+
pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`):
489489
Can be either:
490490
491491
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
@@ -494,6 +494,8 @@ def load_textual_inversion(
494494
- A path to a *directory* containing textual inversion weights, e.g.
495495
`./my_text_inversion_directory/`.
496496
- A path to a *file* containing textual inversion weights, e.g. `./my_text_inversions.pt`.
497+
- A [torch state
498+
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
497499
498500
Or a list of those elements.
499501
token (`str` or `List[str]`, *optional*):
@@ -618,7 +620,7 @@ def load_textual_inversion(
618620
"framework": "pytorch",
619621
}
620622

621-
if isinstance(pretrained_model_name_or_path, str):
623+
if not isinstance(pretrained_model_name_or_path, list):
622624
pretrained_model_name_or_paths = [pretrained_model_name_or_path]
623625
else:
624626
pretrained_model_name_or_paths = pretrained_model_name_or_path
@@ -643,16 +645,38 @@ def load_textual_inversion(
643645
token_ids_and_embeddings = []
644646

645647
for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens):
646-
# 1. Load textual inversion file
647-
model_file = None
648-
# Let's first try to load .safetensors weights
649-
if (use_safetensors and weight_name is None) or (
650-
weight_name is not None and weight_name.endswith(".safetensors")
651-
):
652-
try:
648+
if not isinstance(pretrained_model_name_or_path, dict):
649+
# 1. Load textual inversion file
650+
model_file = None
651+
# Let's first try to load .safetensors weights
652+
if (use_safetensors and weight_name is None) or (
653+
weight_name is not None and weight_name.endswith(".safetensors")
654+
):
655+
try:
656+
model_file = _get_model_file(
657+
pretrained_model_name_or_path,
658+
weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
659+
cache_dir=cache_dir,
660+
force_download=force_download,
661+
resume_download=resume_download,
662+
proxies=proxies,
663+
local_files_only=local_files_only,
664+
use_auth_token=use_auth_token,
665+
revision=revision,
666+
subfolder=subfolder,
667+
user_agent=user_agent,
668+
)
669+
state_dict = safetensors.torch.load_file(model_file, device="cpu")
670+
except Exception as e:
671+
if not allow_pickle:
672+
raise e
673+
674+
model_file = None
675+
676+
if model_file is None:
653677
model_file = _get_model_file(
654678
pretrained_model_name_or_path,
655-
weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
679+
weights_name=weight_name or TEXT_INVERSION_NAME,
656680
cache_dir=cache_dir,
657681
force_download=force_download,
658682
resume_download=resume_download,
@@ -663,28 +687,9 @@ def load_textual_inversion(
663687
subfolder=subfolder,
664688
user_agent=user_agent,
665689
)
666-
state_dict = safetensors.torch.load_file(model_file, device="cpu")
667-
except Exception as e:
668-
if not allow_pickle:
669-
raise e
670-
671-
model_file = None
672-
673-
if model_file is None:
674-
model_file = _get_model_file(
675-
pretrained_model_name_or_path,
676-
weights_name=weight_name or TEXT_INVERSION_NAME,
677-
cache_dir=cache_dir,
678-
force_download=force_download,
679-
resume_download=resume_download,
680-
proxies=proxies,
681-
local_files_only=local_files_only,
682-
use_auth_token=use_auth_token,
683-
revision=revision,
684-
subfolder=subfolder,
685-
user_agent=user_agent,
686-
)
687-
state_dict = torch.load(model_file, map_location="cpu")
690+
state_dict = torch.load(model_file, map_location="cpu")
691+
else:
692+
state_dict = pretrained_model_name_or_path
688693

689694
# 2. Load token and embedding correcly from file
690695
loaded_token = None

tests/pipelines/test_pipelines.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,65 @@ def test_text_inversion_download(self):
663663
out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
664664
assert out.shape == (1, 128, 128, 3)
665665

666+
# single token state dict load
667+
ten = {"<x>": torch.ones((32,))}
668+
pipe.load_textual_inversion(ten)
669+
670+
token = pipe.tokenizer.convert_tokens_to_ids("<x>")
671+
assert token == num_tokens + 10, "Added token must be at spot `num_tokens`"
672+
assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 32
673+
assert pipe._maybe_convert_prompt("<x>", pipe.tokenizer) == "<x>"
674+
675+
prompt = "hey <x>"
676+
out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
677+
assert out.shape == (1, 128, 128, 3)
678+
679+
# multi embedding state dict load
680+
ten1 = {"<xxxxx>": torch.ones((32,))}
681+
ten2 = {"<xxxxxx>": 2 * torch.ones((1, 32))}
682+
683+
pipe.load_textual_inversion([ten1, ten2])
684+
685+
token = pipe.tokenizer.convert_tokens_to_ids("<xxxxx>")
686+
assert token == num_tokens + 11, "Added token must be at spot `num_tokens`"
687+
assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 32
688+
assert pipe._maybe_convert_prompt("<xxxxx>", pipe.tokenizer) == "<xxxxx>"
689+
690+
token = pipe.tokenizer.convert_tokens_to_ids("<xxxxxx>")
691+
assert token == num_tokens + 12, "Added token must be at spot `num_tokens`"
692+
assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 64
693+
assert pipe._maybe_convert_prompt("<xxxxxx>", pipe.tokenizer) == "<xxxxxx>"
694+
695+
prompt = "hey <xxxxx> <xxxxxx>"
696+
out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
697+
assert out.shape == (1, 128, 128, 3)
698+
699+
# auto1111 multi-token state dict load
700+
ten = {
701+
"string_to_param": {
702+
"*": torch.cat([3 * torch.ones((1, 32)), 4 * torch.ones((1, 32)), 5 * torch.ones((1, 32))])
703+
},
704+
"name": "<xxxx>",
705+
}
706+
707+
pipe.load_textual_inversion(ten)
708+
709+
token = pipe.tokenizer.convert_tokens_to_ids("<xxxx>")
710+
token_1 = pipe.tokenizer.convert_tokens_to_ids("<xxxx>_1")
711+
token_2 = pipe.tokenizer.convert_tokens_to_ids("<xxxx>_2")
712+
713+
assert token == num_tokens + 13, "Added token must be at spot `num_tokens`"
714+
assert token_1 == num_tokens + 14, "Added token must be at spot `num_tokens`"
715+
assert token_2 == num_tokens + 15, "Added token must be at spot `num_tokens`"
716+
assert pipe.text_encoder.get_input_embeddings().weight[-3].sum().item() == 96
717+
assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 128
718+
assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 160
719+
assert pipe._maybe_convert_prompt("<xxxx>", pipe.tokenizer) == "<xxxx> <xxxx>_1 <xxxx>_2"
720+
721+
prompt = "hey <xxxx>"
722+
out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
723+
assert out.shape == (1, 128, 128, 3)
724+
666725
def test_download_ignore_files(self):
667726
# Check https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-ignore-files/blob/72f58636e5508a218c6b3f60550dc96445547817/model_index.json#L4
668727
with tempfile.TemporaryDirectory() as tmpdirname:

0 commit comments

Comments (0)