@@ -227,6 +227,16 @@ def __init__(self, cfg: DictConfig) -> None:
227
227
self ._enable_activation_offloading = cfg .get (
228
228
"enable_activation_offloading" , False
229
229
)
230
# Config flag controlling whether activation offloading uses streams; it is
# forwarded to the activation-offloading context manager during model setup.
# Defaults to True when the key is absent from the config.
self._activation_offloading_use_streams = cfg.get(
    "activation_offloading_use_streams", True
)
# Only warn when activation offloading is actually enabled. Because the
# streams flag defaults to True, checking it alone would emit this warning on
# every tensor-parallel run, including runs that never offload activations.
if (
    self._enable_activation_offloading
    and self._activation_offloading_use_streams
    and self.parallel_dims.tp_enabled
):
    warn(
        message=(
            "Using activation offloading with streams is not advised in tensor parallel, and may "
            "cause unstable training. It is advised to set activation_offloading_use_streams: False"
        )
    )
230
240
if self ._enable_activation_offloading :
231
241
if device_type != "cuda" :
232
242
raise RuntimeError (
@@ -339,6 +349,7 @@ def setup(self, cfg: DictConfig) -> None:
339
349
cfg_model = cfg .model ,
340
350
enable_activation_checkpointing = self ._enable_activation_checkpointing ,
341
351
enable_activation_offloading = self ._enable_activation_offloading ,
352
+ activation_offloading_use_streams = self ._activation_offloading_use_streams ,
342
353
custom_sharded_layers = cfg .get ("custom_sharded_layers" , None ),
343
354
fsdp_cpu_offload = self .fsdp_cpu_offload ,
344
355
reshard_after_forward = cfg .get ("fsdp_reshard_after_forward" , True ),
@@ -541,6 +552,7 @@ def _setup_model(
541
552
cfg_model : DictConfig ,
542
553
enable_activation_checkpointing : bool ,
543
554
enable_activation_offloading : bool ,
555
+ activation_offloading_use_streams : bool ,
544
556
fsdp_cpu_offload : bool ,
545
557
reshard_after_forward : bool ,
546
558
model_state_dict : Dict [str , Any ],
@@ -659,7 +671,7 @@ def _setup_model(
659
671
660
672
# activation offloading
661
673
self .activations_handling_ctx = training .get_act_offloading_ctx_manager (
662
- model , enable_activation_offloading
674
+ model , enable_activation_offloading , activation_offloading_use_streams
663
675
)
664
676
665
677
# Ensure no params and buffers are on meta device
0 commit comments