|
50 | 50 | )
|
51 | 51 | from diffusers.optimization import get_scheduler
|
52 | 52 | from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory
|
53 |
| -from diffusers.utils import check_min_version, is_wandb_available |
| 53 | +from diffusers.utils import check_min_version, is_wandb_available, make_image_grid |
54 | 54 | from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
55 | 55 | from diffusers.utils.torch_utils import is_compiled_module
|
56 | 56 |
|
|
64 | 64 | logger = get_logger(__name__)
|
65 | 65 |
|
66 | 66 |
|
67 |
| -def image_grid(imgs, rows, cols): |
68 |
| - assert len(imgs) == rows * cols |
69 |
| - |
70 |
| - w, h = imgs[0].size |
71 |
| - grid = Image.new("RGB", size=(cols * w, rows * h)) |
72 |
| - |
73 |
| - for i, img in enumerate(imgs): |
74 |
| - grid.paste(img, box=(i % cols * w, i // cols * h)) |
75 |
| - return grid |
76 |
| - |
77 |
| - |
78 | 67 | def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_validation=False):
|
79 | 68 | logger.info("Running validation... ")
|
80 | 69 |
|
@@ -224,7 +213,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
|
224 | 213 | validation_image.save(os.path.join(repo_folder, "image_control.png"))
|
225 | 214 | img_str += f"prompt: {validation_prompt}\n"
|
226 | 215 | images = [validation_image] + images
|
227 |
| - image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) |
| 216 | + make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) |
228 | 217 | img_str += f"\n"
|
229 | 218 |
|
230 | 219 | model_description = f"""
|
|
0 commit comments