Skip to content

[WIP] Add UFOGen Pipeline and Scheduler #6133

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,8 @@
title: ScoreSdeVeScheduler
- local: api/schedulers/score_sde_vp
title: ScoreSdeVpScheduler
- local: api/schedulers/ufogen
title: UFOGenScheduler
- local: api/schedulers/unipc
title: UniPCMultistepScheduler
- local: api/schedulers/vq_diffusion
Expand Down
15 changes: 15 additions & 0 deletions docs/source/en/api/schedulers/ufogen.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# UFOGen Multistep and Single-Step Scheduler

## Overview

Multistep and one-step scheduler introduced with the UFOGen model in the paper [UFOGen: You Forward Once Large Scale Text-to-Image Generation via Diffusion GANs](https://arxiv.org/abs/2311.09257) by Yanwu Xu, Yang Zhao, Zhisheng Xiao, and Tingbo Hou.
This scheduler should be able to generate good samples from a UFOGen model in 1-4 steps.

<Tip warning={true}>

Multistep sampling support is currently experimental.

</Tip>

## UFOGenScheduler
[[autodoc]] UFOGenScheduler
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@
"RePaintScheduler",
"SchedulerMixin",
"ScoreSdeVeScheduler",
"UFOGenScheduler",
"UnCLIPScheduler",
"UniPCMultistepScheduler",
"VQDiffusionScheduler",
Expand Down Expand Up @@ -529,6 +530,7 @@
RePaintScheduler,
SchedulerMixin,
ScoreSdeVeScheduler,
UFOGenScheduler,
UnCLIPScheduler,
UniPCMultistepScheduler,
VQDiffusionScheduler,
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/schedulers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
_import_structure["scheduling_pndm"] = ["PNDMScheduler"]
_import_structure["scheduling_repaint"] = ["RePaintScheduler"]
_import_structure["scheduling_sde_ve"] = ["ScoreSdeVeScheduler"]
_import_structure["scheduling_ufogen"] = ["UFOGenScheduler"]
_import_structure["scheduling_unclip"] = ["UnCLIPScheduler"]
_import_structure["scheduling_unipc_multistep"] = ["UniPCMultistepScheduler"]
_import_structure["scheduling_utils"] = ["KarrasDiffusionSchedulers", "SchedulerMixin"]
Expand Down Expand Up @@ -151,6 +152,7 @@
from .scheduling_pndm import PNDMScheduler
from .scheduling_repaint import RePaintScheduler
from .scheduling_sde_ve import ScoreSdeVeScheduler
from .scheduling_ufogen import UFOGenScheduler
from .scheduling_unclip import UnCLIPScheduler
from .scheduling_unipc_multistep import UniPCMultistepScheduler
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
Expand Down
525 changes: 525 additions & 0 deletions src/diffusers/schedulers/scheduling_ufogen.py

Large diffs are not rendered by default.

15 changes: 15 additions & 0 deletions src/diffusers/utils/dummy_pt_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,6 +1005,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])


# Placeholder that stands in for the real `UFOGenScheduler` when torch is not
# installed: any attempt to construct or load it raises a clear
# missing-backend error instead of an ImportError.
# NOTE(review): dummy_pt_objects.py is auto-generated (utils/check_dummies) —
# keep this class byte-identical to the generator's output.
class UFOGenScheduler(metaclass=DummyObject):
    _backends = ["torch"]  # backend(s) required by the real implementation

    def __init__(self, *args, **kwargs):
        # Raises if torch is unavailable.
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        # Mirrors the real class's config-based constructor.
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Mirrors the real class's hub/checkpoint loader.
        requires_backends(cls, ["torch"])


class UnCLIPScheduler(metaclass=DummyObject):
_backends = ["torch"]

Expand Down
21 changes: 21 additions & 0 deletions tests/pipelines/stable_diffusion/test_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
LMSDiscreteScheduler,
PNDMScheduler,
StableDiffusionPipeline,
UFOGenScheduler,
UNet2DConditionModel,
logging,
)
Expand Down Expand Up @@ -242,6 +243,26 @@ def test_stable_diffusion_lcm_custom_timesteps(self):

assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

def test_stable_diffusion_ufogen(self):
    """Smoke-test the SD pipeline with its scheduler swapped for UFOGenScheduler."""
    # Generator on CPU keeps the run deterministic regardless of torch_device.
    generator_device = "cpu"

    pipe = StableDiffusionPipeline(**self.get_dummy_components())
    # Rebuild the scheduler from the existing config so only the algorithm changes.
    pipe.scheduler = UFOGenScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to(torch_device)
    pipe.set_progress_bar_config(disable=None)

    result = pipe(**self.get_dummy_inputs(generator_device))
    images = result.images
    assert images.shape == (1, 64, 64, 3)

    # Compare the bottom-right 3x3 patch of the last channel against the
    # reference values recorded for this seed.
    expected_slice = np.array([0.3260, 0.4523, 0.4684, 0.3544, 0.3981, 0.4635, 0.5140, 0.3425, 0.4062])
    actual_slice = images[0, -3:, -3:, -1].flatten()
    assert np.abs(actual_slice - expected_slice).max() < 1e-2

def test_stable_diffusion_prompt_embeds(self):
components = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**components)
Expand Down
182 changes: 182 additions & 0 deletions tests/schedulers/test_scheduler_ufogen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
import torch

from diffusers import UFOGenScheduler

from .test_schedulers import SchedulerCommonTest


class UFOGenSchedulerTest(SchedulerCommonTest):
    """Tests for `UFOGenScheduler`: config handling, multistep/one-step
    sampling loops, noise injection, and custom-timestep validation."""

    scheduler_classes = (UFOGenScheduler,)

    def get_scheduler_config(self, **kwargs):
        """Return a default scheduler config; any field can be overridden via kwargs."""
        config = {
            "num_train_timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
            "clip_sample": True,
        }

        config.update(**kwargs)
        return config

    def test_timesteps(self):
        for num_train_timesteps, timestep in zip([1, 5, 100, 1000], [0, 0, 0, 0]):
            self.check_over_configs(num_train_timesteps=num_train_timesteps, time_step=timestep)

    def test_betas(self):
        for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
        for schedule in ["linear", "squaredcos_cap_v2"]:
            self.check_over_configs(beta_schedule=schedule)

    def test_clip_sample(self):
        for clip_sample in [True, False]:
            self.check_over_configs(clip_sample=clip_sample)

    def test_thresholding(self):
        self.check_over_configs(thresholding=False)
        for threshold in [0.5, 1.0, 2.0]:
            for prediction_type in ["epsilon", "sample", "v_prediction"]:
                self.check_over_configs(
                    thresholding=True,
                    prediction_type=prediction_type,
                    sample_max_value=threshold,
                )

    def test_prediction_type(self):
        for prediction_type in ["epsilon", "sample", "v_prediction"]:
            self.check_over_configs(prediction_type=prediction_type)

    def test_time_indices(self):
        for t in [0, 500, 999]:
            self.check_over_forward(time_step=t)

    def full_loop(self, num_inference_steps=10, **config):
        """Run a full denoising loop with a dummy model and return the final sample.

        Extra kwargs are forwarded to the scheduler config.
        """
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(**config)
        scheduler = scheduler_class(**scheduler_config)

        model = self.dummy_model()
        sample = self.dummy_sample_deter
        # Fixed seed keeps the stochastic step reproducible across runs.
        generator = torch.manual_seed(0)

        scheduler.set_timesteps(num_inference_steps)

        for t in scheduler.timesteps:
            residual = model(sample, t)
            sample = scheduler.step(residual, t, sample, generator=generator).prev_sample

        return sample

    def test_full_loop_no_noise(self):
        sample = self.full_loop()

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_sum.item() - 201.3429) < 1e-2
        assert abs(result_mean.item() - 0.2622) < 1e-3

    def test_full_loop_no_noise_onestep(self):
        # UFOGen's headline feature: sampling in a single step.
        sample = self.full_loop(num_inference_steps=1)

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_sum.item() - 61.5819) < 1e-2
        assert abs(result_mean.item() - 0.0802) < 1e-3

    def test_full_loop_with_v_prediction(self):
        sample = self.full_loop(prediction_type="v_prediction")

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_sum.item() - 141.1963) < 1e-2
        assert abs(result_mean.item() - 0.1838) < 1e-3

    def test_full_loop_with_noise(self, num_inference_steps=10):
        """Start the loop partway through the schedule from a freshly-noised sample
        (img2img-style) and check the result against reference values."""
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        model = self.dummy_model()
        sample = self.dummy_sample_deter
        generator = torch.manual_seed(0)

        t_start = num_inference_steps - 2
        scheduler.set_timesteps(num_inference_steps)

        # add noise at the timestep where the truncated loop begins
        noise = self.dummy_noise_deter
        timesteps = scheduler.timesteps[t_start * scheduler.order :]
        sample = scheduler.add_noise(sample, noise, timesteps[:1])

        for t in timesteps:
            residual = model(sample, t)
            sample = scheduler.step(residual, t, sample, generator=generator).prev_sample

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        # Fix: failure messages previously cited stale values (387.9466 / 0.5051)
        # that contradicted the asserted references.
        assert abs(result_sum.item() - 350.1980) < 1e-2, f"expected result sum 350.1980, but got {result_sum}"
        assert abs(result_mean.item() - 0.4560) < 1e-3, f"expected result mean 0.4560, but got {result_mean}"

    def test_custom_timesteps(self):
        """`previous_timestep` must walk the custom schedule, ending at -1."""
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        timesteps = [100, 87, 50, 1, 0]

        scheduler.set_timesteps(timesteps=timesteps)

        scheduler_timesteps = scheduler.timesteps

        for i, timestep in enumerate(scheduler_timesteps):
            if i == len(timesteps) - 1:
                expected_prev_t = -1
            else:
                expected_prev_t = timesteps[i + 1]

            prev_t = scheduler.previous_timestep(timestep)
            prev_t = prev_t.item()

            self.assertEqual(prev_t, expected_prev_t)

    def test_custom_timesteps_increasing_order(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        # 51 after 50 breaks the required descending order
        timesteps = [100, 87, 50, 51, 0]

        with self.assertRaises(ValueError, msg="`custom_timesteps` must be in descending order."):
            scheduler.set_timesteps(timesteps=timesteps)

    def test_custom_timesteps_passing_both_num_inference_steps_and_timesteps(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        timesteps = [100, 87, 50, 1, 0]
        num_inference_steps = len(timesteps)

        with self.assertRaises(ValueError, msg="Can only pass one of `num_inference_steps` or `custom_timesteps`."):
            scheduler.set_timesteps(num_inference_steps=num_inference_steps, timesteps=timesteps)

    def test_custom_timesteps_too_large(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        timesteps = [scheduler.config.num_train_timesteps]

        # Fix: msg was a plain string with an unbalanced `}}`, so the placeholder
        # never interpolated; use an f-string.
        with self.assertRaises(
            ValueError,
            msg=f"`timesteps` must start before `self.config.train_timesteps`: {scheduler.config.num_train_timesteps}",
        ):
            scheduler.set_timesteps(timesteps=timesteps)