Skip to content

Commit 5428046

Browse files
authored
[Refactor] Clean up import utils boilerplate (#11026)
* update * update * update
1 parent e7ffeae commit 5428046

File tree

1 file changed

+63
-236
lines changed

1 file changed

+63
-236
lines changed

src/diffusers/utils/import_utils.py

Lines changed: 63 additions & 236 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from typing import Any, Union
2626

2727
from huggingface_hub.utils import is_jinja_available # noqa: F401
28-
from packaging import version
2928
from packaging.version import Version, parse
3029

3130
from . import logging
@@ -52,36 +51,30 @@
5251

5352
STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
5453

55-
_torch_version = "N/A"
56-
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
57-
_torch_available = importlib.util.find_spec("torch") is not None
58-
if _torch_available:
54+
_is_google_colab = "google.colab" in sys.modules or any(k.startswith("COLAB_") for k in os.environ)
55+
56+
57+
def _is_package_available(pkg_name: str):
58+
pkg_exists = importlib.util.find_spec(pkg_name) is not None
59+
pkg_version = "N/A"
60+
61+
if pkg_exists:
5962
try:
60-
_torch_version = importlib_metadata.version("torch")
61-
logger.info(f"PyTorch version {_torch_version} available.")
62-
except importlib_metadata.PackageNotFoundError:
63-
_torch_available = False
63+
pkg_version = importlib_metadata.version(pkg_name)
64+
logger.debug(f"Successfully imported {pkg_name} version {pkg_version}")
65+
except (ImportError, importlib_metadata.PackageNotFoundError):
66+
pkg_exists = False
67+
68+
return pkg_exists, pkg_version
69+
70+
71+
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
72+
_torch_available, _torch_version = _is_package_available("torch")
73+
6474
else:
6575
logger.info("Disabling PyTorch because USE_TORCH is set")
6676
_torch_available = False
6777

68-
_torch_xla_available = importlib.util.find_spec("torch_xla") is not None
69-
if _torch_xla_available:
70-
try:
71-
_torch_xla_version = importlib_metadata.version("torch_xla")
72-
logger.info(f"PyTorch XLA version {_torch_xla_version} available.")
73-
except ImportError:
74-
_torch_xla_available = False
75-
76-
# check whether torch_npu is available
77-
_torch_npu_available = importlib.util.find_spec("torch_npu") is not None
78-
if _torch_npu_available:
79-
try:
80-
_torch_npu_version = importlib_metadata.version("torch_npu")
81-
logger.info(f"torch_npu version {_torch_npu_version} available.")
82-
except ImportError:
83-
_torch_npu_available = False
84-
8578
_jax_version = "N/A"
8679
_flax_version = "N/A"
8780
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
@@ -97,47 +90,12 @@
9790
_flax_available = False
9891

9992
if USE_SAFETENSORS in ENV_VARS_TRUE_AND_AUTO_VALUES:
100-
_safetensors_available = importlib.util.find_spec("safetensors") is not None
101-
if _safetensors_available:
102-
try:
103-
_safetensors_version = importlib_metadata.version("safetensors")
104-
logger.info(f"Safetensors version {_safetensors_version} available.")
105-
except importlib_metadata.PackageNotFoundError:
106-
_safetensors_available = False
93+
_safetensors_available, _safetensors_version = _is_package_available("safetensors")
94+
10795
else:
10896
logger.info("Disabling Safetensors because USE_TF is set")
10997
_safetensors_available = False
11098

111-
_transformers_available = importlib.util.find_spec("transformers") is not None
112-
try:
113-
_transformers_version = importlib_metadata.version("transformers")
114-
logger.debug(f"Successfully imported transformers version {_transformers_version}")
115-
except importlib_metadata.PackageNotFoundError:
116-
_transformers_available = False
117-
118-
_hf_hub_available = importlib.util.find_spec("huggingface_hub") is not None
119-
try:
120-
_hf_hub_version = importlib_metadata.version("huggingface_hub")
121-
logger.debug(f"Successfully imported huggingface_hub version {_hf_hub_version}")
122-
except importlib_metadata.PackageNotFoundError:
123-
_hf_hub_available = False
124-
125-
126-
_inflect_available = importlib.util.find_spec("inflect") is not None
127-
try:
128-
_inflect_version = importlib_metadata.version("inflect")
129-
logger.debug(f"Successfully imported inflect version {_inflect_version}")
130-
except importlib_metadata.PackageNotFoundError:
131-
_inflect_available = False
132-
133-
134-
_unidecode_available = importlib.util.find_spec("unidecode") is not None
135-
try:
136-
_unidecode_version = importlib_metadata.version("unidecode")
137-
logger.debug(f"Successfully imported unidecode version {_unidecode_version}")
138-
except importlib_metadata.PackageNotFoundError:
139-
_unidecode_available = False
140-
14199
_onnxruntime_version = "N/A"
142100
_onnx_available = importlib.util.find_spec("onnxruntime") is not None
143101
if _onnx_available:
@@ -186,85 +144,6 @@
186144
except importlib_metadata.PackageNotFoundError:
187145
_opencv_available = False
188146

189-
_scipy_available = importlib.util.find_spec("scipy") is not None
190-
try:
191-
_scipy_version = importlib_metadata.version("scipy")
192-
logger.debug(f"Successfully imported scipy version {_scipy_version}")
193-
except importlib_metadata.PackageNotFoundError:
194-
_scipy_available = False
195-
196-
_librosa_available = importlib.util.find_spec("librosa") is not None
197-
try:
198-
_librosa_version = importlib_metadata.version("librosa")
199-
logger.debug(f"Successfully imported librosa version {_librosa_version}")
200-
except importlib_metadata.PackageNotFoundError:
201-
_librosa_available = False
202-
203-
_accelerate_available = importlib.util.find_spec("accelerate") is not None
204-
try:
205-
_accelerate_version = importlib_metadata.version("accelerate")
206-
logger.debug(f"Successfully imported accelerate version {_accelerate_version}")
207-
except importlib_metadata.PackageNotFoundError:
208-
_accelerate_available = False
209-
210-
_xformers_available = importlib.util.find_spec("xformers") is not None
211-
try:
212-
_xformers_version = importlib_metadata.version("xformers")
213-
if _torch_available:
214-
_torch_version = importlib_metadata.version("torch")
215-
if version.Version(_torch_version) < version.Version("1.12"):
216-
raise ValueError("xformers is installed in your environment and requires PyTorch >= 1.12")
217-
218-
logger.debug(f"Successfully imported xformers version {_xformers_version}")
219-
except importlib_metadata.PackageNotFoundError:
220-
_xformers_available = False
221-
222-
_k_diffusion_available = importlib.util.find_spec("k_diffusion") is not None
223-
try:
224-
_k_diffusion_version = importlib_metadata.version("k_diffusion")
225-
logger.debug(f"Successfully imported k-diffusion version {_k_diffusion_version}")
226-
except importlib_metadata.PackageNotFoundError:
227-
_k_diffusion_available = False
228-
229-
_note_seq_available = importlib.util.find_spec("note_seq") is not None
230-
try:
231-
_note_seq_version = importlib_metadata.version("note_seq")
232-
logger.debug(f"Successfully imported note-seq version {_note_seq_version}")
233-
except importlib_metadata.PackageNotFoundError:
234-
_note_seq_available = False
235-
236-
_wandb_available = importlib.util.find_spec("wandb") is not None
237-
try:
238-
_wandb_version = importlib_metadata.version("wandb")
239-
logger.debug(f"Successfully imported wandb version {_wandb_version }")
240-
except importlib_metadata.PackageNotFoundError:
241-
_wandb_available = False
242-
243-
244-
_tensorboard_available = importlib.util.find_spec("tensorboard")
245-
try:
246-
_tensorboard_version = importlib_metadata.version("tensorboard")
247-
logger.debug(f"Successfully imported tensorboard version {_tensorboard_version}")
248-
except importlib_metadata.PackageNotFoundError:
249-
_tensorboard_available = False
250-
251-
252-
_compel_available = importlib.util.find_spec("compel")
253-
try:
254-
_compel_version = importlib_metadata.version("compel")
255-
logger.debug(f"Successfully imported compel version {_compel_version}")
256-
except importlib_metadata.PackageNotFoundError:
257-
_compel_available = False
258-
259-
260-
_ftfy_available = importlib.util.find_spec("ftfy") is not None
261-
try:
262-
_ftfy_version = importlib_metadata.version("ftfy")
263-
logger.debug(f"Successfully imported ftfy version {_ftfy_version}")
264-
except importlib_metadata.PackageNotFoundError:
265-
_ftfy_available = False
266-
267-
268147
_bs4_available = importlib.util.find_spec("bs4") is not None
269148
try:
270149
# importlib metadata under different name
@@ -273,105 +152,49 @@
273152
except importlib_metadata.PackageNotFoundError:
274153
_bs4_available = False
275154

276-
_torchsde_available = importlib.util.find_spec("torchsde") is not None
277-
try:
278-
_torchsde_version = importlib_metadata.version("torchsde")
279-
logger.debug(f"Successfully imported torchsde version {_torchsde_version}")
280-
except importlib_metadata.PackageNotFoundError:
281-
_torchsde_available = False
282-
283155
_invisible_watermark_available = importlib.util.find_spec("imwatermark") is not None
284156
try:
285157
_invisible_watermark_version = importlib_metadata.version("invisible-watermark")
286158
logger.debug(f"Successfully imported invisible-watermark version {_invisible_watermark_version}")
287159
except importlib_metadata.PackageNotFoundError:
288160
_invisible_watermark_available = False
289161

290-
291-
_peft_available = importlib.util.find_spec("peft") is not None
292-
try:
293-
_peft_version = importlib_metadata.version("peft")
294-
logger.debug(f"Successfully imported peft version {_peft_version}")
295-
except importlib_metadata.PackageNotFoundError:
296-
_peft_available = False
297-
298-
_torchvision_available = importlib.util.find_spec("torchvision") is not None
299-
try:
300-
_torchvision_version = importlib_metadata.version("torchvision")
301-
logger.debug(f"Successfully imported torchvision version {_torchvision_version}")
302-
except importlib_metadata.PackageNotFoundError:
303-
_torchvision_available = False
304-
305-
_sentencepiece_available = importlib.util.find_spec("sentencepiece") is not None
306-
try:
307-
_sentencepiece_version = importlib_metadata.version("sentencepiece")
308-
logger.info(f"Successfully imported sentencepiece version {_sentencepiece_version}")
309-
except importlib_metadata.PackageNotFoundError:
310-
_sentencepiece_available = False
311-
312-
_matplotlib_available = importlib.util.find_spec("matplotlib") is not None
313-
try:
314-
_matplotlib_version = importlib_metadata.version("matplotlib")
315-
logger.debug(f"Successfully imported matplotlib version {_matplotlib_version}")
316-
except importlib_metadata.PackageNotFoundError:
317-
_matplotlib_available = False
318-
319-
_timm_available = importlib.util.find_spec("timm") is not None
320-
if _timm_available:
321-
try:
322-
_timm_version = importlib_metadata.version("timm")
323-
logger.info(f"Timm version {_timm_version} available.")
324-
except importlib_metadata.PackageNotFoundError:
325-
_timm_available = False
326-
327-
328-
def is_timm_available():
329-
return _timm_available
330-
331-
332-
_bitsandbytes_available = importlib.util.find_spec("bitsandbytes") is not None
333-
try:
334-
_bitsandbytes_version = importlib_metadata.version("bitsandbytes")
335-
logger.debug(f"Successfully imported bitsandbytes version {_bitsandbytes_version}")
336-
except importlib_metadata.PackageNotFoundError:
337-
_bitsandbytes_available = False
338-
339-
_is_google_colab = "google.colab" in sys.modules or any(k.startswith("COLAB_") for k in os.environ)
340-
341-
_imageio_available = importlib.util.find_spec("imageio") is not None
342-
if _imageio_available:
343-
try:
344-
_imageio_version = importlib_metadata.version("imageio")
345-
logger.debug(f"Successfully imported imageio version {_imageio_version}")
346-
347-
except importlib_metadata.PackageNotFoundError:
348-
_imageio_available = False
349-
350-
_is_gguf_available = importlib.util.find_spec("gguf") is not None
351-
if _is_gguf_available:
352-
try:
353-
_gguf_version = importlib_metadata.version("gguf")
354-
logger.debug(f"Successfully import gguf version {_gguf_version}")
355-
except importlib_metadata.PackageNotFoundError:
356-
_is_gguf_available = False
357-
358-
359-
_is_torchao_available = importlib.util.find_spec("torchao") is not None
360-
if _is_torchao_available:
361-
try:
362-
_torchao_version = importlib_metadata.version("torchao")
363-
logger.debug(f"Successfully import torchao version {_torchao_version}")
364-
except importlib_metadata.PackageNotFoundError:
365-
_is_torchao_available = False
366-
367-
368-
_is_optimum_quanto_available = importlib.util.find_spec("optimum") is not None
369-
if _is_optimum_quanto_available:
162+
_torch_xla_available, _torch_xla_version = _is_package_available("torch_xla")
163+
_torch_npu_available, _torch_npu_version = _is_package_available("torch_npu")
164+
_transformers_available, _transformers_version = _is_package_available("transformers")
165+
_hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub")
166+
_inflect_available, _inflect_version = _is_package_available("inflect")
167+
_unidecode_available, _unidecode_version = _is_package_available("unidecode")
168+
_k_diffusion_available, _k_diffusion_version = _is_package_available("k_diffusion")
169+
_note_seq_available, _note_seq_version = _is_package_available("note_seq")
170+
_wandb_available, _wandb_version = _is_package_available("wandb")
171+
_tensorboard_available, _tensorboard_version = _is_package_available("tensorboard")
172+
_compel_available, _compel_version = _is_package_available("compel")
173+
_sentencepiece_available, _sentencepiece_version = _is_package_available("sentencepiece")
174+
_torchsde_available, _torchsde_version = _is_package_available("torchsde")
175+
_peft_available, _peft_version = _is_package_available("peft")
176+
_torchvision_available, _torchvision_version = _is_package_available("torchvision")
177+
_matplotlib_available, _matplotlib_version = _is_package_available("matplotlib")
178+
_timm_available, _timm_version = _is_package_available("timm")
179+
_bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes")
180+
_imageio_available, _imageio_version = _is_package_available("imageio")
181+
_ftfy_available, _ftfy_version = _is_package_available("ftfy")
182+
_scipy_available, _scipy_version = _is_package_available("scipy")
183+
_librosa_available, _librosa_version = _is_package_available("librosa")
184+
_accelerate_available, _accelerate_version = _is_package_available("accelerate")
185+
_xformers_available, _xformers_version = _is_package_available("xformers")
186+
_gguf_available, _gguf_version = _is_package_available("gguf")
187+
_torchao_available, _torchao_version = _is_package_available("torchao")
188+
_bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes")
189+
_torchao_available, _torchao_version = _is_package_available("torchao")
190+
191+
_optimum_quanto_available = importlib.util.find_spec("optimum") is not None
192+
if _optimum_quanto_available:
370193
try:
371194
_optimum_quanto_version = importlib_metadata.version("optimum_quanto")
372195
logger.debug(f"Successfully import optimum-quanto version {_optimum_quanto_version}")
373196
except importlib_metadata.PackageNotFoundError:
374-
_is_optimum_quanto_available = False
197+
_optimum_quanto_available = False
375198

376199

377200
def is_torch_available():
@@ -495,15 +318,19 @@ def is_imageio_available():
495318

496319

497320
def is_gguf_available():
498-
return _is_gguf_available
321+
return _gguf_available
499322

500323

501324
def is_torchao_available():
502-
return _is_torchao_available
325+
return _torchao_available
503326

504327

505328
def is_optimum_quanto_available():
506-
return _is_optimum_quanto_available
329+
return _optimum_quanto_available
330+
331+
332+
def is_timm_available():
333+
return _timm_available
507334

508335

509336
# docstyle-ignore
@@ -863,7 +690,7 @@ def is_gguf_version(operation: str, version: str):
863690
version (`str`):
864691
A version string
865692
"""
866-
if not _is_gguf_available:
693+
if not _gguf_available:
867694
return False
868695
return compare_versions(parse(_gguf_version), operation, version)
869696

@@ -878,7 +705,7 @@ def is_torchao_version(operation: str, version: str):
878705
version (`str`):
879706
A version string
880707
"""
881-
if not _is_torchao_available:
708+
if not _torchao_available:
882709
return False
883710
return compare_versions(parse(_torchao_version), operation, version)
884711

@@ -908,7 +735,7 @@ def is_optimum_quanto_version(operation: str, version: str):
908735
version (`str`):
909736
A version string
910737
"""
911-
if not _is_optimum_quanto_available:
738+
if not _optimum_quanto_available:
912739
return False
913740
return compare_versions(parse(_optimum_quanto_version), operation, version)
914741

0 commit comments

Comments
 (0)