You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: intermediate_source/FSDP_tutorial.rst
+18-19
Original file line number
Diff line number
Diff line change
@@ -8,7 +8,7 @@ Getting Started with Fully Sharded Data Parallel(FSDP)
8
8
9
9
Training AI models at a large scale is a challenging task that requires a lot of compute power and resources.
10
10
It also comes with considerable engineering complexity to handle the training of these very large models.
11
-
`Pytorch FSDP <https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/>`__, released in PyTorch 1.11 makes this easier.
11
+
`PyTorch FSDP <https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/>`__, released in PyTorch 1.11 makes this easier.
12
12
13
13
In this tutorial, we show how to use `FSDP APIs <https://pytorch.org/docs/1.11/fsdp.html>`__, for simple MNIST models that can be extended to other larger models such as `HuggingFace BERT models <https://huggingface.co/blog/zero-deepspeed-fairscale>`__,
14
14
`GPT 3 models up to 1T parameters <https://pytorch.medium.com/training-a-1-trillion-parameter-model-with-pytorch-fully-sharded-data-parallel-on-aws-3ac13aa96cff>`__ . The sample DDP MNIST code has been borrowed from `here <https://github.com/yqhu/mnist_examples>`__.
@@ -18,7 +18,7 @@ How FSDP works
18
18
--------------
19
19
In `DistributedDataParallel <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html>`__, (DDP) training, each process/ worker owns a replica of the model and processes a batch of data, finally it uses all-reduce to sum up gradients over different workers. In DDP the model weights and optimizer states are replicated across all workers. FSDP is a type of data parallelism that shards model parameters, optimizer states and gradients across DDP ranks.
20
20
21
-
FSDPGPU memory footprint would be smaller than DDP across all workers. This makes the training of some very large models feasible and helps to fit larger models or batch sizes for our training job. This would come with the cost of increased communication volume. The communication overhead is reduced by internal optimizations like communication and computation overlapping.
21
+
When training with FSDP, the GPU memory footprint is smaller than when training with DDP across all workers. This makes the training of some very large models feasible by allowing larger models or batch sizes to fit on device. This comes with the cost of increased communication volume. The communication overhead is reduced by internal optimizations like overlapping communication and computation.
@@ -27,7 +27,7 @@ FSDP GPU memory footprint would be smaller than DDP across all workers. This mak
27
27
28
28
FSDP Workflow
29
29
30
-
At high level FSDP works as follow:
30
+
At a high level FSDP works as follow:
31
31
32
32
*In constructor*
33
33
@@ -48,11 +48,11 @@ At high level FSDP works as follow:
48
48
49
49
How to use FSDP
50
50
--------------
51
-
Here we use a toy model to run training on MNIST dataset for demonstration purposes. Similarly the APIs and logic can be applied to larger models for training.
51
+
Here we use a toy model to run training on the MNIST dataset for demonstration purposes. The APIs and logic can be applied to training larger models as well.
52
52
53
53
*Setup*
54
54
55
-
1.1 Install Pytorch along with Torchvision
55
+
1.1 Install PyTorch along with Torchvision
56
56
57
57
.. code-block:: bash
58
58
@@ -139,7 +139,7 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
139
139
output = F.log_softmax(x, dim=1)
140
140
return output
141
141
142
-
2.2 define a train function
142
+
2.2 Define a train function
143
143
144
144
.. code-block:: python
145
145
@@ -189,7 +189,7 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
189
189
190
190
2.4 Define a distributed train function that wraps the model in FSDP
191
191
192
-
**Note: to save the FSDP model, we need to call the state_dict on each rank then on Rank 0 save the overall states. This is only available in Pytorch nightlies, current Pytorch release is 1.11 at the moment.**
192
+
**Note: to save the FSDP model, we need to call the state_dict on each rank then on Rank 0 save the overall states.**
193
193
194
194
.. code-block:: python
195
195
@@ -250,7 +250,6 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
250
250
if args.save_model:
251
251
# use a barrier to make sure training is done on all ranks
252
252
dist.barrier()
253
-
# state_dict for FSDP model is only available on Nightlies for now
254
253
states = model.state_dict()
255
254
if rank ==0:
256
255
torch.save(states, "mnist_cnn.pt")
@@ -259,7 +258,7 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
259
258
260
259
261
260
262
-
2.5 Finally parsing the arguments and setting the main function
261
+
2.5 Finally parse the arguments and set the main function
263
262
264
263
.. code-block:: python
265
264
@@ -319,7 +318,7 @@ Alternatively, we will look at adding the fsdp_auto_wrap_policy next and will di
319
318
)
320
319
)
321
320
322
-
Following is the peak memory usage from FSDP MNIST training on g4dn.12.xlarge AWS EC2 instance with 4 gpus captured from Pytorch Profiler.
321
+
The following is the peak memory usage from FSDP MNIST training on g4dn.12.xlarge AWS EC2 instance with 4 GPUs captured from PyTorch Profiler.
@@ -329,7 +328,7 @@ Following is the peak memory usage from FSDP MNIST training on g4dn.12.xlarge AW
329
328
330
329
FSDP Peak Memory Usage
331
330
332
-
*Applying fsdp_auto_wrap_policy* in FSDP otherwise, FSDP will put the entire model in one FSDP unit, which will reduce computation efficiency and memory efficiency.
331
+
Applying *fsdp_auto_wrap_policy* in FSDP otherwise, FSDP will put the entire model in one FSDP unit, which will reduce computation efficiency and memory efficiency.
333
332
The way it works is that, suppose your model contains 100 Linear layers. If you do FSDP(model), there will only be one FSDP unit which wraps the entire model.
334
333
In that case, the allgather would collect the full parameters for all 100 linear layers, and hence won't save CUDA memory for parameter sharding.
335
334
Also, there is only one blocking allgather call for the all 100 linear layers, there will not be communication and computation overlapping between layers.
@@ -354,7 +353,7 @@ Finding an optimal auto wrap policy is challenging, PyTorch will add auto tuning
354
353
model = FSDP(model,
355
354
fsdp_auto_wrap_policy=my_auto_wrap_policy)
356
355
357
-
Applying the FSDP_auto_wrap_policy, the model would be as follows:
356
+
Applying the fsdp_auto_wrap_policy, the model would be as follows:
358
357
359
358
.. code-block:: bash
360
359
@@ -381,7 +380,7 @@ Applying the FSDP_auto_wrap_policy, the model would be as follows:
381
380
382
381
CUDA event elapsed time on training loop 41.89130859375sec
383
382
384
-
Following is the peak memory usage from FSDP with auto_wrap policy of MNIST training on g4dn.12.xlarge AWS EC2 instance with 4 gpus captured from Pytorch Profiler.
383
+
The following is the peak memory usage from FSDP with auto_wrap policy of MNIST training on a g4dn.12.xlarge AWS EC2 instance with 4 GPUs captured from PyTorch Profiler.
385
384
It can be observed that the peak memory usage on each device is smaller compared to FSDP without auto wrap policy applied, from ~75 MB to 66 MB.
@@ -391,11 +390,11 @@ It can be observed that the peak memory usage on each device is smaller compared
391
390
392
391
FSDP Peak Memory Usage using Auto_wrap policy
393
392
394
-
*CPU Off-loading*: In case the model is very large that even with FSDP wouldn't fit into gpus, then CPU offload can be helpful here.
393
+
*CPU Off-loading*: In case the model is very large that even with FSDP wouldn't fit into GPUs, then CPU offload can be helpful here.
395
394
396
395
Currently, only parameter and gradient CPU offload is supported. It can be enabled via passing in cpu_offload=CPUOffload(offload_params=True).
397
396
398
-
Note that this currently implicitly enables gradient offloading to CPU in order for params and grads to be on the same device to work with the optimizer. This API is subject to change. Default is None in which case there will be no offloading.
397
+
Note that this currently implicitly enables gradient offloading to CPU in order for params and grads to be on the same device to work with the optimizer. This API is subject to change. The default is None in which case there will be no offloading.
399
398
400
399
Using this feature may slow down the training considerably, due to frequent copying of tensors from host to device, but it could help improve memory efficiency and train larger scale models.
401
400
@@ -409,7 +408,7 @@ In 2.4 we just add it to the FSDP wrapper
409
408
cpu_offload=CPUOffload(offload_params=True))
410
409
411
410
412
-
Compare it with DDP, if in 2.4 we just normally wrap the model in ddp, saving the changes in “DDP_mnist.py”.
411
+
Compare it with DDP, if in 2.4 we just normally wrap the model in DPP, saving the changes in “DDP_mnist.py”.
413
412
414
413
.. code-block:: python
415
414
@@ -423,7 +422,7 @@ Compare it with DDP, if in 2.4 we just normally wrap the model in ddp, saving th
423
422
424
423
CUDA event elapsed time on training loop 39.77766015625sec
425
424
426
-
Following is the peak memory usage from DDP MNIST training on g4dn.12.xlarge AWS EC2 instance with 4 gpus captured from Pytorch profiler.
425
+
The following is the peak memory usage from DDP MNIST training on g4dn.12.xlarge AWS EC2 instance with 4 GPUs captured from PyTorch profiler.
@@ -434,8 +433,8 @@ Following is the peak memory usage from DDP MNIST training on g4dn.12.xlarge AWS
434
433
435
434
436
435
Considering the toy example and tiny MNIST model we defined here, we can observe the difference between peak memory usage of DDP and FSDP.
437
-
In DDP each process holds a replica of the model, so the memory footprint is higher compared to FSDP that shards the model parameter, optimizer states and gradients over DDP ranks.
436
+
In DDP each process holds a replica of the model, so the memory footprint is higher compared to FSDP which shards the model parameters, optimizer states and gradients over DDP ranks.
438
437
The peak memory usage using FSDP with auto_wrap policy is the lowest followed by FSDP and DDP.
439
438
440
-
Also, looking at timings, considering the small model and running the training on a single machine, FSDP with/out auto_wrap policy performed almost as fast as DDP.
439
+
Also, looking at timings, considering the small model and running the training on a single machine, FSDP with and without auto_wrap policy performed almost as fast as DDP.
441
440
This example does not represent most of the real applications, for detailed analysis and comparison between DDP and FSDP please refer to this `blog post <https://pytorch.medium.com/6c8da2be180d>`__ .
0 commit comments