Skip to content

Commit 7f16187

Browse files
Modularize Dreambooth LoRA SDXL inferencing during and after training (#6655)
* Modularize log validation.
* Run `make style`.
* Revert `import wandb`.
* Fix code quality & `import wandb`.

Co-authored-by: Sayak Paul <[email protected]>
1 parent f11b922 commit 7f16187

File tree

1 file changed

+74
-80
lines changed

1 file changed

+74
-80
lines changed

examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 74 additions & 80 deletions
Original file line number · Diff line number · Diff line change
@@ -67,6 +67,9 @@
6767
from diffusers.utils.torch_utils import is_compiled_module
6868

6969

70+
if is_wandb_available():
71+
import wandb
72+
7073
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
7174
check_min_version("0.27.0.dev0")
7275

@@ -140,6 +143,61 @@ def save_model_card(
140143
model_card.save(os.path.join(repo_folder, "README.md"))
141144

142145

146+
def log_validation(
    pipeline,
    args,
    accelerator,
    pipeline_args,
    epoch,
    is_final_validation=False,
):
    """Run validation inference with `pipeline` and log the images to all trackers.

    Args:
        pipeline: Diffusion pipeline used for inference; its scheduler is
            replaced with a `DPMSolverMultistepScheduler` before sampling.
        args: Parsed training arguments. Reads `num_validation_images`,
            `validation_prompt`, and `seed`.
        accelerator: `accelerate.Accelerator` providing the target device and
            the configured trackers (tensorboard / wandb).
        pipeline_args: Keyword arguments forwarded to each pipeline call
            (e.g. `{"prompt": ...}`).
        epoch: Step value used when logging images to TensorBoard.
        is_final_validation: When True, images are logged under the "test"
            phase instead of "validation".

    Returns:
        The list of generated PIL images.
    """
    logger.info(
        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
        f" {args.validation_prompt}."
    )

    # We train on the simplified learning objective. If we were previously
    # predicting a variance, we need the scheduler to ignore it.
    scheduler_args = {}

    if "variance_type" in pipeline.scheduler.config:
        variance_type = pipeline.scheduler.config.variance_type

        if variance_type in ["learned", "learned_range"]:
            variance_type = "fixed_small"

        scheduler_args["variance_type"] = variance_type

    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)

    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    # Run inference. Compare against None (not truthiness) so that an explicit
    # seed of 0 still yields deterministic generations.
    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None

    # NOTE(review): torch.cuda.amp.autocast assumes a CUDA device is available —
    # confirm this script is never run for validation on CPU-only hosts.
    with torch.cuda.amp.autocast():
        images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]

    for tracker in accelerator.trackers:
        phase_name = "test" if is_final_validation else "validation"
        if tracker.name == "tensorboard":
            np_images = np.stack([np.asarray(img) for img in images])
            tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
        if tracker.name == "wandb":
            tracker.log(
                {
                    phase_name: [
                        wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
                    ]
                }
            )

    # Free pipeline memory before training / teardown continues.
    del pipeline
    torch.cuda.empty_cache()

    return images
199+
200+
143201
def import_model_class_from_model_name_or_path(
144202
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
145203
):
@@ -862,7 +920,6 @@ def main(args):
862920
if args.report_to == "wandb":
863921
if not is_wandb_available():
864922
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
865-
import wandb
866923

867924
# Make one log on every process with the configuration for debugging.
868925
logging.basicConfig(
@@ -1615,10 +1672,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
16151672

16161673
if accelerator.is_main_process:
16171674
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
1618-
logger.info(
1619-
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
1620-
f" {args.validation_prompt}."
1621-
)
16221675
# create pipeline
16231676
if not args.train_text_encoder:
16241677
text_encoder_one = text_encoder_cls_one.from_pretrained(
@@ -1644,50 +1697,15 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
16441697
torch_dtype=weight_dtype,
16451698
)
16461699

1647-
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
1648-
scheduler_args = {}
1649-
1650-
if "variance_type" in pipeline.scheduler.config:
1651-
variance_type = pipeline.scheduler.config.variance_type
1652-
1653-
if variance_type in ["learned", "learned_range"]:
1654-
variance_type = "fixed_small"
1655-
1656-
scheduler_args["variance_type"] = variance_type
1657-
1658-
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
1659-
pipeline.scheduler.config, **scheduler_args
1660-
)
1661-
1662-
pipeline = pipeline.to(accelerator.device)
1663-
pipeline.set_progress_bar_config(disable=True)
1664-
1665-
# run inference
1666-
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
16671700
pipeline_args = {"prompt": args.validation_prompt}
16681701

1669-
with torch.cuda.amp.autocast():
1670-
images = [
1671-
pipeline(**pipeline_args, generator=generator).images[0]
1672-
for _ in range(args.num_validation_images)
1673-
]
1674-
1675-
for tracker in accelerator.trackers:
1676-
if tracker.name == "tensorboard":
1677-
np_images = np.stack([np.asarray(img) for img in images])
1678-
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
1679-
if tracker.name == "wandb":
1680-
tracker.log(
1681-
{
1682-
"validation": [
1683-
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
1684-
for i, image in enumerate(images)
1685-
]
1686-
}
1687-
)
1688-
1689-
del pipeline
1690-
torch.cuda.empty_cache()
1702+
images = log_validation(
1703+
pipeline,
1704+
args,
1705+
accelerator,
1706+
pipeline_args,
1707+
epoch,
1708+
)
16911709

16921710
# Save the lora layers
16931711
accelerator.wait_for_everyone()
@@ -1733,45 +1751,21 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
17331751
torch_dtype=weight_dtype,
17341752
)
17351753

1736-
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
1737-
scheduler_args = {}
1738-
1739-
if "variance_type" in pipeline.scheduler.config:
1740-
variance_type = pipeline.scheduler.config.variance_type
1741-
1742-
if variance_type in ["learned", "learned_range"]:
1743-
variance_type = "fixed_small"
1744-
1745-
scheduler_args["variance_type"] = variance_type
1746-
1747-
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
1748-
17491754
# load attention processors
17501755
pipeline.load_lora_weights(args.output_dir)
17511756

17521757
# run inference
17531758
# Run inference with the final (saved) LoRA weights and log the results
# under the "test" phase.
images = []
if args.validation_prompt and args.num_validation_images > 0:
    pipeline_args = {"prompt": args.validation_prompt, "num_inference_steps": 25}
    images = log_validation(
        pipeline,
        args,
        accelerator,
        pipeline_args,
        epoch,
        # Bug fix: the keyword defined by `log_validation` is
        # `is_final_validation`; passing `final_validation=True` raises
        # `TypeError: log_validation() got an unexpected keyword argument`
        # at final-inference time.
        is_final_validation=True,
    )
17751769

17761770
if args.push_to_hub:
17771771
save_model_card(

0 commit comments

Comments
 (0)