
Commit 6660ea3

Use Blockwise by default

1 parent 16401bc, commit 6660ea3

6 files changed: +84 -27 lines
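Taken together, the commit flips the default: the nlinalg helpers below now return Blockwise-wrapped Ops, and new graph rewrites strip the wrapper whenever no input has batch dimensions, instead of Blockwise.make_node doing so eagerly. A minimal sketch of the intended user-facing behaviour, not part of the diff; it assumes the core Ops expose gufunc signatures, as the code below requires:

# Sketch of the behaviour introduced by this commit (illustration only).
from pytensor import function
from pytensor.tensor import matrix, tensor3
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import pinv

x = matrix("x")     # core case, no batch dims: shape (m, n)
xs = tensor3("xs")  # one leading batch dim: shape (b, m, n)

# Both graphs are now built with a Blockwise node by default.
assert isinstance(pinv(x).owner.op, Blockwise)
assert isinstance(pinv(xs).owner.op, Blockwise)

# At compile time, the new local_useless_unbatched_blockwise rewrite strips
# the wrapper from the unbatched graph, leaving only the core MatrixPinv Op.
fn = function([x], pinv(x), mode="FAST_COMPILE")
assert not isinstance(fn.maker.fgraph.outputs[0].owner.op, Blockwise)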

pytensor/tensor/blockwise.py (-7)

@@ -178,13 +178,6 @@ def make_node(self, *inputs):
             inp.type.ndim - len(sig) for inp, sig in zip(inputs, self.inputs_sig)
         )
 
-        # Don't pollute the graph with useless BlockWise
-        # TODO: Do we want to do this? Or leave it as a Blockwise and later have a rewrite that removes useless casse
-        # A reason to not eagerly avoid Blockwise is that we could make all rewrites track the Blockwise version,
-        # instead of having to track both or only the more restricted core case.
-        if not batch_ndims:
-            return self.core_op.make_node(*inputs)
-
         batched_inputs = []
         batch_shapes = []
         for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)):
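With the early return gone, make_node always produces a Blockwise Apply node, even when every input has zero batch dimensions; removing the useless wrapper is now the job of the rewrite added in pytensor/tensor/rewriting/blockwise.py further down. As a rough illustration (not the actual implementation), the batch dimensionality that used to trigger the early return is just each input's ndim minus its number of core dims from the signature:

# Illustration only: how batch_ndims relates to the gufunc signature.
# For pinv the (assumed) signature is "(m,n)->(n,m)", i.e. two core dims per input.
core_dims_per_input = [2]

def batch_ndims(input_ndims):
    return max(ndim - core for ndim, core in zip(input_ndims, core_dims_per_input))

assert batch_ndims([2]) == 0  # plain matrix: previously short-circuited in make_node
assert batch_ndims([3]) == 1  # batched input: always went (and still goes) through Blockwise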

pytensor/tensor/nlinalg.py (+6 -5)

@@ -10,6 +10,7 @@
 from pytensor.tensor import basic as at
 from pytensor.tensor import math as tm
 from pytensor.tensor.basic import as_tensor_variable, extract_diag
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.type import dvector, lscalar, matrix, scalar, vector
 
 
@@ -76,7 +77,7 @@ def pinv(x, hermitian=False):
     solve op.
 
     """
-    return MatrixPinv(hermitian=hermitian)(x)
+    return Blockwise(MatrixPinv(hermitian=hermitian))(x)
 
 
 class MatrixInverse(Op):
@@ -153,7 +154,7 @@ def infer_shape(self, fgraph, node, shapes):
         return shapes
 
 
-matrix_inverse = MatrixInverse()
+matrix_inverse = Blockwise(MatrixInverse())
 inv = matrix_inverse
 
 
@@ -215,7 +216,7 @@ def __str__(self):
         return "Det"
 
 
-det = Det()
+det = Blockwise(Det())
 
 
 class SLogDet(Op):
@@ -249,7 +250,7 @@ def __str__(self):
         return "SLogDet"
 
 
-slogdet = SLogDet()
+slogdet = Blockwise(SLogDet())
 
 
 class Eig(Op):
@@ -279,7 +280,7 @@ def infer_shape(self, fgraph, node, shapes):
         return [(n,), (n, n)]
 
 
-eig = Eig()
+eig = Blockwise(Eig())
 
 
 class Eigh(Eig):
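Because these module-level helpers are now Blockwise-wrapped, they broadcast over leading dimensions much like their numpy.linalg counterparts. A small usage sketch, assuming Det carries the gufunc signature "(m,m)->()" that Blockwise reads when none is passed explicitly:

# Usage sketch, not part of the commit: batched determinant via the new default.
import numpy as np
from pytensor import function
from pytensor.tensor import tensor3
from pytensor.tensor.nlinalg import det

xs = tensor3("xs")  # shape (batch, m, m)
fn = function([xs], det(xs))

vals = np.random.default_rng(0).normal(size=(4, 3, 3)).astype(xs.dtype)
np.testing.assert_allclose(fn(vals), np.linalg.det(vals), rtol=1e-4)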

pytensor/tensor/rewriting/__init__.py (+1)

@@ -2,6 +2,7 @@
 import pytensor.tensor.rewriting.blas
 import pytensor.tensor.rewriting.blas_c
 import pytensor.tensor.rewriting.blas_scipy
+import pytensor.tensor.rewriting.blockwise
 import pytensor.tensor.rewriting.elemwise
 import pytensor.tensor.rewriting.extra_ops

pytensor/tensor/rewriting/blockwise.py (new file, +41)

@@ -0,0 +1,41 @@
+from pytensor.compile.mode import optdb
+from pytensor.graph import node_rewriter
+from pytensor.graph.rewriting.basic import out2in, copy_stack_trace
+from pytensor.tensor.blockwise import Blockwise, vectorize_node
+from pytensor.tensor.rewriting.basic import register_useless
+
+
+@register_useless("fast_compile")
+@node_rewriter([Blockwise])
+def local_useless_blockwise(fgraph, node):
+    # If there is a dispatch implementation that does not require Blockwise, use that instead.
+    # This means a user created a Blockwise manually when there was no need.
+    op: Blockwise = node.op
+    inputs = node.inputs
+    dummy_core_node = op._create_dummy_core_node(node.inputs)
+    vect_node = vectorize_node(dummy_core_node, *inputs)
+    if not isinstance(vect_node.op, Blockwise):
+        return copy_stack_trace(node.outputs, vect_node.outputs)
+
+
+@node_rewriter([Blockwise])
+def local_useless_unbatched_blockwise(fgraph, node):
+    """Remove Blockwise that don't have any batched dims."""
+    op: Blockwise = node.op
+    inputs = node.inputs
+
+    if max(
+        inp.type.ndim - len(sig) for inp, sig in zip(inputs, op.inputs_sig)
+    ) == 0:
+        return copy_stack_trace(node.outputs, op.core_op.make_node(*inputs).outputs)
+
+
+# We register this rewrite late, so that other rewrites need only target Blockwise Ops
+optdb.register(
+    "local_useless_unbatched_blockwise",
+    out2in(local_useless_unbatched_blockwise, ignore_newtrees=True),
+    "fast_run",
+    "fast_compile",
+    "blockwise",
+    position=49,
+)
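The second rewrite is easy to exercise by hand, which makes the registered behaviour concrete: on a node whose inputs carry no batch dimensions it hands back the core Op's outputs, otherwise it returns nothing and the Blockwise node is kept. A hedged sketch, assuming the module path pytensor.tensor.rewriting.blockwise indicated by the __init__ import above and the usual node-rewriter transform(fgraph, node) interface:

# Sketch only: invoking the node rewriter directly instead of via optdb.
from pytensor.tensor import matrix, tensor3
from pytensor.tensor.nlinalg import pinv, MatrixPinv
from pytensor.tensor.rewriting.blockwise import local_useless_unbatched_blockwise

# Unbatched input: the rewrite fires and returns the core MatrixPinv outputs.
node = pinv(matrix("x")).owner
replacement = local_useless_unbatched_blockwise.transform(None, node)
assert isinstance(replacement[0].owner.op, MatrixPinv)

# Batched input: max(batch dims) != 0, so the rewrite declines and returns None.
batched_node = pinv(tensor3("x")).owner
assert local_useless_unbatched_blockwise.transform(None, batched_node) is None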
New test file (+36)

@@ -0,0 +1,36 @@
+from pytensor import function
+from pytensor.scalar import log as scalar_log
+from pytensor.tensor import matrix, tensor3
+from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.elemwise import Elemwise
+from pytensor.tensor.nlinalg import pinv, MatrixPinv
+
+
+def test_useless_blockwise_of_elemwise():
+    x = matrix("x")
+    out = Blockwise(Elemwise(scalar_log), signature="()->()")(x)
+
+    assert isinstance(out.owner.op, Blockwise)
+    assert isinstance(out.owner.op.core_op, Elemwise)
+
+    fn = function([x], out, mode="FAST_COMPILE")
+    assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Elemwise)
+
+
+def test_useless_unbatched_blockwise():
+    x = matrix("x")
+    out = pinv(x)
+
+    assert isinstance(out.owner.op, Blockwise)
+    assert isinstance(out.owner.op.core_op, MatrixPinv)
+
+    fn = function([x], out, mode="FAST_COMPILE")
+    assert isinstance(fn.maker.fgraph.outputs[0].owner.op, MatrixPinv)
+
+    # Test that it's not removed when there are batched dims
+    x = tensor3("x")
+    out = pinv(x)
+    fn = function([x], out, mode="FAST_COMPILE")
+    assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Blockwise)
+    assert isinstance(fn.maker.fgraph.outputs[0].owner.op.core_op, MatrixPinv)
+

tests/tensor/test_blockwise.py (-15)

@@ -77,21 +77,6 @@ def test_vectorize_node():
     assert new_vect_node.inputs[0] is tns4
 
 
-def test_useless_blockwise():
-    cop = MatrixInverse()
-    bop = Blockwise(cop, signature=("(m, m) -> (m, m)"))
-
-    inp = tensor(shape=(None, None, None))
-    out = bop(inp)
-    assert out.owner.op is bop
-    assert out.owner.inputs[0] is inp
-
-    inp = tensor(shape=(None, None))
-    out = bop(inp)
-    assert out.owner.op is cop
-    assert out.owner.inputs[0] is inp
-
-
 class TestOp(Op):
     def make_node(self, *inputs):
         return Apply(self, inputs, [i.type() for i in inputs])
