Skip to content

Commit 9191054

Browse files
Wauplin and testbot
authored and committed
Use upload_folder in training scripts (huggingface#2934)
use upload folder in training scripts Co-authored-by: testbot <[email protected]>
1 parent 238039e commit 9191054

File tree

20 files changed

+271
-553
lines changed

20 files changed

+271
-553
lines changed

examples/controlnet/train_controlnet.py

Lines changed: 13 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,6 @@
1919
import os
2020
import random
2121
from pathlib import Path
22-
from typing import Optional
2322

2423
import accelerate
2524
import numpy as np
@@ -31,7 +30,7 @@
3130
from accelerate.logging import get_logger
3231
from accelerate.utils import ProjectConfiguration, set_seed
3332
from datasets import load_dataset
34-
from huggingface_hub import HfFolder, Repository, create_repo, whoami
33+
from huggingface_hub import create_repo, upload_folder
3534
from packaging import version
3635
from PIL import Image
3736
from torchvision import transforms
@@ -661,16 +660,6 @@ def collate_fn(examples):
661660
}
662661

663662

664-
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
665-
if token is None:
666-
token = HfFolder.get_token()
667-
if organization is None:
668-
username = whoami(token)["name"]
669-
return f"{username}/{model_id}"
670-
else:
671-
return f"{organization}/{model_id}"
672-
673-
674663
def main(args):
675664
logging_dir = Path(args.output_dir, args.logging_dir)
676665

@@ -704,22 +693,14 @@ def main(args):
704693

705694
# Handle the repository creation
706695
if accelerator.is_main_process:
707-
if args.push_to_hub:
708-
if args.hub_model_id is None:
709-
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
710-
else:
711-
repo_name = args.hub_model_id
712-
create_repo(repo_name, exist_ok=True, token=args.hub_token)
713-
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
714-
715-
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
716-
if "step_*" not in gitignore:
717-
gitignore.write("step_*\n")
718-
if "epoch_*" not in gitignore:
719-
gitignore.write("epoch_*\n")
720-
elif args.output_dir is not None:
696+
if args.output_dir is not None:
721697
os.makedirs(args.output_dir, exist_ok=True)
722698

699+
if args.push_to_hub:
700+
repo_id = create_repo(
701+
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
702+
).repo_id
703+
723704
# Load the tokenizer
724705
if args.tokenizer_name:
725706
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
@@ -1053,7 +1034,12 @@ def load_model_hook(models, input_dir):
10531034
controlnet.save_pretrained(args.output_dir)
10541035

10551036
if args.push_to_hub:
1056-
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
1037+
upload_folder(
1038+
repo_id=repo_id,
1039+
folder_path=args.output_dir,
1040+
commit_message="End of training",
1041+
ignore_patterns=["step_*", "epoch_*"],
1042+
)
10571043

10581044
accelerator.end_training()
10591045

examples/controlnet/train_controlnet_flax.py

Lines changed: 16 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import os
2020
import random
2121
from pathlib import Path
22-
from typing import Optional
2322

2423
import jax
2524
import jax.numpy as jnp
@@ -33,7 +32,7 @@
3332
from flax.core.frozen_dict import unfreeze
3433
from flax.training import train_state
3534
from flax.training.common_utils import shard
36-
from huggingface_hub import HfFolder, Repository, create_repo, whoami
35+
from huggingface_hub import create_repo, upload_folder
3736
from PIL import Image
3837
from torch.utils.data import IterableDataset
3938
from torchvision import transforms
@@ -148,7 +147,7 @@ def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_d
148147
return image_logs
149148

150149

151-
def save_model_card(repo_name, image_logs=None, base_model=str, repo_folder=None):
150+
def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
152151
img_str = ""
153152
for i, log in enumerate(image_logs):
154153
images = log["images"]
@@ -174,7 +173,7 @@ def save_model_card(repo_name, image_logs=None, base_model=str, repo_folder=None
174173
---
175174
"""
176175
model_card = f"""
177-
# controlnet- {repo_name}
176+
# controlnet- {repo_id}
178177
179178
These are controlnet weights trained on {base_model} with new type of conditioning. You can find some example images in the following. \n
180179
{img_str}
@@ -612,16 +611,6 @@ def collate_fn(examples):
612611
return batch
613612

614613

615-
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
616-
if token is None:
617-
token = HfFolder.get_token()
618-
if organization is None:
619-
username = whoami(token)["name"]
620-
return f"{username}/{model_id}"
621-
else:
622-
return f"{organization}/{model_id}"
623-
624-
625614
def get_params_to_save(params):
626615
return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params))
627616

@@ -656,22 +645,14 @@ def main():
656645

657646
# Handle the repository creation
658647
if jax.process_index() == 0:
659-
if args.push_to_hub:
660-
if args.hub_model_id is None:
661-
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
662-
else:
663-
repo_name = args.hub_model_id
664-
repo_url = create_repo(repo_name, exist_ok=True, token=args.hub_token)
665-
repo = Repository(args.output_dir, clone_from=repo_url, token=args.hub_token)
666-
667-
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
668-
if "step_*" not in gitignore:
669-
gitignore.write("step_*\n")
670-
if "epoch_*" not in gitignore:
671-
gitignore.write("epoch_*\n")
672-
elif args.output_dir is not None:
648+
if args.output_dir is not None:
673649
os.makedirs(args.output_dir, exist_ok=True)
674650

651+
if args.push_to_hub:
652+
repo_id = create_repo(
653+
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
654+
).repo_id
655+
675656
# Load the tokenizer and add the placeholder token as a additional special token
676657
if args.tokenizer_name:
677658
tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
@@ -1020,12 +1001,17 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
10201001

10211002
if args.push_to_hub:
10221003
save_model_card(
1023-
repo_name,
1004+
repo_id,
10241005
image_logs=image_logs,
10251006
base_model=args.pretrained_model_name_or_path,
10261007
repo_folder=args.output_dir,
10271008
)
1028-
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
1009+
upload_folder(
1010+
repo_id=repo_id,
1011+
folder_path=args.output_dir,
1012+
commit_message="End of training",
1013+
ignore_patterns=["step_*", "epoch_*"],
1014+
)
10291015

10301016

10311017
if __name__ == "__main__":

examples/dreambooth/train_dreambooth.py

Lines changed: 13 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import os
2222
import warnings
2323
from pathlib import Path
24-
from typing import Optional
2524

2625
import accelerate
2726
import numpy as np
@@ -32,7 +31,7 @@
3231
from accelerate import Accelerator
3332
from accelerate.logging import get_logger
3433
from accelerate.utils import ProjectConfiguration, set_seed
35-
from huggingface_hub import HfFolder, Repository, create_repo, whoami
34+
from huggingface_hub import create_repo, upload_folder
3635
from packaging import version
3736
from PIL import Image
3837
from torch.utils.data import Dataset
@@ -575,16 +574,6 @@ def __getitem__(self, index):
575574
return example
576575

577576

578-
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
579-
if token is None:
580-
token = HfFolder.get_token()
581-
if organization is None:
582-
username = whoami(token)["name"]
583-
return f"{username}/{model_id}"
584-
else:
585-
return f"{organization}/{model_id}"
586-
587-
588577
def main(args):
589578
logging_dir = Path(args.output_dir, args.logging_dir)
590579

@@ -677,22 +666,14 @@ def main(args):
677666

678667
# Handle the repository creation
679668
if accelerator.is_main_process:
680-
if args.push_to_hub:
681-
if args.hub_model_id is None:
682-
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
683-
else:
684-
repo_name = args.hub_model_id
685-
create_repo(repo_name, exist_ok=True, token=args.hub_token)
686-
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
687-
688-
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
689-
if "step_*" not in gitignore:
690-
gitignore.write("step_*\n")
691-
if "epoch_*" not in gitignore:
692-
gitignore.write("epoch_*\n")
693-
elif args.output_dir is not None:
669+
if args.output_dir is not None:
694670
os.makedirs(args.output_dir, exist_ok=True)
695671

672+
if args.push_to_hub:
673+
repo_id = create_repo(
674+
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
675+
).repo_id
676+
696677
# Load the tokenizer
697678
if args.tokenizer_name:
698679
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
@@ -1043,7 +1024,12 @@ def load_model_hook(models, input_dir):
10431024
pipeline.save_pretrained(args.output_dir)
10441025

10451026
if args.push_to_hub:
1046-
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
1027+
upload_folder(
1028+
repo_id=repo_id,
1029+
folder_path=args.output_dir,
1030+
commit_message="End of training",
1031+
ignore_patterns=["step_*", "epoch_*"],
1032+
)
10471033

10481034
accelerator.end_training()
10491035

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 16 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import os
2121
import warnings
2222
from pathlib import Path
23-
from typing import Optional
2423

2524
import numpy as np
2625
import torch
@@ -30,7 +29,7 @@
3029
from accelerate import Accelerator
3130
from accelerate.logging import get_logger
3231
from accelerate.utils import ProjectConfiguration, set_seed
33-
from huggingface_hub import HfFolder, Repository, create_repo, whoami
32+
from huggingface_hub import create_repo, upload_folder
3433
from packaging import version
3534
from PIL import Image
3635
from torch.utils.data import Dataset
@@ -59,7 +58,7 @@
5958
logger = get_logger(__name__)
6059

6160

62-
def save_model_card(repo_name, images=None, base_model=str, prompt=str, repo_folder=None):
61+
def save_model_card(repo_id: str, images=None, base_model=str, prompt=str, repo_folder=None):
6362
img_str = ""
6463
for i, image in enumerate(images):
6564
image.save(os.path.join(repo_folder, f"image_{i}.png"))
@@ -80,7 +79,7 @@ def save_model_card(repo_name, images=None, base_model=str, prompt=str, repo_fol
8079
---
8180
"""
8281
model_card = f"""
83-
# LoRA DreamBooth - {repo_name}
82+
# LoRA DreamBooth - {repo_id}
8483
8584
These are LoRA adaption weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n
8685
{img_str}
@@ -528,16 +527,6 @@ def __getitem__(self, index):
528527
return example
529528

530529

531-
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
532-
if token is None:
533-
token = HfFolder.get_token()
534-
if organization is None:
535-
username = whoami(token)["name"]
536-
return f"{username}/{model_id}"
537-
else:
538-
return f"{organization}/{model_id}"
539-
540-
541530
def main(args):
542531
logging_dir = Path(args.output_dir, args.logging_dir)
543532

@@ -625,23 +614,14 @@ def main(args):
625614

626615
# Handle the repository creation
627616
if accelerator.is_main_process:
628-
if args.push_to_hub:
629-
if args.hub_model_id is None:
630-
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
631-
else:
632-
repo_name = args.hub_model_id
633-
634-
create_repo(repo_name, exist_ok=True, token=args.hub_token)
635-
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
636-
637-
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
638-
if "step_*" not in gitignore:
639-
gitignore.write("step_*\n")
640-
if "epoch_*" not in gitignore:
641-
gitignore.write("epoch_*\n")
642-
elif args.output_dir is not None:
617+
if args.output_dir is not None:
643618
os.makedirs(args.output_dir, exist_ok=True)
644619

620+
if args.push_to_hub:
621+
repo_id = create_repo(
622+
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
623+
).repo_id
624+
645625
# Load the tokenizer
646626
if args.tokenizer_name:
647627
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
@@ -1027,13 +1007,18 @@ def main(args):
10271007

10281008
if args.push_to_hub:
10291009
save_model_card(
1030-
repo_name,
1010+
repo_id,
10311011
images=images,
10321012
base_model=args.pretrained_model_name_or_path,
10331013
prompt=args.instance_prompt,
10341014
repo_folder=args.output_dir,
10351015
)
1036-
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
1016+
upload_folder(
1017+
repo_id=repo_id,
1018+
folder_path=args.output_dir,
1019+
commit_message="End of training",
1020+
ignore_patterns=["step_*", "epoch_*"],
1021+
)
10371022

10381023
accelerator.end_training()
10391024

0 commit comments

Comments
 (0)