Skip to content

Commit 0bde5e4

Browse files
patrickvonplatenhari10599
authored andcommitted
Make dreambooth lora more robust to orig unet (huggingface#3462)
* Make dreambooth lora more robust to orig unet * up
1 parent 3271531 commit 0bde5e4

File tree

1 file changed

+5
-13
lines changed

1 file changed

+5
-13
lines changed

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from accelerate import Accelerator
3232
from accelerate.logging import get_logger
3333
from accelerate.utils import ProjectConfiguration, set_seed
34-
from huggingface_hub import create_repo, model_info, upload_folder
34+
from huggingface_hub import create_repo, upload_folder
3535
from packaging import version
3636
from PIL import Image
3737
from torch.utils.data import Dataset
@@ -589,16 +589,6 @@ def __getitem__(self, index):
589589
return example
590590

591591

592-
def model_has_vae(args):
593-
config_file_name = os.path.join("vae", AutoencoderKL.config_name)
594-
if os.path.isdir(args.pretrained_model_name_or_path):
595-
config_file_name = os.path.join(args.pretrained_model_name_or_path, config_file_name)
596-
return os.path.isfile(config_file_name)
597-
else:
598-
files_in_repo = model_info(args.pretrained_model_name_or_path, revision=args.revision).siblings
599-
return any(file.rfilename == config_file_name for file in files_in_repo)
600-
601-
602592
def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
603593
if tokenizer_max_length is not None:
604594
max_length = tokenizer_max_length
@@ -753,11 +743,13 @@ def main(args):
753743
text_encoder = text_encoder_cls.from_pretrained(
754744
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
755745
)
756-
if model_has_vae(args):
746+
try:
757747
vae = AutoencoderKL.from_pretrained(
758748
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
759749
)
760-
else:
750+
except OSError:
751+
# IF does not have a VAE so let's just set it to None
752+
# We don't have to error out here
761753
vae = None
762754

763755
unet = UNet2DConditionModel.from_pretrained(

0 commit comments

Comments
 (0)