Skip to content

Commit 0103154

Browse files
committed
Use Blockwise by default
1 parent 4a72735 commit 0103154

File tree

7 files changed

+82
-30
lines changed

7 files changed

+82
-30
lines changed

pytensor/tensor/blockwise.py

-7
Original file line numberDiff line numberDiff line change
@@ -171,13 +171,6 @@ def make_node(self, *inputs):
171171
inp.type.ndim - len(sig) for inp, sig in zip(inputs, self.inputs_sig)
172172
)
173173

174-
# Don't pollute the graph with useless BlockWise
175-
# TODO: Do we want to do this? Or leave it as a Blockwise and later have a rewrite that removes useless casse
176-
# A reason to not eagerly avoid Blockwise is that we could make all rewrites track the Blockwise version,
177-
# instead of having to track both or only the more restricted core case.
178-
if not batch_ndims:
179-
return self.core_op.make_node(*inputs)
180-
181174
batched_inputs = []
182175
batch_shapes = []
183176
for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)):

pytensor/tensor/nlinalg.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from pytensor.tensor import basic as at
1111
from pytensor.tensor import math as tm
1212
from pytensor.tensor.basic import as_tensor_variable, extract_diag
13+
from pytensor.tensor.blockwise import Blockwise
1314
from pytensor.tensor.type import dvector, lscalar, matrix, scalar, vector
1415

1516

@@ -76,7 +77,7 @@ def pinv(x, hermitian=False):
7677
solve op.
7778
7879
"""
79-
return MatrixPinv(hermitian=hermitian)(x)
80+
return Blockwise(MatrixPinv(hermitian=hermitian))(x)
8081

8182

8283
class MatrixInverse(Op):
@@ -153,7 +154,7 @@ def infer_shape(self, fgraph, node, shapes):
153154
return shapes
154155

155156

156-
inv = matrix_inverse = MatrixInverse()
157+
inv = matrix_inverse = Blockwise(MatrixInverse())
157158

158159

159160
def matrix_dot(*args):
@@ -214,7 +215,7 @@ def __str__(self):
214215
return "Det"
215216

216217

217-
det = Det()
218+
det = Blockwise(Det())
218219

219220

220221
class SLogDet(Op):
@@ -248,7 +249,7 @@ def __str__(self):
248249
return "SLogDet"
249250

250251

251-
slogdet = SLogDet()
252+
slogdet = Blockwise(SLogDet())
252253

253254

254255
class Eig(Op):
@@ -278,7 +279,7 @@ def infer_shape(self, fgraph, node, shapes):
278279
return [(n,), (n, n)]
279280

280281

281-
eig = Eig()
282+
eig = Blockwise(Eig())
282283

283284

284285
class Eigh(Eig):

pytensor/tensor/rewriting/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pytensor.tensor.rewriting.blas
33
import pytensor.tensor.rewriting.blas_c
44
import pytensor.tensor.rewriting.blas_scipy
5+
import pytensor.tensor.rewriting.blockwise
56
import pytensor.tensor.rewriting.elemwise
67
import pytensor.tensor.rewriting.extra_ops
78

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from pytensor.compile.mode import optdb
2+
from pytensor.graph import node_rewriter
3+
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
4+
from pytensor.tensor.blockwise import Blockwise, vectorize_node
5+
from pytensor.tensor.rewriting.basic import register_useless
6+
7+
8+
@register_useless("fast_compile")
9+
@node_rewriter([Blockwise])
10+
def local_useless_blockwise(fgraph, node):
11+
# If there is a dispatch implementation that does not require Blockwise, use that instead.
12+
# This means a user created a Blockwise manually when there was no need.
13+
op = node.op
14+
inputs = node.inputs
15+
dummy_core_node = op._create_dummy_core_node(node.inputs)
16+
vect_node = vectorize_node(dummy_core_node, *inputs)
17+
if not isinstance(vect_node.op, Blockwise):
18+
return copy_stack_trace(node.outputs, vect_node.outputs)
19+
20+
21+
@node_rewriter([Blockwise])
22+
def local_useless_unbatched_blockwise(fgraph, node):
23+
"""Remove Blockwise that don't have any batched dims."""
24+
op = node.op
25+
inputs = node.inputs
26+
27+
if max(inp.type.ndim - len(sig) for inp, sig in zip(inputs, op.inputs_sig)) == 0:
28+
return copy_stack_trace(node.outputs, op.core_op.make_node(*inputs).outputs)
29+
30+
31+
# We register this rewrite late, so that other rewrites need only target Blockwise Ops
32+
optdb.register(
33+
"local_useless_unbatched_blockwise",
34+
out2in(local_useless_unbatched_blockwise, ignore_newtrees=True),
35+
"fast_run",
36+
"fast_compile",
37+
"blockwise",
38+
position=49,
39+
)

pytensor/tensor/utils.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import Callable, Optional
2-
31
import numpy as np
42

53
import pytensor
@@ -132,4 +130,4 @@ def import_func_from_string(func_string: str): # -> Optional[Callable]:
132130
except AttributeError:
133131
module = None
134132
break
135-
return module
133+
return module
+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from pytensor import function
2+
from pytensor.scalar import log as scalar_log
3+
from pytensor.tensor import matrix, tensor3
4+
from pytensor.tensor.blockwise import Blockwise
5+
from pytensor.tensor.elemwise import Elemwise
6+
from pytensor.tensor.nlinalg import MatrixPinv, pinv
7+
8+
9+
def test_useless_blockwise_of_elemwise():
10+
x = matrix("x")
11+
out = Blockwise(Elemwise(scalar_log), signature="()->()")(x)
12+
13+
assert isinstance(out.owner.op, Blockwise)
14+
assert isinstance(out.owner.op.core_op, Elemwise)
15+
16+
fn = function([x], out, mode="FAST_COMPILE")
17+
assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Elemwise)
18+
19+
20+
def test_useless_unbatched_blockwise():
21+
x = matrix("x")
22+
out = pinv(x)
23+
24+
assert isinstance(out.owner.op, Blockwise)
25+
assert isinstance(out.owner.op.core_op, MatrixPinv)
26+
27+
fn = function([x], out, mode="FAST_COMPILE")
28+
assert isinstance(fn.maker.fgraph.outputs[0].owner.op, MatrixPinv)
29+
30+
# Test that it's not removed when there are batched dims
31+
x = tensor3("x")
32+
out = pinv(x)
33+
fn = function([x], out, mode="FAST_COMPILE")
34+
assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Blockwise)
35+
assert isinstance(fn.maker.fgraph.outputs[0].owner.op.core_op, MatrixPinv)

tests/tensor/test_blockwise.py

-15
Original file line numberDiff line numberDiff line change
@@ -77,21 +77,6 @@ def test_vectorize_node():
7777
assert new_vect_node.inputs[0] is tns4
7878

7979

80-
def test_useless_blockwise():
81-
cop = MatrixInverse()
82-
bop = Blockwise(cop, signature=("(m, m) -> (m, m)"))
83-
84-
inp = tensor(shape=(None, None, None))
85-
out = bop(inp)
86-
assert out.owner.op is bop
87-
assert out.owner.inputs[0] is inp
88-
89-
inp = tensor(shape=(None, None))
90-
out = bop(inp)
91-
assert out.owner.op is cop
92-
assert out.owner.inputs[0] is inp
93-
94-
9580
class TestOp(Op):
9681
def make_node(self, *inputs):
9782
return Apply(self, inputs, [i.type() for i in inputs])

0 commit comments

Comments
 (0)