Closed
Description
The parallelism tutorial mentions this code to forward the attributes of a DataParallel
object to its wrapped module:
    class MyDataParallel(nn.DataParallel):
        def __getattr__(self, name):
            return getattr(self.module, name)
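This, however, leads to a recursion error: nn.Module keeps submodules in self._modules rather than in the instance __dict__, so looking up self.module inside __getattr__ falls through to the same __getattr__ again. I think it should be: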
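    class MyDataParallel(nn.DataParallel):
        def __getattr__(self, name):
            try:
                # nn.Module.__getattr__ resolves parameters, buffers and
                # submodules, including self.module itself.
                return super().__getattr__(name)
            except AttributeError:
                # Anything else is forwarded to the wrapped model.
                return getattr(self.module, name)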
As was discussed here.
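
For anyone who wants to reproduce this without PyTorch installed, here is a minimal sketch of the mechanism in plain Python. `TinyModule`, `Wrapper`, `Net`, `BrokenWrapper` and `FixedWrapper` are hypothetical stand-ins I made up for illustration; `TinyModule` only imitates the one relevant behaviour of nn.Module (submodules are stored in `_modules` and resolved through `__getattr__`) and is not the real implementation.

```python
class TinyModule:
    """Hypothetical stand-in for nn.Module: submodules live in self._modules."""

    def __init__(self):
        # Write directly to the instance __dict__, bypassing __setattr__ below.
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        if isinstance(value, TinyModule):
            # Submodules are NOT stored as regular instance attributes.
            self._modules[name] = value
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Only reached when normal attribute lookup has already failed.
        modules = self.__dict__.get("_modules", {})
        if name in modules:
            return modules[name]
        raise AttributeError(name)


class Wrapper(TinyModule):
    """Hypothetical stand-in for nn.DataParallel: wraps a model as self.module."""

    def __init__(self, module):
        super().__init__()
        self.module = module  # ends up in self._modules


class Net(TinyModule):
    def __init__(self):
        super().__init__()
        self.hidden_size = 16  # plain attribute we want to reach via the wrapper


class BrokenWrapper(Wrapper):
    def __getattr__(self, name):
        # self.module is not in __dict__, so this re-enters __getattr__ forever.
        return getattr(self.module, name)


class FixedWrapper(Wrapper):
    def __getattr__(self, name):
        try:
            # Let the base class resolve submodules (including "module") first.
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.module, name)


def broken_raises_recursion_error():
    try:
        BrokenWrapper(Net()).hidden_size
    except RecursionError:
        return True
    return False
```

With these stand-ins, `FixedWrapper(Net()).hidden_size` is forwarded to the wrapped object, while the `BrokenWrapper` variant never gets past the lookup of `self.module`. Attributes that exist on neither object still raise `AttributeError` in the fixed version, since the final `getattr` on the wrapped model is not caught.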