Skip to content

Commit 5181972

Browse files
authored
[WIP] Check UNet shapes in StableDiffusionInpaintPipeline __init__ (huggingface#2853)
Add warning in __init__ if user loads a checkpoint with pipeline.unet.config.in_channels other than 9.
1 parent 28c6c6b commit 5181972

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,14 @@ def __init__(
243243
new_config = dict(unet.config)
244244
new_config["sample_size"] = 64
245245
unet._internal_dict = FrozenDict(new_config)
246+
# Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4
247+
if unet.config.in_channels != 9:
248+
logger.warning(
249+
f"You have loaded a UNet with {unet.config.in_channels} input channels, whereas by default,"
250+
f" {self.__class__} assumes that `pipeline.unet` has 9 input channels: 4 for `num_channels_latents`,"
251+
" 1 for `num_channels_mask`, and 4 for `num_channels_masked_image`. If you did not intend to modify"
252+
" this behavior, please check whether you have loaded the right checkpoint."
253+
)
246254

247255
self.register_modules(
248256
vae=vae,

0 commit comments

Comments
 (0)