`Introduction <ddp_series_intro.html>`__ \|\| `What is DDP <ddp_series_theory.html>`__ \|\| `Single-Node
Multi-GPU Training <ddp_series_multigpu.html>`__ \|\| **Fault
Tolerance** \|\| `Multi-Node
training <../intermediate/ddp_series_multinode.html>`__ \|\| `minGPT Training <../intermediate/ddp_series_minGPT.html>`__

Fault-tolerant Distributed Training with ``torchrun``
======================================================

Authors: `Suraj Subramanian <https://github.com/suraj813>`__

.. grid:: 2

   .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
      :margin: 0

      - Launching multi-GPU training jobs with ``torchrun``
      - Saving and loading snapshots of your training job
      - Structuring your training script for graceful restarts

      .. grid:: 1

         .. grid-item::

            :octicon:`code-square;1.0em;` View the code used in this tutorial on `GitHub <https://github.com/pytorch/examples/blob/main/distributed/ddp-tutorial-series/multigpu_torchrun.py>`__

   .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
      :margin: 0

      * High-level `overview <ddp_series_theory.html>`__ of DDP
      * Familiarity with `DDP code <ddp_series_multigpu.html>`__
      * A machine with multiple GPUs (this tutorial uses an AWS p3.8xlarge instance)
      * PyTorch `installed <https://pytorch.org/get-started/locally/>`__ with CUDA

Follow along with the video below or on `youtube <https://www.youtube.com/watch/9kIvQOiwYzg>`__.

.. raw:: html

   <div style="margin-top:10px; margin-bottom:10px;">
     <iframe width="560" height="315" src="https://www.youtube.com/embed/9kIvQOiwYzg" frameborder="0" allow="accelerometer; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>
   </div>

In distributed training, a single process failure can disrupt the entire
training job. Because the likelihood of failure is higher when many processes
and machines are involved, making your training script robust to failures is
particularly important. You might also prefer your training job to be
*elastic*, i.e., able to continue as compute resources join or leave the job
dynamically over its course.

PyTorch offers a utility called ``torchrun`` that provides fault-tolerance and
elastic training. When a failure occurs, ``torchrun`` logs the errors and
attempts to automatically restart all the processes from the last saved
“snapshot” of the training job.

The snapshot saves more than just the model state; it can include
details about the number of epochs run, optimizer states or any other
stateful attribute of the training job necessary for its continuity.

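For example, a snapshot can be an ordinary dictionary written with ``torch.save``. The helper below is a sketch, not part of the tutorial's code: the ``MODEL_STATE`` and ``EPOCHS_RUN`` keys match the diffs shown later in this tutorial, while ``OPTIMIZER_STATE`` (and the ``capture_snapshot`` name itself) is an illustrative addition.

.. code:: python

   import torch

   def capture_snapshot(model: torch.nn.Module,
                        optimizer: torch.optim.Optimizer,
                        epoch: int,
                        path: str = "snapshot.pt") -> None:
       # Everything needed to pick up training where it left off.
       snapshot = {
           "MODEL_STATE": model.state_dict(),          # learned weights
           "OPTIMIZER_STATE": optimizer.state_dict(),  # e.g. Adam moment estimates
           "EPOCHS_RUN": epoch,                        # how many epochs have finished
       }
       torch.save(snapshot, path)
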
Why use ``torchrun``
~~~~~~~~~~~~~~~~~~~~~

``torchrun`` handles the minutiae of distributed training so that you
don't need to. For instance:

- You don't need to set environment variables or explicitly pass the ``rank`` and ``world_size``; ``torchrun`` assigns these, along with several other `environment variables <https://pytorch.org/docs/stable/elastic/run.html#environment-variables>`__ (see the sketch after this list).
- No need to call ``mp.spawn`` in your script; you only need a generic ``main()`` entrypoint and launch the script with ``torchrun``. This way, the same script can be run in non-distributed, single-node, and multi-node setups.
- Gracefully restarting training from the last saved training snapshot.

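As a quick illustration of the first two points, a script launched with ``torchrun`` can simply read the environment it is given, and fall back to defaults when launched as a plain Python process. This standalone snippet is a sketch, not part of the tutorial's code:

.. code:: python

   import os

   # torchrun sets RANK and WORLD_SIZE for every process it launches;
   # a plain `python script.py` run does not. Falling back to defaults
   # lets the same entrypoint work in both cases.
   rank = int(os.environ.get("RANK", 0))
   world_size = int(os.environ.get("WORLD_SIZE", 1))
   print(f"rank {rank} of {world_size} "
         f"({'distributed' if world_size > 1 else 'single-process'} run)")
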
Graceful restarts
~~~~~~~~~~~~~~~~~~~

For graceful restarts, you should structure your training script like:

.. code:: python

   def main():
       load_snapshot(snapshot_path)
       initialize()
       train()

   def train():
       for batch in iter(dataset):
           train_step(batch)

           if should_checkpoint:
               save_snapshot(snapshot_path)

If a failure occurs, ``torchrun`` will terminate all the processes and restart them.
Each process entrypoint first loads and initializes the last saved snapshot, and continues training from there.
So at any failure, you only lose the training progress from the last saved snapshot.

In elastic training, whenever there are any membership changes (adding or removing nodes), ``torchrun`` will terminate and spawn processes
on available devices. Having this structure ensures your training job can continue without manual intervention.

Diff: ``multigpu.py`` vs. ``multigpu_torchrun.py``
----------------------------------------------------

The sections below walk through the changes from
`multigpu.py <https://github.com/pytorch/examples/blob/main/distributed/ddp-tutorial-series/multigpu.py>`__
(used in the previous tutorial) to
`multigpu_torchrun.py <https://github.com/pytorch/examples/blob/main/distributed/ddp-tutorial-series/multigpu_torchrun.py>`__.

Process group initialization
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

- ``torchrun`` assigns ``RANK`` and ``WORLD_SIZE`` automatically, among `other environment variables <https://pytorch.org/docs/stable/elastic/run.html#environment-variables>`__.

.. code:: diff

   - def ddp_setup(rank, world_size):
   + def ddp_setup():
   -     """
   -     Args:
   -         rank: Unique identifier of each process
   -         world_size: Total number of processes
   -     """
   -     os.environ["MASTER_ADDR"] = "localhost"
   -     os.environ["MASTER_PORT"] = "12355"
   -     init_process_group(backend="nccl", rank=rank, world_size=world_size)
   +     init_process_group(backend="nccl")

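One way to flesh out the trimmed-down ``ddp_setup`` is to also pin each process to the GPU matching its ``LOCAL_RANK`` before creating the process group, so that subsequent collectives and ``.to(device)`` calls land on the right device. The sketch below does that; check the linked example script for the exact version used in this series.

.. code:: python

   import os

   import torch
   from torch.distributed import init_process_group

   def ddp_setup():
       # Pin this process to its own GPU so NCCL collectives and later
       # .to(gpu_id) calls use the right device.
       torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
       # rank and world_size are picked up from the env variables set by torchrun.
       init_process_group(backend="nccl")
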
Use Torchrun-provided env variables
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. code:: diff

   - self.gpu_id = gpu_id
   + self.gpu_id = int(os.environ["LOCAL_RANK"])

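``LOCAL_RANK`` is used here rather than ``RANK`` because it is the process index on the current node, and therefore a valid GPU index on that machine; ``RANK`` is the globally unique index across all nodes. On a single machine the two coincide. A sketch of the distinction (variable names are illustrative, not from the tutorial code):

.. code:: python

   import os

   # Both variables are set by torchrun:
   local_rank = int(os.environ["LOCAL_RANK"])  # index on this node -> safe to use as a CUDA device id
   global_rank = int(os.environ["RANK"])       # unique across all nodes -> handy for rank-0-only logic
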
Saving and loading snapshots
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Regularly storing all the relevant information in snapshots allows our
training job to seamlessly resume after an interruption.

.. code:: diff

   + def _save_snapshot(self, epoch):
   +     snapshot = {}
   +     snapshot["MODEL_STATE"] = self.model.module.state_dict()
   +     snapshot["EPOCHS_RUN"] = epoch
   +     torch.save(snapshot, "snapshot.pt")
   +     print(f"Epoch {epoch} | Training snapshot saved at snapshot.pt")

   + def _load_snapshot(self, snapshot_path):
   +     snapshot = torch.load(snapshot_path)
   +     self.model.load_state_dict(snapshot["MODEL_STATE"])
   +     self.epochs_run = snapshot["EPOCHS_RUN"]
   +     print(f"Resuming training from snapshot at Epoch {self.epochs_run}")

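When several processes on one machine read the same snapshot file, it can help to map the saved tensors onto each process's own device instead of the device they were saved from. A sketch of ``_load_snapshot`` with an explicit ``map_location`` (an optional refinement, not required by the diff above):

.. code:: python

   import torch

   class Trainer:
       # ...rest of the class as in the diffs in this tutorial...

       def _load_snapshot(self, snapshot_path):
           # Deserialize directly onto this process's GPU rather than onto
           # the GPU the snapshot was written from (typically rank 0's device).
           loc = f"cuda:{self.gpu_id}"
           snapshot = torch.load(snapshot_path, map_location=loc)
           self.model.load_state_dict(snapshot["MODEL_STATE"])
           self.epochs_run = snapshot["EPOCHS_RUN"]
           print(f"Resuming training from snapshot at Epoch {self.epochs_run}")
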
Loading a snapshot in the Trainer constructor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

When restarting an interrupted training job, your script will first try
to load a snapshot to resume training from.

.. code:: diff

   class Trainer:
       def __init__(self, snapshot_path, ...):
           ...
   +       if os.path.exists(snapshot_path):
   +           self._load_snapshot(snapshot_path)
           ...

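Note that ``self.epochs_run`` needs a sensible default for the very first run, when no snapshot file exists yet. A sketch of a constructor that covers both the fresh-start and the resume case (the ``model`` argument and exact attribute layout here are illustrative):

.. code:: python

   import os

   import torch

   class Trainer:
       def __init__(self, model: torch.nn.Module, snapshot_path: str):
           self.gpu_id = int(os.environ["LOCAL_RANK"])
           self.model = model.to(self.gpu_id)
           self.snapshot_path = snapshot_path
           self.epochs_run = 0  # fresh run: start from epoch 0
           if os.path.exists(snapshot_path):
               print("Loading snapshot")
               self._load_snapshot(snapshot_path)  # sets self.epochs_run from the snapshot
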
Resuming training
~~~~~~~~~~~~~~~~~~

Training can resume from the last epoch run, instead of starting all
over from scratch.

.. code:: diff

   def train(self, max_epochs: int):
   -    for epoch in range(max_epochs):
   +    for epoch in range(self.epochs_run, max_epochs):
            self._run_epoch(epoch)

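How much work is lost on a failure depends on how often snapshots are written. Below is a sketch of a training loop that checkpoints every ``save_every`` epochs from one rank only, assuming the Trainer stores the ``save_every`` command-line argument as ``self.save_every``; the exact condition in the example script may differ.

.. code:: python

   class Trainer:
       # ...initialization and snapshot helpers as sketched above...

       def train(self, max_epochs: int):
           for epoch in range(self.epochs_run, max_epochs):
               self._run_epoch(epoch)
               # One writer is enough: all ranks hold identical weights under DDP.
               if self.gpu_id == 0 and epoch % self.save_every == 0:
                   self._save_snapshot(epoch)
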
Running the script
~~~~~~~~~~~~~~~~~~~

Simply call your entrypoint function as you would for a non-multiprocessing script; ``torchrun`` automatically
spawns the processes.

.. code:: diff

   if __name__ == "__main__":
       import sys
       total_epochs = int(sys.argv[1])
       save_every = int(sys.argv[2])
   -   world_size = torch.cuda.device_count()
   -   mp.spawn(main, args=(world_size, total_epochs, save_every,), nprocs=world_size)
   +   main(save_every, total_epochs)

.. code:: diff

   - python multigpu.py 50 10
   + torchrun --standalone --nproc_per_node=4 multigpu_torchrun.py 50 10

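``torchrun`` also exposes launcher-level options related to fault tolerance; for example, ``--max_restarts`` bounds how many times the worker group is restarted after failures. See the torchrun launch options linked below for the full list; the values here are illustrative.

.. code:: bash

   # Retry the worker group up to 3 times before giving up on the job.
   torchrun --standalone --nproc_per_node=4 --max_restarts=3 multigpu_torchrun.py 50 10
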
Further Reading
----------------

- `Multi-Node training with DDP <../intermediate/ddp_series_multinode.html>`__ (next tutorial in this series)
- `Multi-GPU Training with DDP <ddp_series_multigpu.html>`__ (previous tutorial in this series)
- `torchrun <https://pytorch.org/docs/stable/elastic/run.html>`__
- `Torchrun launch options <https://github.com/pytorch/pytorch/blob/bbe803cb35948df77b46a2d38372910c96693dcd/torch/distributed/run.py#L401>`__
- `Migrating from torch.distributed.launch to torchrun <https://pytorch.org/docs/stable/elastic/train_script.html#elastic-train-script>`__