Update Flax ControlNet training script #2951
Changes from 2 commits
@@ -28,13 +28,13 @@
 import torch
 import torch.utils.checkpoint
 import transformers
-from datasets import load_dataset
+from datasets import load_dataset, load_from_disk
 from flax import jax_utils
 from flax.core.frozen_dict import unfreeze
 from flax.training import train_state
 from flax.training.common_utils import shard
 from huggingface_hub import HfFolder, Repository, create_repo, whoami
-from PIL import Image
+from PIL import Image, PngImagePlugin
 from torch.utils.data import IterableDataset
 from torchvision import transforms
 from tqdm.auto import tqdm
@@ -50,6 +50,9 @@
 from diffusers.utils import check_min_version, is_wandb_available

+LARGE_ENOUGH_NUMBER = 100
+PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
+
 if is_wandb_available():
     import wandb
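For context on the `MAX_TEXT_CHUNK` override: Pillow caps the decompressed size of a single PNG text chunk at `PngImagePlugin.MAX_TEXT_CHUNK` (1 MiB by default) and raises `ValueError: Decompressed Data Too Large` when a chunk exceeds it, which can happen with web-scraped images carrying heavy metadata. A minimal sketch of the same workaround in isolation (the filename is a placeholder):

```python
from PIL import Image, PngImagePlugin

# Pillow's default cap is 1 MiB per PNG text chunk; raise it to 100 MiB so
# images with oversized tEXt/zTXt/iTXt metadata open instead of raising
# "ValueError: Decompressed Data Too Large".
PngImagePlugin.MAX_TEXT_CHUNK = 100 * (1024**2)

image = Image.open("sample_with_large_metadata.png")  # placeholder filename
image.load()
```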
@@ -247,6 +250,12 @@ def parse_args():
         default=None,
         help="Total number of training steps to perform.",
     )
+    parser.add_argument(
+        "--checkpointing_steps",
+        type=int,
+        default=5000,
+        help="Save a checkpoint of the training state every X updates.",
+    )
     parser.add_argument(
         "--learning_rate",
         type=float,
@@ -345,11 +354,16 @@ def parse_args():
         type=str,
         default=None,
         help=(
-            "A folder containing the training data. Folder contents must follow the structure described in"
-            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
-            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+            "A folder containing the training dataset. By default, the `load_dataset` method is used to load a"
+            " custom dataset from the folder, so the folder must contain a dataset script as described in"
+            " https://huggingface.co/docs/datasets/dataset_script. If the `--load_from_disk` flag is passed, the"
+            " `load_from_disk` method is used instead. Ignored if `dataset_name` is specified."
         ),
     )
+    parser.add_argument(
+        "--load_from_disk",
+        action="store_true",
+        help="If passed, load a dataset that was previously saved with `save_to_disk` from `--train_data_dir`.",
+    )
Review comment: Provide a link to the …

Review comment: Also, the nit: I think we can just do …
     parser.add_argument(
         "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
     )
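As a usage sketch for the new flag (dataset name and path are placeholders): a dataset materialized once with `save_to_disk` can then be passed back via `--train_data_dir` together with `--load_from_disk`. Note that the script later indexes `dataset["train"]`, so what you save should be a `DatasetDict` with a `train` split:

```python
from datasets import load_dataset

# One-time preparation; "fusing/fill50k" and the output path are placeholders.
# Calling load_dataset without `split=` returns a DatasetDict, which keeps
# the "train" split the training script expects.
ds = load_dataset("fusing/fill50k")
ds.save_to_disk("/data/fill50k_arrow")
```

Training would then be launched with `--train_data_dir=/data/fill50k_arrow --load_from_disk`.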
@@ -478,14 +492,16 @@ def make_train_dataset(args, tokenizer, batch_size=None):
             streaming=args.streaming,
         )
     else:
-        data_files = {}
-        if args.train_data_dir is not None:
-            data_files["train"] = os.path.join(args.train_data_dir, "**")
-        dataset = load_dataset(
-            "imagefolder",
-            data_files=data_files,
-            cache_dir=args.cache_dir,
-        )
+        if args.load_from_disk:
+            dataset = load_from_disk(
+                args.train_data_dir,
+            )
+        else:
+            dataset = load_dataset(
+                args.train_data_dir,
+                cache_dir=args.cache_dir,
+            )
         # See more about loading custom images at
         # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
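To illustrate what the new branch returns: `load_from_disk` reloads exactly what was saved, so the directory from the sketch above comes back as a `DatasetDict`:

```python
from datasets import load_from_disk

# Placeholder path from the save_to_disk sketch above.
dataset = load_from_disk("/data/fill50k_arrow")
print(dataset["train"].column_names)  # e.g. the image, conditioning image, and caption columns
```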
@@ -549,6 +565,7 @@ def tokenize_captions(examples, is_train=True):
     image_transforms = transforms.Compose(
         [
             transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+            transforms.CenterCrop(args.resolution),
             transforms.ToTensor(),
             transforms.Normalize([0.5], [0.5]),
         ]
@@ -557,6 +574,7 @@ def tokenize_captions(examples, is_train=True):
     conditioning_image_transforms = transforms.Compose(
         [
             transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+            transforms.CenterCrop(args.resolution),
             transforms.ToTensor(),
         ]
     )
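The added `CenterCrop` matters for non-square inputs: `Resize(n)` only scales the shorter side to `n`, so the longer side still exceeds `n`, and downstream batching needs fixed square shapes. A quick shape check (dummy image, sizes chosen for round numbers):

```python
from PIL import Image
from torchvision import transforms

resolution = 512
tfm = transforms.Compose(
    [
        transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(resolution),  # without this, the result below would be 3x512x1024
        transforms.ToTensor(),
    ]
)

img = Image.new("RGB", (1536, 768))  # dummy non-square image (WxH)
print(tfm(img).shape)  # torch.Size([3, 512, 512])
```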
@@ -1003,6 +1021,11 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
                         "train/loss": jax_utils.unreplicate(train_metric)["loss"],
                     }
                 )
+                if global_step % args.checkpointing_steps == 0 and jax.process_index() == 0:
+                    controlnet.save_pretrained(
+                        f"{args.output_dir}/{global_step}",
+                        params=get_params_to_save(state.params),
+                    )
Review comment on lines +1011 to +1014: Very cool!
             train_metric = jax_utils.unreplicate(train_metric)
         train_step_progress_bar.close()
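A checkpoint written this way is a regular `save_pretrained` directory, so it can later be reloaded for inference or to resume from. A minimal sketch, assuming a run that saved step 5000 (the path is a placeholder):

```python
import jax.numpy as jnp
from diffusers import FlaxControlNetModel

# Checkpoints land in f"{args.output_dir}/{global_step}"; "./output/5000" is
# a placeholder for one such directory. The Flax from_pretrained returns both
# the model definition and its parameters.
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    "./output/5000",
    dtype=jnp.float32,
)
```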
Review comment: Maybe add a comment for the users to know why this needs to be set?

Review comment: What exactly does this change do here?

Author reply: Added a comment there. It prevents an error I would get when working with COYO-700M: "Decompressed Data Too Large".