Skip to content

Commit ce93063

Browse files
piEspositoEandrewJonesana-tamaispatrickvonplatensayakpaul
authored
add load textual inversion embeddings to stable diffusion (huggingface#2009)
* add load textual inversion embeddings draft * fix quality * fix typo * make fix copies * move to textual inversion mixin * make it accept from sd-concept library * accept list of paths to embeddings * fix styling of stable diffusion pipeline * add dummy TextualInversionMixin * add docstring to textualinversionmixin * add load textual inversion embeddings draft * fix quality * fix typo * make fix copies * move to textual inversion mixin * make it accept from sd-concept library * accept list of paths to embeddings * fix styling of stable diffusion pipeline * add dummy TextualInversionMixin * add docstring to textualinversionmixin * add case for parsing embedding from auto1111 UI format Co-authored-by: Evan Jones <[email protected]> Co-authored-by: Ana Tamais <[email protected]> * fix style after rebase * move textual inversion mixin to loaders * move mixin inheritance to DiffusionPipeline from StableDiffusionPipeline) * update dummy class name * addressed allo comments * fix old dangling import * fix style * proposal * remove bogus * Apply suggestions from code review Co-authored-by: Sayak Paul <[email protected]> Co-authored-by: Will Berman <[email protected]> * finish * make style * up * fix code quality * fix code quality - again * fix code quality - 3 * fix alt diffusion code quality * fix model editing pipeline * Apply suggestions from code review Co-authored-by: Pedro Cuenca <[email protected]> * Finish --------- Co-authored-by: Evan Jones <[email protected]> Co-authored-by: Ana Tamais <[email protected]> Co-authored-by: Patrick von Platen <[email protected]> Co-authored-by: Sayak Paul <[email protected]> Co-authored-by: Will Berman <[email protected]> Co-authored-by: Pedro Cuenca <[email protected]>
1 parent 938ec58 commit ce93063

26 files changed

+648
-168
lines changed

__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@
109109
except OptionalDependencyNotAvailable:
110110
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
111111
else:
112+
from .loaders import TextualInversionLoaderMixin
112113
from .pipelines import (
113114
AltDiffusionImg2ImgPipeline,
114115
AltDiffusionPipeline,

loaders.py

Lines changed: 284 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,25 +13,38 @@
1313
# limitations under the License.
1414
import os
1515
from collections import defaultdict
16-
from typing import Callable, Dict, Union
16+
from typing import Callable, Dict, List, Optional, Union
1717

1818
import torch
1919

2020
from .models.attention_processor import LoRAAttnProcessor
21-
from .models.modeling_utils import _get_model_file
22-
from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, deprecate, is_safetensors_available, logging
21+
from .utils import (
22+
DIFFUSERS_CACHE,
23+
HF_HUB_OFFLINE,
24+
_get_model_file,
25+
deprecate,
26+
is_safetensors_available,
27+
is_transformers_available,
28+
logging,
29+
)
2330

2431

2532
if is_safetensors_available():
2633
import safetensors
2734

35+
if is_transformers_available():
36+
from transformers import PreTrainedModel, PreTrainedTokenizer
37+
2838

2939
logger = logging.get_logger(__name__)
3040

3141

3242
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
3343
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
3444

45+
TEXT_INVERSION_NAME = "learned_embeds.bin"
46+
TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
47+
3548

3649
class AttnProcsLayers(torch.nn.Module):
3750
def __init__(self, state_dict: Dict[str, torch.Tensor]):
@@ -123,13 +136,6 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
123136
It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
124137
models](https://huggingface.co/docs/hub/models-gated#gated-models).
125138
126-
</Tip>
127-
128-
<Tip>
129-
130-
Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
131-
this method in a firewalled environment.
132-
133139
</Tip>
134140
"""
135141

@@ -292,5 +298,272 @@ def save_function(weights, filename):
292298

293299
# Save the model
294300
save_function(state_dict, os.path.join(save_directory, weight_name))
295-
296301
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
302+
303+
304+
class TextualInversionLoaderMixin:
    r"""
    Mixin class for loading textual inversion tokens and embeddings to the tokenizer and text encoder.

    The host class is expected to expose `self.tokenizer` and `self.text_encoder`;
    `load_textual_inversion` validates both before doing any work.
    """

    # NOTE: `PreTrainedTokenizer` is written as a string annotation on purpose. The name is only imported when
    # `transformers` is available, and an unquoted annotation would be evaluated while the class body executes.
    def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"):
        r"""
        Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
        to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
        is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
        inversion token or a textual inversion token that is a single vector, the input prompt is simply returned.

        Parameters:
            prompt (`str` or list of `str`):
                The prompt or prompts to guide the image generation.
            tokenizer (`PreTrainedTokenizer`):
                The tokenizer responsible for encoding the prompt into input tokens.

        Returns:
            `str` or list of `str`: The converted prompt. A list input yields a list, a single string yields a
            single string.
        """
        if isinstance(prompt, list):
            return [self._maybe_convert_prompt(p, tokenizer) for p in prompt]

        return self._maybe_convert_prompt(prompt, tokenizer)

    def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"):
        r"""
        Expand every multi-vector textual inversion token found in a single prompt.

        A token `tok` is considered multi-vector when `tok_1`, `tok_2`, ... are also registered in
        `tokenizer.added_tokens_encoder`; each occurrence of `tok` in the prompt is then replaced by
        `tok` immediately followed by `tok_1`, `tok_2`, ... Tokens without numbered companions are left untouched.

        Parameters:
            prompt (`str`):
                The prompt to guide the image generation.
            tokenizer (`PreTrainedTokenizer`):
                The tokenizer responsible for encoding the prompt into input tokens.

        Returns:
            `str`: The converted prompt
        """
        added_tokens = tokenizer.added_tokens_encoder

        # De-duplicate while keeping first-seen order: `str.replace` below already rewrites *every* occurrence of
        # the token, so visiting a repeated token a second time would expand the already expanded text again.
        for token in dict.fromkeys(tokenizer.tokenize(prompt)):
            if token not in added_tokens:
                continue

            replacement = token
            i = 1
            while f"{token}_{i}" in added_tokens:
                replacement += f"{token}_{i}"
                i += 1

            if replacement != token:
                prompt = prompt.replace(token, replacement)

        return prompt

    def load_textual_inversion(
        self, pretrained_model_name_or_path: Union[str, Dict[str, torch.Tensor]], token: Optional[str] = None, **kwargs
    ):
        r"""
        Load textual inversion embeddings into the text encoder of stable diffusion pipelines. Both `diffusers` and
        `Automatic1111` formats are supported.

        <Tip warning={true}>

        This function is experimental and might change in the future.

        </Tip>

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                Can be either:

                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                      Valid model ids should have an organization name, like
                      `"sd-concepts-library/low-poly-hd-logos-icons"`.
                    - A path to a *directory* containing textual inversion weights, e.g.
                      `./my_text_inversion_directory/`.
            token (`str`, *optional*):
                Name of the token to register for the embedding. Required when the file only contains a raw tensor;
                otherwise it overrides the token name stored in the file (a warning is logged when they differ).
            weight_name (`str`, *optional*):
                Name of a custom weight file. This should be used in two cases:

                    - The saved textual inversion file is in `diffusers` format, but was saved under a specific weight
                      name, such as `text_inv.bin`.
                    - The saved textual inversion file is in the "Automatic1111" form.
            use_safetensors (`bool`, *optional*):
                If `True`, only a safetensors file is accepted. If left unset, safetensors is tried first when the
                library is installed and loading falls back to `torch.load` on failure.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only(`bool`, *optional*, defaults to `False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
            use_auth_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
                when running `diffusers-cli login` (stored in `~/.huggingface`).
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
            subfolder (`str`, *optional*, defaults to `""`):
                In case the relevant files are located inside a subfolder of the model repo (either remote in
                huggingface.co or downloaded locally), you can specify the folder name here.

            mirror (`str`, *optional*):
                Mirror source to accelerate downloads in China. If you are from China and have an accessibility
                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
                Please refer to the mirror site for more information.

        Raises:
            `ValueError`: If the host object lacks a suitable `tokenizer` / `text_encoder`, if `use_safetensors=True`
            without safetensors installed, if a raw tensor is loaded without `token`, if the file layout is not
            recognised, or if the token (or its multi-vector companions) already exists in the tokenizer.

        <Tip>

        It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
        models](https://huggingface.co/docs/hub/models-gated#gated-models).

        </Tip>
        """
        if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer):
            raise ValueError(
                f"{self.__class__.__name__} requires `self.tokenizer` of type `PreTrainedTokenizer` for calling"
                f" `{self.load_textual_inversion.__name__}`"
            )

        if not hasattr(self, "text_encoder") or not isinstance(self.text_encoder, PreTrainedModel):
            raise ValueError(
                f"{self.__class__.__name__} requires `self.text_encoder` of type `PreTrainedModel` for calling"
                f" `{self.load_textual_inversion.__name__}`"
            )

        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
        use_safetensors = kwargs.pop("use_safetensors", None)

        if use_safetensors and not is_safetensors_available():
            raise ValueError(
                "`use_safetensors`=True but safetensors is not installed. Please install safetensors with"
                " `pip install safetensors`"
            )

        # Falling back to a pickled file is only allowed when the caller did not explicitly ask for safetensors.
        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = is_safetensors_available()
            allow_pickle = True

        # Keyword arguments shared by both download attempts below.
        download_kwargs = {
            "cache_dir": cache_dir,
            "force_download": force_download,
            "resume_download": resume_download,
            "proxies": proxies,
            "local_files_only": local_files_only,
            "use_auth_token": use_auth_token,
            "revision": revision,
            "subfolder": subfolder,
            "user_agent": {
                "file_type": "text_inversion",
                "framework": "pytorch",
            },
        }

        # 1. Load textual inversion file
        model_file = None
        # Let's first try to load .safetensors weights
        if (use_safetensors and weight_name is None) or (
            weight_name is not None and weight_name.endswith(".safetensors")
        ):
            try:
                model_file = _get_model_file(
                    pretrained_model_name_or_path,
                    weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
                    **download_kwargs,
                )
                state_dict = safetensors.torch.load_file(model_file, device="cpu")
            except Exception:
                if not allow_pickle:
                    raise

                # Best-effort attempt failed; retry with the pickle-based file below.
                model_file = None

        if model_file is None:
            model_file = _get_model_file(
                pretrained_model_name_or_path,
                weights_name=weight_name or TEXT_INVERSION_NAME,
                **download_kwargs,
            )
            # NOTE(security): `torch.load` unpickles the file, which can execute arbitrary code. Only load
            # embeddings from sources you trust, or pass `use_safetensors=True` to forbid this path.
            state_dict = torch.load(model_file, map_location="cpu")

        # 2. Load token and embedding correctly from file
        loaded_token = None
        if isinstance(state_dict, torch.Tensor):
            # A bare tensor carries no token name, so the caller has to supply one.
            if token is None:
                raise ValueError(
                    "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
                )
            embedding = state_dict
        elif len(state_dict) == 1:
            # diffusers
            loaded_token, embedding = next(iter(state_dict.items()))
        elif "string_to_param" in state_dict:
            # A1111
            loaded_token = state_dict["name"]
            embedding = state_dict["string_to_param"]["*"]
        else:
            raise ValueError(
                "The loaded textual inversion file has an unsupported layout. Expected a single tensor, a dictionary"
                " with exactly one `token -> embedding` entry, or a dictionary containing `string_to_param`, but"
                f" found the keys: {list(state_dict.keys())}."
            )

        if token is None:
            token = loaded_token
        elif loaded_token is not None and loaded_token != token:
            logger.warning(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")

        embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device)

        # 3. Make sure we don't mess up the tokenizer or text encoder
        vocab = self.tokenizer.get_vocab()
        if token in vocab:
            raise ValueError(
                f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
            )
        elif f"{token}_1" in vocab:
            multi_vector_tokens = [token]
            i = 1
            while f"{token}_{i}" in self.tokenizer.added_tokens_encoder:
                multi_vector_tokens.append(f"{token}_{i}")
                i += 1

            raise ValueError(
                f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
            )

        is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1

        if is_multi_vector:
            # One token per row: `token`, `token_1`, `token_2`, ...
            tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
            embeddings = list(embedding)
        else:
            tokens = [token]
            # A single vector may be stored with a leading axis of size one or as a plain 1-D tensor;
            # only strip the leading axis when it is actually there.
            embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]

        # add tokens and get ids
        self.tokenizer.add_tokens(tokens)
        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)

        # resize token embeddings and set new embeddings
        self.text_encoder.resize_token_embeddings(len(self.tokenizer))
        input_embeddings = self.text_encoder.get_input_embeddings()
        for token_id, vector in zip(token_ids, embeddings):
            input_embeddings.weight.data[token_id] = vector

        logger.info(f"Loaded textual inversion embedding for {token}.")

0 commit comments

Comments
 (0)