
Commit 1d686ba

Authored by sayakpaul, patrickvonplaten, DN6, a-r-r-o-w, and EdoardoBotta
[feat: Benchmarking Workflow] add stuff for a benchmarking workflow (#5839)
* add poc for benchmarking workflow.
* import
* fix argument
* fix: argument
* fix: path
* fix
* fix
* path
* output csv files.
* workflow cleanup
* append token
* add utility to push to hf dataset
* fix: kw arg
* better reporting
* fix: headers
* better formatting of the numbers.
* better type annotation
* fix: formatting
* moentarily disable check
* push results.
* remove disable check
* introduce base classes.
* img2img class
* add inpainting pipeline
* intoduce base benchmark class.
* add img2img and inpainting
* feat: utility to compare changes
* fix
* fix import
* add args
* basepath
* better exception handling
* better path handling
* fix
* fix
* remove
* ifx
* fix
* add: support for controlnet.
* image_url -> url
* move images to huggingface hub
* correct urls.
* root_ckpt
* flush before benchmarking
* don't install accelerate from source
* add runner
* simplify Diffusers Benchmarking step
* change runner
* fix: subprocess call.
* filter percentage values
* fix controlnet benchmark
* add t2i adapters.
* fix filter columns
* fix t2i adapter benchmark
* fix init.
* fix
* remove safetensors flag
* fix args print
* fix
* feat: run_command
* add adapter resolution mapping
* benchmark t2i adapter fix.
* fix adapter input
* fix
* convert to L.
* add flush() add appropriate places
* better filtering
* okay
* get env for torch
* convert to float
* fix
* filter out nans.
* better coment
* sdxl
* sdxl for other benchmarks.
* fix: condition
* fix: condition for inpainting
* fix: mapping for resolution
* fix
* include kandinsky and wuerstchen
* fix: Wuerstchen
* Empty-Commit
* [Community] AnimateDiff + Controlnet Pipeline (#5928)
  * begin work on animatediff + controlnet pipeline
  * complete todos, uncomment multicontrolnet, input checks
    Co-Authored-By: EdoardoBotta <[email protected]>
  * update
    Co-Authored-By: EdoardoBotta <[email protected]>
  * add example
  * update community README
  * Update examples/community/README.md
  ---------
  Co-authored-by: EdoardoBotta <[email protected]>
  Co-authored-by: Patrick von Platen <[email protected]>
* EulerDiscreteScheduler add `rescale_betas_zero_snr` (#6024)
  * EulerDiscreteScheduler add `rescale_betas_zero_snr`
* Revert "[Community] AnimateDiff + Controlnet Pipeline (#5928)"
  This reverts commit 821726d.
* Revert "EulerDiscreteScheduler add `rescale_betas_zero_snr` (#6024)"
  This reverts commit 3dc2362.
* add SDXL turbo
* add lcm lora to the mix as well.
* fix
* increase steps to 2 when running turbo i2i
* debug
* debug
* debug
* fix for good
* fix and isolate better
* fuse lora so that torch compile works with peft
* fix: LCMLoRA
* better identification for LCM
* change to cron job

---------

Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Dhruv Nair <[email protected]>
Co-authored-by: Aryan V S <[email protected]>
Co-authored-by: EdoardoBotta <[email protected]>
Co-authored-by: Beinsezii <[email protected]>
1 parent 0a401b9 commit 1d686ba

12 files changed: +791 −1 lines changed

.github/workflows/benchmark.yml

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
name: Benchmarking tests

on:
  schedule:
    - cron: "30 1 1,15 * *"  # every 2 weeks on the 1st and the 15th of every month at 1:30 AM

env:
  DIFFUSERS_IS_CI: yes
  HF_HOME: /mnt/cache
  OMP_NUM_THREADS: 8
  MKL_NUM_THREADS: 8

jobs:
  torch_pipelines_cuda_benchmark_tests:
    name: Torch Core Pipelines CUDA Benchmarking Tests
    strategy:
      fail-fast: false
      max-parallel: 1
    runs-on: [single-gpu, nvidia-gpu, a10, ci]
    container:
      image: diffusers/diffusers-pytorch-cuda
      options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ --gpus 0
    steps:
      - name: Checkout diffusers
        uses: actions/checkout@v3
        with:
          fetch-depth: 2
      - name: NVIDIA-SMI
        run: |
          nvidia-smi
      - name: Install dependencies
        run: |
          apt-get update && apt-get install libsndfile1-dev libgl1 -y
          python -m pip install -e .[quality,test]
          python -m pip install pandas
      - name: Environment
        run: |
          python utils/print_env.py
      - name: Diffusers Benchmarking
        env:
          HUGGING_FACE_HUB_TOKEN: ${{ secrets.DIFFUSERS_BOT_TOKEN }}
          BASE_PATH: benchmark_outputs
        run: |
          export TOTAL_GPU_MEMORY=$(python -c "import torch; print(torch.cuda.get_device_properties(0).total_memory / (1024**3))")
          cd benchmarks && mkdir ${BASE_PATH} && python run_all.py && python push_results.py

      - name: Test suite reports artifacts
        if: ${{ always() }}
        uses: actions/upload-artifact@v2
        with:
          name: benchmark_test_reports
          path: benchmarks/benchmark_outputs

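The "Diffusers Benchmarking" step above drives two scripts, run_all.py and push_results.py, which are part of this commit's 12 changed files but are not reproduced in this excerpt. As a rough, hypothetical sketch of what such a runner could look like (the script names matched by the glob, the extra --run_compile pass, and the run_command helper are illustrative assumptions, not the committed file):

# Hypothetical sketch of a benchmark runner; not the run_all.py shipped in this commit.
import glob
import subprocess


def run_command(command: list, return_stdout=False):
    # Thin wrapper so a failing benchmark surfaces its stderr instead of failing silently.
    try:
        output = subprocess.check_output(command, stderr=subprocess.STDOUT)
        if return_stdout:
            return output.decode("utf-8")
    except subprocess.CalledProcessError as e:
        raise EnvironmentError(e.output.decode("utf-8")) from e


if __name__ == "__main__":
    # Each benchmark_*.py script writes its own CSV into BASE_PATH (see base_classes.py below).
    for script in sorted(glob.glob("benchmark_*.py")):
        run_command(["python", script])
        run_command(["python", script, "--run_compile"])

push_results.py would then collect the per-run CSVs and upload them, matching the "add utility to push to hf dataset" item in the commit message; its exact contents are likewise not shown here.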
Makefile

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
 export PYTHONPATH = src
 
-check_dirs := examples scripts src tests utils
+check_dirs := examples scripts src tests utils benchmarks
 
 modified_only_fixup:
 	$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))

benchmarks/base_classes.py

Lines changed: 297 additions & 0 deletions
@@ -0,0 +1,297 @@
import os
import sys

import torch

from diffusers import (
    AutoPipelineForImage2Image,
    AutoPipelineForInpainting,
    AutoPipelineForText2Image,
    ControlNetModel,
    LCMScheduler,
    StableDiffusionAdapterPipeline,
    StableDiffusionControlNetPipeline,
    StableDiffusionXLAdapterPipeline,
    StableDiffusionXLControlNetPipeline,
    T2IAdapter,
    WuerstchenCombinedPipeline,
)
from diffusers.utils import load_image


sys.path.append(".")

from utils import (  # noqa: E402
    BASE_PATH,
    PROMPT,
    BenchmarkInfo,
    benchmark_fn,
    bytes_to_giga_bytes,
    flush,
    generate_csv_dict,
    write_to_csv,
)


RESOLUTION_MAPPING = {
    "runwayml/stable-diffusion-v1-5": (512, 512),
    "lllyasviel/sd-controlnet-canny": (512, 512),
    "diffusers/controlnet-canny-sdxl-1.0": (1024, 1024),
    "TencentARC/t2iadapter_canny_sd14v1": (512, 512),
    "TencentARC/t2i-adapter-canny-sdxl-1.0": (1024, 1024),
    "stabilityai/stable-diffusion-2-1": (768, 768),
    "stabilityai/stable-diffusion-xl-base-1.0": (1024, 1024),
    "stabilityai/stable-diffusion-xl-refiner-1.0": (1024, 1024),
    "stabilityai/sdxl-turbo": (512, 512),
}


class BaseBenchmak:
    pipeline_class = None

    def __init__(self, args):
        super().__init__()

    def run_inference(self, args):
        raise NotImplementedError

    def benchmark(self, args):
        raise NotImplementedError

    def get_result_filepath(self, args):
        pipeline_class_name = str(self.pipe.__class__.__name__)
        name = (
            args.ckpt.replace("/", "_")
            + "_"
            + pipeline_class_name
            + f"-bs@{args.batch_size}-steps@{args.num_inference_steps}-mco@{args.model_cpu_offload}-compile@{args.run_compile}.csv"
        )
        filepath = os.path.join(BASE_PATH, name)
        return filepath


class TextToImageBenchmark(BaseBenchmak):
    pipeline_class = AutoPipelineForText2Image

    def __init__(self, args):
        pipe = self.pipeline_class.from_pretrained(args.ckpt, torch_dtype=torch.float16)
        pipe = pipe.to("cuda")

        if args.run_compile:
            if not isinstance(pipe, WuerstchenCombinedPipeline):
                pipe.unet.to(memory_format=torch.channels_last)
                print("Run torch compile")
                pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

                if hasattr(pipe, "movq") and getattr(pipe, "movq", None) is not None:
                    pipe.movq.to(memory_format=torch.channels_last)
                    pipe.movq = torch.compile(pipe.movq, mode="reduce-overhead", fullgraph=True)
            else:
                print("Run torch compile")
                pipe.decoder = torch.compile(pipe.decoder, mode="reduce-overhead", fullgraph=True)
                pipe.vqgan = torch.compile(pipe.vqgan, mode="reduce-overhead", fullgraph=True)

        pipe.set_progress_bar_config(disable=True)
        self.pipe = pipe

    def run_inference(self, pipe, args):
        _ = pipe(
            prompt=PROMPT,
            num_inference_steps=args.num_inference_steps,
            num_images_per_prompt=args.batch_size,
        )

    def benchmark(self, args):
        flush()

        print(f"[INFO] {self.pipe.__class__.__name__}: Running benchmark with: {vars(args)}\n")

        time = benchmark_fn(self.run_inference, self.pipe, args)  # in seconds.
        memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated())  # in GBs.
        benchmark_info = BenchmarkInfo(time=time, memory=memory)

        pipeline_class_name = str(self.pipe.__class__.__name__)
        flush()
        csv_dict = generate_csv_dict(
            pipeline_cls=pipeline_class_name, ckpt=args.ckpt, args=args, benchmark_info=benchmark_info
        )
        filepath = self.get_result_filepath(args)
        write_to_csv(filepath, csv_dict)
        print(f"Logs written to: {filepath}")
        flush()


class TurboTextToImageBenchmark(TextToImageBenchmark):
    def __init__(self, args):
        super().__init__(args)

    def run_inference(self, pipe, args):
        _ = pipe(
            prompt=PROMPT,
            num_inference_steps=args.num_inference_steps,
            num_images_per_prompt=args.batch_size,
            guidance_scale=0.0,
        )


class LCMLoRATextToImageBenchmark(TextToImageBenchmark):
    lora_id = "latent-consistency/lcm-lora-sdxl"

    def __init__(self, args):
        super().__init__(args)
        self.pipe.load_lora_weights(self.lora_id)
        self.pipe.fuse_lora()
        self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)

    def get_result_filepath(self, args):
        pipeline_class_name = str(self.pipe.__class__.__name__)
        name = (
            self.lora_id.replace("/", "_")
            + "_"
            + pipeline_class_name
            + f"-bs@{args.batch_size}-steps@{args.num_inference_steps}-mco@{args.model_cpu_offload}-compile@{args.run_compile}.csv"
        )
        filepath = os.path.join(BASE_PATH, name)
        return filepath

    def run_inference(self, pipe, args):
        _ = pipe(
            prompt=PROMPT,
            num_inference_steps=args.num_inference_steps,
            num_images_per_prompt=args.batch_size,
            guidance_scale=1.0,
        )


class ImageToImageBenchmark(TextToImageBenchmark):
    pipeline_class = AutoPipelineForImage2Image
    url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/1665_Girl_with_a_Pearl_Earring.jpg"
    image = load_image(url).convert("RGB")

    def __init__(self, args):
        super().__init__(args)
        self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])

    def run_inference(self, pipe, args):
        _ = pipe(
            prompt=PROMPT,
            image=self.image,
            num_inference_steps=args.num_inference_steps,
            num_images_per_prompt=args.batch_size,
        )


class TurboImageToImageBenchmark(ImageToImageBenchmark):
    def __init__(self, args):
        super().__init__(args)

    def run_inference(self, pipe, args):
        _ = pipe(
            prompt=PROMPT,
            image=self.image,
            num_inference_steps=args.num_inference_steps,
            num_images_per_prompt=args.batch_size,
            guidance_scale=0.0,
            strength=0.5,
        )


class InpaintingBenchmark(ImageToImageBenchmark):
    pipeline_class = AutoPipelineForInpainting
    mask_url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/overture-creations-5sI6fQgYIuo_mask.png"
    mask = load_image(mask_url).convert("RGB")

    def __init__(self, args):
        super().__init__(args)
        self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])
        self.mask = self.mask.resize(RESOLUTION_MAPPING[args.ckpt])

    def run_inference(self, pipe, args):
        _ = pipe(
            prompt=PROMPT,
            image=self.image,
            mask_image=self.mask,
            num_inference_steps=args.num_inference_steps,
            num_images_per_prompt=args.batch_size,
        )


class ControlNetBenchmark(TextToImageBenchmark):
    pipeline_class = StableDiffusionControlNetPipeline
    aux_network_class = ControlNetModel
    root_ckpt = "runwayml/stable-diffusion-v1-5"

    url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/canny_image_condition.png"
    image = load_image(url).convert("RGB")

    def __init__(self, args):
        aux_network = self.aux_network_class.from_pretrained(args.ckpt, torch_dtype=torch.float16)
        pipe = self.pipeline_class.from_pretrained(self.root_ckpt, controlnet=aux_network, torch_dtype=torch.float16)
        pipe = pipe.to("cuda")

        pipe.set_progress_bar_config(disable=True)
        self.pipe = pipe

        if args.run_compile:
            pipe.unet.to(memory_format=torch.channels_last)
            pipe.controlnet.to(memory_format=torch.channels_last)

            print("Run torch compile")
            pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
            pipe.controlnet = torch.compile(pipe.controlnet, mode="reduce-overhead", fullgraph=True)

        self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])

    def run_inference(self, pipe, args):
        _ = pipe(
            prompt=PROMPT,
            image=self.image,
            num_inference_steps=args.num_inference_steps,
            num_images_per_prompt=args.batch_size,
        )


class ControlNetSDXLBenchmark(ControlNetBenchmark):
    pipeline_class = StableDiffusionXLControlNetPipeline
    root_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"

    def __init__(self, args):
        super().__init__(args)


class T2IAdapterBenchmark(ControlNetBenchmark):
    pipeline_class = StableDiffusionAdapterPipeline
    aux_network_class = T2IAdapter
    root_ckpt = "CompVis/stable-diffusion-v1-4"

    url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/canny_for_adapter.png"
    image = load_image(url).convert("L")

    def __init__(self, args):
        aux_network = self.aux_network_class.from_pretrained(args.ckpt, torch_dtype=torch.float16)
        pipe = self.pipeline_class.from_pretrained(self.root_ckpt, adapter=aux_network, torch_dtype=torch.float16)
        pipe = pipe.to("cuda")

        pipe.set_progress_bar_config(disable=True)
        self.pipe = pipe

        if args.run_compile:
            pipe.unet.to(memory_format=torch.channels_last)
            pipe.adapter.to(memory_format=torch.channels_last)

            print("Run torch compile")
            pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
            pipe.adapter = torch.compile(pipe.adapter, mode="reduce-overhead", fullgraph=True)

        self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])


class T2IAdapterSDXLBenchmark(T2IAdapterBenchmark):
    pipeline_class = StableDiffusionXLAdapterPipeline
    root_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"

    url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/canny_for_adapter_sdxl.png"
    image = load_image(url)

    def __init__(self, args):
        super().__init__(args)

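base_classes.py imports several helpers (BASE_PATH, PROMPT, BenchmarkInfo, benchmark_fn, bytes_to_giga_bytes, flush, generate_csv_dict, write_to_csv) from a local utils module that belongs to this commit but is not shown in this excerpt. A minimal sketch of what such helpers could look like, assuming timing via torch.utils.benchmark and plain csv output (the committed utils.py may differ in names, prompt text, and CSV columns):

# Hypothetical sketch of the helpers base_classes.py expects from utils.py; not the committed file.
import csv
import gc
from dataclasses import dataclass

import torch
import torch.utils.benchmark as benchmark


BASE_PATH = "benchmark_outputs"
PROMPT = "ghibli style, a fantasy landscape with castles"


@dataclass
class BenchmarkInfo:
    time: float
    memory: float


def flush():
    # Release cached GPU memory and reset the peak-memory counter between runs.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()


def bytes_to_giga_bytes(num_bytes):
    return num_bytes / 1024 / 1024 / 1024


def benchmark_fn(f, *args, **kwargs):
    # Mean wall-clock time of f(*args, **kwargs) in seconds.
    timer = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"f": f, "args": args, "kwargs": kwargs},
    )
    return round(timer.blocked_autorange().mean, 3)


def generate_csv_dict(pipeline_cls: str, ckpt: str, args, benchmark_info: BenchmarkInfo) -> dict:
    # One flat row per run; a downstream script can concatenate the per-run CSVs.
    return {
        "pipeline_cls": pipeline_cls,
        "ckpt_id": ckpt,
        "batch_size": args.batch_size,
        "num_inference_steps": args.num_inference_steps,
        "model_cpu_offload": args.model_cpu_offload,
        "run_compile": args.run_compile,
        "time (secs)": benchmark_info.time,
        "memory (gbs)": benchmark_info.memory,
    }


def write_to_csv(file_name: str, data_dict: dict):
    with open(file_name, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(data_dict.keys()))
        writer.writeheader()
        writer.writerow(data_dict)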
benchmarks/benchmark_controlnet.py

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
import argparse
import sys


sys.path.append(".")
from base_classes import ControlNetBenchmark, ControlNetSDXLBenchmark  # noqa: E402


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpt",
        type=str,
        default="lllyasviel/sd-controlnet-canny",
        choices=["lllyasviel/sd-controlnet-canny", "diffusers/controlnet-canny-sdxl-1.0"],
    )
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--num_inference_steps", type=int, default=50)
    parser.add_argument("--model_cpu_offload", action="store_true")
    parser.add_argument("--run_compile", action="store_true")
    args = parser.parse_args()

    benchmark_pipe = (
        ControlNetBenchmark(args) if args.ckpt == "lllyasviel/sd-controlnet-canny" else ControlNetSDXLBenchmark(args)
    )
    benchmark_pipe.benchmark(args)

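For local experimentation outside the workflow, the benchmark classes can also be driven without the CLI by building the same argument set by hand; a small illustrative example (run from the benchmarks/ directory on a CUDA machine, with the values simply mirroring the parser defaults above):

# Illustrative only: drive a benchmark directly instead of via benchmark_controlnet.py.
import argparse

from base_classes import ControlNetBenchmark

args = argparse.Namespace(
    ckpt="lllyasviel/sd-controlnet-canny",
    batch_size=1,
    num_inference_steps=50,
    model_cpu_offload=False,
    run_compile=False,
)
ControlNetBenchmark(args).benchmark(args)  # writes a CSV into BASE_PATH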