
Commit 70862b0

[doc] minor fixups to DDP tutorial
Summary:
Add "set_device" call to keep things consistent between all DDP tutorials. This was inspired by the following change in the PyTorch repo: pytorch/examples#1285 (review)

Test Plan:
Ran tutorial with the applied changes and we see:

"""
Running basic DDP example on rank 3.
Running basic DDP example on rank 1.
Running basic DDP example on rank 2.
Running basic DDP example on rank 0.
Finished running basic DDP example on rank 0.
Finished running basic DDP example on rank 1.
Finished running basic DDP example on rank 3.
Finished running basic DDP example on rank 2.
Running DDP checkpoint example on rank 2.
Running DDP checkpoint example on rank 1.
Running DDP checkpoint example on rank 0.
Running DDP checkpoint example on rank 3.
Finished DDP checkpoint example on rank 0.
Finished DDP checkpoint example on rank 3.
Finished DDP checkpoint example on rank 1.
Finished DDP checkpoint example on rank 2.
Running DDP with model parallel example on rank 0.
Running DDP with model parallel example on rank 1.
Finished running DDP with model parallel example on rank 0.
Finished running DDP with model parallel example on rank 1.
"""
1 parent 904ca90 commit 70862b0

1 file changed: +51 −37 lines changed

intermediate_source/ddp_tutorial.rst (+51 −37)
@@ -2,7 +2,7 @@ Getting Started with Distributed Data Parallel
 =================================================
 **Author**: `Shen Li <https://mrshenli.github.io/>`_
 
-**Edited by**: `Joe Zhu <https://github.com/gunandrose4u>`_
+**Edited by**: `Joe Zhu <https://github.com/gunandrose4u>`_, `Chirag Pandya <https://github.com/c-p-i-o>`__
 
 .. note::
 |edit| View and edit this tutorial in `github <https://github.com/pytorch/tutorials/blob/main/intermediate_source/ddp_tutorial.rst>`__.
@@ -12,27 +12,34 @@ Prerequisites:
 - `PyTorch Distributed Overview <../beginner/dist_overview.html>`__
 - `DistributedDataParallel API documents <https://pytorch.org/docs/master/generated/torch.nn.parallel.DistributedDataParallel.html>`__
 - `DistributedDataParallel notes <https://pytorch.org/docs/master/notes/ddp.html>`__
+- The code in this tutorial runs on an 8-GPU server, but it can be easily generalized to other environments.
 
 
 `DistributedDataParallel <https://pytorch.org/docs/stable/nn.html#module-torch.nn.parallel>`__
-(DDP) implements data parallelism at the module level which can run across
-multiple machines. Applications using DDP should spawn multiple processes and
-create a single DDP instance per process. DDP uses collective communications in the
+(DDP) is a powerful module in PyTorch that allows you to parallelize your model across
+multiple machines, making it perfect for large-scale deep learning applications.
+To use DDP, you'll need to spawn multiple processes and create a single instance of DDP per process.
+
+But how does it work? DDP uses collective communications from the
 `torch.distributed <https://pytorch.org/tutorials/intermediate/dist_tuto.html>`__
-package to synchronize gradients and buffers. More specifically, DDP registers
-an autograd hook for each parameter given by ``model.parameters()`` and the
-hook will fire when the corresponding gradient is computed in the backward
-pass. Then DDP uses that signal to trigger gradient synchronization across
-processes. Please refer to
-`DDP design note <https://pytorch.org/docs/master/notes/ddp.html>`__ for more details.
+package to synchronize gradients and buffers across all processes. This means that each process will have
+its own copy of the model, but they'll all work together to train the model as if it were on a single machine.
+
+To make this happen, DDP registers an autograd hook for each parameter in the model.
+When the backward pass is run, this hook fires and triggers gradient synchronization across all processes.
+This ensures that each process has the same gradients, which are then used to update the model.
+
+For more information on how DDP works and how to use it effectively, be sure to check out the
+`DDP design note <https://pytorch.org/docs/master/notes/ddp.html>`__.
+With DDP, you can train your models faster and more efficiently than ever before!
+
+The recommended way to use DDP is to spawn one process for each model replica. The model replica can span
+multiple devices. DDP processes can be placed on the same machine or across machines. Note that GPU devices
+cannot be shared across DDP processes.
 
 
-The recommended way to use DDP is to spawn one process for each model replica,
-where a model replica can span multiple devices. DDP processes can be
-placed on the same machine or across machines, but GPU devices cannot be
-shared across processes. This tutorial starts from a basic DDP use case and
-then demonstrates more advanced use cases including checkpointing models and
-combining DDP with model parallel.
+In this tutorial, we'll start with a basic DDP use case and then demonstrate more advanced use cases,
+including checkpointing models and combining DDP with model parallel.
 
 
 .. note::
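
For readers of the rewritten paragraph above: "one process per model replica" boils down to each process wrapping its own copy of the model in DDP, after which the registered autograd hooks synchronize gradients during ``backward()``. A minimal sketch, not part of this commit, with ``nn.Linear`` standing in for the tutorial's ``ToyModel`` and assuming the process group is already initialized in each process:

.. code:: python

    # Hedged sketch (not part of this diff): one process per replica.
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    def make_replica(rank):
        # assumes dist.init_process_group(...) has already run in this process
        model = nn.Linear(10, 5).to(rank)           # stand-in for the tutorial's ToyModel
        ddp_model = DDP(model, device_ids=[rank])   # registers per-parameter autograd hooks
        return ddp_model

    # calling backward() on ddp_model's loss now overlaps gradient all-reduce
    # with the rest of the backward pass.
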
@@ -43,25 +50,23 @@ combining DDP with model parallel.
 Comparison between ``DataParallel`` and ``DistributedDataParallel``
 -------------------------------------------------------------------
 
-Before we dive in, let's clarify why, despite the added complexity, you would
-consider using ``DistributedDataParallel`` over ``DataParallel``:
+Before we dive in, let's clarify why you would consider using ``DistributedDataParallel``
+over ``DataParallel``, despite its added complexity:
 
 - First, ``DataParallel`` is single-process, multi-thread, and only works on a
-single machine, while ``DistributedDataParallel`` is multi-process and works
-for both single- and multi- machine training. ``DataParallel`` is usually
-slower than ``DistributedDataParallel`` even on a single machine due to GIL
-contention across threads, per-iteration replicated model, and additional
-overhead introduced by scattering inputs and gathering outputs.
+single machine. In contrast, ``DistributedDataParallel`` is multi-process and supports
+both single- and multi- machine training.
+Due to GIL contention across threads, per-iteration replicated model, and additional overhead introduced by
+scattering inputs and gathering outputs, ``DataParallel`` is usually
+slower than ``DistributedDataParallel`` even on a single machine.
 - Recall from the
 `prior tutorial <https://pytorch.org/tutorials/intermediate/model_parallel_tutorial.html>`__
 that if your model is too large to fit on a single GPU, you must use **model parallel**
 to split it across multiple GPUs. ``DistributedDataParallel`` works with
-**model parallel**; ``DataParallel`` does not at this time. When DDP is combined
+**model parallel**, while ``DataParallel`` does not at this time. When DDP is combined
 with model parallel, each DDP process would use model parallel, and all processes
 collectively would use data parallel.
-- If your model needs to span multiple machines or if your use case does not fit
-into data parallelism paradigm, please see `the RPC API <https://pytorch.org/docs/stable/rpc.html>`__
-for more generic distributed training support.
+
 
 Basic Use Case
 --------------
@@ -99,6 +104,9 @@ be found in
 os.environ['MASTER_ADDR'] = 'localhost'
 os.environ['MASTER_PORT'] = '12355'
 
+# set the device id for this process
+torch.cuda.set_device(rank)
+
 # initialize the process group
 dist.init_process_group("gloo", rank=rank, world_size=world_size)
 
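
With the hunk above applied, the amended setup helper probably reads roughly as follows, together with its matching ``cleanup`` (a hedged reconstruction from the diff, not a verbatim copy of the file):

.. code:: python

    import os
    import torch
    import torch.distributed as dist

    def setup(rank, world_size):
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'

        # set the device id for this process (the line this commit adds)
        torch.cuda.set_device(rank)

        # initialize the process group
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

    def cleanup():
        dist.destroy_process_group()

Calling ``torch.cuda.set_device(rank)`` up front pins each process to its own GPU, which is what keeps this tutorial consistent with the other DDP tutorials mentioned in the commit summary.
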
@@ -141,6 +149,7 @@ different DDP processes starting from different initial model parameter values.
 optimizer.step()
 
 cleanup()
+print(f"Finished running basic DDP example on rank {rank}.")
 
 
 def run_demo(demo_fn, world_size):
@@ -149,6 +158,7 @@ different DDP processes starting from different initial model parameter values.
 nprocs=world_size,
 join=True)
 
+
 As you can see, DDP wraps lower-level distributed communication details and
 provides a clean API as if it were a local model. Gradient synchronization
 communications take place during the backward pass and overlap with the
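
The ``nprocs``/``join`` arguments in this hunk belong to the tutorial's ``run_demo`` helper, which launches one process per rank via ``torch.multiprocessing.spawn``. A hedged sketch of how it is driven; the placeholder ``demo_fn`` here stands in for the tutorial's ``demo_basic``:

.. code:: python

    import torch
    import torch.multiprocessing as mp

    def demo_fn(rank, world_size):
        # placeholder for the tutorial's demo_basic(rank, world_size)
        print(f"Running placeholder demo on rank {rank} of {world_size}.")

    def run_demo(demo_fn, world_size):
        # mp.spawn passes the rank as the first argument to demo_fn
        mp.spawn(demo_fn,
                 args=(world_size,),
                 nprocs=world_size,
                 join=True)

    if __name__ == "__main__":
        run_demo(demo_fn, torch.cuda.device_count())
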
@@ -182,7 +192,7 @@ for more details. When using DDP, one optimization is to save the model in
 only one process and then load it to all processes, reducing write overhead.
 This is correct because all processes start from the same parameters and
 gradients are synchronized in backward passes, and hence optimizers should keep
-setting parameters to the same values. If you use this optimization, make sure no process starts
+setting parameters to the same values. If you use this optimization, make sure no process starts
 loading before the saving is finished. Additionally, when
 loading the module, you need to provide an appropriate ``map_location``
 argument to prevent a process from stepping into others' devices. If ``map_location``
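
Concretely, the save-on-one-rank / load-everywhere pattern this paragraph describes tends to look like the sketch below; ``CHECKPOINT_PATH`` and the helper name are illustrative placeholders, and the ``map_location`` dict remaps tensors saved from rank 0's GPU onto the loading rank's GPU:

.. code:: python

    import torch
    import torch.distributed as dist

    CHECKPOINT_PATH = "/tmp/model.checkpoint"  # placeholder path

    def save_and_reload(ddp_model, rank):
        # only rank 0 writes the checkpoint
        if rank == 0:
            torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

        # make sure no rank starts loading before rank 0 has finished saving
        dist.barrier()

        # remap tensors saved from cuda:0 onto this rank's device
        map_location = {'cuda:0': f'cuda:{rank}'}
        ddp_model.load_state_dict(
            torch.load(CHECKPOINT_PATH, map_location=map_location))
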
@@ -218,7 +228,7 @@ and elasticity support, please refer to `TorchElastic <https://pytorch.org/elast
 
 loss_fn = nn.MSELoss()
 optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
-
+
 optimizer.zero_grad()
 outputs = ddp_model(torch.randn(20, 10))
 labels = torch.randn(20, 5).to(rank)
@@ -234,6 +244,7 @@ and elasticity support, please refer to `TorchElastic <https://pytorch.org/elast
 os.remove(CHECKPOINT_PATH)
 
 cleanup()
+print(f"Finished running DDP checkpoint example on rank {rank}.")
 
 Combining DDP with Model Parallelism
 ------------------------------------
@@ -285,6 +296,7 @@ either the application or the model ``forward()`` method.
 optimizer.step()
 
 cleanup()
+print(f"Finished running DDP with model parallel example on rank {rank}.")
 
 
 if __name__ == "__main__":
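
For context on what "DDP with model parallel" means in this demo: each process owns a replica that is itself split across two GPUs, and the multi-device module is wrapped in DDP without ``device_ids``. A hedged sketch in the spirit of the tutorial's ``ToyMpModel``; layer sizes and device assignment are illustrative:

.. code:: python

    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    class ToyMpModel(nn.Module):
        # one replica split across two devices
        def __init__(self, dev0, dev1):
            super().__init__()
            self.dev0, self.dev1 = dev0, dev1
            self.net1 = nn.Linear(10, 10).to(dev0)
            self.relu = nn.ReLU()
            self.net2 = nn.Linear(10, 5).to(dev1)

        def forward(self, x):
            x = self.relu(self.net1(x.to(self.dev0)))
            return self.net2(x.to(self.dev1))

    # inside the demo, with two GPUs per process:
    # dev0, dev1 = rank * 2, rank * 2 + 1
    # ddp_mp_model = DDP(ToyMpModel(dev0, dev1))  # no device_ids for a multi-device module
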
@@ -325,8 +337,9 @@ Let's still use the Toymodel example and create a file named ``elastic_ddp.py``.
 def demo_basic():
 dist.init_process_group("nccl")
 rank = dist.get_rank()
+
 print(f"Start running basic DDP example on rank {rank}.")
-
+
 # create model and move it to GPU with id rank
 device_id = rank % torch.cuda.device_count()
 model = ToyModel().to(device_id)
@@ -340,23 +353,24 @@ Let's still use the Toymodel example and create a file named ``elastic_ddp.py``.
 labels = torch.randn(20, 5).to(device_id)
 loss_fn(outputs, labels).backward()
 optimizer.step()
-dist.destroy_process_group()
-
+cleanup()
+print(f"Finished running basic DDP example on rank {rank}.")
+
 if __name__ == "__main__":
 demo_basic()
 
-One can then run a `torch elastic/torchrun <https://pytorch.org/docs/stable/elastic/quickstart.html>`__ command
+One can then run a `torch elastic/torchrun <https://pytorch.org/docs/stable/elastic/quickstart.html>`__ command
 on all nodes to initialize the DDP job created above:
 
 .. code:: bash
 
 torchrun --nnodes=2 --nproc_per_node=8 --rdzv_id=100 --rdzv_backend=c10d --rdzv_endpoint=$MASTER_ADDR:29400 elastic_ddp.py
 
-We are running the DDP script on two hosts, and each host we run with 8 processes, aka, we
+We are running the DDP script on two hosts, and each host we run with 8 processes, aka, we
 are running it on 16 GPUs. Note that ``$MASTER_ADDR`` must be the same across all nodes.
 
-Here torchrun will launch 8 process and invoke ``elastic_ddp.py``
-on each process on the node it is launched on, but user also needs to apply cluster
+Here torchrun will launch 8 process and invoke ``elastic_ddp.py``
+on each process on the node it is launched on, but user also needs to apply cluster
 management tools like slurm to actually run this command on 2 nodes.
 
 For example, on a SLURM enabled cluster, we can write a script to run the command above
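
One detail worth spelling out for the ``torchrun`` path in this hunk: the launcher exports ``LOCAL_RANK``, ``RANK``, and ``WORLD_SIZE`` into each process's environment, so the script reads them there instead of receiving arguments from ``mp.spawn``. A hedged sketch (the helper name is illustrative):

.. code:: python

    import os
    import torch
    import torch.distributed as dist

    def init_from_torchrun():
        # torchrun sets these environment variables for every process it launches
        local_rank = int(os.environ["LOCAL_RANK"])
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

        # pin this process to its GPU, then join the default process group
        torch.cuda.set_device(local_rank)
        dist.init_process_group("nccl")  # rank/world_size are read from the environment
        return local_rank, rank, world_size
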
@@ -371,5 +385,5 @@ Then we can just run this script using the SLURM command: ``srun --nodes=2 ./tor
 Of course, this is just an example; you can choose your own cluster scheduling tools
 to initiate the torchrun job.
 
-For more information about Elastic run, one can check this
+For more information about Elastic run, one can check this
 `quick start document <https://pytorch.org/docs/stable/elastic/quickstart.html>`__ to learn more.
