Commit 2f0f281

sayakpaul and DN6 authored
[Tests] restrict memory tests for quanto for certain schemes. (#11052)
* restrict memory tests for quanto for certain schemes.

* Apply suggestions from code review

Co-authored-by: Dhruv Nair <[email protected]>

* fixes

* style

---------

Co-authored-by: Dhruv Nair <[email protected]>
1 parent ccc8321 · commit 2f0f281

2 files changed: +19 -0 lines

src/diffusers/utils/testing_utils.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -101,6 +101,8 @@
     mps_backend_registered = hasattr(torch.backends, "mps")
     torch_device = "mps" if (mps_backend_registered and torch.backends.mps.is_available()) else torch_device
 
+    from .torch_utils import get_torch_cuda_device_capability
+
 
 def torch_all_close(a, b, *args, **kwargs):
     if not is_torch_available():
@@ -282,6 +284,20 @@ def require_torch_gpu(test_case):
     )
 
 
+def require_torch_cuda_compatibility(expected_compute_capability):
+    def decorator(test_case):
+        if not torch.cuda.is_available():
+            return unittest.skip(test_case)
+        else:
+            current_compute_capability = get_torch_cuda_device_capability()
+            return unittest.skipUnless(
+                float(current_compute_capability) == float(expected_compute_capability),
+                "Test not supported for this compute capability.",
+            )
+
+    return decorator
+
+
 # These decorators are for accelerator-specific behaviours that are not GPU-specific
 def require_torch_accelerator(test_case):
     """Decorator marking a test that requires an accelerator backend and PyTorch."""
```

tests/quantization/quanto/test_quanto.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -10,6 +10,7 @@
     numpy_cosine_similarity_distance,
     require_accelerate,
     require_big_gpu_with_torch_cuda,
+    require_torch_cuda_compatibility,
     torch_device,
 )
 
@@ -311,13 +312,15 @@ def get_dummy_init_kwargs(self):
         return {"weights_dtype": "int8"}
 
 
+@require_torch_cuda_compatibility(8.0)
 class FluxTransformerInt4WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
     expected_memory_reduction = 0.55
 
     def get_dummy_init_kwargs(self):
         return {"weights_dtype": "int4"}
 
 
+@require_torch_cuda_compatibility(8.0)
 class FluxTransformerInt2WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
     expected_memory_reduction = 0.65
 
```
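
With both classes gated, the quanto int4 and int2 memory tests now run only on GPUs reporting compute capability 8.0 exactly (Ampere data-center parts such as the A100) and are skipped elsewhere or when CUDA is unavailable; the int8 test class above remains unrestricted.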
