
Update Noise Autocorrelation Loss Function for Pix2PixZero Pipeline #2942

Merged on Apr 20, 2023 (14 commits).

Changes from all commits:
Pipeline file:

```diff
@@ -36,6 +36,7 @@
 from ...utils import (
     PIL_INTERPOLATION,
     BaseOutput,
+    deprecate,
     is_accelerate_available,
     is_accelerate_version,
     logging,
```
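The newly imported `deprecate` helper is what `prepare_image_latents` (next hunk) uses to flag the old prompt/image batch-mismatch behavior. As a hedged illustration of the call pattern (the message here is shortened, not the one from the diff):

```python
from diffusers.utils import deprecate

# Emits a deprecation warning stating that the flagged behavior goes away in
# diffusers 1.0.0; standard_warn=False shows the custom message as-is instead
# of prepending the templated preamble.
deprecate(
    "len(prompt) != len(image)",  # label for the deprecated behavior
    "1.0.0",                      # version in which the behavior is removed
    "Pass as many initial images as text prompts.",  # shortened example message
    standard_warn=False,
)
```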
```diff
@@ -721,23 +722,31 @@ def prepare_image_latents(self, image, batch_size, dtype, device, generator=None):
             )
 
         if isinstance(generator, list):
-            init_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
-            ]
-            init_latents = torch.cat(init_latents, dim=0)
+            latents = [self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)]
+            latents = torch.cat(latents, dim=0)
         else:
-            init_latents = self.vae.encode(image).latent_dist.sample(generator)
-
-        init_latents = self.vae.config.scaling_factor * init_latents
-
-        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
-            raise ValueError(
-                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
-            )
+            latents = self.vae.encode(image).latent_dist.sample(generator)
+
+        latents = self.vae.config.scaling_factor * latents
+
+        if batch_size != latents.shape[0]:
+            if batch_size % latents.shape[0] == 0:
+                # expand image_latents for batch_size
+                deprecation_message = (
+                    f"You have passed {batch_size} text prompts (`prompt`), but only {latents.shape[0]} initial"
+                    " images (`image`). Initial images are now duplicated to match the number of text prompts. Note"
+                    " that this behavior is deprecated and will be removed in version 1.0.0. Please make sure to update"
+                    " your script to pass as many initial images as text prompts to suppress this warning."
+                )
+                deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
+                additional_latents_per_image = batch_size // latents.shape[0]
+                latents = torch.cat([latents] * additional_latents_per_image, dim=0)
+            else:
+                raise ValueError(
+                    f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts."
+                )
         else:
-            init_latents = torch.cat([init_latents], dim=0)
-
-        latents = init_latents
+            latents = torch.cat([latents], dim=0)
 
         return latents
 
```
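The change above replaces the old hard error with a deprecation path: when callers pass more prompts than images and the counts divide evenly, the image latents are tiled to match; otherwise the `ValueError` still fires. A minimal standalone sketch of just that duplication rule (the `expand_latents` helper is hypothetical, not part of the pipeline):

```python
import torch

def expand_latents(latents: torch.Tensor, batch_size: int) -> torch.Tensor:
    # Tile latents only when the prompt batch size is an exact multiple of the
    # image batch size; otherwise keep the old failure mode.
    if batch_size != latents.shape[0]:
        if batch_size % latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts."
            )
        latents = torch.cat([latents] * (batch_size // latents.shape[0]), dim=0)
    return latents

print(expand_latents(torch.randn(1, 4, 8, 8), 3).shape)  # torch.Size([3, 4, 8, 8])
# expand_latents(torch.randn(2, 4, 8, 8), 3) raises: 3 is not a multiple of 2.
```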

```diff
@@ -759,23 +768,18 @@ def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep: int):
             )
 
     def auto_corr_loss(self, hidden_states, generator=None):
-        batch_size, channel, height, width = hidden_states.shape
-        if batch_size > 1:
-            raise ValueError("Only batch_size 1 is supported for now")
-
-        hidden_states = hidden_states.squeeze(0)
-        # hidden_states must be shape [C,H,W] now
         reg_loss = 0.0
         for i in range(hidden_states.shape[0]):
-            noise = hidden_states[i][None, None, :, :]
-            while True:
-                roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()
-                reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
-                reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2
-
-                if noise.shape[2] <= 8:
-                    break
-                noise = F.avg_pool2d(noise, kernel_size=2)
+            for j in range(hidden_states.shape[1]):
+                noise = hidden_states[i : i + 1, j : j + 1, :, :]
+                while True:
+                    roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()
+                    reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
+                    reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2
+
+                    if noise.shape[2] <= 8:
+                        break
+                    noise = F.avg_pool2d(noise, kernel_size=2)
         return reg_loss
 
     def kl_divergence(self, hidden_states):
```
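The `auto_corr_loss` rewrite is the core of this PR: the old implementation rejected any batch larger than one, squeezed the batch dimension away, and looped only over channels. The new version iterates over every `(batch, channel)` slice, so the same multi-scale autocorrelation penalty now works for batched noise. A self-contained sketch of the new logic, lifted out of the pipeline so it runs standalone:

```python
import torch
import torch.nn.functional as F

def auto_corr_loss(hidden_states: torch.Tensor, generator=None):
    # For each (batch, channel) map, correlate the noise with randomly rolled
    # copies of itself along H and W, then repeat at progressively
    # average-pooled scales until the map is 8x8. White noise drives this
    # penalty toward zero; spatially correlated noise does not.
    reg_loss = 0.0
    for i in range(hidden_states.shape[0]):
        for j in range(hidden_states.shape[1]):
            noise = hidden_states[i : i + 1, j : j + 1, :, :]
            while True:
                roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()
                reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
                reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2
                if noise.shape[2] <= 8:
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)
    return reg_loss

# Batched input now works; before this PR only batch size 1 was accepted.
print(auto_corr_loss(torch.randn(2, 4, 64, 64), generator=torch.manual_seed(0)))
```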
Test file:

```diff
@@ -14,6 +14,8 @@
 # limitations under the License.
 
 import gc
+import random
+import tempfile
 import unittest
 
 import numpy as np
```
```diff
@@ -30,7 +32,7 @@
     StableDiffusionPix2PixZeroPipeline,
     UNet2DConditionModel,
 )
-from diffusers.utils import load_numpy, slow, torch_device
+from diffusers.utils import floats_tensor, load_numpy, slow, torch_device
 from diffusers.utils.testing_utils import load_image, load_pt, require_torch_gpu, skip_mps
 
 from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
```
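`floats_tensor` (now imported above) is the diffusers test utility that builds deterministic dummy tensors; seeded with `random.Random`, it returns the same values on every run, which keeps the new inversion tests reproducible. A quick sketch of what it produces:

```python
import random
from diffusers.utils import floats_tensor

# A deterministic batch of two 3x32x32 dummy "images" with values in [0, 1),
# matching what get_dummy_inversion_inputs constructs below.
dummy_image = floats_tensor((2, 3, 32, 32), rng=random.Random(0))
print(dummy_image.shape)  # torch.Size([2, 3, 32, 32])
```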
```diff
@@ -69,6 +71,7 @@ def get_dummy_components(self):
             cross_attention_dim=32,
         )
         scheduler = DDIMScheduler()
+        inverse_scheduler = DDIMInverseScheduler()
         torch.manual_seed(0)
         vae = AutoencoderKL(
             block_out_channels=[32, 64],
```
```diff
@@ -101,7 +104,7 @@ def get_dummy_components(self):
             "tokenizer": tokenizer,
             "safety_checker": None,
             "feature_extractor": None,
-            "inverse_scheduler": None,
+            "inverse_scheduler": inverse_scheduler,
             "caption_generator": None,
             "caption_processor": None,
         }
```
```diff
@@ -122,6 +125,90 @@ def get_dummy_inputs(self, device, seed=0):
         }
         return inputs
 
+    def get_dummy_inversion_inputs(self, device, seed=0):
+        dummy_image = floats_tensor((2, 3, 32, 32), rng=random.Random(seed)).to(torch_device)
+        generator = torch.manual_seed(seed)
+
+        inputs = {
+            "prompt": [
+                "A painting of a squirrel eating a burger",
+                "A painting of a burger eating a squirrel",
+            ],
+            "image": dummy_image.cpu(),
+            "num_inference_steps": 2,
+            "guidance_scale": 6.0,
+            "generator": generator,
+            "output_type": "numpy",
+        }
+        return inputs
+
+    def test_save_load_optional_components(self):
+        if not hasattr(self.pipeline_class, "_optional_components"):
+            return
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        # set all optional components to None and update pipeline config accordingly
+        for optional_component in pipe._optional_components:
+            setattr(pipe, optional_component, None)
+        pipe.register_modules(**{optional_component: None for optional_component in pipe._optional_components})
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output = pipe(**inputs)[0]
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            pipe.save_pretrained(tmpdir)
+            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+            pipe_loaded.to(torch_device)
+            pipe_loaded.set_progress_bar_config(disable=None)
+
+        for optional_component in pipe._optional_components:
+            self.assertTrue(
+                getattr(pipe_loaded, optional_component) is None,
+                f"`{optional_component}` did not stay set to None after loading.",
+            )
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output_loaded = pipe_loaded(**inputs)[0]
+
+        max_diff = np.abs(output - output_loaded).max()
+        self.assertLess(max_diff, 1e-4)
+
+    def test_stable_diffusion_pix2pix_zero_inversion(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionPix2PixZeroPipeline(**components)
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inversion_inputs(device)
+        inputs["image"] = inputs["image"][:1]
+        inputs["prompt"] = inputs["prompt"][:1]
+        image = sd_pipe.invert(**inputs).images
+        image_slice = image[0, -3:, -3:, -1]
+        assert image.shape == (1, 32, 32, 3)
+        expected_slice = np.array([0.4833, 0.4696, 0.5574, 0.5194, 0.5248, 0.5638, 0.5040, 0.5423, 0.5072])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+
+    def test_stable_diffusion_pix2pix_zero_inversion_batch(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionPix2PixZeroPipeline(**components)
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inversion_inputs(device)
+        image = sd_pipe.invert(**inputs).images
+        image_slice = image[1, -3:, -3:, -1]
+        assert image.shape == (2, 32, 32, 3)
+        expected_slice = np.array([0.6672, 0.5203, 0.4908, 0.4376, 0.4517, 0.5544, 0.4605, 0.4826, 0.5007])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+
     def test_stable_diffusion_pix2pix_zero_default_case(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
```
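For context, the inversion path these tests exercise is the same one end users drive via `pipe.invert`. A hedged usage sketch against a real checkpoint (the model id, image URL, and step count are illustrative, not taken from this PR):

```python
import torch
from diffusers import DDIMInverseScheduler, DDIMScheduler, StableDiffusionPix2PixZeroPipeline
from diffusers.utils import load_image

# Illustrative checkpoint; any Stable Diffusion v1-style checkpoint should work.
pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)

raw_image = load_image("https://example.com/cat.png").resize((512, 512))  # placeholder URL

# invert() runs DDIM inversion; auto_corr_loss regularizes the inverted noise
# along the way, and .latents seeds the subsequent pix2pix-zero edit.
inv_latents = pipe.invert("a photo of a cat", image=raw_image, num_inference_steps=50).latents
```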