
KeyError: 'self' in save_hyperparameters() when custom metaclass used #20697

Open
@aditya0by0

Description


Bug description

I'm working with LightningDataModule and wanted to ensure that a method (_after_init) runs only once after full initialization, regardless of subclassing. For that, I implemented a custom metaclass (_InitMeta) that overrides __call__ to invoke _after_init after the instance is fully created.
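For reference, here is a minimal sketch of the hook pattern described above with Lightning removed. The `calls` list is a hypothetical stand-in for the `print_ab` output so the result can be checked. `type.__call__` runs `__new__` and the entire `__init__` chain before it returns, so `_after_init` fires exactly once, on the fully initialized most-derived instance.

```python
from typing import Any, List, Tuple


class _InitMeta(type):
    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
        # type.__call__ runs __new__ and the whole __init__ chain first...
        instance = super().__call__(*args, **kwargs)
        # ...so the hook sees the finished object, and runs only once.
        if hasattr(instance, "_after_init"):
            instance._after_init(**kwargs)
        return instance


class A(metaclass=_InitMeta):
    def __init__(self, **kwargs: Any) -> None:
        self.a, self.b = 1, 2
        self.calls: List[Tuple[int, int]] = []  # hypothetical record of hook runs

    def _after_init(self, **kwargs: Any) -> None:
        self.calls.append((self.a, self.b))


class B(A):
    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.a += 1
        self.b += 2


class C(B):
    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.a += 1
        self.b += 2


print(C().calls)  # [(3, 6)]
```

`B()` records `[(2, 4)]` and `A()` records `[(1, 2)]`. This matches the expected output noted in the reproduction, which shows the pattern itself is sound without Lightning.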

When I create an instance of the final subclass, I get a KeyError: 'self' raised from save_hyperparameters().

What version are you seeing the problem on?

v2.5, v2.1

How to reproduce the bug

from typing import Any

from lightning import LightningDataModule


class _InitMeta(type):
    def __call__(
        cls: Any, *args: Any, **kwargs: Any
    ) -> Any:
        instance = super().__call__(*args, **kwargs)  # Create the instance
        if hasattr(instance, "_after_init"):
            instance._after_init(**kwargs)  # Call the method if defined
        return instance


class A(LightningDataModule, metaclass=_InitMeta):
    def __init__(self, *args, **kwargs):
        self.save_hyperparameters()
        self.a = 1
        self.b = 2
        super().__init__(*args, **kwargs)

    def print_ab(self, **kwargs: Any):
        print("in print ab")
        if kwargs.get("flag", False):
            print("flag is set to True")
            print("some other logic")
        else:
            print(self.a, self.b)

    def _after_init(self, **kwargs):
        """Called only once after full initialization."""
        self.print_ab(**kwargs)


class B(A):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.a += 1
        self.b += 2


class C(B):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.a += 1
        self.b += 2


if __name__ == "__main__":
    print("Creating C instance:")
    c = C()  # Should print 3, 6 only once

    print("\nCreating B instance:")
    b = B()  # Should print 2, 4 only once

    print("\nCreating A instance:")
    a = A()  # Should print 1, 2 only once

Error messages and logs

Creating C instance:
Traceback (most recent call last):
  File "G:\github-aditya0by0\python-chebai\test.py", line 48, in <module>
    c = C()  # Should print 3, 6 only once
  File "G:\github-aditya0by0\python-chebai\test.py", line 10, in __call__
    instance = super().__call__(*args, **kwargs)  # Create the instance
  File "G:\github-aditya0by0\python-chebai\test.py", line 41, in __init__
    super().__init__(**kwargs)
  File "G:\github-aditya0by0\python-chebai\test.py", line 34, in __init__
    super().__init__(**kwargs)
  File "G:\github-aditya0by0\python-chebai\test.py", line 18, in __init__
    self.save_hyperparameters()
  File "G:\anaconda3\envs\env_chebai\lib\site-packages\lightning\pytorch\core\mixins\hparams_mixin.py", line 112, in save_hyperparameters
    save_hyperparameters(self, *args, ignore=ignore, frame=frame)
  File "G:\anaconda3\envs\env_chebai\lib\site-packages\lightning\pytorch\utilities\parsing.py", line 165, in save_hyperparameters
    for local_args in collect_init_args(frame, [], classes=(HyperparametersMixin,)):
  File "G:\anaconda3\envs\env_chebai\lib\site-packages\lightning\pytorch\utilities\parsing.py", line 135, in collect_init_args
    return collect_init_args(frame.f_back, path_args, inside=True, classes=classes)
  File "G:\anaconda3\envs\env_chebai\lib\site-packages\lightning\pytorch\utilities\parsing.py", line 135, in collect_init_args
    return collect_init_args(frame.f_back, path_args, inside=True, classes=classes)
  File "G:\anaconda3\envs\env_chebai\lib\site-packages\lightning\pytorch\utilities\parsing.py", line 135, in collect_init_args
    return collect_init_args(frame.f_back, path_args, inside=True, classes=classes)
  File "G:\anaconda3\envs\env_chebai\lib\site-packages\lightning\pytorch\utilities\parsing.py", line 131, in collect_init_args
    local_self, local_args = _get_init_args(frame)
  File "G:\anaconda3\envs\env_chebai\lib\site-packages\lightning\pytorch\utilities\parsing.py", line 97, in _get_init_args
    local_args = {k: local_vars[k] for k in init_parameters}
  File "G:\anaconda3\envs\env_chebai\lib\site-packages\lightning\pytorch\utilities\parsing.py", line 97, in <dictcomp>
    local_args = {k: local_vars[k] for k in init_parameters}
KeyError: 'self'
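The last frames of the traceback point at the dictionary comprehension in `_get_init_args`. One plausible reading, not confirmed by the maintainers, is as follows. The zero-argument `super()` in `_InitMeta.__call__` gives that frame a `__class__` cell, so the frame walker treats it like an `__init__` frame. It then reads the parameter names of `_InitMeta.__init__`, which is `type.__init__` with signature `(self, /, *args, **kwargs)`, and looks up `self` in a frame whose first parameter is named `cls`. The sketch below reproduces that failure with a simplified, hypothetical mimic of the lookup. The names `lookup_init_args`, `Probe`, `_SuperMeta` and `_PlainMeta` are illustrative, not Lightning APIs. The sketch also shows that calling `type.__call__(cls, ...)` explicitly avoids the `__class__` cell. It assumes CPython, where `inspect.currentframe()` returns a frame object.

```python
import inspect
import types
from typing import Any, Dict


def lookup_init_args(frame: types.FrameType) -> Dict[str, Any]:
    """Hypothetical, simplified mimic of the lookup seen in the traceback.

    If the frame has a __class__ cell, take the parameter names of that
    class's __init__ and read each one from the frame's locals.
    """
    local_vars = inspect.getargvalues(frame).locals
    if "__class__" not in local_vars:
        return {}  # not treated as an __init__ frame
    params = inspect.signature(local_vars["__class__"].__init__).parameters
    return {k: local_vars[k] for k in params}  # KeyError if a name is absent


class _SuperMeta(type):
    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
        # Zero-argument super() makes the compiler add a __class__ cell
        # (bound to _SuperMeta) to this frame.
        return super().__call__(*args, **kwargs)


class _PlainMeta(type):
    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
        # Explicit type.__call__: no super(), so no __class__ cell here.
        return type.__call__(cls, *args, **kwargs)


class Probe:
    """Hypothetical class whose __init__ inspects the frame that called it."""

    def __init__(self) -> None:
        caller = inspect.currentframe().f_back  # the metaclass __call__ frame
        self.found = None
        self.error = None
        try:
            self.found = lookup_init_args(caller)
        except KeyError as exc:
            self.error = exc


class WithSuper(Probe, metaclass=_SuperMeta):
    pass


class WithoutSuper(Probe, metaclass=_PlainMeta):
    pass


print(repr(WithSuper().error))  # KeyError('self')
print(WithoutSuper().found)     # {}
```

If this reading is right, a possible workaround for the original `_InitMeta` is to replace `super().__call__(*args, **kwargs)` with `type.__call__(cls, *args, **kwargs)`. This has not been verified here against Lightning itself.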

Environment

Current environment
- PyTorch Lightning: 2.1.2
- Python: 3.10.14
- Torch: 2.5.1

More info

This was acknowledged as a bug by @jsbueno on Stack Overflow: https://stackoverflow.com/questions/79554986/keyerror-self-in-save-hyperparameters-when-custom-metaclass-used-pytorch
