Commit 26de419

Fix AC in T5 example (#1273)
1 parent a38cbfc commit 26de419

2 files changed: 2 additions, 1 deletion

distributed/FSDP/T5_training.py

+1

@@ -121,6 +121,7 @@ def fsdp_main(args):
         device_id=torch.cuda.current_device(),
         limit_all_gathers=fsdp_config.limit_all_gathers)
 
+    # Enabling this causes https://github.com/pytorch/examples/issues/1210
     if fsdp_config.fsdp_activation_checkpointing:
         policies.apply_fsdp_checkpointing(model)
 
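For context on the call being gated here: policies.apply_fsdp_checkpointing is not part of this diff, but a helper like it typically follows the standard PyTorch activation-checkpointing recipe sketched below. This is a minimal sketch, assuming the helper targets T5Block submodules and uses the non-reentrant wrapper (both assumptions, since the helper's body is not shown).

    # Sketch only: the actual policies.apply_fsdp_checkpointing is not in this
    # diff; this shows the usual PyTorch recipe such a helper follows.
    from functools import partial

    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        CheckpointImpl,
        apply_activation_checkpointing,
        checkpoint_wrapper,
    )
    from transformers.models.t5.modeling_t5 import T5Block  # assumed wrap target

    def apply_fsdp_checkpointing(model):
        # Non-reentrant checkpointing is the variant commonly paired with FSDP.
        non_reentrant_wrapper = partial(
            checkpoint_wrapper,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
        )
        apply_activation_checkpointing(
            model,
            checkpoint_wrapper_fn=non_reentrant_wrapper,
            # Wrap only the transformer blocks, not embeddings or the LM head.
            check_fn=lambda submodule: isinstance(submodule, T5Block),
        )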
distributed/FSDP/configs/fsdp.py

+1 -1

@@ -8,7 +8,7 @@ class fsdp_config:
     mixed_precision: bool=True
     use_fp16: bool=False
     seed: int=42
-    fsdp_activation_checkpointing: bool=True
+    fsdp_activation_checkpointing: bool=False
     limit_all_gathers: bool=True
     sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD #HYBRID_SHARD, SHARD_GRAD_OP
     checkpoint_type: StateDictType = StateDictType.FULL_STATE_DICT # alternatively can use SHARDED_STATE_DICT to avoid OOMs
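With the default flipped to False, activation checkpointing becomes opt-in. A hypothetical way to re-enable it, assuming fsdp_config is a dataclass (as the annotated-fields-with-defaults syntax suggests) and assuming the import path from the file header above:

    # Hypothetical opt-in sketch; how the training script actually constructs
    # its config is not shown in this diff.
    from configs.fsdp import fsdp_config  # import path assumed

    # Flip the new default back on explicitly; the default is now False
    # because of https://github.com/pytorch/examples/issues/1210.
    cfg = fsdp_config(fsdp_activation_checkpointing=True)
    assert cfg.fsdp_activation_checkpointing is True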
