Skip to content

Commit a9288b4

Browse files
authored
Modularize InstructPix2Pix SDXL inferencing during and after training in examples (#6569)
1 parent c544196 commit a9288b4

File tree

1 file changed

+75

-74

lines changed

examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py

Lines changed: 75 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@
5555
from diffusers.utils.torch_utils import is_compiled_module
5656

5757

58+
if is_wandb_available():
59+
import wandb
60+
5861
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
5962
check_min_version("0.26.0.dev0")
6063

@@ -67,6 +70,57 @@
6770
TORCH_DTYPE_MAPPING = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
6871

6972

73+
def log_validation(
    pipeline,
    args,
    accelerator,
    generator,
    global_step,
    is_final_validation=False,
):
    """Run validation inference with `pipeline` and log the edited images.

    Generates `args.num_validation_images` edited versions of the validation
    image (`args.val_image_url_or_path`) from `args.validation_prompt`, saves
    each image to `<args.output_dir>/validation_images/`, and logs a table of
    (original, edited, prompt) rows to every wandb tracker.

    Args:
        pipeline: A diffusers InstructPix2Pix-style pipeline ready for inference.
        args: Parsed training arguments (validation prompt/image, output dir, ...).
        accelerator: The `accelerate.Accelerator` driving training.
        generator: A seeded `torch.Generator` for reproducible sampling.
        global_step: Current training step; embedded in saved image filenames.
        is_final_validation: When True, log under the "test" key (post-training
            evaluation) instead of "validation".
    """
    logger.info(
        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
        f" {args.validation_prompt}."
    )

    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    val_save_dir = os.path.join(args.output_dir, "validation_images")
    # exist_ok=True is race-free, unlike a separate exists()/makedirs() pair.
    os.makedirs(val_save_dir, exist_ok=True)

    # Accept either a URL (fetched via `load_image`) or a local file path.
    if urlparse(args.val_image_url_or_path).scheme:
        original_image = load_image(args.val_image_url_or_path)
    else:
        original_image = Image.open(args.val_image_url_or_path).convert("RGB")

    # torch.autocast wants a bare device type ("cuda", not "cuda:0"); only
    # enable it under fp16 mixed precision.
    with torch.autocast(str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"):
        edited_images = []
        # Run inference
        for val_img_idx in range(args.num_validation_images):
            a_val_img = pipeline(
                args.validation_prompt,
                image=original_image,
                num_inference_steps=20,
                image_guidance_scale=1.5,
                guidance_scale=7,
                generator=generator,
            ).images[0]
            edited_images.append(a_val_img)
            # Save validation images
            a_val_img.save(os.path.join(val_save_dir, f"step_{global_step}_val_img_{val_img_idx}.png"))

    for tracker in accelerator.trackers:
        if tracker.name == "wandb":
            wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
            for edited_image in edited_images:
                wandb_table.add_data(wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt)
            # Post-training runs log under "test"; mid-training under "validation".
            logger_name = "test" if is_final_validation else "validation"
            tracker.log({logger_name: wandb_table})
122+
123+
70124
def import_model_class_from_model_name_or_path(
71125
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
72126
):
@@ -447,11 +501,6 @@ def main():
447501

448502
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
449503

450-
if args.report_to == "wandb":
451-
if not is_wandb_available():
452-
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
453-
import wandb
454-
455504
# Make one log on every process with the configuration for debugging.
456505
logging.basicConfig(
457506
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -1111,11 +1160,6 @@ def collate_fn(examples):
11111160
### BEGIN: Perform validation every `validation_epochs` steps
11121161
if global_step % args.validation_steps == 0:
11131162
if (args.val_image_url_or_path is not None) and (args.validation_prompt is not None):
1114-
logger.info(
1115-
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
1116-
f" {args.validation_prompt}."
1117-
)
1118-
11191163
# create pipeline
11201164
if args.use_ema:
11211165
# Store the UNet parameters temporarily and load the EMA parameters to perform inference.
@@ -1135,44 +1179,16 @@ def collate_fn(examples):
11351179
variant=args.variant,
11361180
torch_dtype=weight_dtype,
11371181
)
1138-
pipeline = pipeline.to(accelerator.device)
1139-
pipeline.set_progress_bar_config(disable=True)
1140-
1141-
# run inference
1142-
# Save validation images
1143-
val_save_dir = os.path.join(args.output_dir, "validation_images")
1144-
if not os.path.exists(val_save_dir):
1145-
os.makedirs(val_save_dir)
1146-
1147-
original_image = (
1148-
lambda image_url_or_path: load_image(image_url_or_path)
1149-
if urlparse(image_url_or_path).scheme
1150-
else Image.open(image_url_or_path).convert("RGB")
1151-
)(args.val_image_url_or_path)
1152-
with torch.autocast(
1153-
str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
1154-
):
1155-
edited_images = []
1156-
for val_img_idx in range(args.num_validation_images):
1157-
a_val_img = pipeline(
1158-
args.validation_prompt,
1159-
image=original_image,
1160-
num_inference_steps=20,
1161-
image_guidance_scale=1.5,
1162-
guidance_scale=7,
1163-
generator=generator,
1164-
).images[0]
1165-
edited_images.append(a_val_img)
1166-
a_val_img.save(os.path.join(val_save_dir, f"step_{global_step}_val_img_{val_img_idx}.png"))
1167-
1168-
for tracker in accelerator.trackers:
1169-
if tracker.name == "wandb":
1170-
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
1171-
for edited_image in edited_images:
1172-
wandb_table.add_data(
1173-
wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
1174-
)
1175-
tracker.log({"validation": wandb_table})
1182+
1183+
log_validation(
1184+
pipeline,
1185+
args,
1186+
accelerator,
1187+
generator,
1188+
global_step,
1189+
is_final_validation=False,
1190+
)
1191+
11761192
if args.use_ema:
11771193
# Switch back to the original UNet parameters.
11781194
ema_unet.restore(unet.parameters())
@@ -1187,7 +1203,6 @@ def collate_fn(examples):
11871203
# Create the pipeline using the trained modules and save it.
11881204
accelerator.wait_for_everyone()
11891205
if accelerator.is_main_process:
1190-
unet = unwrap_model(unet)
11911206
if args.use_ema:
11921207
ema_unet.copy_to(unet.parameters())
11931208

@@ -1198,10 +1213,11 @@ def collate_fn(examples):
11981213
tokenizer=tokenizer_1,
11991214
tokenizer_2=tokenizer_2,
12001215
vae=vae,
1201-
unet=unet,
1216+
unet=unwrap_model(unet),
12021217
revision=args.revision,
12031218
variant=args.variant,
12041219
)
1220+
12051221
pipeline.save_pretrained(args.output_dir)
12061222

12071223
if args.push_to_hub:
@@ -1212,30 +1228,15 @@ def collate_fn(examples):
12121228
ignore_patterns=["step_*", "epoch_*"],
12131229
)
12141230

1215-
if args.validation_prompt is not None:
1216-
edited_images = []
1217-
pipeline = pipeline.to(accelerator.device)
1218-
with torch.autocast(str(accelerator.device).replace(":0", "")):
1219-
for _ in range(args.num_validation_images):
1220-
edited_images.append(
1221-
pipeline(
1222-
args.validation_prompt,
1223-
image=original_image,
1224-
num_inference_steps=20,
1225-
image_guidance_scale=1.5,
1226-
guidance_scale=7,
1227-
generator=generator,
1228-
).images[0]
1229-
)
1230-
1231-
for tracker in accelerator.trackers:
1232-
if tracker.name == "wandb":
1233-
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
1234-
for edited_image in edited_images:
1235-
wandb_table.add_data(
1236-
wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
1237-
)
1238-
tracker.log({"test": wandb_table})
1231+
if (args.val_image_url_or_path is not None) and (args.validation_prompt is not None):
1232+
log_validation(
1233+
pipeline,
1234+
args,
1235+
accelerator,
1236+
generator,
1237+
global_step,
1238+
is_final_validation=True,
1239+
)
12391240

12401241
accelerator.end_training()
12411242

0 commit comments

Comments (0)