
Commit 3d8b3d7

Batched load of textual inversions (#3277)
* Batched load of textual inversions

  - Only call resize_token_embeddings once per batch as it is the most expensive operation
  - Allow pretrained_model_name_or_path and token to be an optional list
  - Remove Dict from type annotation pretrained_model_name_or_path as it was not supported in this function
  - Add comment that single files (e.g. .pt/.safetensors) are supported
  - Add comment for token parameter
  - Convert token override log message from warning to info

* Update src/diffusers/loaders.py

  Check for duplicate tokens

  Co-authored-by: Patrick von Platen <[email protected]>

* Update condition for None tokens

---------

Co-authored-by: Patrick von Platen <[email protected]>
1 parent 0ffac97 commit 3d8b3d7
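
With this change, `load_textual_inversion` accepts a list of embedding sources and an optional matching list of tokens, and resizes the token embeddings only once for the whole batch. A minimal usage sketch (the `runwayml/stable-diffusion-v1-5` checkpoint, the local `.pt` path, and the `<my-token>` name are illustrative placeholders, not part of this commit):

    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

    # One call, several embeddings: a Hub repo (token read from the file) and a local
    # single-file embedding (token supplied explicitly). If `token` is a list, it must
    # have the same length as the list of paths/repos.
    pipe.load_textual_inversion(
        ["sd-concepts-library/low-poly-hd-logos-icons", "./my_text_inversions.pt"],
        token=[None, "<my-token>"],
    )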

File tree

2 files changed: +138 additions, -78 deletions

src/diffusers/loaders.py

Lines changed: 113 additions & 78 deletions
@@ -436,7 +436,10 @@ def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"):
         return prompt
 
     def load_textual_inversion(
-        self, pretrained_model_name_or_path: Union[str, Dict[str, torch.Tensor]], token: Optional[str] = None, **kwargs
+        self,
+        pretrained_model_name_or_path: Union[str, List[str]],
+        token: Optional[Union[str, List[str]]] = None,
+        **kwargs,
     ):
         r"""
         Load textual inversion embeddings into the text encoder of stable diffusion pipelines. Both `diffusers` and
@@ -449,14 +452,20 @@ def load_textual_inversion(
         </Tip>
 
         Parameters:
-            pretrained_model_name_or_path (`str` or `os.PathLike`):
+            pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]`):
                 Can be either:
 
                     - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                       Valid model ids should have an organization name, like
                       `"sd-concepts-library/low-poly-hd-logos-icons"`.
                     - A path to a *directory* containing textual inversion weights, e.g.
                      `./my_text_inversion_directory/`.
+                    - A path to a *file* containing textual inversion weights, e.g. `./my_text_inversions.pt`.
+
+                Or a list of those elements.
+            token (`str` or `List[str]`, *optional*):
+                Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a
+                list, then `token` must also be a list of equal length.
             weight_name (`str`, *optional*):
                 Name of a custom weight file. This should be used in two cases:
 
@@ -576,16 +585,62 @@ def load_textual_inversion(
             "framework": "pytorch",
         }
 
-        # 1. Load textual inversion file
-        model_file = None
-        # Let's first try to load .safetensors weights
-        if (use_safetensors and weight_name is None) or (
-            weight_name is not None and weight_name.endswith(".safetensors")
-        ):
-            try:
+        if isinstance(pretrained_model_name_or_path, str):
+            pretrained_model_name_or_paths = [pretrained_model_name_or_path]
+        else:
+            pretrained_model_name_or_paths = pretrained_model_name_or_path
+
+        if isinstance(token, str):
+            tokens = [token]
+        elif token is None:
+            tokens = [None] * len(pretrained_model_name_or_paths)
+        else:
+            tokens = token
+
+        if len(pretrained_model_name_or_paths) != len(tokens):
+            raise ValueError(
+                f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and list of tokens of length {len(tokens)}"
+                f"Make sure both lists have the same length."
+            )
+
+        valid_tokens = [t for t in tokens if t is not None]
+        if len(set(valid_tokens)) < len(valid_tokens):
+            raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")
+
+        token_ids_and_embeddings = []
+
+        for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens):
+            # 1. Load textual inversion file
+            model_file = None
+            # Let's first try to load .safetensors weights
+            if (use_safetensors and weight_name is None) or (
+                weight_name is not None and weight_name.endswith(".safetensors")
+            ):
+                try:
+                    model_file = _get_model_file(
+                        pretrained_model_name_or_path,
+                        weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
+                        cache_dir=cache_dir,
+                        force_download=force_download,
+                        resume_download=resume_download,
+                        proxies=proxies,
+                        local_files_only=local_files_only,
+                        use_auth_token=use_auth_token,
+                        revision=revision,
+                        subfolder=subfolder,
+                        user_agent=user_agent,
+                    )
+                    state_dict = safetensors.torch.load_file(model_file, device="cpu")
+                except Exception as e:
+                    if not allow_pickle:
+                        raise e
+
+                    model_file = None
+
+            if model_file is None:
                 model_file = _get_model_file(
                     pretrained_model_name_or_path,
-                    weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
+                    weights_name=weight_name or TEXT_INVERSION_NAME,
                     cache_dir=cache_dir,
                     force_download=force_download,
                     resume_download=resume_download,
@@ -596,88 +651,68 @@ def load_textual_inversion(
                     subfolder=subfolder,
                     user_agent=user_agent,
                 )
-                state_dict = safetensors.torch.load_file(model_file, device="cpu")
-            except Exception as e:
-                if not allow_pickle:
-                    raise e
+                state_dict = torch.load(model_file, map_location="cpu")
 
-                model_file = None
+            # 2. Load token and embedding correcly from file
+            if isinstance(state_dict, torch.Tensor):
+                if token is None:
+                    raise ValueError(
+                        "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
+                    )
+                embedding = state_dict
+            elif len(state_dict) == 1:
+                # diffusers
+                loaded_token, embedding = next(iter(state_dict.items()))
+            elif "string_to_param" in state_dict:
+                # A1111
+                loaded_token = state_dict["name"]
+                embedding = state_dict["string_to_param"]["*"]
+
+            if token is not None and loaded_token != token:
+                logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
+            else:
+                token = loaded_token
 
-        if model_file is None:
-            model_file = _get_model_file(
-                pretrained_model_name_or_path,
-                weights_name=weight_name or TEXT_INVERSION_NAME,
-                cache_dir=cache_dir,
-                force_download=force_download,
-                resume_download=resume_download,
-                proxies=proxies,
-                local_files_only=local_files_only,
-                use_auth_token=use_auth_token,
-                revision=revision,
-                subfolder=subfolder,
-                user_agent=user_agent,
-            )
-            state_dict = torch.load(model_file, map_location="cpu")
+            embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device)
 
-        # 2. Load token and embedding correcly from file
-        if isinstance(state_dict, torch.Tensor):
-            if token is None:
+            # 3. Make sure we don't mess up the tokenizer or text encoder
+            vocab = self.tokenizer.get_vocab()
+            if token in vocab:
                 raise ValueError(
-                    "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
+                    f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
                 )
-            embedding = state_dict
-        elif len(state_dict) == 1:
-            # diffusers
-            loaded_token, embedding = next(iter(state_dict.items()))
-        elif "string_to_param" in state_dict:
-            # A1111
-            loaded_token = state_dict["name"]
-            embedding = state_dict["string_to_param"]["*"]
-
-        if token is not None and loaded_token != token:
-            logger.warn(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
-        else:
-            token = loaded_token
-
-        embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device)
+            elif f"{token}_1" in vocab:
+                multi_vector_tokens = [token]
+                i = 1
+                while f"{token}_{i}" in self.tokenizer.added_tokens_encoder:
+                    multi_vector_tokens.append(f"{token}_{i}")
+                    i += 1
 
-        # 3. Make sure we don't mess up the tokenizer or text encoder
-        vocab = self.tokenizer.get_vocab()
-        if token in vocab:
-            raise ValueError(
-                f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
-            )
-        elif f"{token}_1" in vocab:
-            multi_vector_tokens = [token]
-            i = 1
-            while f"{token}_{i}" in self.tokenizer.added_tokens_encoder:
-                multi_vector_tokens.append(f"{token}_{i}")
-                i += 1
+                raise ValueError(
+                    f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
                )
 
-            raise ValueError(
-                f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
-            )
+            is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
 
-        is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
+            if is_multi_vector:
+                tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
+                embeddings = [e for e in embedding]  # noqa: C416
+            else:
+                tokens = [token]
+                embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]
 
-        if is_multi_vector:
-            tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
-            embeddings = [e for e in embedding]  # noqa: C416
-        else:
-            tokens = [token]
-            embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]
+            # add tokens and get ids
+            self.tokenizer.add_tokens(tokens)
+            token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
+            token_ids_and_embeddings += zip(token_ids, embeddings)
 
-        # add tokens and get ids
-        self.tokenizer.add_tokens(tokens)
-        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
+            logger.info(f"Loaded textual inversion embedding for {token}.")
 
-        # resize token embeddings and set new embeddings
+        # resize token embeddings and set all new embeddings
         self.text_encoder.resize_token_embeddings(len(self.tokenizer))
-        for token_id, embedding in zip(token_ids, embeddings):
+        for token_id, embedding in token_ids_and_embeddings:
             self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
 
-        logger.info(f"Loaded textual inversion embedding for {token}.")
-
 
 class LoraLoaderMixin:
     r"""

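The core of the loaders.py change above is the batching pattern: each file only registers its tokens and collects `(token_id, embedding)` pairs, and the expensive `resize_token_embeddings` call runs once after the loop. A standalone sketch of that pattern against the `transformers` API (the CLIP checkpoint and the toy `<tok-a>`/`<tok-b>` vectors are stand-ins, not the pipeline's actual state):

    import torch
    from transformers import CLIPTextModel, CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

    # Stand-ins for embeddings that would normally be read from the loaded files;
    # 512 matches this checkpoint's text hidden size.
    new_embeddings = {"<tok-a>": torch.randn(512), "<tok-b>": torch.randn(512)}

    # Register all tokens first, collecting (token_id, embedding) pairs per entry.
    token_ids_and_embeddings = []
    for token, embedding in new_embeddings.items():
        tokenizer.add_tokens([token])
        token_id = tokenizer.convert_tokens_to_ids(token)
        token_ids_and_embeddings.append((token_id, embedding))

    # Resize once for the whole batch, then write each new row of the embedding matrix.
    text_encoder.resize_token_embeddings(len(tokenizer))
    for token_id, embedding in token_ids_and_embeddings:
        text_encoder.get_input_embeddings().weight.data[token_id] = embedding
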
tests/pipelines/test_pipelines.py

Lines changed: 25 additions & 0 deletions
@@ -575,6 +575,31 @@ def test_text_inversion_download(self):
         out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
         assert out.shape == (1, 128, 128, 3)
 
+        # multi embedding load
+        with tempfile.TemporaryDirectory() as tmpdirname1:
+            with tempfile.TemporaryDirectory() as tmpdirname2:
+                ten = {"<*****>": torch.ones((32,))}
+                torch.save(ten, os.path.join(tmpdirname1, "learned_embeds.bin"))
+
+                ten = {"<******>": 2 * torch.ones((1, 32))}
+                torch.save(ten, os.path.join(tmpdirname2, "learned_embeds.bin"))
+
+                pipe.load_textual_inversion([tmpdirname1, tmpdirname2])
+
+                token = pipe.tokenizer.convert_tokens_to_ids("<*****>")
+                assert token == num_tokens + 8, "Added token must be at spot `num_tokens`"
+                assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 32
+                assert pipe._maybe_convert_prompt("<*****>", pipe.tokenizer) == "<*****>"
+
+                token = pipe.tokenizer.convert_tokens_to_ids("<******>")
+                assert token == num_tokens + 9, "Added token must be at spot `num_tokens`"
+                assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 64
+                assert pipe._maybe_convert_prompt("<******>", pipe.tokenizer) == "<******>"
+
+                prompt = "hey <*****> <******>"
+                out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
+                assert out.shape == (1, 128, 128, 3)
+
     def test_download_ignore_files(self):
         # Check https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-ignore-files/blob/72f58636e5508a218c6b3f60550dc96445547817/model_index.json#L4
         with tempfile.TemporaryDirectory() as tmpdirname: