Commit c3c532a

Allow disabling activation offloading streams in full finetune recipe (#2710)
Signed-off-by: Nathan Azrak <[email protected]>
1 parent 137eec3 commit c3c532a

File tree

2 files changed: +16 -3 lines

recipes/full_finetune_distributed.py

Lines changed: 13 additions & 1 deletion
```diff
@@ -227,6 +227,16 @@ def __init__(self, cfg: DictConfig) -> None:
         self._enable_activation_offloading = cfg.get(
             "enable_activation_offloading", False
         )
+        self._activation_offloading_use_streams = cfg.get(
+            "activation_offloading_use_streams", True
+        )
+        if self._activation_offloading_use_streams and self.parallel_dims.tp_enabled:
+            warn(
+                message=(
+                    "Using activation offloading with streams is not advised in tensor parallel, and may "
+                    "cause unstable training. It is advised to set activation_offloading_use_streams: False"
+                )
+            )
         if self._enable_activation_offloading:
             if device_type != "cuda":
                 raise RuntimeError(
@@ -339,6 +349,7 @@ def setup(self, cfg: DictConfig) -> None:
             cfg_model=cfg.model,
             enable_activation_checkpointing=self._enable_activation_checkpointing,
             enable_activation_offloading=self._enable_activation_offloading,
+            activation_offloading_use_streams=self._activation_offloading_use_streams,
             custom_sharded_layers=cfg.get("custom_sharded_layers", None),
             fsdp_cpu_offload=self.fsdp_cpu_offload,
             reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True),
@@ -541,6 +552,7 @@ def _setup_model(
         cfg_model: DictConfig,
         enable_activation_checkpointing: bool,
         enable_activation_offloading: bool,
+        activation_offloading_use_streams: bool,
         fsdp_cpu_offload: bool,
         reshard_after_forward: bool,
         model_state_dict: Dict[str, Any],
@@ -659,7 +671,7 @@ def _setup_model(
 
         # activation offloading
         self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
-            model, enable_activation_offloading
+            model, enable_activation_offloading, activation_offloading_use_streams
         )
 
         # Ensure no params and buffers are on meta device
```
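
The recipe now reads two config keys, `enable_activation_offloading` and the new `activation_offloading_use_streams` (treated as `True` when absent), and warns when streams are combined with tensor parallelism. A minimal sketch of that flow, using an in-memory OmegaConf config and a plain boolean as a stand-in for the recipe's `parallel_dims.tp_enabled` attribute:

```python
# Minimal sketch of how the new flag is read; the config values and the
# tp_enabled boolean are illustrative stand-ins, not the recipe's own objects.
from warnings import warn

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "enable_activation_offloading": True,
        # New knob added by this commit; defaults to True when not set.
        "activation_offloading_use_streams": False,
    }
)

enable_activation_offloading = cfg.get("enable_activation_offloading", False)
use_streams = cfg.get("activation_offloading_use_streams", True)

tp_enabled = True  # stand-in for self.parallel_dims.tp_enabled
if use_streams and tp_enabled:
    warn(
        "Using activation offloading with streams is not advised in tensor "
        "parallel, and may cause unstable training. It is advised to set "
        "activation_offloading_use_streams: False"
    )
```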

torchtune/training/_activation_offloading.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -378,7 +378,7 @@ def noop(tensor):
 
 
 def get_act_offloading_ctx_manager(
-    model: nn.Module, enable_activation_offloading: bool
+    model: nn.Module, enable_activation_offloading: bool, use_streams: bool = True
 ) -> Union[OffloadActivations, contextlib.nullcontext]:
     """Returns the activation offloading context manager for the model, which will be
     a null context if enable_activation_offloading is False.
@@ -390,6 +390,7 @@ def get_act_offloading_ctx_manager(
         model (nn.Module): the model to wrap with the activation offloading context manager.
         enable_activation_offloading (bool): whether or not to enable activation offloading
             for the model.
+        use_streams (bool): whether or not to enable streams for overlapping communication.
 
     Returns:
         contextlib.ContextDecorator: the activation offloading context manager for the model.
@@ -398,7 +399,7 @@ def get_act_offloading_ctx_manager(
         NotImplementedError: If the model is a multimodal model and activation offloading is enabled.
     """
     if enable_activation_offloading:
-        activations_handling_ctx = OffloadActivations()
+        activations_handling_ctx = OffloadActivations(use_streams=use_streams)
 
         # Below is our hack to disable offloading the last output Linear in every
         # step, as the cost for offloading the activation and then soon after bringing
```
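
With the new `use_streams` parameter, callers can keep offloading activations to CPU while skipping the stream-based overlap of copies and compute. A minimal usage sketch, assuming a CUDA device and a toy `nn.Sequential` in place of a real torchtune model:

```python
# Minimal usage sketch; the toy model and tensor shapes are illustrative only.
import torch
import torch.nn as nn

from torchtune.training import get_act_offloading_ctx_manager

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16)).cuda()

# Offload activations to CPU, but without the extra CUDA streams that overlap
# host<->device copies with compute (the behavior this commit makes optional).
activations_handling_ctx = get_act_offloading_ctx_manager(
    model, enable_activation_offloading=True, use_streams=False
)

x = torch.randn(4, 16, device="cuda")
with activations_handling_ctx:
    out = model(x)
out.sum().backward()
```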
