
Recursion error when following DataParallel tutorial #836

Closed
@florisdf

Description

The parallelism tutorial mentions this code to forward the attributes of a DataParallel object to its wrapped module:

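```python
import torch.nn as nn

class MyDataParallel(nn.DataParallel):
    def __getattr__(self, name):
        # Recurses: self.module is itself resolved through this __getattr__
        return getattr(self.module, name)
```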

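This, however, leads to a recursion error. nn.Module stores submodules in its internal _modules dict rather than in the instance __dict__, so normal attribute lookup does not find self.module. Python then falls back to the overridden __getattr__, which evaluates self.module again, and so on until the recursion limit is hit. I think it should be: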

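```python
import torch.nn as nn

class MyDataParallel(nn.DataParallel):
    def __getattr__(self, name):
        try:
            # nn.Module.__getattr__ finds registered names such as "module"
            return super().__getattr__(name)
        except AttributeError:
            # Otherwise forward the lookup to the wrapped model
            return getattr(self.module, name)
```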

As was discussed here.
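To make the mechanism concrete without needing torch or a GPU, here is a minimal pure-Python sketch. ModuleLike, Inner, Wrapper, BrokenWrapper and FixedWrapper are hypothetical stand-ins, not PyTorch classes. ModuleLike mimics, in simplified form, how nn.Module registers submodules in _modules and exposes them only through __getattr__. The two wrappers reproduce the tutorial's forwarding code and the proposed fix.

```python
# Hypothetical, torch-free stand-ins that mimic (in simplified form) how
# nn.Module stores submodules; these are NOT PyTorch classes.


class ModuleLike:
    """Stand-in for nn.Module: submodules live in self._modules, not __dict__."""

    def __init__(self):
        # Create the registry without going through our own __setattr__.
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        if isinstance(value, ModuleLike):
            # Like nn.Module, register submodules in _modules ...
            self._modules[name] = value
        else:
            # ... and store everything else as a normal attribute.
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Only called after normal lookup (instance __dict__, class) fails.
        modules = self.__dict__.get("_modules", {})
        if name in modules:
            return modules[name]
        raise AttributeError(
            f"{type(self).__name__!r} object has no attribute {name!r}"
        )


class Inner(ModuleLike):
    """Stand-in for the wrapped model, with one plain attribute."""

    def __init__(self):
        super().__init__()
        self.hidden_size = 128


class Wrapper(ModuleLike):
    """Stand-in for nn.DataParallel: keeps the wrapped model as self.module."""

    def __init__(self, module):
        super().__init__()
        self.module = module  # registered in _modules, not in __dict__


class BrokenWrapper(Wrapper):
    # Tutorial version: looking up self.module itself ends up here,
    # which looks up self.module again -> RecursionError.
    def __getattr__(self, name):
        return getattr(self.module, name)


class FixedWrapper(Wrapper):
    # Proposed fix: let the parent lookup find "module" (and other
    # registered names) first; forward to the wrapped model otherwise.
    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.module, name)


if __name__ == "__main__":
    print(FixedWrapper(Inner()).hidden_size)  # 128
    try:
        BrokenWrapper(Inner()).hidden_size
    except RecursionError:
        print("BrokenWrapper: RecursionError")
```

The fixed version terminates because the name "module" is resolved by the parent's __getattr__ before any forwarding happens. Only names the wrapper itself does not know are passed on to the wrapped model.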