Skip to content

Add StableDiffusionXLControlNetPAGImg2ImgPipeline #8990

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Aug 21, 2024

Conversation

satani99
Copy link
Contributor

@satani99 satani99 commented Jul 26, 2024

What does this PR do?

fix #8700

Before submitting

Who can review?

@yiyixuxu

@satani99
Copy link
Contributor Author

Generation code

import torch
import numpy as np
from PIL import Image 

from transformers import DPTFeatureExtractor, DPTForDepthEstimation
from diffusers import ControlNetModel, AutoencoderKL, AutoPipelineForImage2Image
from diffusers.utils import load_image

depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
controlnet = ControlNetModel.from_pretrained(
        "diffusers/controlnet-depth-sdxl-1.0-small",
        variant="fp16",
        use_safetensors="True",
        torch_dtype=torch.float16,
        )
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
pipe = AutoPipelineForImage2Image.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        controlnet=controlnet,
        vae=vae,
        variant="fp16",
        use_safetensors=True,
        torch_dtype=torch.float16,
        enable_pag=True,
        )
pipe.enable_model_cpu_offload()

def get_depth_map(image):
   image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
   with torch.no_grad(), torch.autocast("cuda"):
       depth_map = depth_estimator(image).predicted_depth

   depth_map = torch.nn.functional.interpolate(
        depth_map.unsqueeze(1),
        size=(1024, 1024),
        mode="bicubic",
        align_corners=False,
    )
   depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
   depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
   depth_map = (depth_map - depth_min) / (depth_max - depth_min)
   image = torch.cat([depth_map] * 3, dim=1)
   image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
   image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
   return image



prompt = "A robot, 4k photo"
image = load_image(
        "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
        "/kandinsky/cat.png"
        ).resize((1024, 1024))

controlnet_conditioning_scale = 0.5 
depth_image = get_depth_map(image)

images = pipe(
        prompt,
        image=image,
        control_image=depth_image,
        strength=0.99,
        num_inference_steps=50,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        ).images
images[0].save(f"robot_cat.png")

It works with enable_pag=False but gives error when enable_pag=True.

Error: AttributeError: 'Image' object has no attribute 'shape'. Did you mean: 'save'?

@satani99
Copy link
Contributor Author

Any help would be nice. Thanks

@yiyixuxu
Copy link
Collaborator

can you share the full stack trace?

@satani99
Copy link
Contributor Author

File "/home/nikhil/Desktop/pag.py", line 72, in <module> images = pipe( File "/home/nikhil/miniconda3/envs/pag/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context return func(*args, **kwargs) File "/home/nikhil/Desktop/diffusers/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py", line 1422, in __call__ height, width = control_image.shape[-2:] AttributeError: 'Image' object has no attribute 'shape'. Did you mean: 'save'?
pag.py is the above script.

@satani99
Copy link
Contributor Author

hi @yiyixuxu can you review this?

@yiyixuxu
Copy link
Collaborator

sorry this PR got lost too
could you resolve the conflicts?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yiyixuxu
Copy link
Collaborator

@asomoza
I tested it, and it works now
do you want to check if it works as expected for you? (no worries if you don't have time)

@asomoza
Copy link
Member

asomoza commented Aug 21, 2024

@yiyixuxu Tested it and seems ok, it's harder to see the difference here because the base image helps a lot even without PAG, but it still works similar to the other ones.

w/o pag with pag
20240821035941_1195154797 20240821035852_1195154797

@yiyixuxu yiyixuxu merged commit 9003d75 into huggingface:main Aug 21, 2024
15 checks passed
@yiyixuxu
Copy link
Collaborator

@satani99 thank you!

@satani99 satani99 deleted the sdxl_pag branch August 21, 2024 17:37
@yiyixuxu yiyixuxu added the PAG label Sep 4, 2024
sayakpaul pushed a commit that referenced this pull request Dec 23, 2024
* Added pad controlnet sdxl img2img pipeline

---------

Co-authored-by: YiYi Xu <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[PAG] add StableDiffusionXLControlNetPAGImg2ImgPipeline
4 participants