Skip to content

Commit 71317a8

Browse files
DavyMorgansayakpaul
authored andcommitted
Update sd3 controlnet example (#9735)
* use make_image_grid in diffusers.utils * use checkpoint on the Hub
1 parent fdc716f commit 71317a8

File tree

2 files changed

+3
-14
lines changed

2 files changed

+3
-14
lines changed

examples/controlnet/README_sd3.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ from diffusers.utils import load_image
104104
import torch
105105

106106
base_model_path = "stabilityai/stable-diffusion-3-medium-diffusers"
107-
controlnet_path = "sd3-controlnet-out/checkpoint-6500/controlnet"
107+
controlnet_path = "DavyMorgan/sd3-controlnet-out"
108108

109109
controlnet = SD3ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
110110
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(

examples/controlnet/train_controlnet_sd3.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
)
5151
from diffusers.optimization import get_scheduler
5252
from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory
53-
from diffusers.utils import check_min_version, is_wandb_available
53+
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
5454
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
5555
from diffusers.utils.torch_utils import is_compiled_module
5656

@@ -64,17 +64,6 @@
6464
logger = get_logger(__name__)
6565

6666

67-
def image_grid(imgs, rows, cols):
68-
assert len(imgs) == rows * cols
69-
70-
w, h = imgs[0].size
71-
grid = Image.new("RGB", size=(cols * w, rows * h))
72-
73-
for i, img in enumerate(imgs):
74-
grid.paste(img, box=(i % cols * w, i // cols * h))
75-
return grid
76-
77-
7867
def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_validation=False):
7968
logger.info("Running validation... ")
8069

@@ -224,7 +213,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
224213
validation_image.save(os.path.join(repo_folder, "image_control.png"))
225214
img_str += f"prompt: {validation_prompt}\n"
226215
images = [validation_image] + images
227-
image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
216+
make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
228217
img_str += f"![images_{i})](./images_{i}.png)\n"
229218

230219
model_description = f"""

0 commit comments

Comments
 (0)