
Commit 21a7ff1

Authored by yiyixuxu and sayakpaul

update the logic of is_sequential_cpu_offload (#7788)

* up
* add comment to the tests + fix dit

Co-authored-by: Sayak Paul <[email protected]>
1 parent 8909ab4 commit 21a7ff1
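For context (this summary is not part of the commit itself): when `enable_sequential_cpu_offload` is used, accelerate may attach an `AlignDevicesHook` to a module directly, or bundle several hooks into a `SequentialHook` whose `hooks` attribute contains the `AlignDevicesHook`. The old check only recognized the first layout; this commit updates the detection to handle both. A minimal sketch of the idea (the helper name is illustrative, not from the codebase):

```python
from accelerate.hooks import AlignDevicesHook

def is_sequentially_offloaded(module) -> bool:
    # Treat a module as sequentially offloaded if accelerate attached an
    # AlignDevicesHook directly, or wrapped one inside a SequentialHook
    # (SequentialHook exposes its wrapped hooks via the `hooks` attribute).
    hook = getattr(module, "_hf_hook", None)
    if hook is None:
        return False
    if isinstance(hook, AlignDevicesHook):
        return True
    return hasattr(hook, "hooks") and isinstance(hook.hooks[0], AlignDevicesHook)
```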

File tree

9 files changed: +123, -24 lines


examples/community/pipeline_demofusion_sdxl.py (5 additions, 1 deletion)

```diff
@@ -1304,7 +1304,11 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
                 if isinstance(component, torch.nn.Module):
                     if hasattr(component, "_hf_hook"):
                         is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
-                        is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                        is_sequential_cpu_offload = (
+                            isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                            or hasattr(component._hf_hook, "hooks")
+                            and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+                        )
                         logger.info(
                             "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
                         )
```
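One subtlety worth noting: the new expression mixes `or` and `and` without parentheses. Since `and` binds tighter than `or` in Python, it evaluates as the explicitly grouped (and behaviorally identical) form below, with `hook` standing in for `component._hf_hook`:

```python
# Equivalent, explicitly parenthesized form of the new check:
is_sequential_cpu_offload = isinstance(hook, AlignDevicesHook) or (
    hasattr(hook, "hooks") and isinstance(hook.hooks[0], AlignDevicesHook)
)
```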

src/diffusers/loaders/lora.py (5 additions, 1 deletion)

```diff
@@ -369,7 +369,11 @@ def _optionally_disable_offloading(cls, _pipeline):
                     if not is_model_cpu_offload:
                         is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
                     if not is_sequential_cpu_offload:
-                        is_sequential_cpu_offload = isinstance(component._hf_hook, AlignDevicesHook)
+                        is_sequential_cpu_offload = (
+                            isinstance(component._hf_hook, AlignDevicesHook)
+                            or hasattr(component._hf_hook, "hooks")
+                            and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+                        )

                     logger.info(
                         "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
```

src/diffusers/loaders/textual_inversion.py (5 additions, 1 deletion)

```diff
@@ -423,7 +423,11 @@ def load_textual_inversion(
                 if isinstance(component, nn.Module):
                     if hasattr(component, "_hf_hook"):
                         is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
-                        is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                        is_sequential_cpu_offload = (
+                            isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                            or hasattr(component._hf_hook, "hooks")
+                            and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+                        )
                         logger.info(
                             "Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
                         )
```

src/diffusers/loaders/unet.py (5 additions, 1 deletion)

```diff
@@ -359,7 +359,11 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
             for _, component in _pipeline.components.items():
                 if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
                     is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
-                    is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                    is_sequential_cpu_offload = (
+                        isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                        or hasattr(component._hf_hook, "hooks")
+                        and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+                    )

                     logger.info(
                         "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
```

src/diffusers/pipelines/dit/pipeline_dit.py (3 additions, 0 deletions)

```diff
@@ -227,6 +227,9 @@ def __call__(
         if output_type == "pil":
             samples = self.numpy_to_pil(samples)

+        # Offload all models
+        self.maybe_free_model_hooks()
+
         if not return_dict:
             return (samples,)
```
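The DiT fix makes the pipeline release and reset its offload hooks at the end of `__call__`, as the other pipelines already do. A usage sketch (the model id and class label are illustrative):

```python
import torch
from diffusers import DiTPipeline

pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()

# With the added maybe_free_model_hooks() call, components are returned to
# the offload device after each generation, keeping repeated calls and any
# subsequent hook manipulation consistent.
samples = pipe(class_labels=[207], num_inference_steps=25).images
```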

src/diffusers/pipelines/pipeline_utils.py (6 additions, 3 deletions)

```diff
@@ -376,7 +376,11 @@ def module_is_sequentially_offloaded(module):
     if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
         return False

-    return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook)
+    return hasattr(module, "_hf_hook") and (
+        isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook)
+        or hasattr(module._hf_hook, "hooks")
+        and isinstance(module._hf_hook.hooks[0], accelerate.hooks.AlignDevicesHook)
+    )

 def module_is_offloaded(module):
     if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"):
@@ -1005,8 +1009,7 @@ def remove_all_hooks(self):
         """
         for _, model in self.components.items():
             if isinstance(model, torch.nn.Module) and hasattr(model, "_hf_hook"):
-                is_sequential_cpu_offload = isinstance(getattr(model, "_hf_hook"), accelerate.hooks.AlignDevicesHook)
-                accelerate.hooks.remove_hook_from_module(model, recurse=is_sequential_cpu_offload)
+                accelerate.hooks.remove_hook_from_module(model, recurse=True)
         self._all_hooks = []

     def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
```
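`module_is_sequentially_offloaded` is what guards operations such as moving a sequentially offloaded pipeline with `pipe.to(...)`, and `remove_all_hooks` now always recurses because sequential offload attaches hooks to submodules, not only the top-level model. A small diagnostic sketch (hypothetical helper, assumes accelerate is installed) for inspecting which hook layout a component carries:

```python
import accelerate

def describe_hook(module) -> str:
    # Report the accelerate hook layout on a module: no hook, a plain hook
    # class name, or a SequentialHook with the classes of its inner hooks.
    hook = getattr(module, "_hf_hook", None)
    if hook is None:
        return "no hook"
    if isinstance(hook, accelerate.hooks.SequentialHook):
        return f"SequentialHook[{', '.join(type(h).__name__ for h in hook.hooks)}]"
    return type(hook).__name__
```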

tests/pipelines/pixart_alpha/test_pixart.py (0 additions, 4 deletions)

```diff
@@ -324,10 +324,6 @@ def test_raises_warning_for_mask_feature(self):
     def test_inference_batch_single_identical(self):
         self._test_inference_batch_single_identical(expected_max_diff=1e-3)

-    # PixArt transformer model does not work with sequential offload so skip it for now
-    def test_sequential_offload_forward_pass_twice(self):
-        pass
-

 @slow
 @require_torch_gpu
```
tests/pipelines/pixart_sigma/test_pixart.py (0 additions, 4 deletions)

```diff
@@ -308,10 +308,6 @@ def test_inference_with_multiple_images_per_prompt(self):
     def test_inference_batch_single_identical(self):
         self._test_inference_batch_single_identical(expected_max_diff=1e-3)

-    # PixArt transformer model does not work with sequential offload so skip it for now
-    def test_sequential_offload_forward_pass_twice(self):
-        pass
-

 @slow
 @require_torch_gpu
```

tests/pipelines/test_pipelines_common.py (94 additions, 9 deletions)

```diff
@@ -1360,6 +1360,8 @@ def _test_attention_slicing_forward_pass(
         reason="CPU offload is only available with CUDA and `accelerate v0.14.0` or higher",
     )
     def test_sequential_cpu_offload_forward_pass(self, expected_max_diff=1e-4):
+        import accelerate
+
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         for component in pipe.components.values():
@@ -1373,18 +1375,56 @@ def test_sequential_cpu_offload_forward_pass(self, expected_max_diff=1e-4):
         output_without_offload = pipe(**inputs)[0]

         pipe.enable_sequential_cpu_offload()
+        assert pipe._execution_device.type == pipe._offload_device.type

         inputs = self.get_dummy_inputs(generator_device)
         output_with_offload = pipe(**inputs)[0]

         max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
         self.assertLess(max_diff, expected_max_diff, "CPU offloading should not affect the inference results")

+        # make sure all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are offloaded correctly
+        offloaded_modules = {
+            k: v
+            for k, v in pipe.components.items()
+            if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
+        }
+        # 1. all offloaded modules should be saved to cpu and moved to meta device
+        self.assertTrue(
+            all(v.device.type == "meta" for v in offloaded_modules.values()),
+            f"Not offloaded: {[k for k, v in offloaded_modules.items() if v.device.type != 'meta']}",
+        )
+        # 2. all offloaded modules should have hook installed
+        self.assertTrue(
+            all(hasattr(v, "_hf_hook") for k, v in offloaded_modules.items()),
+            f"No hook attached: {[k for k, v in offloaded_modules.items() if not hasattr(v, '_hf_hook')]}",
+        )
+        # 3. all offloaded modules should have correct hooks installed, should be either one of these two
+        # - `AlignDevicesHook`
+        # - a `SequentialHook` that contains `AlignDevicesHook`
+        offloaded_modules_with_incorrect_hooks = {}
+        for k, v in offloaded_modules.items():
+            if hasattr(v, "_hf_hook"):
+                if isinstance(v._hf_hook, accelerate.hooks.SequentialHook):
+                    # if it is a `SequentialHook`, we loop through its `hooks` attribute to check if it only contains `AlignDevicesHook`
+                    for hook in v._hf_hook.hooks:
+                        if not isinstance(hook, accelerate.hooks.AlignDevicesHook):
+                            offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook.hooks[0])
+                elif not isinstance(v._hf_hook, accelerate.hooks.AlignDevicesHook):
+                    offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook)
+
+        self.assertTrue(
+            len(offloaded_modules_with_incorrect_hooks) == 0,
+            f"Not installed correct hook: {offloaded_modules_with_incorrect_hooks}",
+        )
+
     @unittest.skipIf(
         torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.17.0"),
         reason="CPU offload is only available with CUDA and `accelerate v0.17.0` or higher",
     )
     def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4):
+        import accelerate
+
         generator_device = "cpu"
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
@@ -1400,19 +1440,39 @@ def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4):
         output_without_offload = pipe(**inputs)[0]

         pipe.enable_model_cpu_offload()
+        assert pipe._execution_device.type == pipe._offload_device.type
+
         inputs = self.get_dummy_inputs(generator_device)
         output_with_offload = pipe(**inputs)[0]

         max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
         self.assertLess(max_diff, expected_max_diff, "CPU offloading should not affect the inference results")
-        offloaded_modules = [
-            v
+
+        # make sure all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are offloaded correctly
+        offloaded_modules = {
+            k: v
             for k, v in pipe.components.items()
             if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
-        ]
-        (
-            self.assertTrue(all(v.device.type == "cpu" for v in offloaded_modules)),
-            f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'cpu']}",
+        }
+        # 1. check if all offloaded modules are saved to cpu
+        self.assertTrue(
+            all(v.device.type == "cpu" for v in offloaded_modules.values()),
+            f"Not offloaded: {[k for k, v in offloaded_modules.items() if v.device.type != 'cpu']}",
+        )
+        # 2. check if all offloaded modules have hooks installed
+        self.assertTrue(
+            all(hasattr(v, "_hf_hook") for k, v in offloaded_modules.items()),
+            f"No hook attached: {[k for k, v in offloaded_modules.items() if not hasattr(v, '_hf_hook')]}",
+        )
+        # 3. check if all offloaded modules have correct type of hooks installed, should be `CpuOffload`
+        offloaded_modules_with_incorrect_hooks = {}
+        for k, v in offloaded_modules.items():
+            if hasattr(v, "_hf_hook") and not isinstance(v._hf_hook, accelerate.hooks.CpuOffload):
+                offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook)
+
+        self.assertTrue(
+            len(offloaded_modules_with_incorrect_hooks) == 0,
+            f"Not installed correct hook: {offloaded_modules_with_incorrect_hooks}",
         )

     @unittest.skipIf(
@@ -1444,16 +1504,24 @@ def test_cpu_offload_forward_pass_twice(self, expected_max_diff=2e-4):
         self.assertLess(
             max_diff, expected_max_diff, "running CPU offloading 2nd time should not affect the inference results"
         )
+
+        # make sure all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are offloaded correctly
         offloaded_modules = {
             k: v
             for k, v in pipe.components.items()
             if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
         }
+        # 1. check if all offloaded modules are saved to cpu
         self.assertTrue(
             all(v.device.type == "cpu" for v in offloaded_modules.values()),
             f"Not offloaded: {[k for k, v in offloaded_modules.items() if v.device.type != 'cpu']}",
         )
-
+        # 2. check if all offloaded modules have hooks installed
+        self.assertTrue(
+            all(hasattr(v, "_hf_hook") for k, v in offloaded_modules.items()),
+            f"No hook attached: {[k for k, v in offloaded_modules.items() if not hasattr(v, '_hf_hook')]}",
+        )
+        # 3. check if all offloaded modules have correct type of hooks installed, should be `CpuOffload`
         offloaded_modules_with_incorrect_hooks = {}
         for k, v in offloaded_modules.items():
             if hasattr(v, "_hf_hook") and not isinstance(v._hf_hook, accelerate.hooks.CpuOffload):
@@ -1493,19 +1561,36 @@ def test_sequential_offload_forward_pass_twice(self, expected_max_diff=2e-4):
         self.assertLess(
             max_diff, expected_max_diff, "running sequential offloading second time should have the inference results"
         )
+
+        # make sure all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are offloaded correctly
         offloaded_modules = {
             k: v
             for k, v in pipe.components.items()
             if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
         }
+        # 1. check if all offloaded modules are moved to meta device
         self.assertTrue(
             all(v.device.type == "meta" for v in offloaded_modules.values()),
             f"Not offloaded: {[k for k, v in offloaded_modules.items() if v.device.type != 'meta']}",
         )
+        # 2. check if all offloaded modules have hook installed
+        self.assertTrue(
+            all(hasattr(v, "_hf_hook") for k, v in offloaded_modules.items()),
+            f"No hook attached: {[k for k, v in offloaded_modules.items() if not hasattr(v, '_hf_hook')]}",
+        )
+        # 3. check if all offloaded modules have correct hooks installed, should be either one of these two
+        # - `AlignDevicesHook`
+        # - a `SequentialHook` that contains `AlignDevicesHook`
         offloaded_modules_with_incorrect_hooks = {}
         for k, v in offloaded_modules.items():
-            if hasattr(v, "_hf_hook") and not isinstance(v._hf_hook, accelerate.hooks.AlignDevicesHook):
-                offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook)
+            if hasattr(v, "_hf_hook"):
+                if isinstance(v._hf_hook, accelerate.hooks.SequentialHook):
+                    # if it is a `SequentialHook`, we loop through its `hooks` attribute to check if it only contains `AlignDevicesHook`
+                    for hook in v._hf_hook.hooks:
+                        if not isinstance(hook, accelerate.hooks.AlignDevicesHook):
+                            offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook.hooks[0])
+                elif not isinstance(v._hf_hook, accelerate.hooks.AlignDevicesHook):
+                    offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook)

         self.assertTrue(
             len(offloaded_modules_with_incorrect_hooks) == 0,
```
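The hook-validation loop is duplicated between the two sequential-offload tests; a consolidated helper in the same spirit (hypothetical, not part of this PR) might look like:

```python
import accelerate

def modules_with_incorrect_hooks(offloaded_modules, expected_hook_cls):
    # Flag modules whose hook is neither an instance of `expected_hook_cls`
    # nor a SequentialHook containing only `expected_hook_cls` instances.
    incorrect = {}
    for name, module in offloaded_modules.items():
        hook = getattr(module, "_hf_hook", None)
        if hook is None:
            continue
        if isinstance(hook, accelerate.hooks.SequentialHook):
            bad = [type(h) for h in hook.hooks if not isinstance(h, expected_hook_cls)]
            if bad:
                incorrect[name] = bad
        elif not isinstance(hook, expected_hook_cls):
            incorrect[name] = type(hook)
    return incorrect
```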
