# Distributed inference with multiple GPUs

On distributed setups, you can run inference across multiple GPUs with [PyTorch Distributed](https://pytorch.org/tutorials/beginner/dist_overview.html) or 🤗 [Accelerate](https://huggingface.co/docs/accelerate/index), which is useful for generating multiple prompts in parallel.

This guide will show you how to use PyTorch Distributed and 🤗 Accelerate for distributed inference.

## PyTorch Distributed

PyTorch supports [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html), which enables data parallelism by running a copy of the model in a separate process on each GPU (unlike [`DataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html), which is single-process and multi-threaded).

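For reference, here is a minimal sketch of how a module is typically wrapped in `DistributedDataParallel` (the model and process-group setup are illustrative assumptions); the rest of this guide uses the lower-level `torch.distributed` primitives directly:

```py
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_model(rank: int, world_size: int, model: torch.nn.Module) -> DDP:
    # One process per GPU; NCCL is the recommended backend for CUDA devices.
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    model.to(rank)
    # DDP replicates the module and keeps the copies in sync across processes.
    return DDP(model, device_ids=[rank])
```
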
To start, create a Python file and import `torch.distributed` and `torch.multiprocessing` to set up the distributed process group and to spawn the processes for inference on each GPU. You should also initialize a [`DiffusionPipeline`]:

```py
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from diffusers import DiffusionPipeline

sd = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
```

You'll want to create a function to run inference; [`init_process_group`](https://pytorch.org/docs/stable/distributed.html?highlight=init_process_group#torch.distributed.init_process_group) handles creating a distributed environment with the type of backend to use, the `rank` of the current process, and the `world_size`, or the number of processes participating. If you're running inference in parallel over 2 GPUs, then the `world_size` would be 2.

Move the [`DiffusionPipeline`] to `rank` and use `get_rank` to assign a GPU to each process, where each process handles a different prompt:

```py
def run_inference(rank, world_size):
    # NCCL is the recommended backend for communication between CUDA devices
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    sd.to(rank)

    if torch.distributed.get_rank() == 0:
        prompt = "a dog"
    elif torch.distributed.get_rank() == 1:
        prompt = "a cat"

    image = sd(prompt).images[0]
    image.save(f"./{'_'.join(prompt.split())}.png")
```

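The hard-coded `if`/`elif` above only covers two processes. For more GPUs, one option is to index a prompt list by `rank`; a minimal sketch, assuming at least one prompt per process:

```py
prompts = ["a dog", "a cat", "a bird", "a fish"]


def run_inference(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    sd.to(rank)

    # Each process picks the prompt that matches its rank.
    prompt = prompts[rank]
    image = sd(prompt).images[0]
    image.save(f"./{'_'.join(prompt.split())}.png")
```
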
To run the distributed inference, call [`mp.spawn`](https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn) to run the `run_inference` function on the number of GPUs defined in `world_size`:

```py
def main():
    world_size = 2
    mp.spawn(run_inference, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    main()
```

Once you've completed the inference script, launch it with `torchrun`, which also sets the `MASTER_ADDR` and `MASTER_PORT` environment variables that `init_process_group` uses to rendezvous:

```bash
torchrun run_distributed.py
```

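If you'd rather launch the script with plain `python`, you can set the rendezvous variables yourself before initializing the process group. A minimal sketch (the address and port values are common defaults, not requirements):

```py
import os

import torch.distributed as dist


def init_distributed(rank: int, world_size: int) -> None:
    # torchrun normally provides these; set them manually for a plain
    # `python` launch so init_process_group can rendezvous.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
```
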
## 🤗 Accelerate

🤗 [Accelerate](https://huggingface.co/docs/accelerate/index) is a library designed to make it easy to train or run inference across distributed setups. It simplifies the process of setting up the distributed environment, allowing you to focus on your PyTorch code.

Start by initializing an [`accelerate.PartialState`] to create a distributed environment; your setup is automatically detected, so you don't need to explicitly define the `rank` or `world_size`. Move the [`DiffusionPipeline`] to `state.device` to assign a GPU to each process, and use `process_index` to assign a different prompt to each process:

```py
import torch

from accelerate import PartialState
from diffusers import DiffusionPipeline

sd = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)


def main():
    state = PartialState()

    sd.to(state.device)

    if state.process_index == 0:
        prompt = "a dog"
    elif state.process_index == 1:
        prompt = "a cat"

    image = sd(prompt).images[0]
    image.save(f"./{'_'.join(prompt.split())}.png")


if __name__ == "__main__":
    main()
```

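If your prompt list doesn't map one-to-one onto processes, recent versions of 🤗 Accelerate provide a `split_between_processes` context manager on `PartialState` that shards a list across processes. A minimal sketch:

```py
import torch

from accelerate import PartialState
from diffusers import DiffusionPipeline

sd = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)

state = PartialState()
sd.to(state.device)

# Each process receives its own slice of the prompt list.
with state.split_between_processes(["a dog", "a cat", "a bird"]) as prompts:
    for prompt in prompts:
        image = sd(prompt).images[0]
        image.save(f"./{'_'.join(prompt.split())}.png")
```
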
Call `accelerate launch` to run the distributed inference script:

```bash
accelerate launch run_distributed.py
```
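
If the launcher doesn't detect the number of GPUs you expect, you can also pass it explicitly with the `--num_processes` flag:

```bash
accelerate launch --num_processes=2 run_distributed.py
```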