Description
This whole thing (i.e., calling `out.cpu()` on every output) is suboptimal. I think we don't need it for JAX, which returns JAX arrays (not NumPy arrays), because `np.asarray` works with those directly; I guess it only fails for torch tensors.
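A duck-typed conversion along these lines could avoid the unconditional `.cpu()` call. This is only a sketch (the `as_numpy` helper name is hypothetical, not PyTensor API), and a stand-in class mimics a torch tensor's `.cpu()` so the example is self-contained:

```python
import numpy as np

def as_numpy(out):
    # JAX arrays already satisfy the array protocol, so np.asarray handles
    # them directly; torch tensors on an accelerator device must first be
    # moved to host memory via .cpu(). We duck-type on the .cpu attribute.
    if hasattr(out, "cpu"):
        out = out.cpu()
    return np.asarray(out)

# Minimal stand-in mimicking a torch tensor's .cpu(), for illustration only.
class FakeTensor:
    def __init__(self, data):
        self.data = data

    def cpu(self):
        return np.array(self.data)

print(as_numpy(FakeTensor([1, 2, 3])))  # → [1 2 3]
print(as_numpy([4.0, 5.0]))             # → [4. 5.]
```

For real torch tensors, `.cpu()` is a no-op when the tensor already lives on the CPU, so the conversion cost is only paid when a device transfer is genuinely required.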
`pytensor/pytensor/link/pytorch/linker.py`, line 16 in `7b13a95`
This should only be needed for updated shared variables, where we have to convert to a common type because they could be used in multiple functions with distinct backends.
Perhaps we should expand the `TorchLinker` a bit to perform the updates itself, and only force conversion in that case. This is already supported by `Function`.
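A rough sketch of the idea, with hypothetical names (`make_thunk`, the `updates` mapping layout, and the container structure are all assumptions, not the actual `TorchLinker` interface): the thunk applies shared-variable updates itself and converts only those values to NumPy, leaving ordinary outputs in the backend-native type:

```python
import numpy as np

def make_thunk(fn, shared_containers, updates):
    """Wrap a compiled backend function so it performs its own updates.

    updates maps a shared container index to the output index holding its
    new value (a hypothetical layout for illustration).
    """
    def thunk(*inputs):
        outs = fn(*inputs)
        for container_idx, out_idx in updates.items():
            val = outs[out_idx]
            # Shared variables may be read by functions compiled for other
            # backends, so only their storage is forced to a common NumPy type.
            if hasattr(val, "cpu"):
                val = val.cpu()
            shared_containers[container_idx][0] = np.asarray(val)
        # Non-updated outputs stay as whatever the backend returned.
        return outs
    return thunk

# Usage: a counter-style shared variable updated in place.
containers = [[np.array(0)]]
thunk = make_thunk(lambda x: (x + 1,), containers, {0: 0})
thunk(np.array(5))
print(containers[0][0])  # → 6
```

The point of the design is that the costly host conversion happens once per update target rather than once per output, which matters when most outputs never feed another backend.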
`pytensor/pytensor/compile/function/types.py`, lines 1009 to 1017 in `7b13a95`
Originally posted by @ricardoV94 in #1032 (comment)