27
27
import torch
28
28
import torch .utils .checkpoint
29
29
import transformers
30
- from datasets import load_dataset
30
+ from datasets import load_dataset , load_from_disk
31
31
from flax import jax_utils
32
32
from flax .core .frozen_dict import unfreeze
33
33
from flax .training import train_state
34
34
from flax .training .common_utils import shard
35
35
from huggingface_hub import create_repo , upload_folder
36
- from PIL import Image
36
+ from PIL import Image , PngImagePlugin
37
37
from torch .utils .data import IterableDataset
38
38
from torchvision import transforms
39
39
from tqdm .auto import tqdm
49
49
from diffusers .utils import check_min_version , is_wandb_available
50
50
51
51
52
+ # Prevent an error that occurs when a PNG image contains an abnormally large compressed data chunk.
53
+ # See https://github.com/python-pillow/Pillow/issues/5610 for details.
54
+ LARGE_ENOUGH_NUMBER = 100
55
+ PngImagePlugin .MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024 ** 2 )
56
+
52
57
if is_wandb_available ():
53
58
import wandb
54
59
@@ -246,6 +251,12 @@ def parse_args():
246
251
default = None ,
247
252
help = "Total number of training steps to perform." ,
248
253
)
254
+ parser .add_argument (
255
+ "--checkpointing_steps" ,
256
+ type = int ,
257
+ default = 5000 ,
258
+ help = ("Save a checkpoint of the training state every X updates." ),
259
+ )
249
260
parser .add_argument (
250
261
"--learning_rate" ,
251
262
type = float ,
@@ -344,9 +355,17 @@ def parse_args():
344
355
type = str ,
345
356
default = None ,
346
357
help = (
347
- "A folder containing the training data. Folder contents must follow the structure described in"
348
- " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
349
- " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
358
+ "A folder containing the training dataset. By default it will use `load_dataset` method to load a custom dataset from the folder."
359
+ "Folder must contain a dataset script as described here https://huggingface.co/docs/datasets/dataset_script) ."
360
+ "If `--load_from_disk` flag is passed, it will use `load_from_disk` method instead. Ignored if `dataset_name` is specified."
361
+ ),
362
+ )
363
+ parser .add_argument (
364
+ "--load_from_disk" ,
365
+ action = "store_true" ,
366
+ help = (
367
+ "If True, will load a dataset that was previously saved using `save_to_disk` from `--train_data_dir`"
368
+ "See more https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.Dataset.load_from_disk"
350
369
),
351
370
)
352
371
parser .add_argument (
@@ -478,10 +497,15 @@ def make_train_dataset(args, tokenizer, batch_size=None):
478
497
)
479
498
else :
480
499
if args .train_data_dir is not None :
481
- dataset = load_dataset (
482
- args .train_data_dir ,
483
- cache_dir = args .cache_dir ,
484
- )
500
+ if args .load_from_disk :
501
+ dataset = load_from_disk (
502
+ args .train_data_dir ,
503
+ )
504
+ else :
505
+ dataset = load_dataset (
506
+ args .train_data_dir ,
507
+ cache_dir = args .cache_dir ,
508
+ )
485
509
# See more about loading custom images at
486
510
# https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
487
511
@@ -545,6 +569,7 @@ def tokenize_captions(examples, is_train=True):
545
569
image_transforms = transforms .Compose (
546
570
[
547
571
transforms .Resize (args .resolution , interpolation = transforms .InterpolationMode .BILINEAR ),
572
+ transforms .CenterCrop (args .resolution ),
548
573
transforms .ToTensor (),
549
574
transforms .Normalize ([0.5 ], [0.5 ]),
550
575
]
@@ -553,6 +578,7 @@ def tokenize_captions(examples, is_train=True):
553
578
conditioning_image_transforms = transforms .Compose (
554
579
[
555
580
transforms .Resize (args .resolution , interpolation = transforms .InterpolationMode .BILINEAR ),
581
+ transforms .CenterCrop (args .resolution ),
556
582
transforms .ToTensor (),
557
583
]
558
584
)
@@ -981,6 +1007,11 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
981
1007
"train/loss" : jax_utils .unreplicate (train_metric )["loss" ],
982
1008
}
983
1009
)
1010
+ if global_step % args .checkpointing_steps == 0 and jax .process_index () == 0 :
1011
+ controlnet .save_pretrained (
1012
+ f"{ args .output_dir } /{ global_step } " ,
1013
+ params = get_params_to_save (state .params ),
1014
+ )
984
1015
985
1016
train_metric = jax_utils .unreplicate (train_metric )
986
1017
train_step_progress_bar .close ()
0 commit comments