2 files changed, +19 -0 lines changed

File 1:

@@ -101,6 +101,8 @@
    mps_backend_registered = hasattr(torch.backends, "mps")
    torch_device = "mps" if (mps_backend_registered and torch.backends.mps.is_available()) else torch_device

+    from .torch_utils import get_torch_cuda_device_capability
+

def torch_all_close(a, b, *args, **kwargs):
    if not is_torch_available():
@@ -282,6 +284,20 @@ def require_torch_gpu(test_case):
    )


+def require_torch_cuda_compatibility(expected_compute_capability):
+    def decorator(test_case):
+        if not torch.cuda.is_available():
+            return unittest.skip("Test requires CUDA.")(test_case)
+        else:
+            current_compute_capability = get_torch_cuda_device_capability()
+            return unittest.skipUnless(
+                float(current_compute_capability) == float(expected_compute_capability),
+                "Test not supported for this compute capability.",
+            )(test_case)
+
+    return decorator
+
+
# These decorators are for accelerator-specific behaviours that are not GPU-specific
def require_torch_accelerator(test_case):
    """Decorator marking a test that requires an accelerator backend and PyTorch."""
File 2 (tests/quantization/quanto):

@@ -10,6 +10,7 @@
    numpy_cosine_similarity_distance,
    require_accelerate,
    require_big_gpu_with_torch_cuda,
+    require_torch_cuda_compatibility,
    torch_device,
)

@@ -311,13 +312,15 @@ def get_dummy_init_kwargs(self):
        return {"weights_dtype": "int8"}


+@require_torch_cuda_compatibility(8.0)
class FluxTransformerInt4WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
    expected_memory_reduction = 0.55

    def get_dummy_init_kwargs(self):
        return {"weights_dtype": "int4"}


+@require_torch_cuda_compatibility(8.0)
class FluxTransformerInt2WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
    expected_memory_reduction = 0.65

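To show what the new decorator does in practice, here is a usage sketch. The import path diffusers.utils.testing_utils and the ExampleSm80OnlyTest class are assumptions for illustration; only the decorator itself comes from this diff.

import unittest

from diffusers.utils.testing_utils import require_torch_cuda_compatibility  # assumed import path


# The class runs only when the current GPU's compute capability is exactly 8.0
# (e.g. A100); on any other capability, or without CUDA, every test in it is skipped.
@require_torch_cuda_compatibility(8.0)
class ExampleSm80OnlyTest(unittest.TestCase):
    def test_runs_on_sm80_only(self):
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()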