[docs] Distributed inference #3376


Merged · 6 commits · May 19, 2023
`docs/source/en/_toctree.yml` (2 additions, 0 deletions)

```diff
@@ -42,6 +42,8 @@
       title: Text-guided image-inpainting
     - local: using-diffusers/depth2img
       title: Text-guided depth-to-image
+    - local: training/distributed_inference
+      title: Distributed inference with multiple GPUs
     - local: using-diffusers/reusing_seeds
       title: Improve image quality with deterministic generation
     - local: using-diffusers/reproducibility
```
`docs/source/en/training/distributed_inference.mdx` (97 additions, 0 deletions)
# Distributed inference with multiple GPUs

On distributed setups, you can run inference across multiple GPUs with [PyTorch Distributed](https://pytorch.org/tutorials/beginner/dist_overview.html) or 🤗 [Accelerate](https://huggingface.co/docs/accelerate/index), which is useful for generating images from multiple prompts in parallel.

This guide will show you how to use PyTorch Distributed and 🤗 Accelerate for distributed inference.

## PyTorch Distributed

PyTorch supports [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html), which enables data parallelism.
> **Review comment (Member):** Why "faster"?
>
> Suggested change:
>
> ```diff
> -PyTorch supports [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) which enables faster data parallelism.
> +PyTorch supports [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) which enables data parallelism.
> ```

> **Reply (Member, Author):** I think I was trying to describe how it's different from just `DataParallel`, but I guess it doesn't really make sense if I don't mention it at all! 😅 Better to just remove it and keep it simple :)

To start, create a Python file and import `torch.distributed` and `torch.multiprocessing` to set up the distributed process group and to spawn the processes for inference on each GPU. You should also initialize a [`DiffusionPipeline`]:

```py
#!/usr/bin/env python3
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from diffusers import DiffusionPipeline

sd = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
```

You'll want to create a function to run inference; [`init_process_group`](https://pytorch.org/docs/stable/distributed.html?highlight=init_process_group#torch.distributed.init_process_group) creates a distributed environment from the type of backend to use, the `rank` of the current process, and the `world_size`, the total number of participating processes. If you're running inference in parallel over 2 GPUs, the `world_size` is 2.

Move the [`DiffusionPipeline`] to `rank` to assign a GPU to each process, and use `get_rank` to give each process a different prompt:

```py
def run_inference(rank, world_size):
    # The default env:// rendezvous needs a master address and port; torchrun
    # would set these automatically, but mp.spawn does not.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    sd.to(rank)

    if torch.distributed.get_rank() == 0:
        prompt = "a dog"
    elif torch.distributed.get_rank() == 1:
        prompt = "a cat"

    image = sd(prompt).images[0]
    # Replace spaces so the prompt becomes a clean filename
    image.save(f"./{prompt.replace(' ', '_')}.png")
```
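Note that the example uses the `"gloo"` backend, which is CPU-oriented; it works here because the ranks never exchange tensors, but `"nccl"` is the usual choice when GPU tensors need to be communicated between processes.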

To run the distributed inference, call [`mp.spawn`](https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn) to run the `run_inference` function on the number of GPUs defined in `world_size`:

```py
def main():
    world_size = 2
    mp.spawn(run_inference, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    main()
```

Once you've completed the inference script, run it with Python; the script spawns its own worker processes with `mp.spawn`, so it doesn't need a separate launcher:

```bash
python run_distributed.py
```
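If you'd rather launch with `torchrun`, here is a minimal sketch of a variant (assuming the same `run_inference` function as above): it reads the rank and world size from the environment variables that `torchrun` sets, instead of spawning processes with `mp.spawn`:

```py
import os

# Hypothetical torchrun entry point: torchrun starts one process per GPU and
# sets RANK and WORLD_SIZE itself, so mp.spawn is no longer needed.
if __name__ == "__main__":
    run_inference(int(os.environ["RANK"]), int(os.environ["WORLD_SIZE"]))
```

```bash
torchrun --nproc_per_node=2 run_distributed.py
```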

## 🤗 Accelerate

🤗 [Accelerate](https://huggingface.co/docs/accelerate/index) is a library designed to make it easy to train or run inference across distributed setups. It simplifies the process of setting up the distributed environment, allowing you to focus on your PyTorch code.

Start by initializing an [`accelerate.PartialState`] to create a distributed environment; your setup is automatically detected so you don't need to explicitly define the `rank` or `world_size`. Move the [`DiffusionPipeline`] to `state.device` to assign a GPU to each process, and use `process_index` to give each process a different prompt:

```py
#!/usr/bin/env python3
import torch

from accelerate import PartialState
from diffusers import DiffusionPipeline

sd = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)


def main():
    state = PartialState()

    sd.to(state.device)

    if state.process_index == 0:
        prompt = "a dog"
    elif state.process_index == 1:
        prompt = "a cat"

    image = sd(prompt).images[0]
    # Replace spaces so the prompt becomes a clean filename
    image.save(f"./{prompt.replace(' ', '_')}.png")


if __name__ == "__main__":
    main()
```

> **Review comment (Contributor)** on the `image.save` line: I feel like we may worry about deadlocking the CPU here by accident, no? (I haven't fully tried this yet, just thinking of how we do things in Accelerate that's close to this.) So it might be better to have a `main_process_first` during the save portion to be safe. (Again, something I'd need to play with, just voicing my thoughts here.)
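As a rough sketch of the reviewer's suggestion (untested, and assuming the `main_process_first` context manager that 🤗 Accelerate provides on the state object), the save portion inside `main()` could be wrapped so the main process writes before the others proceed:

```py
# Untested sketch of the reviewer's suggestion: non-main processes wait until
# the main process has exited the block, so the main process saves its image
# before the other ranks save theirs.
with state.main_process_first():
    image.save(f"./{prompt.replace(' ', '_')}.png")
```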

Call `accelerate launch` to run the distributed inference script:

```bash
accelerate launch run_distributed.py
```
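Your hardware setup is detected automatically, but you can also pin the number of processes explicitly, for example:

```bash
accelerate launch --num_processes=2 run_distributed.py
```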