Skip to content

Commit 10d856a

Browse files
patrickvonplatendg845
authored andcommitted
Let's make sure that dreambooth always uploads to the Hub (huggingface#3272)
* Update Dreambooth README * Adapt all docs as well * automatically write model card * fix * make style
1 parent ffe6e92 commit 10d856a

File tree

3 files changed

+71
-13
lines changed

3 files changed

+71
-13
lines changed

docs/source/en/training/dreambooth.mdx

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,8 @@ accelerate launch train_dreambooth.py \
9898
--learning_rate=5e-6 \
9999
--lr_scheduler="constant" \
100100
--lr_warmup_steps=0 \
101-
--max_train_steps=400
101+
--max_train_steps=400 \
102+
--push_to_hub
102103
```
103104
</pt>
104105
<jax>
@@ -161,7 +162,8 @@ accelerate launch train_dreambooth.py \
161162
--lr_scheduler="constant" \
162163
--lr_warmup_steps=0 \
163164
--num_class_images=200 \
164-
--max_train_steps=800
165+
--max_train_steps=800 \
166+
--push_to_hub
165167
```
166168
</pt>
167169
<jax>
@@ -225,7 +227,8 @@ accelerate launch train_dreambooth.py \
225227
--lr_scheduler="constant" \
226228
--lr_warmup_steps=0 \
227229
--num_class_images=200 \
228-
--max_train_steps=800
230+
--max_train_steps=800 \
231+
--push_to_hub
229232
```
230233
</pt>
231234
<jax>
@@ -387,7 +390,8 @@ accelerate launch train_dreambooth.py \
387390
--lr_scheduler="constant" \
388391
--lr_warmup_steps=0 \
389392
--num_class_images=200 \
390-
--max_train_steps=800
393+
--max_train_steps=800 \
394+
--push_to_hub
391395
```
392396

393397
### 12GB GPU
@@ -418,7 +422,8 @@ accelerate launch train_dreambooth.py \
418422
--lr_scheduler="constant" \
419423
--lr_warmup_steps=0 \
420424
--num_class_images=200 \
421-
--max_train_steps=800
425+
--max_train_steps=800 \
426+
--push_to_hub
422427
```
423428

424429
### 8 GB GPU
@@ -464,7 +469,8 @@ accelerate launch train_dreambooth.py \
464469
--lr_warmup_steps=0 \
465470
--num_class_images=200 \
466471
--max_train_steps=800 \
467-
--mixed_precision=fp16
472+
--mixed_precision=fp16 \
473+
--push_to_hub
468474
```
469475

470476
## Inference

examples/dreambooth/README.md

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ accelerate launch train_dreambooth.py \
8080
--learning_rate=5e-6 \
8181
--lr_scheduler="constant" \
8282
--lr_warmup_steps=0 \
83-
--max_train_steps=400
83+
--max_train_steps=400 \
84+
--push_to_hub
8485
```
8586

8687
### Training with prior-preservation loss
@@ -109,7 +110,8 @@ accelerate launch train_dreambooth.py \
109110
--lr_scheduler="constant" \
110111
--lr_warmup_steps=0 \
111112
--num_class_images=200 \
112-
--max_train_steps=800
113+
--max_train_steps=800 \
114+
--push_to_hub
113115
```
114116

115117

@@ -141,7 +143,8 @@ accelerate launch train_dreambooth.py \
141143
--lr_scheduler="constant" \
142144
--lr_warmup_steps=0 \
143145
--num_class_images=200 \
144-
--max_train_steps=800
146+
--max_train_steps=800 \
147+
--push_to_hub
145148
```
146149

147150

@@ -176,7 +179,8 @@ accelerate launch train_dreambooth.py \
176179
--lr_scheduler="constant" \
177180
--lr_warmup_steps=0 \
178181
--num_class_images=200 \
179-
--max_train_steps=800
182+
--max_train_steps=800 \
183+
--push_to_hub
180184
```
181185

182186

@@ -218,7 +222,8 @@ accelerate launch --mixed_precision="fp16" train_dreambooth.py \
218222
--lr_scheduler="constant" \
219223
--lr_warmup_steps=0 \
220224
--num_class_images=200 \
221-
--max_train_steps=800
225+
--max_train_steps=800 \
226+
--push_to_hub
222227
```
223228

224229
### Fine-tune text encoder with the UNet.
@@ -251,7 +256,8 @@ accelerate launch train_dreambooth.py \
251256
--lr_scheduler="constant" \
252257
--lr_warmup_steps=0 \
253258
--num_class_images=200 \
254-
--max_train_steps=800
259+
--max_train_steps=800 \
260+
--push_to_hub
255261
```
256262

257263
### Using DreamBooth for pipelines other than Stable Diffusion

examples/dreambooth/train_dreambooth.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,39 @@
6161
logger = get_logger(__name__)
6262

6363

64+
def save_model_card(repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None):
65+
img_str = ""
66+
for i, image in enumerate(images):
67+
image.save(os.path.join(repo_folder, f"image_{i}.png"))
68+
img_str += f"![img_{i}](./image_{i}.png)\n"
69+
70+
yaml = f"""
71+
---
72+
license: creativeml-openrail-m
73+
base_model: {base_model}
74+
instance_prompt: {prompt}
75+
tags:
76+
- stable-diffusion
77+
- stable-diffusion-diffusers
78+
- text-to-image
79+
- diffusers
80+
- dreambooth
81+
inference: true
82+
---
83+
"""
84+
model_card = f"""
85+
# DreamBooth - {repo_id}
86+
87+
This is a dreambooth model derived from {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/).
88+
You can find some example images in the following. \n
89+
{img_str}
90+
91+
DreamBooth for the text encoder was enabled: {train_text_encoder}.
92+
"""
93+
with open(os.path.join(repo_folder, "README.md"), "w") as f:
94+
f.write(yaml + model_card)
95+
96+
6497
def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
6598
logger.info(
6699
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
@@ -104,6 +137,8 @@ def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight
104137
del pipeline
105138
torch.cuda.empty_cache()
106139

140+
return images
141+
107142

108143
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
109144
text_encoder_config = PretrainedConfig.from_pretrained(
@@ -997,13 +1032,16 @@ def load_model_hook(models, input_dir):
9971032
global_step += 1
9981033

9991034
if accelerator.is_main_process:
1035+
images = []
10001036
if global_step % args.checkpointing_steps == 0:
10011037
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
10021038
accelerator.save_state(save_path)
10031039
logger.info(f"Saved state to {save_path}")
10041040

10051041
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
1006-
log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)
1042+
images = log_validation(
1043+
text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch
1044+
)
10071045

10081046
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
10091047
progress_bar.set_postfix(**logs)
@@ -1024,6 +1062,14 @@ def load_model_hook(models, input_dir):
10241062
pipeline.save_pretrained(args.output_dir)
10251063

10261064
if args.push_to_hub:
1065+
save_model_card(
1066+
repo_id,
1067+
images=images,
1068+
base_model=args.pretrained_model_name_or_path,
1069+
train_text_encoder=args.train_text_encoder,
1070+
prompt=args.instance_prompt,
1071+
repo_folder=args.output_dir,
1072+
)
10271073
upload_folder(
10281074
repo_id=repo_id,
10291075
folder_path=args.output_dir,

0 commit comments

Comments
 (0)