
Commit 4da5d18

yiyixuxu authored and patrickvonplaten committed

allow use custom local dataset for controlnet training scripts (huggingface#2928)

use custom local dataset

Co-authored-by: yiyixuxu <yixu310@gmail.com>
Co-authored-by: Patrick von Platen <[email protected]>

1 parent 9191054 commit 4da5d18

File tree

2 files changed: +10 −16 lines changed

examples/controlnet/train_controlnet.py (5 additions, 8 deletions)

@@ -542,16 +542,13 @@ def make_train_dataset(args, tokenizer, accelerator):
             cache_dir=args.cache_dir,
         )
     else:
-        data_files = {}
         if args.train_data_dir is not None:
-            data_files["train"] = os.path.join(args.train_data_dir, "**")
-            dataset = load_dataset(
-                "imagefolder",
-                data_files=data_files,
-                cache_dir=args.cache_dir,
-            )
+            dataset = load_dataset(
+                args.train_data_dir,
+                cache_dir=args.cache_dir,
+            )
         # See more about loading custom images at
-        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+        # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
 
     # Preprocessing the datasets.
     # We need to tokenize inputs and targets.
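The effect of this change is that `args.train_data_dir` is now handed directly to `load_dataset`, so the directory may contain a custom dataset loading script instead of only a flat folder of images. A minimal sketch of the new branching, using `select_dataset_source` as a hypothetical helper name (not part of the actual script), which returns the positional and keyword arguments that would be forwarded to `datasets.load_dataset`:

```python
def select_dataset_source(dataset_name=None, train_data_dir=None, cache_dir=None):
    """Sketch of the post-commit branching in make_train_dataset.

    Returns (args, kwargs) as they would be passed to datasets.load_dataset.
    """
    if dataset_name is not None:
        # Dataset named explicitly (e.g. a Hub dataset identifier).
        return (dataset_name,), {"cache_dir": cache_dir}
    if train_data_dir is not None:
        # The directory path itself is passed to load_dataset, so a custom
        # dataset script inside it is now picked up automatically.
        return (train_data_dir,), {"cache_dir": cache_dir}
    raise ValueError("Either dataset_name or train_data_dir must be provided")
```

Before this commit, the second branch instead called `load_dataset("imagefolder", data_files={"train": os.path.join(train_data_dir, "**")}, ...)`, which only supported plain image folders.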

examples/controlnet/train_controlnet_flax.py (5 additions, 8 deletions)

@@ -477,16 +477,13 @@ def make_train_dataset(args, tokenizer, batch_size=None):
             streaming=args.streaming,
         )
     else:
-        data_files = {}
         if args.train_data_dir is not None:
-            data_files["train"] = os.path.join(args.train_data_dir, "**")
-            dataset = load_dataset(
-                "imagefolder",
-                data_files=data_files,
-                cache_dir=args.cache_dir,
-            )
+            dataset = load_dataset(
+                args.train_data_dir,
+                cache_dir=args.cache_dir,
+            )
         # See more about loading custom images at
-        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+        # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
 
     # Preprocessing the datasets.
     # We need to tokenize inputs and targets.
