@@ -1360,6 +1360,8 @@ def _test_attention_slicing_forward_pass(
1360
1360
reason = "CPU offload is only available with CUDA and `accelerate v0.14.0` or higher" ,
1361
1361
)
1362
1362
def test_sequential_cpu_offload_forward_pass (self , expected_max_diff = 1e-4 ):
1363
+ import accelerate
1364
+
1363
1365
components = self .get_dummy_components ()
1364
1366
pipe = self .pipeline_class (** components )
1365
1367
for component in pipe .components .values ():
@@ -1373,18 +1375,56 @@ def test_sequential_cpu_offload_forward_pass(self, expected_max_diff=1e-4):
1373
1375
output_without_offload = pipe (** inputs )[0 ]
1374
1376
1375
1377
pipe .enable_sequential_cpu_offload ()
1378
+ assert pipe ._execution_device .type == pipe ._offload_device .type
1376
1379
1377
1380
inputs = self .get_dummy_inputs (generator_device )
1378
1381
output_with_offload = pipe (** inputs )[0 ]
1379
1382
1380
1383
max_diff = np .abs (to_np (output_with_offload ) - to_np (output_without_offload )).max ()
1381
1384
self .assertLess (max_diff , expected_max_diff , "CPU offloading should not affect the inference results" )
1382
1385
1386
+ # make sure all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are offloaded correctly
1387
+ offloaded_modules = {
1388
+ k : v
1389
+ for k , v in pipe .components .items ()
1390
+ if isinstance (v , torch .nn .Module ) and k not in pipe ._exclude_from_cpu_offload
1391
+ }
1392
+ # 1. all offloaded modules should be saved to cpu and moved to meta device
1393
+ self .assertTrue (
1394
+ all (v .device .type == "meta" for v in offloaded_modules .values ()),
1395
+ f"Not offloaded: { [k for k , v in offloaded_modules .items () if v .device .type != 'meta' ]} " ,
1396
+ )
1397
+ # 2. all offloaded modules should have hook installed
1398
+ self .assertTrue (
1399
+ all (hasattr (v , "_hf_hook" ) for k , v in offloaded_modules .items ()),
1400
+ f"No hook attached: { [k for k , v in offloaded_modules .items () if not hasattr (v , '_hf_hook' )]} " ,
1401
+ )
1402
+ # 3. all offloaded modules should have correct hooks installed, should be either one of these two
1403
+ # - `AlignDevicesHook`
1404
+ # - a `SequentialHook` that contains `AlignDevicesHook`
1405
+ offloaded_modules_with_incorrect_hooks = {}
1406
+ for k , v in offloaded_modules .items ():
1407
+ if hasattr (v , "_hf_hook" ):
1408
+ if isinstance (v ._hf_hook , accelerate .hooks .SequentialHook ):
1409
+ # if it is a `SequentialHook`, we loop through its `hooks` attribute to check if it only contains `AlignDevicesHook`
1410
+ for hook in v ._hf_hook .hooks :
1411
+ if not isinstance (hook , accelerate .hooks .AlignDevicesHook ):
1412
+ offloaded_modules_with_incorrect_hooks [k ] = type (v ._hf_hook .hooks [0 ])
1413
+ elif not isinstance (v ._hf_hook , accelerate .hooks .AlignDevicesHook ):
1414
+ offloaded_modules_with_incorrect_hooks [k ] = type (v ._hf_hook )
1415
+
1416
+ self .assertTrue (
1417
+ len (offloaded_modules_with_incorrect_hooks ) == 0 ,
1418
+ f"Not installed correct hook: { offloaded_modules_with_incorrect_hooks } " ,
1419
+ )
1420
+
1383
1421
@unittest .skipIf (
1384
1422
torch_device != "cuda" or not is_accelerate_available () or is_accelerate_version ("<" , "0.17.0" ),
1385
1423
reason = "CPU offload is only available with CUDA and `accelerate v0.17.0` or higher" ,
1386
1424
)
1387
1425
def test_model_cpu_offload_forward_pass (self , expected_max_diff = 2e-4 ):
1426
+ import accelerate
1427
+
1388
1428
generator_device = "cpu"
1389
1429
components = self .get_dummy_components ()
1390
1430
pipe = self .pipeline_class (** components )
@@ -1400,19 +1440,39 @@ def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4):
1400
1440
output_without_offload = pipe (** inputs )[0 ]
1401
1441
1402
1442
pipe .enable_model_cpu_offload ()
1443
+ assert pipe ._execution_device .type == pipe ._offload_device .type
1444
+
1403
1445
inputs = self .get_dummy_inputs (generator_device )
1404
1446
output_with_offload = pipe (** inputs )[0 ]
1405
1447
1406
1448
max_diff = np .abs (to_np (output_with_offload ) - to_np (output_without_offload )).max ()
1407
1449
self .assertLess (max_diff , expected_max_diff , "CPU offloading should not affect the inference results" )
1408
- offloaded_modules = [
1409
- v
1450
+
1451
+ # make sure all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are offloaded correctly
1452
+ offloaded_modules = {
1453
+ k : v
1410
1454
for k , v in pipe .components .items ()
1411
1455
if isinstance (v , torch .nn .Module ) and k not in pipe ._exclude_from_cpu_offload
1412
- ]
1413
- (
1414
- self .assertTrue (all (v .device .type == "cpu" for v in offloaded_modules )),
1415
- f"Not offloaded: { [v for v in offloaded_modules if v .device .type != 'cpu' ]} " ,
1456
+ }
1457
+ # 1. check if all offloaded modules are saved to cpu
1458
+ self .assertTrue (
1459
+ all (v .device .type == "cpu" for v in offloaded_modules .values ()),
1460
+ f"Not offloaded: { [k for k , v in offloaded_modules .items () if v .device .type != 'cpu' ]} " ,
1461
+ )
1462
+ # 2. check if all offloaded modules have hooks installed
1463
+ self .assertTrue (
1464
+ all (hasattr (v , "_hf_hook" ) for k , v in offloaded_modules .items ()),
1465
+ f"No hook attached: { [k for k , v in offloaded_modules .items () if not hasattr (v , '_hf_hook' )]} " ,
1466
+ )
1467
+ # 3. check if all offloaded modules have correct type of hooks installed, should be `CpuOffload`
1468
+ offloaded_modules_with_incorrect_hooks = {}
1469
+ for k , v in offloaded_modules .items ():
1470
+ if hasattr (v , "_hf_hook" ) and not isinstance (v ._hf_hook , accelerate .hooks .CpuOffload ):
1471
+ offloaded_modules_with_incorrect_hooks [k ] = type (v ._hf_hook )
1472
+
1473
+ self .assertTrue (
1474
+ len (offloaded_modules_with_incorrect_hooks ) == 0 ,
1475
+ f"Not installed correct hook: { offloaded_modules_with_incorrect_hooks } " ,
1416
1476
)
1417
1477
1418
1478
@unittest .skipIf (
@@ -1444,16 +1504,24 @@ def test_cpu_offload_forward_pass_twice(self, expected_max_diff=2e-4):
1444
1504
self .assertLess (
1445
1505
max_diff , expected_max_diff , "running CPU offloading 2nd time should not affect the inference results"
1446
1506
)
1507
+
1508
+ # make sure all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are offloaded correctly
1447
1509
offloaded_modules = {
1448
1510
k : v
1449
1511
for k , v in pipe .components .items ()
1450
1512
if isinstance (v , torch .nn .Module ) and k not in pipe ._exclude_from_cpu_offload
1451
1513
}
1514
+ # 1. check if all offloaded modules are saved to cpu
1452
1515
self .assertTrue (
1453
1516
all (v .device .type == "cpu" for v in offloaded_modules .values ()),
1454
1517
f"Not offloaded: { [k for k , v in offloaded_modules .items () if v .device .type != 'cpu' ]} " ,
1455
1518
)
1456
-
1519
+ # 2. check if all offloaded modules have hooks installed
1520
+ self .assertTrue (
1521
+ all (hasattr (v , "_hf_hook" ) for k , v in offloaded_modules .items ()),
1522
+ f"No hook attached: { [k for k , v in offloaded_modules .items () if not hasattr (v , '_hf_hook' )]} " ,
1523
+ )
1524
+ # 3. check if all offloaded modules have correct type of hooks installed, should be `CpuOffload`
1457
1525
offloaded_modules_with_incorrect_hooks = {}
1458
1526
for k , v in offloaded_modules .items ():
1459
1527
if hasattr (v , "_hf_hook" ) and not isinstance (v ._hf_hook , accelerate .hooks .CpuOffload ):
@@ -1493,19 +1561,36 @@ def test_sequential_offload_forward_pass_twice(self, expected_max_diff=2e-4):
1493
1561
self .assertLess (
1494
1562
max_diff , expected_max_diff , "running sequential offloading second time should have the inference results"
1495
1563
)
1564
+
1565
+ # make sure all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are offloaded correctly
1496
1566
offloaded_modules = {
1497
1567
k : v
1498
1568
for k , v in pipe .components .items ()
1499
1569
if isinstance (v , torch .nn .Module ) and k not in pipe ._exclude_from_cpu_offload
1500
1570
}
1571
+ # 1. check if all offloaded modules are moved to meta device
1501
1572
self .assertTrue (
1502
1573
all (v .device .type == "meta" for v in offloaded_modules .values ()),
1503
1574
f"Not offloaded: { [k for k , v in offloaded_modules .items () if v .device .type != 'meta' ]} " ,
1504
1575
)
1576
+ # 2. check if all offloaded modules have hook installed
1577
+ self .assertTrue (
1578
+ all (hasattr (v , "_hf_hook" ) for k , v in offloaded_modules .items ()),
1579
+ f"No hook attached: { [k for k , v in offloaded_modules .items () if not hasattr (v , '_hf_hook' )]} " ,
1580
+ )
1581
+ # 3. check if all offloaded modules have correct hooks installed, should be either one of these two
1582
+ # - `AlignDevicesHook`
1583
+ # - a `SequentialHook` that contains `AlignDevicesHook`
1505
1584
offloaded_modules_with_incorrect_hooks = {}
1506
1585
for k , v in offloaded_modules .items ():
1507
- if hasattr (v , "_hf_hook" ) and not isinstance (v ._hf_hook , accelerate .hooks .AlignDevicesHook ):
1508
- offloaded_modules_with_incorrect_hooks [k ] = type (v ._hf_hook )
1586
+ if hasattr (v , "_hf_hook" ):
1587
+ if isinstance (v ._hf_hook , accelerate .hooks .SequentialHook ):
1588
+ # if it is a `SequentialHook`, we loop through its `hooks` attribute to check if it only contains `AlignDevicesHook`
1589
+ for hook in v ._hf_hook .hooks :
1590
+ if not isinstance (hook , accelerate .hooks .AlignDevicesHook ):
1591
+ offloaded_modules_with_incorrect_hooks [k ] = type (v ._hf_hook .hooks [0 ])
1592
+ elif not isinstance (v ._hf_hook , accelerate .hooks .AlignDevicesHook ):
1593
+ offloaded_modules_with_incorrect_hooks [k ] = type (v ._hf_hook )
1509
1594
1510
1595
self .assertTrue (
1511
1596
len (offloaded_modules_with_incorrect_hooks ) == 0 ,
0 commit comments