Skip to content

Commit 80bc0c0

Browse files
config fixes (#3060)
1 parent 091a058 commit 80bc0c0

File tree

4 files changed

+15
-12
lines changed

4 files changed

+15
-12
lines changed

examples/community/sd_text2img_k_diffusion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def __init__(
105105
)
106106

107107
model = ModelWrapper(unet, scheduler.alphas_cumprod)
108-
if scheduler.prediction_type == "v_prediction":
108+
if scheduler.config.prediction_type == "v_prediction":
109109
self.k_diffusion_model = CompVisVDenoiser(model)
110110
else:
111111
self.k_diffusion_model = CompVisDenoiser(model)

src/diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,9 @@ def get_input_dims(self) -> Tuple:
6060
input_module = self.vqvae if self.vqvae is not None else self.unet
6161
# For backwards compatibility
6262
sample_size = (
63-
(input_module.sample_size, input_module.sample_size)
64-
if type(input_module.sample_size) == int
65-
else input_module.sample_size
63+
(input_module.config.sample_size, input_module.config.sample_size)
64+
if type(input_module.config.sample_size) == int
65+
else input_module.config.sample_size
6666
)
6767
return sample_size
6868

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def __init__(
113113
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
114114

115115
model = ModelWrapper(unet, scheduler.alphas_cumprod)
116-
if scheduler.prediction_type == "v_prediction":
116+
if scheduler.config.prediction_type == "v_prediction":
117117
self.k_diffusion_model = CompVisVDenoiser(model)
118118
else:
119119
self.k_diffusion_model = CompVisDenoiser(model)

tests/pipelines/audio_diffusion/test_audio_diffusion.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,11 @@ def test_audio_diffusion(self):
115115
output = pipe(generator=generator, steps=4, return_dict=False)
116116
image_from_tuple = output[0][0]
117117

118-
assert audio.shape == (1, (self.dummy_unet.sample_size[1] - 1) * mel.hop_length)
119-
assert image.height == self.dummy_unet.sample_size[0] and image.width == self.dummy_unet.sample_size[1]
118+
assert audio.shape == (1, (self.dummy_unet.config.sample_size[1] - 1) * mel.hop_length)
119+
assert (
120+
image.height == self.dummy_unet.config.sample_size[0]
121+
and image.width == self.dummy_unet.config.sample_size[1]
122+
)
120123
image_slice = np.frombuffer(image.tobytes(), dtype="uint8")[:10]
121124
image_from_tuple_slice = np.frombuffer(image_from_tuple.tobytes(), dtype="uint8")[:10]
122125
expected_slice = np.array([69, 255, 255, 255, 0, 0, 77, 181, 12, 127])
@@ -133,14 +136,14 @@ def test_audio_diffusion(self):
133136
pipe.set_progress_bar_config(disable=None)
134137

135138
np.random.seed(0)
136-
raw_audio = np.random.uniform(-1, 1, ((dummy_vqvae_and_unet[0].sample_size[1] - 1) * mel.hop_length,))
139+
raw_audio = np.random.uniform(-1, 1, ((dummy_vqvae_and_unet[0].config.sample_size[1] - 1) * mel.hop_length,))
137140
generator = torch.Generator(device=device).manual_seed(42)
138141
output = pipe(raw_audio=raw_audio, generator=generator, start_step=5, steps=10)
139142
image = output.images[0]
140143

141144
assert (
142-
image.height == self.dummy_vqvae_and_unet[0].sample_size[0]
143-
and image.width == self.dummy_vqvae_and_unet[0].sample_size[1]
145+
image.height == self.dummy_vqvae_and_unet[0].config.sample_size[0]
146+
and image.width == self.dummy_vqvae_and_unet[0].config.sample_size[1]
144147
)
145148
image_slice = np.frombuffer(image.tobytes(), dtype="uint8")[:10]
146149
expected_slice = np.array([120, 117, 110, 109, 138, 167, 138, 148, 132, 121])
@@ -183,8 +186,8 @@ def test_audio_diffusion(self):
183186
audio = output.audios[0]
184187
image = output.images[0]
185188

186-
assert audio.shape == (1, (pipe.unet.sample_size[1] - 1) * pipe.mel.hop_length)
187-
assert image.height == pipe.unet.sample_size[0] and image.width == pipe.unet.sample_size[1]
189+
assert audio.shape == (1, (pipe.unet.config.sample_size[1] - 1) * pipe.mel.hop_length)
190+
assert image.height == pipe.unet.config.sample_size[0] and image.width == pipe.unet.config.sample_size[1]
188191
image_slice = np.frombuffer(image.tobytes(), dtype="uint8")[:10]
189192
expected_slice = np.array([151, 167, 154, 144, 122, 134, 121, 105, 70, 26])
190193

0 commit comments

Comments (0)