File tree 2 files changed +2
-1
lines changed
2 files changed +2
-1
lines changed Original file line number Diff line number Diff line change @@ -121,6 +121,7 @@ def fsdp_main(args):
121
121
device_id = torch .cuda .current_device (),
122
122
limit_all_gathers = fsdp_config .limit_all_gathers )
123
123
124
+ # Enabling this causes https://github.com/pytorch/examples/issues/1210
124
125
if fsdp_config .fsdp_activation_checkpointing :
125
126
policies .apply_fsdp_checkpointing (model )
126
127
Original file line number Diff line number Diff line change @@ -8,7 +8,7 @@ class fsdp_config:
8
8
mixed_precision : bool = True
9
9
use_fp16 : bool = False
10
10
seed : int = 42
11
- fsdp_activation_checkpointing : bool = True
11
+ fsdp_activation_checkpointing : bool = False
12
12
limit_all_gathers : bool = True
13
13
sharding_strategy : ShardingStrategy = ShardingStrategy .FULL_SHARD #HYBRID_SHARD, SHARD_GRAD_OP
14
14
checkpoint_type : StateDictType = StateDictType .FULL_STATE_DICT # alternatively can use SHARDED_STATE_DICT to avoid OOMs
You can’t perform that action at this time.
0 commit comments