[feat] implement record_stream when using CUDA streams during group offloading (#11081)
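Background on the primitive: `torch.Tensor.record_stream(stream)` tells PyTorch's caching allocator that a tensor is in use on `stream`, so when the tensor is freed its memory is not handed out again until the work queued on that stream at free time has completed. This is what lets the offloading path below avoid a blocking synchronize. A minimal illustration of the primitive (not code from this PR; names are illustrative and it requires a CUDA device):

```python
import torch

# Minimal illustration of the PyTorch primitive this PR builds on; not code
# from the PR itself. Requires a CUDA device.
side = torch.cuda.Stream()
cpu_t = torch.randn(1 << 20, pin_memory=True)

with torch.cuda.stream(side):
    gpu_t = cpu_t.to("cuda", non_blocking=True)  # allocated and copied on `side`

torch.cuda.current_stream().wait_stream(side)     # order the copy before compute
gpu_t.record_stream(torch.cuda.current_stream())  # allocator: also used on this stream
out = gpu_t.sum()                                 # consumed on the default stream
del gpu_t  # memory is not reused until the default stream's pending work completes
```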
Changes from 3 commits
@@ -56,6 +56,7 @@ def __init__(
         buffers: Optional[List[torch.Tensor]] = None,
         non_blocking: bool = False,
         stream: Optional[torch.cuda.Stream] = None,
+        record_stream: Optional[bool] = False,
         cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
         onload_self: bool = True,
     ) -> None:
@@ -68,33 +69,47 @@ def __init__(
         self.buffers = buffers
         self.non_blocking = non_blocking or stream is not None
         self.stream = stream
+        self.record_stream = record_stream
         self.cpu_param_dict = cpu_param_dict
         self.onload_self = onload_self

         if self.stream is not None and self.cpu_param_dict is None:
-            raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")
+            raise ValueError("`cpu_param_dict` must be provided when using stream for data transfer.")
+
+        if self.record_stream and not self.stream:
+            raise ValueError("`record_stream` cannot be True when `stream` is None.")

     def onload_(self):
         r"""Onloads the group of modules to the onload_device."""
         context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
+        current_stream = torch.cuda.current_stream() if self.record_stream else None

         if self.stream is not None:
             # Wait for previous Host->Device transfer to complete
             self.stream.synchronize()

         with context:
             for group_module in self.modules:
                 group_module.to(self.onload_device, non_blocking=self.non_blocking)
+                if self.record_stream:
+                    for param in group_module.parameters():
+                        param.data.record_stream(current_stream)
[Review thread attached to the `param.data.record_stream(current_stream)` line above]

a-r-r-o-w: `current_stream` should be `self.stream` here, I think. We need to tell PyTorch that the `param.data` and `buffer.data` here are owned by the non-default stream. Currently, we're telling it that they are owned by the default stream, which seems incorrect to me. Sorry for the back and forth, but I think we will have to run the benchmark once more with the change 😅 Apart from this, everything else looks good. We can button up the docs and merge after @DN6 gives a look. Let's make sure to mention that this may use more memory in comparison to

a-r-r-o-w: Oh, never mind, ignore my comment. We don't create anything on the non-default stream, so `torch.cuda.current_stream` is correct.

PR author: @a-r-r-o-w no problem. Better to be rigorous and double-check everything. Just ran the benchmark after merging:

    {"record_stream": false, "memory": "1.3514", "time": "32.792"}
    {"record_stream": true, "memory": "1.3514", "time": "30.944"}

Feel free to run it yourself if you want. Absolutely, I will mention it in the docstrings. Is there any other place you wanted me to mention it?
             if self.parameters is not None:
                 for param in self.parameters:
                     param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+                    if self.record_stream:
+                        param.data.record_stream(current_stream)
             if self.buffers is not None:
                 for buffer in self.buffers:
                     buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
+                    if self.record_stream:
+                        buffer.data.record_stream(current_stream)

     def offload_(self):
         r"""Offloads the group of modules to the offload_device."""
         if self.stream is not None:
-            torch.cuda.current_stream().synchronize()
+            if not self.record_stream:
+                torch.cuda.current_stream().synchronize()
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
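To summarize the semantics the hunks above rely on: the H2D copy runs on the transfer stream, the weight is consumed on the default stream, and `record_stream()` informs the caching allocator about that consumer, which is why `offload_()` can skip the blocking `torch.cuda.current_stream().synchronize()`. A condensed, self-contained sketch of the same pattern (illustrative names, not diffusers code; requires a CUDA device):

```python
import torch

# Illustrative sketch of the onload/offload pattern in the diff above;
# not diffusers code. Requires a CUDA device.
transfer_stream = torch.cuda.Stream()
cpu_weight = torch.randn(1024, 1024, pin_memory=True)  # pinned, like cpu_param_dict entries
gpu_weight = None


def onload(record_stream: bool) -> None:
    global gpu_weight
    # Capture the consumer stream *before* switching streams, mirroring
    # `current_stream = torch.cuda.current_stream()` in onload_().
    consumer_stream = torch.cuda.current_stream()
    transfer_stream.synchronize()  # wait for the previous H2D transfer
    with torch.cuda.stream(transfer_stream):
        gpu_weight = cpu_weight.to("cuda", non_blocking=True)
        if record_stream:
            # The copy ran on transfer_stream, but the weight is consumed on the
            # default stream; record_stream() stops the caching allocator from
            # reusing this memory until consumer_stream's pending work finishes.
            gpu_weight.record_stream(consumer_stream)


def offload(record_stream: bool) -> None:
    global gpu_weight
    if not record_stream:
        # Without record_stream, we must block until the consumer is done, or
        # the freed memory could be handed to a new tensor while still in use.
        torch.cuda.current_stream().synchronize()
    gpu_weight = None  # drop the CUDA copy; the pinned CPU master weight remains


onload(record_stream=True)
torch.cuda.current_stream().wait_stream(transfer_stream)  # order copy before compute
out = gpu_weight.sum()  # consumed on the default stream
offload(record_stream=True)  # no blocking synchronize needed
```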
@@ -268,6 +283,7 @@ def apply_group_offloading( | |
num_blocks_per_group: Optional[int] = None, | ||
non_blocking: bool = False, | ||
use_stream: bool = False, | ||
record_stream: bool = False, | ||
) -> None: | ||
r""" | ||
Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and | ||
|
@@ -314,6 +330,7 @@ def apply_group_offloading(
         use_stream (`bool`, defaults to `False`):
             If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
             overlapping computation and data transfer.
+        record_stream: TODO
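The `record_stream: TODO` docstring entry is still a placeholder at this commit (to be buttoned up per the review thread). As a reviewer aid, a hedged usage sketch of the new flag; the tiny model is a hypothetical stand-in for a diffusers transformer, and the import path assumes `apply_group_offloading` is exposed via `diffusers.hooks`:

```python
import torch
from diffusers.hooks import apply_group_offloading  # assumed import path


class TinyModel(torch.nn.Module):
    # Hypothetical stand-in for a diffusers transformer/UNet.
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(64, 64) for _ in range(4)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


model = TinyModel()
apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    use_stream=True,     # transfers run on a separate CUDA stream
    record_stream=True,  # new flag: lets offload_() skip the blocking synchronize
)
x = torch.randn(1, 64, device="cuda")
out = model(x)  # block groups are onloaded before use and offloaded after
```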
@@ -349,10 +366,10 @@ def apply_group_offloading( | |
raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.") | ||
|
||
_apply_group_offloading_block_level( | ||
module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream | ||
module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, record_stream | ||
) | ||
elif offload_type == "leaf_level": | ||
_apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream) | ||
_apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream, record_stream) | ||
else: | ||
raise ValueError(f"Unsupported offload_type: {offload_type}") | ||
|
||
|
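The review thread above quotes memory/time numbers in a JSON shape. A sketch of how such a measurement could be set up; the harness is hypothetical (the PR's actual benchmark script is not shown here), the `run_inference` callable is a placeholder, and the import path is assumed as before:

```python
import json
import time

import torch
from diffusers.hooks import apply_group_offloading  # assumed import path


def benchmark(model: torch.nn.Module, run_inference, record_stream: bool) -> None:
    # Hypothetical harness: applies leaf-level offloading with streams, then
    # reports peak memory (GiB) and wall time in the JSON shape quoted above.
    apply_group_offloading(
        model,
        onload_device=torch.device("cuda"),
        offload_device=torch.device("cpu"),
        offload_type="leaf_level",
        use_stream=True,
        record_stream=record_stream,
    )
    torch.cuda.reset_peak_memory_stats()
    start = time.time()
    run_inference()  # e.g. one full pipeline call
    torch.cuda.synchronize()  # make sure all queued work is counted in the timing
    print(json.dumps({
        "record_stream": record_stream,
        "memory": f"{torch.cuda.max_memory_allocated() / 1024**3:.4f}",
        "time": f"{time.time() - start:.3f}",
    }))
```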
@@ -364,6 +381,7 @@ def _apply_group_offloading_block_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
+    record_stream: Optional[bool] = False,
 ) -> None:
     r"""
     This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to

@@ -382,6 +400,7 @@ def _apply_group_offloading_block_level(
         stream (`torch.cuda.Stream`, *optional*):
             If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
             for overlapping computation and data transfer.
+        record_stream: TODO
     """

     # Create a pinned CPU parameter dict for async data transfer if streams are to be used
@@ -411,6 +430,7 @@ def _apply_group_offloading_block_level( | |
onload_leader=current_modules[0], | ||
non_blocking=non_blocking, | ||
stream=stream, | ||
record_stream=record_stream, | ||
cpu_param_dict=cpu_param_dict, | ||
onload_self=stream is None, | ||
) | ||
|
@@ -448,6 +468,7 @@ def _apply_group_offloading_block_level( | |
buffers=buffers, | ||
non_blocking=False, | ||
stream=None, | ||
record_stream=False, | ||
cpu_param_dict=None, | ||
onload_self=True, | ||
) | ||
|
@@ -461,6 +482,7 @@ def _apply_group_offloading_leaf_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
+    record_stream: Optional[bool] = False,
 ) -> None:
     r"""
     This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory

@@ -481,6 +503,7 @@ def _apply_group_offloading_leaf_level(
         stream (`torch.cuda.Stream`, *optional*):
             If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
             for overlapping computation and data transfer.
+        record_stream: TODO
     """

     # Create a pinned CPU parameter dict for async data transfer if streams are to be used
@@ -503,6 +526,7 @@ def _apply_group_offloading_leaf_level(
             onload_leader=submodule,
             non_blocking=non_blocking,
             stream=stream,
+            record_stream=record_stream,
            cpu_param_dict=cpu_param_dict,
             onload_self=True,
         )

@@ -548,6 +572,7 @@ def _apply_group_offloading_leaf_level(
         buffers=buffers,
         non_blocking=non_blocking,
         stream=stream,
+        record_stream=record_stream,
         cpu_param_dict=cpu_param_dict,
         onload_self=True,
     )

@@ -567,6 +592,7 @@ def _apply_group_offloading_leaf_level(
         buffers=None,
         non_blocking=False,
         stream=None,
+        record_stream=False,
         cpu_param_dict=None,
         onload_self=True,
     )