@@ -134,7 +134,7 @@ class AudioPipelineOutput(BaseOutput):
     audios: np.ndarray
 
 
-def is_safetensors_compatible(filenames, variant=None) -> bool:
+def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool:
     """
     Checking for safetensors compatibility:
     - By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch
@@ -150,9 +150,14 @@ def is_safetensors_compatible(filenames, variant=None) -> bool:
 
     sf_filenames = set()
 
+    passed_components = passed_components or []
+
     for filename in filenames:
         _, extension = os.path.splitext(filename)
 
+        if len(filename.split("/")) == 2 and filename.split("/")[0] in passed_components:
+            continue
+
         if extension == ".bin":
             pt_filenames.append(filename)
         elif extension == ".safetensors":
@@ -163,10 +168,8 @@ def is_safetensors_compatible(filenames, variant=None) -> bool:
         path, filename = os.path.split(filename)
         filename, extension = os.path.splitext(filename)
 
-        if filename == "pytorch_model":
-            filename = "model"
-        elif filename == f"pytorch_model.{variant}":
-            filename = f"model.{variant}"
+        if filename.startswith("pytorch_model"):
+            filename = filename.replace("pytorch_model", "model")
         else:
             filename = filename
 
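Note on the hunk above: the old exact-name comparison only recognized `pytorch_model` and `pytorch_model.{variant}`, so sharded checkpoint names fell through. The startswith/replace pair covers every suffix. A quick standalone check (filenames illustrative):

    for filename in ["pytorch_model", "pytorch_model.fp16", "pytorch_model.fp16-00001-of-00002"]:
        if filename.startswith("pytorch_model"):
            print(filename, "->", filename.replace("pytorch_model", "model"))
    # pytorch_model                      -> model
    # pytorch_model.fp16                 -> model.fp16
    # pytorch_model.fp16-00001-of-00002  -> model.fp16-00001-of-00002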
@@ -196,24 +199,51 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]:
     weight_prefixes = [w.split(".")[0] for w in weight_names]
     # .bin, .safetensors, ...
     weight_suffixs = [w.split(".")[-1] for w in weight_names]
+    # -00001-of-00002
+    transformers_index_format = "\d{5}-of-\d{5}"
+
+    if variant is not None:
+        # `diffusion_pytorch_model.fp16.bin` as well as `model.fp16-00001-of-00002.safetensors`
+        variant_file_re = re.compile(
+            f"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
+        )
+        # `text_encoder/pytorch_model.bin.index.fp16.json`
+        variant_index_re = re.compile(
+            f"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
+        )
 
-    variant_file_regex = (
-        re.compile(f"({'|'.join(weight_prefixes)})(.{variant}.)({'|'.join(weight_suffixs)})")
-        if variant is not None
-        else None
+    # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
+    non_variant_file_re = re.compile(
+        f"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
     )
-    non_variant_file_regex = re.compile(f"{'|'.join(weight_names)}")
+    # `text_encoder/pytorch_model.bin.index.json`
+    non_variant_index_re = re.compile(f"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")
 
     if variant is not None:
-        variant_filenames = {f for f in filenames if variant_file_regex.match(f.split("/")[-1]) is not None}
+        variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None}
+        variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None}
+        variant_filenames = variant_weights | variant_indexes
     else:
         variant_filenames = set()
 
-    non_variant_filenames = {f for f in filenames if non_variant_file_regex.match(f.split("/")[-1]) is not None}
+    non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None}
+    non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None}
+    non_variant_filenames = non_variant_weights | non_variant_indexes
 
+    # all variant filenames will be used by default
     usable_filenames = set(variant_filenames)
+
+    def convert_to_variant(filename):
+        if "index" in filename:
+            variant_filename = filename.replace("index", f"index.{variant}")
+        elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None:
+            variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}"
+        else:
+            variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
+        return variant_filename
+
     for f in non_variant_filenames:
-        variant_filename = f"{f.split('.')[0]}.{variant}.{f.split('.')[1]}"
+        variant_filename = convert_to_variant(f)
         if variant_filename not in usable_filenames:
             usable_filenames.add(f)
 
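The four compiled patterns and the convert_to_variant helper can be exercised in isolation. In this sketch the prefix/suffix lists are illustrative stand-ins for the values the function derives from the library's weight file names:

    import re

    weight_prefixes = ["pytorch_model", "diffusion_pytorch_model", "model"]  # stand-in values
    weight_suffixs = ["bin", "safetensors"]                                  # stand-in values
    transformers_index_format = r"\d{5}-of-\d{5}"
    variant = "fp16"

    variant_file_re = re.compile(
        f"({'|'.join(weight_prefixes)})\\.({variant}|{variant}-{transformers_index_format})\\.({'|'.join(weight_suffixs)})$"
    )
    for name in ["diffusion_pytorch_model.fp16.bin", "model.fp16-00001-of-00002.safetensors", "model.bin"]:
        print(name, bool(variant_file_re.match(name)))  # True, True, False

    def convert_to_variant(filename):  # mirrors the helper added above
        if "index" in filename:
            return filename.replace("index", f"index.{variant}")
        if re.match(f"^(.*?){transformers_index_format}", filename):
            return f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}"
        return f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"

    print(convert_to_variant("pytorch_model.bin.index.json"))      # pytorch_model.bin.index.fp16.json
    print(convert_to_variant("model-00001-of-00002.safetensors"))  # model.fp16-00001-of-00002.safetensors
    print(convert_to_variant("diffusion_pytorch_model.bin"))       # diffusion_pytorch_model.fp16.bin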
@@ -292,6 +322,27 @@ def get_class_obj_and_candidates(library_name, class_name, importable_classes, pipelines):
     return class_obj, class_candidates
 
 
+def _get_pipeline_class(class_obj, config, custom_pipeline=None, cache_dir=None, revision=None):
+    if custom_pipeline is not None:
+        if custom_pipeline.endswith(".py"):
+            path = Path(custom_pipeline)
+            # decompose into folder & file
+            file_name = path.name
+            custom_pipeline = path.parent.absolute()
+        else:
+            file_name = CUSTOM_PIPELINE_FILE_NAME
+
+        return get_class_from_dynamic_module(
+            custom_pipeline, module_file=file_name, cache_dir=cache_dir, revision=revision
+        )
+
+    if class_obj != DiffusionPipeline:
+        return class_obj
+
+    diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0])
+    return getattr(diffusers_module, config["_class_name"])
+
+
 def load_sub_model(
     library_name: str,
     class_name: str,
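The new helper centralizes class resolution so that from_pretrained and download share one code path. A minimal sketch of the fallback chain with stand-in classes (the real helper additionally resolves custom pipelines from a local .py file or the Hub):

    import importlib

    class Base: ...           # stands in for DiffusionPipeline
    class Derived(Base): ...  # stands in for a concrete pipeline class

    def get_pipeline_class(class_obj, config):
        if class_obj is not Base:  # an explicit subclass wins
            return class_obj
        # otherwise resolve the class named in the config from the root package
        module = importlib.import_module(class_obj.__module__.split(".")[0])
        return getattr(module, config["_class_name"])

    print(get_pipeline_class(Derived, {}).__name__)                       # Derived
    print(get_pipeline_class(Base, {"_class_name": "Derived"}).__name__)  # Derived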
@@ -779,7 +830,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
         device_map = kwargs.pop("device_map", None)
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         variant = kwargs.pop("variant", None)
-        kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
+        use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
 
         # 1. Download the checkpoints and configs
         # use snapshot download here to get it working from from_pretrained
@@ -794,8 +845,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
                 use_auth_token=use_auth_token,
                 revision=revision,
                 from_flax=from_flax,
+                use_safetensors=use_safetensors,
                 custom_pipeline=custom_pipeline,
+                custom_revision=custom_revision,
                 variant=variant,
+                **kwargs,
             )
         else:
             cached_folder = pretrained_model_name_or_path
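Together, capturing use_safetensors and forwarding **kwargs lets download() see which components the caller already supplied. Illustrative usage (model id assumed, requires network access):

    from diffusers import DiffusionPipeline
    from transformers import CLIPTextModel

    text_encoder = CLIPTextModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="text_encoder"
    )
    pipe = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        text_encoder=text_encoder,  # passed component: its weights are not re-downloaded
        use_safetensors=True,       # now captured instead of being silently discarded
    )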
@@ -810,29 +864,17 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
             for folder in os.listdir(cached_folder):
                 folder_path = os.path.join(cached_folder, folder)
                 is_folder = os.path.isdir(folder_path) and folder in config_dict
-                variant_exists = is_folder and any(path.split(".")[1] == variant for path in os.listdir(folder_path))
+                variant_exists = is_folder and any(
+                    p.split(".")[1].startswith(variant) for p in os.listdir(folder_path)
+                )
                 if variant_exists:
                     model_variants[folder] = variant
 
         # 3. Load the pipeline class, if using custom module then load it from the hub
         # if we load from explicit class, let's use it
-        if custom_pipeline is not None:
-            if custom_pipeline.endswith(".py"):
-                path = Path(custom_pipeline)
-                # decompose into folder & file
-                file_name = path.name
-                custom_pipeline = path.parent.absolute()
-            else:
-                file_name = CUSTOM_PIPELINE_FILE_NAME
-
-            pipeline_class = get_class_from_dynamic_module(
-                custom_pipeline, module_file=file_name, cache_dir=cache_dir, revision=custom_revision
-            )
-        elif cls != DiffusionPipeline:
-            pipeline_class = cls
-        else:
-            diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
-            pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
+        pipeline_class = _get_pipeline_class(
+            cls, config_dict, custom_pipeline=custom_pipeline, cache_dir=cache_dir, revision=custom_revision
+        )
 
         # DEPRECATED: To be removed in 1.0.0
         if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
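The == -> startswith change matters because a sharded variant file's second dot-segment carries the shard suffix, so an exact comparison misses it (filenames illustrative):

    variant = "fp16"
    for p in ["diffusion_pytorch_model.fp16.bin", "diffusion_pytorch_model.fp16-00001-of-00002.bin"]:
        segment = p.split(".")[1]
        print(segment, "==", segment == variant, "| startswith", segment.startswith(variant))
    # fp16                == True  | startswith True
    # fp16-00001-of-00002 == False | startswith True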
@@ -1095,6 +1137,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
         revision = kwargs.pop("revision", None)
         from_flax = kwargs.pop("from_flax", False)
         custom_pipeline = kwargs.pop("custom_pipeline", None)
+        custom_revision = kwargs.pop("custom_revision", None)
         variant = kwargs.pop("variant", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
 
@@ -1153,7 +1196,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
             # this enables downloading schedulers, tokenizers, ...
             allow_patterns += [os.path.join(k, "*") for k in folder_names if k not in model_folder_names]
             # also allow downloading config.json files with the model
-            allow_patterns += [os.path.join(k, "*.json") for k in model_folder_names]
+            allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]
 
             allow_patterns += [
                 SCHEDULER_CONFIG_NAME,
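Narrowing the glob matters because "*.json" also matched sharding index files, which the variant logic now selects or ignores explicitly. Since these patterns are evaluated with fnmatch further down, the difference is easy to check standalone (paths illustrative):

    import fnmatch
    import os

    files = ["unet/config.json", "unet/diffusion_pytorch_model.bin.index.json"]
    for pattern in [os.path.join("unet", "*.json"), os.path.join("unet", "config.json")]:
        print(pattern, "->", fnmatch.filter(files, pattern))
    # unet/*.json      -> both files
    # unet/config.json -> only the component config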
@@ -1162,17 +1205,28 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
                 CUSTOM_PIPELINE_FILE_NAME,
             ]
 
+            # retrieve passed components that should not be downloaded
+            pipeline_class = _get_pipeline_class(
+                cls, config_dict, custom_pipeline=custom_pipeline, cache_dir=cache_dir, revision=custom_revision
+            )
+            expected_components, _ = cls._get_signature_keys(pipeline_class)
+            passed_components = [k for k in expected_components if k in kwargs]
+
             if (
                 use_safetensors
                 and not allow_pickle
-                and not is_safetensors_compatible(model_filenames, variant=variant)
+                and not is_safetensors_compatible(
+                    model_filenames, variant=variant, passed_components=passed_components
+                )
             ):
                 raise EnvironmentError(
                     f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
                 )
             if from_flax:
                 ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
-            elif use_safetensors and is_safetensors_compatible(model_filenames, variant=variant):
+            elif use_safetensors and is_safetensors_compatible(
+                model_filenames, variant=variant, passed_components=passed_components
+            ):
                 ignore_patterns = ["*.bin", "*.msgpack"]
 
                 safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
@@ -1194,6 +1248,13 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
                     f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
                 )
 
+            # Don't download any objects that are passed
+            allow_patterns = [
+                p for p in allow_patterns if not (len(p.split("/")) == 2 and p.split("/")[0] in passed_components)
+            ]
+            # Don't download index files of forbidden patterns either
+            ignore_patterns = ignore_patterns + [f"{i}.index.*json" for i in ignore_patterns]
+
             re_ignore_pattern = [re.compile(fnmatch.translate(p)) for p in ignore_patterns]
             re_allow_pattern = [re.compile(fnmatch.translate(p)) for p in allow_patterns]
 
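The net effect of the two new filters, run standalone with illustrative values:

    passed_components = ["text_encoder"]
    allow_patterns = ["text_encoder/*", "unet/*", "model_index.json"]
    allow_patterns = [
        p for p in allow_patterns if not (len(p.split("/")) == 2 and p.split("/")[0] in passed_components)
    ]
    ignore_patterns = ["*.bin", "*.msgpack"]
    ignore_patterns = ignore_patterns + [f"{i}.index.*json" for i in ignore_patterns]
    print(allow_patterns)   # ['unet/*', 'model_index.json']
    print(ignore_patterns)  # ['*.bin', '*.msgpack', '*.bin.index.*json', '*.msgpack.index.*json']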