saving and loading checkpoints do not work on train_dreambooth.py when using text_encoder #3296

Closed
@amitgurintelcom

Describe the bug

Two issues arise when resuming from a checkpoint with --train_text_encoder.
The first is #2480, which concerns how the checkpoint is saved. I described a fix for train_dreambooth.py there.
The second is similar but lies in load_model_hook, which loads the text_encoder checkpoint.
In my environment, type(model) is not equal to type(text_encoder), so the text encoder falls through to the else branch and is handled as if it were the UNet.

The current code reads:
def load_model_hook(models, input_dir):
    while len(models) > 0:
        # pop models so that they are not loaded again
        model = models.pop()

        if type(model) == type(text_encoder):
            # load transformers style into model
            load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
            model.config = load_model.config
        else:
            # load diffusers style into model
            load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
            model.register_to_config(**load_model.config)

        model.load_state_dict(load_model.state_dict())
        del load_model

The fix should be:
def load_model_hook(models, input_dir):
    while len(models) > 0:
        # pop models so that they are not loaded again
        model = models.pop()

        if "CLIPTextModel" in str(type(model)):
            # load transformers style into model
            load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
            model.config = load_model.config
        else:
            # load diffusers style into model
            load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
            model.register_to_config(**load_model.config)

        model.load_state_dict(load_model.state_dict())
        del load_model
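
One scenario in which an exact type comparison can fail is when one of the two objects is wrapped and the other is not. Whether that is what happens in my environment is not established; it is only consistent with the log, where a CLIPTextModel instance reaches the else branch. The sketch below compares three checks. It uses stand-in classes only: `Wrapper`, `unwrap`, and the three `is_text_encoder_*` helpers are hypothetical names for illustration, and no real model is loaded.

```python
class CLIPTextModel:
    """Stand-in for the transformers text encoder (hypothetical, no weights)."""


class UNet2DConditionModel:
    """Stand-in for the diffusers UNet (hypothetical, no weights)."""


class Wrapper:
    """Hypothetical wrapper, playing the role of a training-time model wrapper."""

    def __init__(self, module):
        self.module = module


def unwrap(model):
    """Peel off any number of Wrapper layers and return the inner object."""
    while isinstance(model, Wrapper):
        model = model.module
    return model


def is_text_encoder_exact(model, text_encoder):
    # The check in the current code: fails if only one side is wrapped.
    return type(model) == type(text_encoder)


def is_text_encoder_by_name(model):
    # The check proposed in this issue: matches on the class name.
    return "CLIPTextModel" in str(type(model))


def is_text_encoder_unwrapped(model, text_encoder):
    # An alternative: unwrap both sides, then use isinstance.
    return isinstance(unwrap(model), type(unwrap(text_encoder)))


# Assumed situation: the script holds a wrapped reference,
# while the hook receives the bare model.
text_encoder = Wrapper(CLIPTextModel())
model = CLIPTextModel()

print(is_text_encoder_exact(model, text_encoder))      # False
print(is_text_encoder_by_name(model))                  # True
print(is_text_encoder_unwrapped(model, text_encoder))  # True
```

The name-based check does not depend on the `text_encoder` reference at all, which is why it works here, but it would also miss a text encoder class with a different name. In the real script, `accelerator.unwrap_model` could play the role of `unwrap`; that substitution is not tested here.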

Reproduction

Command line:
accelerate launch train_dreambooth.py \
  --mixed_precision="fp16" \
  --revision="fp16" \
  --instance_data_dir="./XYZ" \
  --output_dir="checkpoints" \
  --instance_prompt="XYZ" \
  --resolution=512 \
  --train_batch_size=2 \
  --gradient_accumulation_steps=1 \
  --learning_rate=1.5e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=1200 \
  --num_class_images=260 \
  --with_prior_preservation \
  --prior_loss_weight=1.0 \
  --class_prompt="man" \
  --class_data_dir="./man" \
  --train_text_encoder \
  --validation_prompt="Portrait of XYZ" \
  --validation_steps=100 \
  --checkpointing_steps=200 \
  --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" \
  --resume_from_checkpoint="latest"
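
For the bug to trigger, the run has to find an existing checkpoint to resume from. The following is a minimal sketch of how a "latest" checkpoint might be selected, assuming checkpoints are written as directories named checkpoint-<step> under --output_dir. Both the naming scheme and the `find_latest_checkpoint` helper are assumptions for illustration and are not taken from the script.

```python
import os
import tempfile


def find_latest_checkpoint(output_dir):
    """Return the checkpoint-<step> entry with the highest step, or None."""
    candidates = []
    for name in os.listdir(output_dir):
        prefix, _, step = name.partition("-")
        if prefix == "checkpoint" and step.isdigit():
            # Compare steps numerically: as strings, "400" sorts after "1000".
            candidates.append((int(step), name))
    return max(candidates)[1] if candidates else None


# Build a throwaway output directory with three fake checkpoints.
with tempfile.TemporaryDirectory() as output_dir:
    for step in (200, 400, 1000):
        os.mkdir(os.path.join(output_dir, f"checkpoint-{step}"))
    latest = find_latest_checkpoint(output_dir)
    print(latest)  # checkpoint-1000
```

With --checkpointing_steps=200, a first run that is interrupted after step 200 would leave at least one such directory, and a second run with --resume_from_checkpoint="latest" would then reach load_model_hook.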

Logs

AttributeError: 'CLIPTextModel' object has no attribute 'register_to_config'
Traceback (most recent call last):
  File "/home/agur/.vscode-server/extensions/ms-python.python-2023.6.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/pydevd.py", line 3489, in <module>
    main()
  File "/home/agur/.vscode-server/extensions/ms-python.python-2023.6.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/pydevd.py", line 3482, in main
    globals = debugger.run(setup['file'], None, None, is_module)
  File "/home/agur/.vscode-server/extensions/ms-python.python-2023.6.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/pydevd.py", line 2510, in run
    return self._exec(is_module, entry_point_fn, module_name, file, globals, locals)
  File "/home/agur/.vscode-server/extensions/ms-python.python-2023.6.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/pydevd.py", line 2517, in _exec
    globals = pydevd_runpy.run_path(file, globals, '__main__')
  File "/home/agur/.vscode-server/extensions/ms-python.python-2023.6.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/home/agur/.vscode-server/extensions/ms-python.python-2023.6.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/home/agur/.vscode-server/extensions/ms-python.python-2023.6.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "train_dreambooth.py", line 1085, in <module>
    main(args)
  File "train_dreambooth.py", line 944, in main
    accelerator.load_state(os.path.join(args.output_dir, path))
  File "/home/agur/.local/lib/python3.8/site-packages/accelerate/accelerator.py", line 2394, in load_state
    hook(models, input_dir)
  File "train_dreambooth.py", line 759, in load_model_hook
    model.register_to_config(**load_model.config)
  File "/home/agur/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1614, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'CLIPTextModel' object has no attribute 'register_to_config'

System Info

diffusers-cli env:

  • diffusers version: 0.17.0.dev0
  • Platform: Linux-5.15.0-71-generic-x86_64-with-glibc2.29
  • Python version: 3.8.10
  • PyTorch version (GPU?): 2.0.0+cu117 (True)
  • Huggingface_hub version: 0.13.3
  • Transformers version: 4.28.1
  • Accelerate version: 0.18.0
  • xFormers version: not installed
  • Using GPU in script?: yes. nvidia-driver-530
  • Using distributed or parallel set-up in script?: yes. Data parallel

Labels

bug (Something isn't working), stale (Issues that haven't received updates)
