@@ -470,7 +470,7 @@ def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"):
 
     def load_textual_inversion(
         self,
-        pretrained_model_name_or_path: Union[str, List[str]],
+        pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
         token: Optional[Union[str, List[str]]] = None,
         **kwargs,
     ):
@@ -485,7 +485,7 @@ def load_textual_inversion(
         </Tip>

         Parameters:
-            pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]`):
+            pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`):
                 Can be either:

                     - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
@@ -494,6 +494,8 @@ def load_textual_inversion(
                     - A path to a *directory* containing textual inversion weights, e.g.
                       `./my_text_inversion_directory/`.
                     - A path to a *file* containing textual inversion weights, e.g. `./my_text_inversions.pt`.
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

                 Or a list of those elements.
             token (`str` or `List[str]`, *optional*):
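For context on the docstring change above, here is a minimal sketch (not part of the diff) of what the newly accepted dict input might look like, assuming a diffusers-style `learned_embeds` state dict that maps the placeholder token to its embedding tensor; the token name and embedding width are illustrative only.

```python
import torch

# Hypothetical textual inversion state dict: the placeholder token mapped to
# its learned embedding (768 is just an illustrative hidden size).
state_dict = {"<my-concept>": torch.randn(768)}
```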
@@ -618,7 +620,7 @@ def load_textual_inversion(
             "framework": "pytorch",
         }

-        if isinstance(pretrained_model_name_or_path, str):
+        if not isinstance(pretrained_model_name_or_path, list):
             pretrained_model_name_or_paths = [pretrained_model_name_or_path]
         else:
             pretrained_model_name_or_paths = pretrained_model_name_or_path
@@ -643,16 +645,38 @@ def load_textual_inversion(
         token_ids_and_embeddings = []

         for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens):
-            # 1. Load textual inversion file
-            model_file = None
-            # Let's first try to load .safetensors weights
-            if (use_safetensors and weight_name is None) or (
-                weight_name is not None and weight_name.endswith(".safetensors")
-            ):
-                try:
+            if not isinstance(pretrained_model_name_or_path, dict):
+                # 1. Load textual inversion file
+                model_file = None
+                # Let's first try to load .safetensors weights
+                if (use_safetensors and weight_name is None) or (
+                    weight_name is not None and weight_name.endswith(".safetensors")
+                ):
+                    try:
+                        model_file = _get_model_file(
+                            pretrained_model_name_or_path,
+                            weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
+                            cache_dir=cache_dir,
+                            force_download=force_download,
+                            resume_download=resume_download,
+                            proxies=proxies,
+                            local_files_only=local_files_only,
+                            use_auth_token=use_auth_token,
+                            revision=revision,
+                            subfolder=subfolder,
+                            user_agent=user_agent,
+                        )
+                        state_dict = safetensors.torch.load_file(model_file, device="cpu")
+                    except Exception as e:
+                        if not allow_pickle:
+                            raise e
+
+                        model_file = None
+
+                if model_file is None:
                     model_file = _get_model_file(
                         pretrained_model_name_or_path,
-                        weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
+                        weights_name=weight_name or TEXT_INVERSION_NAME,
                         cache_dir=cache_dir,
                         force_download=force_download,
                         resume_download=resume_download,
@@ -663,28 +687,9 @@ def load_textual_inversion(
                         subfolder=subfolder,
                         user_agent=user_agent,
                     )
-                    state_dict = safetensors.torch.load_file(model_file, device="cpu")
-                except Exception as e:
-                    if not allow_pickle:
-                        raise e
-
-                    model_file = None
-
-            if model_file is None:
-                model_file = _get_model_file(
-                    pretrained_model_name_or_path,
-                    weights_name=weight_name or TEXT_INVERSION_NAME,
-                    cache_dir=cache_dir,
-                    force_download=force_download,
-                    resume_download=resume_download,
-                    proxies=proxies,
-                    local_files_only=local_files_only,
-                    use_auth_token=use_auth_token,
-                    revision=revision,
-                    subfolder=subfolder,
-                    user_agent=user_agent,
-                )
-                state_dict = torch.load(model_file, map_location="cpu")
+                    state_dict = torch.load(model_file, map_location="cpu")
+            else:
+                state_dict = pretrained_model_name_or_path

             # 2. Load token and embedding correcly from file
             loaded_token = None
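Taken together, the hunks above let callers skip the file lookup entirely and hand a state dict straight to `load_textual_inversion`. A minimal usage sketch follows, assuming a pipeline that mixes in this loader (e.g. `StableDiffusionPipeline`); the model id, file name, and token below are placeholders, not part of the diff.

```python
import torch
from diffusers import StableDiffusionPipeline

# Any pipeline that exposes load_textual_inversion via the loader mixin.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Before this change only repo ids / paths were accepted; now a state dict
# (or a list of state dicts) can be passed directly.
embeds = torch.load("learned_embeds.bin", map_location="cpu")
pipe.load_textual_inversion(embeds, token="<my-concept>")
```

Because the dict branch simply assigns `state_dict = pretrained_model_name_or_path`, the existing token and embedding handling from step 2 onward is reused unchanged.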