Commit ee20d1f

update flax controlnet training script (#2951)
* load_from_disk + checkpointing_steps
* apply feedback
1 parent 0d0fa2a commit ee20d1f


examples/controlnet/train_controlnet_flax.py

Lines changed: 40 additions & 9 deletions
@@ -27,13 +27,13 @@
 import torch
 import torch.utils.checkpoint
 import transformers
-from datasets import load_dataset
+from datasets import load_dataset, load_from_disk
 from flax import jax_utils
 from flax.core.frozen_dict import unfreeze
 from flax.training import train_state
 from flax.training.common_utils import shard
 from huggingface_hub import create_repo, upload_folder
-from PIL import Image
+from PIL import Image, PngImagePlugin
 from torch.utils.data import IterableDataset
 from torchvision import transforms
 from tqdm.auto import tqdm
@@ -49,6 +49,11 @@
 from diffusers.utils import check_min_version, is_wandb_available


+# To prevent an error that occurs when there are abnormally large compressed data chunk in the png image
+# see more https://github.com/python-pillow/Pillow/issues/5610
+LARGE_ENOUGH_NUMBER = 100
+PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
+
 if is_wandb_available():
     import wandb
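Note: this new module-level constant raises Pillow's per-chunk decompression limit so that PNGs with very large text metadata chunks do not abort decoding. A minimal sketch of the failure mode it guards against, assuming a hypothetical PNG whose zTXt/iTXt chunk decompresses past the default limit:

from PIL import Image, PngImagePlugin

# Raise the per-chunk limit before decoding, mirroring the training script;
# without this, Pillow can abort with "Decompressed Data Too Large" for
# oversized text chunks (see python-pillow/Pillow#5610).
PngImagePlugin.MAX_TEXT_CHUNK = 100 * (1024**2)

image = Image.open("conditioning_with_large_metadata.png")  # hypothetical example file
image.load()  # decodes successfully now that the chunk limit is raised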

@@ -246,6 +251,12 @@ def parse_args():
         default=None,
         help="Total number of training steps to perform.",
     )
+    parser.add_argument(
+        "--checkpointing_steps",
+        type=int,
+        default=5000,
+        help=("Save a checkpoint of the training state every X updates."),
+    )
     parser.add_argument(
         "--learning_rate",
         type=float,
@@ -344,9 +355,17 @@ def parse_args():
         type=str,
         default=None,
         help=(
-            "A folder containing the training data. Folder contents must follow the structure described in"
-            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
-            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+            "A folder containing the training dataset. By default it will use `load_dataset` method to load a custom dataset from the folder."
+            "Folder must contain a dataset script as described here https://huggingface.co/docs/datasets/dataset_script) ."
+            "If `--load_from_disk` flag is passed, it will use `load_from_disk` method instead. Ignored if `dataset_name` is specified."
+        ),
+    )
+    parser.add_argument(
+        "--load_from_disk",
+        action="store_true",
+        help=(
+            "If True, will load a dataset that was previously saved using `save_to_disk` from `--train_data_dir`"
+            "See more https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.Dataset.load_from_disk"
         ),
     )
     parser.add_argument(
@@ -478,10 +497,15 @@ def make_train_dataset(args, tokenizer, batch_size=None):
         )
     else:
         if args.train_data_dir is not None:
-            dataset = load_dataset(
-                args.train_data_dir,
-                cache_dir=args.cache_dir,
-            )
+            if args.load_from_disk:
+                dataset = load_from_disk(
+                    args.train_data_dir,
+                )
+            else:
+                dataset = load_dataset(
+                    args.train_data_dir,
+                    cache_dir=args.cache_dir,
+                )
         # See more about loading custom images at
         # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
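As a usage sketch for the new `--load_from_disk` path (the dataset name and local directory below are only examples), a dataset can be materialized once with `save_to_disk` and then reused without hitting the Hub again:

from datasets import load_dataset

# Save a DatasetDict locally; the script's make_train_dataset later reads its "train" split.
dataset = load_dataset("fusing/fill50k")  # example dataset; any ControlNet-style dataset works
dataset.save_to_disk("./fill50k_local")

# Then train with, e.g.:
#   python train_controlnet_flax.py --train_data_dir ./fill50k_local --load_from_disk ...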

@@ -545,6 +569,7 @@ def tokenize_captions(examples, is_train=True):
     image_transforms = transforms.Compose(
         [
             transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+            transforms.CenterCrop(args.resolution),
             transforms.ToTensor(),
             transforms.Normalize([0.5], [0.5]),
         ]
@@ -553,6 +578,7 @@ def tokenize_captions(examples, is_train=True):
     conditioning_image_transforms = transforms.Compose(
         [
             transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+            transforms.CenterCrop(args.resolution),
             transforms.ToTensor(),
         ]
     )
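The added `CenterCrop` makes non-square inputs come out as exact resolution x resolution tensors rather than only having their shorter side resized. A small self-contained sketch (the input image is a hypothetical dummy):

from PIL import Image
from torchvision import transforms

resolution = 512
image_transforms = transforms.Compose(
    [
        transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(resolution),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

example = Image.new("RGB", (768, 512))  # non-square dummy input
print(image_transforms(example).shape)  # torch.Size([3, 512, 512])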
@@ -981,6 +1007,11 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
                             "train/loss": jax_utils.unreplicate(train_metric)["loss"],
                         }
                     )
+            if global_step % args.checkpointing_steps == 0 and jax.process_index() == 0:
+                controlnet.save_pretrained(
+                    f"{args.output_dir}/{global_step}",
+                    params=get_params_to_save(state.params),
+                )

         train_metric = jax_utils.unreplicate(train_metric)
         train_step_progress_bar.close()
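Each intermediate checkpoint is written as a standalone `save_pretrained` folder under `--output_dir`, named after the global step. A minimal sketch for loading one back (the path below is only an example; this assumes diffusers' Flax `from_pretrained`, which returns the module together with its parameters):

from diffusers import FlaxControlNetModel

# Load the ControlNet saved at, e.g., global step 5000 for evaluation or to resume from.
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained("./controlnet-output/5000")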
