Skip to content

Commit 41ea4c8

Browse files
committed
merge #11097
1 parent 2a28f6d commit 41ea4c8

File tree

1 file changed

+27
-9
lines changed

1 file changed

+27
-9
lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,15 @@ def onload_(self):
9090

9191
with context:
9292
for group_module in self.modules:
93-
group_module.to(self.onload_device, non_blocking=self.non_blocking)
94-
if self.record_stream:
95-
for param in group_module.parameters():
93+
for param in group_module.parameters():
94+
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
95+
if self.record_stream:
9696
param.data.record_stream(current_stream)
97+
for buffer in group_module.buffers():
98+
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
99+
if self.record_stream:
100+
buffer.data.record_stream(current_stream)
101+
97102
if self.parameters is not None:
98103
for param in self.parameters:
99104
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
@@ -113,6 +118,12 @@ def offload_(self):
113118
for group_module in self.modules:
114119
for param in group_module.parameters():
115120
param.data = self.cpu_param_dict[param]
121+
if self.parameters is not None:
122+
for param in self.parameters:
123+
param.data = self.cpu_param_dict[param]
124+
if self.buffers is not None:
125+
for buffer in self.buffers:
126+
buffer.data = self.cpu_param_dict[buffer]
116127
else:
117128
for group_module in self.modules:
118129
group_module.to(self.offload_device, non_blocking=self.non_blocking)
@@ -406,9 +417,7 @@ def _apply_group_offloading_block_level(
406417
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
407418
cpu_param_dict = None
408419
if stream is not None:
409-
for param in module.parameters():
410-
param.data = param.data.cpu().pin_memory()
411-
cpu_param_dict = {param: param.data for param in module.parameters()}
420+
cpu_param_dict = _get_pinned_cpu_param_dict(module)
412421

413422
# Create module groups for ModuleList and Sequential blocks
414423
modules_with_group_offloading = set()
@@ -509,9 +518,7 @@ def _apply_group_offloading_leaf_level(
509518
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
510519
cpu_param_dict = None
511520
if stream is not None:
512-
for param in module.parameters():
513-
param.data = param.data.cpu().pin_memory()
514-
cpu_param_dict = {param: param.data for param in module.parameters()}
521+
cpu_param_dict = _get_pinned_cpu_param_dict(module)
515522

516523
# Create module groups for leaf modules and apply group offloading hooks
517524
modules_with_group_offloading = set()
@@ -630,6 +637,17 @@ def _apply_lazy_group_offloading_hook(
630637
registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)
631638

632639

640+
def _get_pinned_cpu_param_dict(module: torch.nn.Module) -> Dict[torch.Tensor, torch.Tensor]:
    """Move every parameter and buffer of ``module`` to pinned CPU memory.

    The ``.data`` of each parameter and buffer is replaced in place by the result of
    ``.cpu().pin_memory()``, and that tensor is recorded in the returned mapping under the
    parameter/buffer object it belongs to.

    Args:
        module: Module whose tensors, as yielded by ``module.parameters()`` and
            ``module.buffers()``, are moved.

    Returns:
        A dict mapping each parameter and each buffer to its pinned CPU ``.data`` tensor. Keys
        are annotated as ``torch.Tensor`` (the base class of ``torch.nn.Parameter``) because
        buffers are stored as keys alongside parameters.
    """
    pinned = {}
    # Parameters are processed first, then buffers, matching two consecutive loops.
    for tensors in (module.parameters(), module.buffers()):
        for tensor in tensors:
            tensor.data = tensor.data.cpu().pin_memory()
            # Store the exact tensor now held by ``.data`` so callers can restore it later.
            pinned[tensor] = tensor.data
    return pinned
649+
650+
633651
def _gather_parameters_with_no_group_offloading_parent(
634652
module: torch.nn.Module, modules_with_group_offloading: Set[str]
635653
) -> List[torch.nn.Parameter]:

0 commit comments

Comments
 (0)