
Commit e589bdb

[docs] Distributed inference (#3376)
* distributed inference
* move to inference section
* apply feedback
* update with split_between_processes
* apply feedback
1 parent 00c76f6 commit e589bdb

2 files changed: 93 additions & 0 deletions

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
```diff
@@ -46,6 +46,8 @@
     title: Text-guided depth-to-image
   - local: using-diffusers/textual_inversion_inference
     title: Textual inversion
+  - local: training/distributed_inference
+    title: Distributed inference with multiple GPUs
   - local: using-diffusers/reusing_seeds
     title: Improve image quality with deterministic generation
   - local: using-diffusers/reproducibility
```
docs/source/en/training/distributed_inference.md

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
# Distributed inference with multiple GPUs

On distributed setups, you can run inference across multiple GPUs with 🤗 [Accelerate](https://huggingface.co/docs/accelerate/index) or [PyTorch Distributed](https://pytorch.org/tutorials/beginner/dist_overview.html), which is useful for generating with multiple prompts in parallel.

This guide will show you how to use 🤗 Accelerate and PyTorch Distributed for distributed inference.

## 🤗 Accelerate

🤗 [Accelerate](https://huggingface.co/docs/accelerate/index) is a library designed to make it easy to train or run inference across distributed setups. It simplifies the process of setting up the distributed environment, allowing you to focus on your PyTorch code.

To begin, create a Python file and initialize an [`accelerate.PartialState`] to set up a distributed environment; your setup is automatically detected so you don't need to explicitly define the `rank` or `world_size`. Move the [`DiffusionPipeline`] to `distributed_state.device` to assign a GPU to each process.

Now use the [`~accelerate.PartialState.split_between_processes`] utility as a context manager to automatically distribute the prompts across the processes.
```py
import torch
from accelerate import PartialState
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
distributed_state = PartialState()
pipeline.to(distributed_state.device)

with distributed_state.split_between_processes(["a dog", "a cat"]) as prompt:
    result = pipeline(prompt).images[0]
    result.save(f"result_{distributed_state.process_index}.png")
```
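If you have more prompts than GPUs, [`~accelerate.PartialState.split_between_processes`] hands each process a sublist rather than a single item, so you can loop over it. The snippet below is only a sketch; the four prompts and output file names are illustrative and not part of the original example:

```py
import torch
from accelerate import PartialState
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
distributed_state = PartialState()
pipeline.to(distributed_state.device)

# four prompts are split across the launched processes; each process iterates over its own share
prompts = ["a dog", "a cat", "a chicken", "a horse"]
with distributed_state.split_between_processes(prompts) as local_prompts:
    for i, prompt in enumerate(local_prompts):
        image = pipeline(prompt).images[0]
        image.save(f"result_{distributed_state.process_index}_{i}.png")
```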
Use the `--num_processes` argument to specify the number of GPUs to use, and call `accelerate launch` to run the script:

```bash
accelerate launch --num_processes=2 run_distributed.py
```

<Tip>

To learn more, take a look at the [Distributed Inference with 🤗 Accelerate](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) guide.

</Tip>

## PyTorch Distributed

PyTorch supports [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html), which enables data parallelism.

To start, create a Python file and import `torch.distributed` and `torch.multiprocessing` to set up the distributed process group and to spawn the processes for inference on each GPU. You should also initialize a [`DiffusionPipeline`]:

```py
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from diffusers import DiffusionPipeline

sd = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
```
You'll want to create a function to run inference; [`init_process_group`](https://pytorch.org/docs/stable/distributed.html?highlight=init_process_group#torch.distributed.init_process_group) handles creating a distributed environment with the type of backend to use, the `rank` of the current process, and the `world_size`, or the number of participating processes. For example, if you're running inference in parallel over 2 GPUs, then the `world_size` is 2.

Move the [`DiffusionPipeline`] to `rank` and use `get_rank` to assign a GPU to each process, where each process handles a different prompt:

```py
def run_inference(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    # each process gets its own GPU
    sd.to(rank)

    # each process handles a different prompt
    if torch.distributed.get_rank() == 0:
        prompt = "a dog"
    elif torch.distributed.get_rank() == 1:
        prompt = "a cat"

    image = sd(prompt).images[0]
    image.save(f"./{'_'.join(prompt.split())}.png")
```
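If you plan to scale past two GPUs, a variant of the same function can pick the prompt by `rank` instead of hardcoding an `if`/`elif` branch per process. This is only a sketch that reuses the imports and the `sd` pipeline defined above; the prompt list is illustrative:

```py
def run_inference(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    sd.to(rank)

    # one prompt per process, selected by rank (illustrative list)
    prompts = ["a dog", "a cat", "a chicken", "a horse"]
    prompt = prompts[rank % len(prompts)]

    image = sd(prompt).images[0]
    image.save(f"./{'_'.join(prompt.split())}.png")
```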
To run the distributed inference, call [`mp.spawn`](https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn) to run the `run_inference` function on the number of GPUs defined in `world_size`:

```py
def main():
    world_size = 2
    mp.spawn(run_inference, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    main()
```
Once you've completed the inference script, launch it with `torchrun`; `mp.spawn` then creates a process for each of the GPUs defined in `world_size`:

```bash
torchrun run_distributed.py
```
