@@ -9,19 +9,6 @@ def __init__(self, *args, **kwargs):
9
9
super ().__init__ (* args , ** kwargs )
10
10
self .gen_functors = []
11
11
12
- def input_filter (self , inp ):
13
- from pytensor .link .pytorch .dispatch import pytorch_typify
14
-
15
- return pytorch_typify (inp )
16
-
17
- def output_filter (self , var , out ):
18
- from torch import is_tensor
19
-
20
- if is_tensor (out ):
21
- return out .cpu ()
22
- else :
23
- return out
24
-
25
12
def fgraph_convert (self , fgraph , input_storage , storage_map , ** kwargs ):
26
13
from pytensor .link .pytorch .dispatch import pytorch_funcify
27
14
@@ -67,34 +54,30 @@ def __init__(self, fn, gen_functors):
67
54
self .fn = torch .compile (fn )
68
55
self .gen_functors = gen_functors .copy ()
69
56
70
- def __call__ (self , * args , ** kwargs ):
57
+ def __call__ (self , * inputs , ** kwargs ):
71
58
import pytensor .link .utils
72
59
73
60
# set attrs
74
61
for n , fn in self .gen_functors :
75
62
setattr (pytensor .link .utils , n [1 :], fn )
76
63
77
- res = self .fn (* args , ** kwargs )
64
+ # Torch does not accept numpy inputs and may return GPU objects
65
+ outs = self .fn (* (pytorch_typify (inp ) for inp in inputs ), ** kwargs )
78
66
79
67
# unset attrs
80
68
for n , _ in self .gen_functors :
81
69
if getattr (pytensor .link .utils , n [1 :], False ):
82
70
delattr (pytensor .link .utils , n [1 :])
83
71
84
- return res
72
+ return tuple ( out . cpu (). numpy () for out in outs )
85
73
86
74
def __del__ (self ):
87
75
del self .gen_functors
88
76
89
77
inner_fn = wrapper (fn , self .gen_functors )
90
78
self .gen_functors = []
91
79
92
- # Torch does not accept numpy inputs and may return GPU objects
93
- def create_outputs (* inputs , inner_fn = inner_fn ):
94
- outs = inner_fn (* (pytorch_typify (inp ) for inp in inputs ))
95
- return tuple (out .cpu ().numpy () for out in outs )
96
-
97
- return create_outputs
80
+ return inner_fn
98
81
99
82
def create_thunk_inputs (self , storage_map ):
100
83
thunk_inputs = []
0 commit comments