Skip to content

Commit 2f70694

Browse files
committed
Pr comments
1 parent b6ce485 commit 2f70694

File tree

2 files changed

+6
-23
lines changed

2 files changed

+6
-23
lines changed

pytensor/link/pytorch/dispatch/elemwise.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def elemwise_fn(*inputs):
198198

199199
out_shape = bcasted_inputs[0].size()
200200
out_size = out_shape.numel()
201-
raveled_outputs = [torch.zeros(out_size) for out in node.outputs]
201+
raveled_outputs = [torch.empty(out_size) for out in node.outputs]
202202

203203
for i in range(out_size):
204204
core_outs = base_fn(*(inp[i] for inp in raveled_inputs))

pytensor/link/pytorch/linker.py

+5-22
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,6 @@ def __init__(self, *args, **kwargs):
99
super().__init__(*args, **kwargs)
1010
self.gen_functors = []
1111

12-
def input_filter(self, inp):
13-
from pytensor.link.pytorch.dispatch import pytorch_typify
14-
15-
return pytorch_typify(inp)
16-
17-
def output_filter(self, var, out):
18-
from torch import is_tensor
19-
20-
if is_tensor(out):
21-
return out.cpu()
22-
else:
23-
return out
24-
2512
def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
2613
from pytensor.link.pytorch.dispatch import pytorch_funcify
2714

@@ -67,34 +54,30 @@ def __init__(self, fn, gen_functors):
6754
self.fn = torch.compile(fn)
6855
self.gen_functors = gen_functors.copy()
6956

70-
def __call__(self, *args, **kwargs):
57+
def __call__(self, *inputs, **kwargs):
7158
import pytensor.link.utils
7259

7360
# set attrs
7461
for n, fn in self.gen_functors:
7562
setattr(pytensor.link.utils, n[1:], fn)
7663

77-
res = self.fn(*args, **kwargs)
64+
# Torch does not accept numpy inputs and may return GPU objects
65+
outs = self.fn(*(pytorch_typify(inp) for inp in inputs), **kwargs)
7866

7967
# unset attrs
8068
for n, _ in self.gen_functors:
8169
if getattr(pytensor.link.utils, n[1:], False):
8270
delattr(pytensor.link.utils, n[1:])
8371

84-
return res
72+
return tuple(out.cpu().numpy() for out in outs)
8573

8674
def __del__(self):
8775
del self.gen_functors
8876

8977
inner_fn = wrapper(fn, self.gen_functors)
9078
self.gen_functors = []
9179

92-
# Torch does not accept numpy inputs and may return GPU objects
93-
def create_outputs(*inputs, inner_fn=inner_fn):
94-
outs = inner_fn(*(pytorch_typify(inp) for inp in inputs))
95-
return tuple(out.cpu().numpy() for out in outs)
96-
97-
return create_outputs
80+
return inner_fn
9881

9982
def create_thunk_inputs(self, storage_map):
10083
thunk_inputs = []

0 commit comments

Comments
 (0)