Skip to content

Commit a308b4e

Browse files
authored
Update DDP tutorial for the correct order of set_device (#1285)
1 parent 26de419 commit a308b4e

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

distributed/ddp-tutorial-series/multigpu.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ def ddp_setup(rank, world_size):
1818
"""
1919
os.environ["MASTER_ADDR"] = "localhost"
2020
os.environ["MASTER_PORT"] = "12355"
21-
init_process_group(backend="nccl", rank=rank, world_size=world_size)
2221
torch.cuda.set_device(rank)
22+
init_process_group(backend="nccl", rank=rank, world_size=world_size)
2323

2424
class Trainer:
2525
def __init__(
@@ -99,6 +99,6 @@ def main(rank: int, world_size: int, save_every: int, total_epochs: int, batch_s
9999
parser.add_argument('save_every', type=int, help='How often to save a snapshot')
100100
parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
101101
args = parser.parse_args()
102-
102+
103103
world_size = torch.cuda.device_count()
104104
mp.spawn(main, args=(world_size, args.save_every, args.total_epochs, args.batch_size), nprocs=world_size)

distributed/ddp-tutorial-series/multigpu_torchrun.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111

1212

1313
def ddp_setup():
14-
init_process_group(backend="nccl")
1514
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
15+
init_process_group(backend="nccl")
1616

1717
class Trainer:
1818
def __init__(
@@ -107,5 +107,5 @@ def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str
107107
parser.add_argument('save_every', type=int, help='How often to save a snapshot')
108108
parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
109109
args = parser.parse_args()
110-
110+
111111
main(args.save_every, args.total_epochs, args.batch_size)

distributed/ddp-tutorial-series/multinode.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111

1212

1313
def ddp_setup():
14-
init_process_group(backend="nccl")
1514
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
15+
init_process_group(backend="nccl")
1616

1717
class Trainer:
1818
def __init__(
@@ -108,5 +108,5 @@ def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str
108108
parser.add_argument('save_every', type=int, help='How often to save a snapshot')
109109
parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
110110
args = parser.parse_args()
111-
111+
112112
main(args.save_every, args.total_epochs, args.batch_size)

0 commit comments

Comments
 (0)