25
25
from typing import Any , Union
26
26
27
27
from huggingface_hub .utils import is_jinja_available # noqa: F401
28
- from packaging import version
29
28
from packaging .version import Version , parse
30
29
31
30
from . import logging
52
51
53
52
STR_OPERATION_TO_FUNC = {">" : op .gt , ">=" : op .ge , "==" : op .eq , "!=" : op .ne , "<=" : op .le , "<" : op .lt }
54
53
55
- _torch_version = "N/A"
56
- if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES :
57
- _torch_available = importlib .util .find_spec ("torch" ) is not None
58
- if _torch_available :
54
+ _is_google_colab = "google.colab" in sys .modules or any (k .startswith ("COLAB_" ) for k in os .environ )
55
+
56
+
57
+ def _is_package_available (pkg_name : str ):
58
+ pkg_exists = importlib .util .find_spec (pkg_name ) is not None
59
+ pkg_version = "N/A"
60
+
61
+ if pkg_exists :
59
62
try :
60
- _torch_version = importlib_metadata .version ("torch" )
61
- logger .info (f"PyTorch version { _torch_version } available." )
62
- except importlib_metadata .PackageNotFoundError :
63
- _torch_available = False
63
+ pkg_version = importlib_metadata .version (pkg_name )
64
+ logger .debug (f"Successfully imported { pkg_name } version { pkg_version } " )
65
+ except (ImportError , importlib_metadata .PackageNotFoundError ):
66
+ pkg_exists = False
67
+
68
+ return pkg_exists , pkg_version
69
+
70
+
71
+ if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES :
72
+ _torch_available , _torch_version = _is_package_available ("torch" )
73
+
64
74
else :
65
75
logger .info ("Disabling PyTorch because USE_TORCH is set" )
66
76
_torch_available = False
67
77
68
- _torch_xla_available = importlib .util .find_spec ("torch_xla" ) is not None
69
- if _torch_xla_available :
70
- try :
71
- _torch_xla_version = importlib_metadata .version ("torch_xla" )
72
- logger .info (f"PyTorch XLA version { _torch_xla_version } available." )
73
- except ImportError :
74
- _torch_xla_available = False
75
-
76
- # check whether torch_npu is available
77
- _torch_npu_available = importlib .util .find_spec ("torch_npu" ) is not None
78
- if _torch_npu_available :
79
- try :
80
- _torch_npu_version = importlib_metadata .version ("torch_npu" )
81
- logger .info (f"torch_npu version { _torch_npu_version } available." )
82
- except ImportError :
83
- _torch_npu_available = False
84
-
85
78
_jax_version = "N/A"
86
79
_flax_version = "N/A"
87
80
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES :
97
90
_flax_available = False
98
91
99
92
if USE_SAFETENSORS in ENV_VARS_TRUE_AND_AUTO_VALUES :
100
- _safetensors_available = importlib .util .find_spec ("safetensors" ) is not None
101
- if _safetensors_available :
102
- try :
103
- _safetensors_version = importlib_metadata .version ("safetensors" )
104
- logger .info (f"Safetensors version { _safetensors_version } available." )
105
- except importlib_metadata .PackageNotFoundError :
106
- _safetensors_available = False
93
+ _safetensors_available , _safetensors_version = _is_package_available ("safetensors" )
94
+
107
95
else :
108
96
logger .info ("Disabling Safetensors because USE_TF is set" )
109
97
_safetensors_available = False
110
98
111
- _transformers_available = importlib .util .find_spec ("transformers" ) is not None
112
- try :
113
- _transformers_version = importlib_metadata .version ("transformers" )
114
- logger .debug (f"Successfully imported transformers version { _transformers_version } " )
115
- except importlib_metadata .PackageNotFoundError :
116
- _transformers_available = False
117
-
118
- _hf_hub_available = importlib .util .find_spec ("huggingface_hub" ) is not None
119
- try :
120
- _hf_hub_version = importlib_metadata .version ("huggingface_hub" )
121
- logger .debug (f"Successfully imported huggingface_hub version { _hf_hub_version } " )
122
- except importlib_metadata .PackageNotFoundError :
123
- _hf_hub_available = False
124
-
125
-
126
- _inflect_available = importlib .util .find_spec ("inflect" ) is not None
127
- try :
128
- _inflect_version = importlib_metadata .version ("inflect" )
129
- logger .debug (f"Successfully imported inflect version { _inflect_version } " )
130
- except importlib_metadata .PackageNotFoundError :
131
- _inflect_available = False
132
-
133
-
134
- _unidecode_available = importlib .util .find_spec ("unidecode" ) is not None
135
- try :
136
- _unidecode_version = importlib_metadata .version ("unidecode" )
137
- logger .debug (f"Successfully imported unidecode version { _unidecode_version } " )
138
- except importlib_metadata .PackageNotFoundError :
139
- _unidecode_available = False
140
-
141
99
_onnxruntime_version = "N/A"
142
100
_onnx_available = importlib .util .find_spec ("onnxruntime" ) is not None
143
101
if _onnx_available :
186
144
except importlib_metadata .PackageNotFoundError :
187
145
_opencv_available = False
188
146
189
- _scipy_available = importlib .util .find_spec ("scipy" ) is not None
190
- try :
191
- _scipy_version = importlib_metadata .version ("scipy" )
192
- logger .debug (f"Successfully imported scipy version { _scipy_version } " )
193
- except importlib_metadata .PackageNotFoundError :
194
- _scipy_available = False
195
-
196
- _librosa_available = importlib .util .find_spec ("librosa" ) is not None
197
- try :
198
- _librosa_version = importlib_metadata .version ("librosa" )
199
- logger .debug (f"Successfully imported librosa version { _librosa_version } " )
200
- except importlib_metadata .PackageNotFoundError :
201
- _librosa_available = False
202
-
203
- _accelerate_available = importlib .util .find_spec ("accelerate" ) is not None
204
- try :
205
- _accelerate_version = importlib_metadata .version ("accelerate" )
206
- logger .debug (f"Successfully imported accelerate version { _accelerate_version } " )
207
- except importlib_metadata .PackageNotFoundError :
208
- _accelerate_available = False
209
-
210
- _xformers_available = importlib .util .find_spec ("xformers" ) is not None
211
- try :
212
- _xformers_version = importlib_metadata .version ("xformers" )
213
- if _torch_available :
214
- _torch_version = importlib_metadata .version ("torch" )
215
- if version .Version (_torch_version ) < version .Version ("1.12" ):
216
- raise ValueError ("xformers is installed in your environment and requires PyTorch >= 1.12" )
217
-
218
- logger .debug (f"Successfully imported xformers version { _xformers_version } " )
219
- except importlib_metadata .PackageNotFoundError :
220
- _xformers_available = False
221
-
222
- _k_diffusion_available = importlib .util .find_spec ("k_diffusion" ) is not None
223
- try :
224
- _k_diffusion_version = importlib_metadata .version ("k_diffusion" )
225
- logger .debug (f"Successfully imported k-diffusion version { _k_diffusion_version } " )
226
- except importlib_metadata .PackageNotFoundError :
227
- _k_diffusion_available = False
228
-
229
- _note_seq_available = importlib .util .find_spec ("note_seq" ) is not None
230
- try :
231
- _note_seq_version = importlib_metadata .version ("note_seq" )
232
- logger .debug (f"Successfully imported note-seq version { _note_seq_version } " )
233
- except importlib_metadata .PackageNotFoundError :
234
- _note_seq_available = False
235
-
236
- _wandb_available = importlib .util .find_spec ("wandb" ) is not None
237
- try :
238
- _wandb_version = importlib_metadata .version ("wandb" )
239
- logger .debug (f"Successfully imported wandb version { _wandb_version } " )
240
- except importlib_metadata .PackageNotFoundError :
241
- _wandb_available = False
242
-
243
-
244
- _tensorboard_available = importlib .util .find_spec ("tensorboard" )
245
- try :
246
- _tensorboard_version = importlib_metadata .version ("tensorboard" )
247
- logger .debug (f"Successfully imported tensorboard version { _tensorboard_version } " )
248
- except importlib_metadata .PackageNotFoundError :
249
- _tensorboard_available = False
250
-
251
-
252
- _compel_available = importlib .util .find_spec ("compel" )
253
- try :
254
- _compel_version = importlib_metadata .version ("compel" )
255
- logger .debug (f"Successfully imported compel version { _compel_version } " )
256
- except importlib_metadata .PackageNotFoundError :
257
- _compel_available = False
258
-
259
-
260
- _ftfy_available = importlib .util .find_spec ("ftfy" ) is not None
261
- try :
262
- _ftfy_version = importlib_metadata .version ("ftfy" )
263
- logger .debug (f"Successfully imported ftfy version { _ftfy_version } " )
264
- except importlib_metadata .PackageNotFoundError :
265
- _ftfy_available = False
266
-
267
-
268
147
_bs4_available = importlib .util .find_spec ("bs4" ) is not None
269
148
try :
270
149
# importlib metadata under different name
273
152
except importlib_metadata .PackageNotFoundError :
274
153
_bs4_available = False
275
154
276
- _torchsde_available = importlib .util .find_spec ("torchsde" ) is not None
277
- try :
278
- _torchsde_version = importlib_metadata .version ("torchsde" )
279
- logger .debug (f"Successfully imported torchsde version { _torchsde_version } " )
280
- except importlib_metadata .PackageNotFoundError :
281
- _torchsde_available = False
282
-
283
155
_invisible_watermark_available = importlib .util .find_spec ("imwatermark" ) is not None
284
156
try :
285
157
_invisible_watermark_version = importlib_metadata .version ("invisible-watermark" )
286
158
logger .debug (f"Successfully imported invisible-watermark version { _invisible_watermark_version } " )
287
159
except importlib_metadata .PackageNotFoundError :
288
160
_invisible_watermark_available = False
289
161
290
-
291
- _peft_available = importlib .util .find_spec ("peft" ) is not None
292
- try :
293
- _peft_version = importlib_metadata .version ("peft" )
294
- logger .debug (f"Successfully imported peft version { _peft_version } " )
295
- except importlib_metadata .PackageNotFoundError :
296
- _peft_available = False
297
-
298
- _torchvision_available = importlib .util .find_spec ("torchvision" ) is not None
299
- try :
300
- _torchvision_version = importlib_metadata .version ("torchvision" )
301
- logger .debug (f"Successfully imported torchvision version { _torchvision_version } " )
302
- except importlib_metadata .PackageNotFoundError :
303
- _torchvision_available = False
304
-
305
- _sentencepiece_available = importlib .util .find_spec ("sentencepiece" ) is not None
306
- try :
307
- _sentencepiece_version = importlib_metadata .version ("sentencepiece" )
308
- logger .info (f"Successfully imported sentencepiece version { _sentencepiece_version } " )
309
- except importlib_metadata .PackageNotFoundError :
310
- _sentencepiece_available = False
311
-
312
- _matplotlib_available = importlib .util .find_spec ("matplotlib" ) is not None
313
- try :
314
- _matplotlib_version = importlib_metadata .version ("matplotlib" )
315
- logger .debug (f"Successfully imported matplotlib version { _matplotlib_version } " )
316
- except importlib_metadata .PackageNotFoundError :
317
- _matplotlib_available = False
318
-
319
- _timm_available = importlib .util .find_spec ("timm" ) is not None
320
- if _timm_available :
321
- try :
322
- _timm_version = importlib_metadata .version ("timm" )
323
- logger .info (f"Timm version { _timm_version } available." )
324
- except importlib_metadata .PackageNotFoundError :
325
- _timm_available = False
326
-
327
-
328
- def is_timm_available ():
329
- return _timm_available
330
-
331
-
332
- _bitsandbytes_available = importlib .util .find_spec ("bitsandbytes" ) is not None
333
- try :
334
- _bitsandbytes_version = importlib_metadata .version ("bitsandbytes" )
335
- logger .debug (f"Successfully imported bitsandbytes version { _bitsandbytes_version } " )
336
- except importlib_metadata .PackageNotFoundError :
337
- _bitsandbytes_available = False
338
-
339
- _is_google_colab = "google.colab" in sys .modules or any (k .startswith ("COLAB_" ) for k in os .environ )
340
-
341
- _imageio_available = importlib .util .find_spec ("imageio" ) is not None
342
- if _imageio_available :
343
- try :
344
- _imageio_version = importlib_metadata .version ("imageio" )
345
- logger .debug (f"Successfully imported imageio version { _imageio_version } " )
346
-
347
- except importlib_metadata .PackageNotFoundError :
348
- _imageio_available = False
349
-
350
- _is_gguf_available = importlib .util .find_spec ("gguf" ) is not None
351
- if _is_gguf_available :
352
- try :
353
- _gguf_version = importlib_metadata .version ("gguf" )
354
- logger .debug (f"Successfully import gguf version { _gguf_version } " )
355
- except importlib_metadata .PackageNotFoundError :
356
- _is_gguf_available = False
357
-
358
-
359
- _is_torchao_available = importlib .util .find_spec ("torchao" ) is not None
360
- if _is_torchao_available :
361
- try :
362
- _torchao_version = importlib_metadata .version ("torchao" )
363
- logger .debug (f"Successfully import torchao version { _torchao_version } " )
364
- except importlib_metadata .PackageNotFoundError :
365
- _is_torchao_available = False
366
-
367
-
368
- _is_optimum_quanto_available = importlib .util .find_spec ("optimum" ) is not None
369
- if _is_optimum_quanto_available :
162
+ _torch_xla_available , _torch_xla_version = _is_package_available ("torch_xla" )
163
+ _torch_npu_available , _torch_npu_version = _is_package_available ("torch_npu" )
164
+ _transformers_available , _transformers_version = _is_package_available ("transformers" )
165
+ _hf_hub_available , _hf_hub_version = _is_package_available ("huggingface_hub" )
166
+ _inflect_available , _inflect_version = _is_package_available ("inflect" )
167
+ _unidecode_available , _unidecode_version = _is_package_available ("unidecode" )
168
+ _k_diffusion_available , _k_diffusion_version = _is_package_available ("k_diffusion" )
169
+ _note_seq_available , _note_seq_version = _is_package_available ("note_seq" )
170
+ _wandb_available , _wandb_version = _is_package_available ("wandb" )
171
+ _tensorboard_available , _tensorboard_version = _is_package_available ("tensorboard" )
172
+ _compel_available , _compel_version = _is_package_available ("compel" )
173
+ _sentencepiece_available , _sentencepiece_version = _is_package_available ("sentencepiece" )
174
+ _torchsde_available , _torchsde_version = _is_package_available ("torchsde" )
175
+ _peft_available , _peft_version = _is_package_available ("peft" )
176
+ _torchvision_available , _torchvision_version = _is_package_available ("torchvision" )
177
+ _matplotlib_available , _matplotlib_version = _is_package_available ("matplotlib" )
178
+ _timm_available , _timm_version = _is_package_available ("timm" )
179
+ _bitsandbytes_available , _bitsandbytes_version = _is_package_available ("bitsandbytes" )
180
+ _imageio_available , _imageio_version = _is_package_available ("imageio" )
181
+ _ftfy_available , _ftfy_version = _is_package_available ("ftfy" )
182
+ _scipy_available , _scipy_version = _is_package_available ("scipy" )
183
+ _librosa_available , _librosa_version = _is_package_available ("librosa" )
184
+ _accelerate_available , _accelerate_version = _is_package_available ("accelerate" )
185
+ _xformers_available , _xformers_version = _is_package_available ("xformers" )
186
+ _gguf_available , _gguf_version = _is_package_available ("gguf" )
187
+ _torchao_available , _torchao_version = _is_package_available ("torchao" )
188
+ _bitsandbytes_available , _bitsandbytes_version = _is_package_available ("bitsandbytes" )
189
+ _torchao_available , _torchao_version = _is_package_available ("torchao" )
190
+
191
+ _optimum_quanto_available = importlib .util .find_spec ("optimum" ) is not None
192
+ if _optimum_quanto_available :
370
193
try :
371
194
_optimum_quanto_version = importlib_metadata .version ("optimum_quanto" )
372
195
logger .debug (f"Successfully import optimum-quanto version { _optimum_quanto_version } " )
373
196
except importlib_metadata .PackageNotFoundError :
374
- _is_optimum_quanto_available = False
197
+ _optimum_quanto_available = False
375
198
376
199
377
200
def is_torch_available ():
@@ -495,15 +318,19 @@ def is_imageio_available():
495
318
496
319
497
320
def is_gguf_available ():
498
- return _is_gguf_available
321
+ return _gguf_available
499
322
500
323
501
324
def is_torchao_available ():
502
- return _is_torchao_available
325
+ return _torchao_available
503
326
504
327
505
328
def is_optimum_quanto_available ():
506
- return _is_optimum_quanto_available
329
+ return _optimum_quanto_available
330
+
331
+
332
+ def is_timm_available ():
333
+ return _timm_available
507
334
508
335
509
336
# docstyle-ignore
@@ -863,7 +690,7 @@ def is_gguf_version(operation: str, version: str):
863
690
version (`str`):
864
691
A version string
865
692
"""
866
- if not _is_gguf_available :
693
+ if not _gguf_available :
867
694
return False
868
695
return compare_versions (parse (_gguf_version ), operation , version )
869
696
@@ -878,7 +705,7 @@ def is_torchao_version(operation: str, version: str):
878
705
version (`str`):
879
706
A version string
880
707
"""
881
- if not _is_torchao_available :
708
+ if not _torchao_available :
882
709
return False
883
710
return compare_versions (parse (_torchao_version ), operation , version )
884
711
@@ -908,7 +735,7 @@ def is_optimum_quanto_version(operation: str, version: str):
908
735
version (`str`):
909
736
A version string
910
737
"""
911
- if not _is_optimum_quanto_available :
738
+ if not _optimum_quanto_available :
912
739
return False
913
740
return compare_versions (parse (_optimum_quanto_version ), operation , version )
914
741
0 commit comments