@@ -436,7 +436,10 @@ def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"):
         return prompt
 
     def load_textual_inversion(
-        self, pretrained_model_name_or_path: Union[str, Dict[str, torch.Tensor]], token: Optional[str] = None, **kwargs
+        self,
+        pretrained_model_name_or_path: Union[str, List[str]],
+        token: Optional[Union[str, List[str]]] = None,
+        **kwargs,
     ):
         r"""
         Load textual inversion embeddings into the text encoder of stable diffusion pipelines. Both `diffusers` and
@@ -449,14 +452,20 @@ def load_textual_inversion(
         </Tip>
 
         Parameters:
-            pretrained_model_name_or_path (`str` or `os.PathLike`):
+            pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]`):
                 Can be either:
 
                     - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                       Valid model ids should have an organization name, like
                       `"sd-concepts-library/low-poly-hd-logos-icons"`.
                     - A path to a *directory* containing textual inversion weights, e.g.
                      `./my_text_inversion_directory/`.
+                    - A path to a *file* containing textual inversion weights, e.g. `./my_text_inversions.pt`.
+
+                Or a list of those elements.
+            token (`str` or `List[str]`, *optional*):
+                Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a
+                list, then `token` must also be a list of equal length.
             weight_name (`str`, *optional*):
                 Name of a custom weight file. This should be used in two cases:
 
@@ -576,16 +585,62 @@ def load_textual_inversion(
             "framework": "pytorch",
         }
 
-        # 1. Load textual inversion file
-        model_file = None
-        # Let's first try to load .safetensors weights
-        if (use_safetensors and weight_name is None) or (
-            weight_name is not None and weight_name.endswith(".safetensors")
-        ):
-            try:
+        if isinstance(pretrained_model_name_or_path, str):
+            pretrained_model_name_or_paths = [pretrained_model_name_or_path]
+        else:
+            pretrained_model_name_or_paths = pretrained_model_name_or_path
+
+        if isinstance(token, str):
+            tokens = [token]
+        elif token is None:
+            tokens = [None] * len(pretrained_model_name_or_paths)
+        else:
+            tokens = token
+
+        if len(pretrained_model_name_or_paths) != len(tokens):
+            raise ValueError(
+                f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and list of tokens of length {len(tokens)} "
+                f"Make sure both lists have the same length."
+            )
+
+        valid_tokens = [t for t in tokens if t is not None]
+        if len(set(valid_tokens)) < len(valid_tokens):
+            raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")
+
+        token_ids_and_embeddings = []
+
+        for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens):
+            # 1. Load textual inversion file
+            model_file = None
+            # Let's first try to load .safetensors weights
+            if (use_safetensors and weight_name is None) or (
+                weight_name is not None and weight_name.endswith(".safetensors")
+            ):
+                try:
+                    model_file = _get_model_file(
+                        pretrained_model_name_or_path,
+                        weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
+                        cache_dir=cache_dir,
+                        force_download=force_download,
+                        resume_download=resume_download,
+                        proxies=proxies,
+                        local_files_only=local_files_only,
+                        use_auth_token=use_auth_token,
+                        revision=revision,
+                        subfolder=subfolder,
+                        user_agent=user_agent,
+                    )
+                    state_dict = safetensors.torch.load_file(model_file, device="cpu")
+                except Exception as e:
+                    if not allow_pickle:
+                        raise e
+
+                    model_file = None
+
+            if model_file is None:
                 model_file = _get_model_file(
                     pretrained_model_name_or_path,
-                    weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
+                    weights_name=weight_name or TEXT_INVERSION_NAME,
                     cache_dir=cache_dir,
                     force_download=force_download,
                     resume_download=resume_download,
@@ -596,88 +651,68 @@ def load_textual_inversion(
                     subfolder=subfolder,
                     user_agent=user_agent,
                 )
-                state_dict = safetensors.torch.load_file(model_file, device="cpu")
-            except Exception as e:
-                if not allow_pickle:
-                    raise e
+                state_dict = torch.load(model_file, map_location="cpu")
 
-                model_file = None
+            # 2. Load token and embedding correcly from file
+            if isinstance(state_dict, torch.Tensor):
+                if token is None:
+                    raise ValueError(
+                        "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
+                    )
+                embedding = state_dict
+            elif len(state_dict) == 1:
+                # diffusers
+                loaded_token, embedding = next(iter(state_dict.items()))
+            elif "string_to_param" in state_dict:
+                # A1111
+                loaded_token = state_dict["name"]
+                embedding = state_dict["string_to_param"]["*"]
+
+            if token is not None and loaded_token != token:
+                logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
+            else:
+                token = loaded_token
 
-        if model_file is None:
-            model_file = _get_model_file(
-                pretrained_model_name_or_path,
-                weights_name=weight_name or TEXT_INVERSION_NAME,
-                cache_dir=cache_dir,
-                force_download=force_download,
-                resume_download=resume_download,
-                proxies=proxies,
-                local_files_only=local_files_only,
-                use_auth_token=use_auth_token,
-                revision=revision,
-                subfolder=subfolder,
-                user_agent=user_agent,
-            )
-            state_dict = torch.load(model_file, map_location="cpu")
+            embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device)
 
-        # 2. Load token and embedding correcly from file
-        if isinstance(state_dict, torch.Tensor):
-            if token is None:
+            # 3. Make sure we don't mess up the tokenizer or text encoder
+            vocab = self.tokenizer.get_vocab()
+            if token in vocab:
                 raise ValueError(
-                    "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
+                    f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
                 )
-            embedding = state_dict
-        elif len(state_dict) == 1:
-            # diffusers
-            loaded_token, embedding = next(iter(state_dict.items()))
-        elif "string_to_param" in state_dict:
-            # A1111
-            loaded_token = state_dict["name"]
-            embedding = state_dict["string_to_param"]["*"]
-
-        if token is not None and loaded_token != token:
-            logger.warn(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
-        else:
-            token = loaded_token
-
-        embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device)
+            elif f"{token}_1" in vocab:
+                multi_vector_tokens = [token]
+                i = 1
+                while f"{token}_{i}" in self.tokenizer.added_tokens_encoder:
+                    multi_vector_tokens.append(f"{token}_{i}")
+                    i += 1
 
-        # 3. Make sure we don't mess up the tokenizer or text encoder
-        vocab = self.tokenizer.get_vocab()
-        if token in vocab:
-            raise ValueError(
-                f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
-            )
-        elif f"{token}_1" in vocab:
-            multi_vector_tokens = [token]
-            i = 1
-            while f"{token}_{i}" in self.tokenizer.added_tokens_encoder:
-                multi_vector_tokens.append(f"{token}_{i}")
-                i += 1
+                raise ValueError(
+                    f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
+                )
 
-            raise ValueError(
-                f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
-            )
+            is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
 
-        is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
+            if is_multi_vector:
+                tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
+                embeddings = [e for e in embedding]  # noqa: C416
+            else:
+                tokens = [token]
+                embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]
 
-        if is_multi_vector:
-            tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
-            embeddings = [e for e in embedding]  # noqa: C416
-        else:
-            tokens = [token]
-            embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]
+            # add tokens and get ids
+            self.tokenizer.add_tokens(tokens)
+            token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
+            token_ids_and_embeddings += zip(token_ids, embeddings)
 
-        # add tokens and get ids
-        self.tokenizer.add_tokens(tokens)
-        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
+            logger.info(f"Loaded textual inversion embedding for {token}.")
 
-        # resize token embeddings and set new embeddings
+        # resize token embeddings and set all new embeddings
         self.text_encoder.resize_token_embeddings(len(self.tokenizer))
-        for token_id, embedding in zip(token_ids, embeddings):
+        for token_id, embedding in token_ids_and_embeddings:
             self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
 
-        logger.info(f"Loaded textual inversion embedding for {token}.")
-
 
 class LoraLoaderMixin:
     r"""