|
31 | 31 | from accelerate import Accelerator
|
32 | 32 | from accelerate.logging import get_logger
|
33 | 33 | from accelerate.utils import ProjectConfiguration, set_seed
|
34 |
| -from huggingface_hub import create_repo, model_info, upload_folder |
| 34 | +from huggingface_hub import create_repo, upload_folder |
35 | 35 | from packaging import version
|
36 | 36 | from PIL import Image
|
37 | 37 | from torch.utils.data import Dataset
|
@@ -589,16 +589,6 @@ def __getitem__(self, index):
|
589 | 589 | return example
|
590 | 590 |
|
591 | 591 |
|
592 |
| -def model_has_vae(args): |
593 |
| - config_file_name = os.path.join("vae", AutoencoderKL.config_name) |
594 |
| - if os.path.isdir(args.pretrained_model_name_or_path): |
595 |
| - config_file_name = os.path.join(args.pretrained_model_name_or_path, config_file_name) |
596 |
| - return os.path.isfile(config_file_name) |
597 |
| - else: |
598 |
| - files_in_repo = model_info(args.pretrained_model_name_or_path, revision=args.revision).siblings |
599 |
| - return any(file.rfilename == config_file_name for file in files_in_repo) |
600 |
| - |
601 |
| - |
602 | 592 | def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
|
603 | 593 | if tokenizer_max_length is not None:
|
604 | 594 | max_length = tokenizer_max_length
|
@@ -753,11 +743,13 @@ def main(args):
|
753 | 743 | text_encoder = text_encoder_cls.from_pretrained(
|
754 | 744 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
|
755 | 745 | )
|
756 |
| - if model_has_vae(args): |
| 746 | + try: |
757 | 747 | vae = AutoencoderKL.from_pretrained(
|
758 | 748 | args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
|
759 | 749 | )
|
760 |
| - else: |
| 750 | + except OSError: |
| 751 | + # IF does not have a VAE so let's just set it to None |
| 752 | + # We don't have to error out here |
761 | 753 | vae = None
|
762 | 754 |
|
763 | 755 | unet = UNet2DConditionModel.from_pretrained(
|
|
0 commit comments