Don't force .cpu() on all PyTorch outputs #1052

Description

This whole thing (i.e., calling out.cpu() on every output) is suboptimal. I think we don't need it for JAX, which returns JAX arrays rather than numpy arrays, because np.asarray works on those directly; it only fails for torch tensors that live on a CUDA device (CPU tensors convert fine).
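
A minimal illustration of that difference (assuming torch is installed; the tensor values are only placeholders): np.asarray can consume CPU torch tensors through __array__, but CUDA tensors must be moved to host explicitly, which is what the blanket out.cpu() call currently works around.

```python
import numpy as np
import torch

cpu_out = torch.ones(3)
print(np.asarray(cpu_out))  # works: CPU tensors expose __array__

if torch.cuda.is_available():
    gpu_out = torch.ones(3, device="cuda")
    try:
        np.asarray(gpu_out)  # TypeError: CUDA tensors can't convert directly
    except TypeError:
        print(np.asarray(gpu_out.cpu()))  # explicit host transfer is required
```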

The conversion should only be needed for updated shared variables, where we have to convert to a common type because the same shared variable could be used in multiple functions with distinct backends.

Perhaps we should extend the TorchLinker a bit so it performs the updates itself, and force the conversion only in that case. Having the VM perform the updates is already supported by Function:

```python
if getattr(self.vm, "need_update_inputs", True):
    # Update the inputs that have an update function
    for input, storage in reversed(
        list(zip(self.maker.expanded_inputs, input_storage))
    ):
        if input.update is not None:
            storage.data = outputs.pop()
else:
    outputs = outputs[: self.n_returned_outputs]
```

Originally posted by @ricardoV94 in #1032 (comment)
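
A rough sketch of the proposed direction (the helper below is hypothetical, not existing TorchLinker or Function code): if the linker knows which outputs are destined for shared-variable updates, only those need to be pulled back to numpy, while ordinary outputs can stay as torch tensors.

```python
import numpy as np

def convert_update_outputs(outputs, is_update_flags):
    """Hypothetical sketch: convert only shared-variable-update outputs.

    is_update_flags[i] is True when outputs[i] will be written back into a
    shared variable's storage and may later be read by a non-torch backend.
    """
    converted = []
    for out, is_update in zip(outputs, is_update_flags):
        if is_update and hasattr(out, "cpu"):
            # updates need a backend-neutral numpy representation
            converted.append(np.asarray(out.cpu()))
        else:
            # plain outputs can remain backend-native torch tensors
            converted.append(out)
    return converted
```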
