Skip to content

Commit 5c63ee7

Browse files
authored
Allow passing static shape to tensor creation helpers (#118)
* Allow passing static shape to tensor creation helpers
* Also default dtype to "floatX" when using `tensor`
* Make tensor API similar to that of other variable constructors
* Name is now the only optional non-keyword argument for all constructors
1 parent 43d91d0 commit 5c63ee7

26 files changed

+419
-130
lines changed

pytensor/sparse/basic.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -3451,7 +3451,12 @@ def make_node(self, a, b):
34513451
return Apply(
34523452
self,
34533453
[a, b],
3454-
[tensor(dtype_out, shape=(None, 1 if b.type.shape[1] == 1 else None))],
3454+
[
3455+
tensor(
3456+
dtype=dtype_out,
3457+
shape=(None, 1 if b.type.shape[1] == 1 else None),
3458+
)
3459+
],
34553460
)
34563461

34573462
def perform(self, node, inputs, outputs):
@@ -3582,7 +3587,9 @@ class StructuredDotGradCSC(COp):
35823587

35833588
def make_node(self, a_indices, a_indptr, b, g_ab):
35843589
return Apply(
3585-
self, [a_indices, a_indptr, b, g_ab], [tensor(g_ab.dtype, shape=(None,))]
3590+
self,
3591+
[a_indices, a_indptr, b, g_ab],
3592+
[tensor(dtype=g_ab.dtype, shape=(None,))],
35863593
)
35873594

35883595
def perform(self, node, inputs, outputs):
@@ -3716,7 +3723,7 @@ class StructuredDotGradCSR(COp):
37163723

37173724
def make_node(self, a_indices, a_indptr, b, g_ab):
37183725
return Apply(
3719-
self, [a_indices, a_indptr, b, g_ab], [tensor(b.dtype, shape=(None,))]
3726+
self, [a_indices, a_indptr, b, g_ab], [tensor(dtype=b.dtype, shape=(None,))]
37203727
)
37213728

37223729
def perform(self, node, inputs, outputs):

pytensor/sparse/rewriting.py

+28-7
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,11 @@ def make_node(self, a_val, a_ind, a_ptr, a_nrows, b):
270270
r = Apply(
271271
self,
272272
[a_val, a_ind, a_ptr, a_nrows, b],
273-
[tensor(dtype_out, shape=(None, 1 if b.type.shape[1] == 1 else None))],
273+
[
274+
tensor(
275+
dtype=dtype_out, shape=(None, 1 if b.type.shape[1] == 1 else None)
276+
)
277+
],
274278
)
275279
return r
276280

@@ -465,7 +469,12 @@ def make_node(self, a_val, a_ind, a_ptr, b):
465469
r = Apply(
466470
self,
467471
[a_val, a_ind, a_ptr, b],
468-
[tensor(self.dtype_out, shape=(None, 1 if b.type.shape[1] == 1 else None))],
472+
[
473+
tensor(
474+
dtype=self.dtype_out,
475+
shape=(None, 1 if b.type.shape[1] == 1 else None),
476+
)
477+
],
469478
)
470479
return r
471480

@@ -705,7 +714,11 @@ def make_node(self, alpha, x_val, x_ind, x_ptr, x_nrows, y, z):
705714
r = Apply(
706715
self,
707716
[alpha, x_val, x_ind, x_ptr, x_nrows, y, z],
708-
[tensor(dtype_out, shape=(None, 1 if y.type.shape[1] == 1 else None))],
717+
[
718+
tensor(
719+
dtype=dtype_out, shape=(None, 1 if y.type.shape[1] == 1 else None)
720+
)
721+
],
709722
)
710723
return r
711724

@@ -1142,7 +1155,9 @@ def make_node(self, a_data, a_indices, a_indptr, b):
11421155
"""
11431156
assert b.type.ndim == 2
11441157
return Apply(
1145-
self, [a_data, a_indices, a_indptr, b], [tensor(b.dtype, shape=(None,))]
1158+
self,
1159+
[a_data, a_indices, a_indptr, b],
1160+
[tensor(dtype=b.dtype, shape=(None,))],
11461161
)
11471162

11481163
def c_code_cache_version(self):
@@ -1280,7 +1295,9 @@ def make_node(self, a_data, a_indices, a_indptr, b):
12801295
"""
12811296
assert b.type.ndim == 2
12821297
return Apply(
1283-
self, [a_data, a_indices, a_indptr, b], [tensor(b.dtype, shape=(None,))]
1298+
self,
1299+
[a_data, a_indices, a_indptr, b],
1300+
[tensor(dtype=b.dtype, shape=(None,))],
12841301
)
12851302

12861303
def c_code_cache_version(self):
@@ -1470,7 +1487,9 @@ def make_node(self, a_data, a_indices, a_indptr, b):
14701487
"""
14711488
assert b.type.ndim == 1
14721489
return Apply(
1473-
self, [a_data, a_indices, a_indptr, b], [tensor(b.dtype, shape=(None,))]
1490+
self,
1491+
[a_data, a_indices, a_indptr, b],
1492+
[tensor(dtype=b.dtype, shape=(None,))],
14741493
)
14751494

14761495
def c_code_cache_version(self):
@@ -1642,7 +1661,9 @@ def make_node(self, a_data, a_indices, a_indptr, b):
16421661
assert a_indptr.type.ndim == 1
16431662
assert b.type.ndim == 1
16441663
return Apply(
1645-
self, [a_data, a_indices, a_indptr, b], [tensor(b.dtype, shape=(None,))]
1664+
self,
1665+
[a_data, a_indices, a_indptr, b],
1666+
[tensor(dtype=b.dtype, shape=(None,))],
16461667
)
16471668

16481669
def c_code_cache_version(self):

pytensor/tensor/basic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2882,7 +2882,7 @@ def make_node(self, start, stop, step):
28822882
assert step.ndim == 0
28832883

28842884
inputs = [start, stop, step]
2885-
outputs = [tensor(self.dtype, shape=(None,))]
2885+
outputs = [tensor(dtype=self.dtype, shape=(None,))]
28862886

28872887
return Apply(self, inputs, outputs)
28882888

pytensor/tensor/blas.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1680,7 +1680,7 @@ def make_node(self, x, y):
16801680
raise TypeError(y)
16811681
if y.type.dtype != x.type.dtype:
16821682
raise TypeError("dtype mismatch to Dot22")
1683-
outputs = [tensor(x.type.dtype, shape=(x.type.shape[0], y.type.shape[1]))]
1683+
outputs = [tensor(dtype=x.type.dtype, shape=(x.type.shape[0], y.type.shape[1]))]
16841684
return Apply(self, [x, y], outputs)
16851685

16861686
def perform(self, node, inp, out):
@@ -1985,7 +1985,7 @@ def make_node(self, x, y, a):
19851985
raise TypeError("Dot22Scalar requires float or complex args", a.dtype)
19861986

19871987
sz = (x.type.shape[0], y.type.shape[1])
1988-
outputs = [tensor(x.type.dtype, shape=sz)]
1988+
outputs = [tensor(dtype=x.type.dtype, shape=sz)]
19891989
return Apply(self, [x, y, a], outputs)
19901990

19911991
def perform(self, node, inp, out):
@@ -2221,7 +2221,7 @@ def make_node(self, *inputs):
22212221
+ inputs[1].type.shape[2:]
22222222
)
22232223
out_shape = tuple(1 if s == 1 else None for s in out_shape)
2224-
return Apply(self, upcasted_inputs, [tensor(dtype, shape=out_shape)])
2224+
return Apply(self, upcasted_inputs, [tensor(dtype=dtype, shape=out_shape)])
22252225

22262226
def perform(self, node, inp, out):
22272227
x, y = inp

pytensor/tensor/io.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __init__(self, dtype, shape, mmap_mode=None):
3636
def make_node(self, path):
3737
if isinstance(path, str):
3838
path = Constant(Generic(), path)
39-
return Apply(self, [path], [tensor(self.dtype, shape=self.shape)])
39+
return Apply(self, [path], [tensor(dtype=self.dtype, shape=self.shape)])
4040

4141
def perform(self, node, inp, out):
4242
path = inp[0]
@@ -135,7 +135,7 @@ def make_node(self):
135135
[],
136136
[
137137
Variable(Generic(), None),
138-
tensor(self.dtype, shape=self.static_shape),
138+
tensor(dtype=self.dtype, shape=self.static_shape),
139139
],
140140
)
141141

@@ -180,7 +180,7 @@ def make_node(self, request, data):
180180
return Apply(
181181
self,
182182
[request, data],
183-
[tensor(data.dtype, shape=data.type.shape)],
183+
[tensor(dtype=data.dtype, shape=data.type.shape)],
184184
)
185185

186186
def perform(self, node, inp, out):

pytensor/tensor/math.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,8 @@ def make_node(self, x):
152152
if i not in all_axes
153153
)
154154
outputs = [
155-
tensor(x.type.dtype, shape=out_shape, name="max"),
156-
tensor("int64", shape=out_shape, name="argmax"),
155+
tensor(dtype=x.type.dtype, shape=out_shape, name="max"),
156+
tensor(dtype="int64", shape=out_shape, name="argmax"),
157157
]
158158
return Apply(self, inputs, outputs)
159159

@@ -370,7 +370,7 @@ def make_node(self, x, axis=None):
370370
# We keep the original broadcastable flags for dimensions on which
371371
# we do not perform the argmax.
372372
out_shape = tuple(s for i, s in enumerate(x.type.shape) if i not in all_axes)
373-
outputs = [tensor("int64", shape=out_shape, name="argmax")]
373+
outputs = [tensor(dtype="int64", shape=out_shape, name="argmax")]
374374
return Apply(self, inputs, outputs)
375375

376376
def prepare_node(self, node, storage_map, compute_map, impl):
@@ -1922,7 +1922,7 @@ def make_node(self, *inputs):
19221922
sz = sx[:-1]
19231923

19241924
i_dtypes = [input.type.dtype for input in inputs]
1925-
outputs = [tensor(aes.upcast(*i_dtypes), shape=sz)]
1925+
outputs = [tensor(dtype=aes.upcast(*i_dtypes), shape=sz)]
19261926
return Apply(self, inputs, outputs)
19271927

19281928
def perform(self, node, inp, out):

pytensor/tensor/shape.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -641,7 +641,7 @@ def make_node(self, x, shp):
641641
except NotScalarConstantError:
642642
pass
643643

644-
return Apply(self, [x, shp], [tensor(x.type.dtype, shape=out_shape)])
644+
return Apply(self, [x, shp], [tensor(dtype=x.type.dtype, shape=out_shape)])
645645

646646
def perform(self, node, inp, out_, params):
647647
x, shp = inp

0 commit comments

Comments (0)