@@ -90,10 +90,15 @@ def onload_(self):
90
90
91
91
with context :
92
92
for group_module in self .modules :
93
- group_module .to ( self . onload_device , non_blocking = self . non_blocking )
94
- if self .record_stream :
95
- for param in group_module . parameters () :
93
+ for param in group_module .parameters ():
94
+ param . data = param . data . to ( self .onload_device , non_blocking = self . non_blocking )
95
+ if self . record_stream :
96
96
param .data .record_stream (current_stream )
97
+ for buffer in group_module .buffers ():
98
+ buffer .data = buffer .data .to (self .onload_device , non_blocking = self .non_blocking )
99
+ if self .record_stream :
100
+ buffer .data .record_stream (current_stream )
101
+
97
102
if self .parameters is not None :
98
103
for param in self .parameters :
99
104
param .data = param .data .to (self .onload_device , non_blocking = self .non_blocking )
@@ -113,6 +118,12 @@ def offload_(self):
113
118
for group_module in self .modules :
114
119
for param in group_module .parameters ():
115
120
param .data = self .cpu_param_dict [param ]
121
+ if self .parameters is not None :
122
+ for param in self .parameters :
123
+ param .data = self .cpu_param_dict [param ]
124
+ if self .buffers is not None :
125
+ for buffer in self .buffers :
126
+ buffer .data = self .cpu_param_dict [buffer ]
116
127
else :
117
128
for group_module in self .modules :
118
129
group_module .to (self .offload_device , non_blocking = self .non_blocking )
@@ -406,9 +417,7 @@ def _apply_group_offloading_block_level(
406
417
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
407
418
cpu_param_dict = None
408
419
if stream is not None :
409
- for param in module .parameters ():
410
- param .data = param .data .cpu ().pin_memory ()
411
- cpu_param_dict = {param : param .data for param in module .parameters ()}
420
+ cpu_param_dict = _get_pinned_cpu_param_dict (module )
412
421
413
422
# Create module groups for ModuleList and Sequential blocks
414
423
modules_with_group_offloading = set ()
@@ -509,9 +518,7 @@ def _apply_group_offloading_leaf_level(
509
518
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
510
519
cpu_param_dict = None
511
520
if stream is not None :
512
- for param in module .parameters ():
513
- param .data = param .data .cpu ().pin_memory ()
514
- cpu_param_dict = {param : param .data for param in module .parameters ()}
521
+ cpu_param_dict = _get_pinned_cpu_param_dict (module )
515
522
516
523
# Create module groups for leaf modules and apply group offloading hooks
517
524
modules_with_group_offloading = set ()
@@ -630,6 +637,17 @@ def _apply_lazy_group_offloading_hook(
630
637
registry .register_hook (lazy_prefetch_hook , _LAZY_PREFETCH_GROUP_OFFLOADING )
631
638
632
639
640
+ def _get_pinned_cpu_param_dict (module : torch .nn .Module ) -> Dict [torch .nn .Parameter , torch .Tensor ]:
641
+ cpu_param_dict = {}
642
+ for param in module .parameters ():
643
+ param .data = param .data .cpu ().pin_memory ()
644
+ cpu_param_dict [param ] = param .data
645
+ for buffer in module .buffers ():
646
+ buffer .data = buffer .data .cpu ().pin_memory ()
647
+ cpu_param_dict [buffer ] = buffer .data
648
+ return cpu_param_dict
649
+
650
+
633
651
def _gather_parameters_with_no_group_offloading_parent (
634
652
module : torch .nn .Module , modules_with_group_offloading : Set [str ]
635
653
) -> List [torch .nn .Parameter ]:
0 commit comments