Commit 5023274

Merge branch 'main' into numba-cholesky
2 parents: 2761406 + c5b96d9

25 files changed: +688 / -77 lines

pytensor/graph/basic.py

Lines changed: 4 additions & 3 deletions
@@ -1439,15 +1439,16 @@ def io_toposort(
         order = []
         while todo:
             cur = todo.pop()
-            # We suppose that all outputs are always computed
-            if cur.outputs[0] in computed:
+            if all(out in computed for out in cur.outputs):
                 continue
             if all(i in computed or i.owner is None for i in cur.inputs):
                 computed.update(cur.outputs)
                 order.append(cur)
             else:
                 todo.append(cur)
-                todo.extend(i.owner for i in cur.inputs if i.owner)
+                todo.extend(
+                    i.owner for i in cur.inputs if (i.owner and i not in computed)
+                )
         return order
 
     compute_deps = None

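A minimal usage sketch, assuming PyTensor's public io_toposort and nlinalg.svd APIs: the single Apply node that produces all three SVD outputs appears exactly once in the returned order, which is what the all-outputs check above is guarding.

    import pytensor.tensor as pt
    from pytensor.graph.basic import io_toposort
    from pytensor.tensor.nlinalg import svd

    A = pt.matrix("A")
    U, s, Vt = svd(A)                     # one Apply node with three outputs
    order = io_toposort([A], [U, s, Vt])
    assert len(order) == 1                # the multi-output node is scheduled once
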
pytensor/graph/replace.py

Lines changed: 5 additions & 0 deletions
@@ -306,6 +306,11 @@ def vectorize_graph(
         vect_inputs = [vect_vars.get(inp, inp) for inp in node.inputs]
         vect_node = vectorize_node(node, *vect_inputs)
         for output, vect_output in zip(node.outputs, vect_node.outputs):
+            if output in vect_vars:
+                # This can happen when some outputs of a multi-output node are given a replacement,
+                # while some of the remaining outputs are still needed in the graph.
+                # We make sure we don't overwrite the provided replacement with the newly vectorized output
+                continue
             vect_vars[output] = vect_output
 
     seq_vect_outputs = [vect_vars[out] for out in seq_outputs]

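For context, vectorize_graph is typically called with a graph output and a replace mapping from core inputs to batched inputs; the guarded branch above only matters when replacements are supplied for some outputs of a multi-output node. A minimal sketch, assuming the current pytensor.graph.replace API and an elemwise-only graph:

    import pytensor.tensor as pt
    from pytensor.graph.replace import vectorize_graph

    x = pt.vector("x")
    y = pt.exp(x) + 1

    batch_x = pt.matrix("batch_x")                      # batched stand-in for x
    batch_y = vectorize_graph(y, replace={x: batch_x})  # same graph, applied row-wise
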
pytensor/link/jax/dispatch/slinalg.py

Lines changed: 9 additions & 1 deletion
@@ -1,7 +1,7 @@
 import jax
 
 from pytensor.link.jax.dispatch.basic import jax_funcify
-from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular
+from pytensor.tensor.slinalg import BlockDiagonal, Cholesky, Solve, SolveTriangular
 
 
 @jax_funcify.register(Cholesky)
@@ -45,3 +45,11 @@ def solve_triangular(A, b):
         )
 
     return solve_triangular
+
+
+@jax_funcify.register(BlockDiagonal)
+def jax_funcify_BlockDiagonalMatrix(op, **kwargs):
+    def block_diag(*inputs):
+        return jax.scipy.linalg.block_diag(*inputs)
+
+    return block_diag

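With this registration, a graph containing the new BlockDiagonal Op can be compiled for the JAX backend. A rough sketch, assuming the dense helper added elsewhere in this merge is exposed as pytensor.tensor.slinalg.block_diag and that JAX is installed:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.slinalg import block_diag

    A = pt.matrix("A")
    B = pt.matrix("B")
    f = pytensor.function([A, B], block_diag(A, B), mode="JAX")
    print(f(np.eye(2), np.full((1, 3), 7.0)))  # a 3x5 block-diagonal result
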
pytensor/link/numba/dispatch/slinalg.py

Lines changed: 22 additions & 2 deletions
@@ -9,7 +9,7 @@
 
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import numba_funcify
-from pytensor.tensor.slinalg import Cholesky, SolveTriangular
+from pytensor.tensor.slinalg import Cholesky, BlockDiagonal, SolveTriangular
 
 
 _PTR = ctypes.POINTER
@@ -299,7 +299,6 @@ def solve_triangular(a, b):
 
     return solve_triangular
 
-
 def _cholesky(a, lower=False, overwrite_a=False, check_finite=True):
     return linalg.cholesky(
         a, lower=lower, overwrite_a=overwrite_a, check_finite=check_finite
@@ -357,3 +356,24 @@ def nb_cholesky(a):
         return res
 
     return nb_cholesky
+
+@numba_funcify.register(BlockDiagonal)
+def numba_funcify_BlockDiagonal(op, node, **kwargs):
+    dtype = node.outputs[0].dtype
+
+    # TODO: Why do we always inline all functions? It doesn't work with starred args, so can't use it in this case.
+    @numba_basic.numba_njit(inline="never")
+    def block_diag(*arrs):
+        shapes = np.array([a.shape for a in arrs], dtype="int")
+        out_shape = [int(s) for s in np.sum(shapes, axis=0)]
+        out = np.zeros((out_shape[0], out_shape[1]), dtype=dtype)
+
+        r, c = 0, 0
+        for arr, shape in zip(arrs, shapes):
+            rr, cc = shape
+            out[r : r + rr, c : c + cc] = arr
+            r += rr
+            c += cc
+        return out
+
+    return block_diag

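The Numba kernel is intended to mirror scipy.linalg.block_diag for 2-D inputs. A quick consistency check, again assuming block_diag is importable from pytensor.tensor.slinalg and Numba is installed:

    import numpy as np
    import scipy.linalg
    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.slinalg import block_diag

    A, B = pt.matrix("A"), pt.matrix("B")
    f = pytensor.function([A, B], block_diag(A, B), mode="NUMBA")
    a, b = np.eye(2), np.arange(6.0).reshape(2, 3)
    np.testing.assert_allclose(f(a, b), scipy.linalg.block_diag(a, b))
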
pytensor/sparse/basic.py

Lines changed: 87 additions & 9 deletions
@@ -7,6 +7,7 @@
 TODO: Automatic methods for determining best sparse format?
 
 """
+from typing import Literal
 from warnings import warn
 
 import numpy as np
@@ -47,6 +48,7 @@
     trunc,
 )
 from pytensor.tensor.shape import shape, specify_broadcastable
+from pytensor.tensor.slinalg import BaseBlockDiagonal, _largest_common_dtype
 from pytensor.tensor.type import TensorType
 from pytensor.tensor.type import continuous_dtypes as tensor_continuous_dtypes
 from pytensor.tensor.type import discrete_dtypes as tensor_discrete_dtypes
@@ -60,7 +62,6 @@
 
 sparse_formats = ["csc", "csr"]
 
-
 """
 Types of sparse matrices to use for testing.
 
@@ -183,7 +184,6 @@ def as_sparse_variable(x, name=None, ndim=None, **kwargs):
 
 as_sparse = as_sparse_variable
 
-
 as_sparse_or_tensor_variable = as_symbolic
 
 
@@ -1800,7 +1800,7 @@ def infer_shape(self, fgraph, node, shapes):
         return r
 
     def __str__(self):
-        return f"{self.__class__.__name__ }{{axis={self.axis}}}"
+        return f"{self.__class__.__name__}{{axis={self.axis}}}"
 
 
 def sp_sum(x, axis=None, sparse_grad=False):
@@ -2775,19 +2775,14 @@ def comparison(self, x, y):
 
 greater_equal_s_d = GreaterEqualSD()
 
-
 eq = __ComparisonSwitch(equal_s_s, equal_s_d, equal_s_d)
 
-
 neq = __ComparisonSwitch(not_equal_s_s, not_equal_s_d, not_equal_s_d)
 
-
 lt = __ComparisonSwitch(less_than_s_s, less_than_s_d, greater_than_s_d)
 
-
 gt = __ComparisonSwitch(greater_than_s_s, greater_than_s_d, less_than_s_d)
 
-
 le = __ComparisonSwitch(less_equal_s_s, less_equal_s_d, greater_equal_s_d)
 
 ge = __ComparisonSwitch(greater_equal_s_s, greater_equal_s_d, less_equal_s_d)
@@ -2992,7 +2987,7 @@ def __str__(self):
         l = []
         if self.inplace:
             l.append("inplace")
-        return f"{self.__class__.__name__ }{{{', '.join(l)}}}"
+        return f"{self.__class__.__name__}{{{', '.join(l)}}}"
 
     def make_node(self, x):
         """
@@ -3291,6 +3286,7 @@ class TrueDot(Op):
     # Simplify code by splitting into DotSS and DotSD.
 
     __props__ = ()
+
    # The grad_preserves_dense attribute doesn't change the
    # execution behavior. To let the optimizer merge nodes with
    # different values of this attribute we shouldn't compare it
@@ -4260,3 +4256,85 @@ def grad(self, inputs, grads):
 
 
 construct_sparse_from_list = ConstructSparseFromList()
+
+
+class SparseBlockDiagonal(BaseBlockDiagonal):
+    __props__ = (
+        "n_inputs",
+        "format",
+    )
+
+    def __init__(self, n_inputs: int, format: Literal["csc", "csr"] = "csc"):
+        super().__init__(n_inputs)
+        self.format = format
+
+    def make_node(self, *matrices):
+        matrices = self._validate_and_prepare_inputs(
+            matrices, as_sparse_or_tensor_variable
+        )
+        dtype = _largest_common_dtype(matrices)
+        out_type = matrix(format=self.format, dtype=dtype)
+
+        return Apply(self, matrices, [out_type])
+
+    def perform(self, node, inputs, output_storage, params=None):
+        dtype = node.outputs[0].type.dtype
+        output_storage[0][0] = scipy.sparse.block_diag(
+            inputs, format=self.format
+        ).astype(dtype)
+
+
+def block_diag(*matrices: TensorVariable, format: Literal["csc", "csr"] = "csc"):
+    r"""
+    Construct a block diagonal matrix from a sequence of input matrices.
+
+    Given the inputs `A`, `B` and `C`, the output will have these arrays arranged on the diagonal:
+
+    [[A, 0, 0],
+     [0, B, 0],
+     [0, 0, C]]
+
+    Parameters
+    ----------
+    A, B, C ... : tensors
+        Input tensors to form the block diagonal matrix. The last two dimensions of the inputs will be used, and all
+        inputs should have at least 2 dimensions.
+
+        Note that the input matrices need not be sparse themselves, and will be automatically converted to the
+        requested format if they are not.
+
+    format: str, optional
+        The format of the output sparse matrix. One of 'csr' or 'csc'. Default is 'csc'.
+
+    Returns
+    -------
+    out: sparse matrix tensor
+        Symbolic sparse matrix in the specified format.
+
+    Examples
+    --------
+    Create a sparse block diagonal matrix from two sparse 2x2 matrices:
+
+    .. code-block:: python
+        import numpy as np
+        from pytensor.sparse import block_diag
+        from scipy.sparse import csr_matrix
+
+        A = csr_matrix([[1, 2], [3, 4]])
+        B = csr_matrix([[5, 6], [7, 8]])
+        result_sparse = block_diag(A, B, format='csr')
+
+        print(result_sparse)
+        >>> SparseVariable{csr,int32}
+
+        print(result_sparse.toarray().eval())
+        >>> array([[1, 2, 0, 0],
+        >>>        [3, 4, 0, 0],
+        >>>        [0, 0, 5, 6],
+        >>>        [0, 0, 7, 8]])
+    """
+    if len(matrices) == 1:
+        return matrices
+
+    _sparse_block_diagonal = SparseBlockDiagonal(n_inputs=len(matrices), format=format)
+    return _sparse_block_diagonal(*matrices)

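Beyond the docstring example, the note about non-sparse inputs can be exercised directly. A short sketch; that dense NumPy inputs are converted on the fly is an assumption about the conversion path, not something the docstring demonstrates:

    import numpy as np
    from scipy.sparse import csr_matrix
    from pytensor.sparse import block_diag

    A = csr_matrix([[1.0, 2.0], [3.0, 4.0]])  # sparse input
    B = np.array([[5.0]])                     # dense input, converted automatically
    X = block_diag(A, B, format="csc")
    print(X.eval().toarray())                 # 3x3 block-diagonal array
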
pytensor/tensor/basic.py

Lines changed: 35 additions & 6 deletions
@@ -43,7 +43,12 @@
     get_vector_length,
 )
 from pytensor.tensor.blockwise import Blockwise
-from pytensor.tensor.elemwise import DimShuffle, Elemwise, scalar_elemwise
+from pytensor.tensor.elemwise import (
+    DimShuffle,
+    Elemwise,
+    get_normalized_batch_axes,
+    scalar_elemwise,
+)
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.shape import (
     Shape,
@@ -3614,13 +3619,18 @@ def diagonal(a, offset=0, axis1=0, axis2=1):
 
 
 @_vectorize_node.register(ExtractDiag)
-def vectorize_extract_diag(op: ExtractDiag, node, batched_x):
-    batched_ndims = batched_x.type.ndim - node.inputs[0].type.ndim
+def vectorize_extract_diag(op: ExtractDiag, node, batch_x):
+    core_ndim = node.inputs[0].type.ndim
+    batch_ndim = batch_x.type.ndim - core_ndim
+    batch_axis1, batch_axis2 = get_normalized_batch_axes(
+        (op.axis1, op.axis2), core_ndim, batch_ndim
+    )
+
     return diagonal(
-        batched_x,
+        batch_x,
         offset=op.offset,
-        axis1=op.axis1 + batched_ndims,
-        axis2=op.axis2 + batched_ndims,
+        axis1=batch_axis1,
+        axis2=batch_axis2,
     ).owner
 
 
@@ -4269,6 +4279,25 @@ def take_along_axis(arr, indices, axis=0):
     return arr[_make_along_axis_idx(arr.shape, indices, axis)]
 
 
+def ix_(*args):
+    """
+    PyTensor np.ix_ analog
+
+    See numpy.lib.index_tricks.ix_ for reference
+    """
+    out = []
+    nd = len(args)
+    for k, new in enumerate(args):
+        if new is None:
+            out.append(slice(None))
+        new = as_tensor(new)
+        if new.ndim != 1:
+            raise ValueError("Cross index must be 1 dimensional")
+        new = new.reshape((1,) * k + (new.size,) + (1,) * (nd - k - 1))
+        out.append(new)
+    return tuple(out)
+
+
 __all__ = [
     "take_along_axis",
     "expand_dims",

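The new ix_ helper mirrors numpy.ix_, building an open mesh from 1-D index tensors so advanced indexing selects a cross product of rows and columns. A small sketch, assuming ix_ is importable from pytensor.tensor.basic and that advanced integer indexing broadcasts as in NumPy:

    import numpy as np
    import pytensor.tensor as pt
    from pytensor.tensor.basic import ix_

    a = pt.matrix("a")
    rows = pt.as_tensor([0, 2])
    cols = pt.as_tensor([1, 3])
    sub = a[ix_(rows, cols)]   # 2x2 cross-section: rows {0, 2} x columns {1, 3}
    print(sub.eval({a: np.arange(16.0).reshape(4, 4)}))
    # [[ 1.  3.]
    #  [ 9. 11.]]
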
pytensor/tensor/blas.py

Lines changed: 9 additions & 9 deletions
@@ -144,20 +144,20 @@
 # If check_init_y() == True we need to initialize y when beta == 0.
 def check_init_y():
     if check_init_y._result is None:
-        if not have_fblas:
+        if not have_fblas:  # pragma: no cover
             check_init_y._result = False
-
-        y = float("NaN") * np.ones((2,))
-        x = np.ones((2,))
-        A = np.ones((2, 2))
-        gemv = _blas_gemv_fns[y.dtype]
-        gemv(1.0, A.T, x, 0.0, y, overwrite_y=True, trans=True)
-        check_init_y._result = np.isnan(y).any()
+        else:
+            y = float("NaN") * np.ones((2,))
+            x = np.ones((2,))
+            A = np.ones((2, 2))
+            gemv = _blas_gemv_fns[y.dtype]
+            gemv(1.0, A.T, x, 0.0, y, overwrite_y=True, trans=True)
+            check_init_y._result = np.isnan(y).any()
 
     return check_init_y._result
 
 
-check_init_y._result = None
+check_init_y._result = None  # type: ignore
 
 
 class Gemv(Op):

pytensor/tensor/blas_scipy.py

Lines changed: 1 addition & 5 deletions
@@ -19,17 +19,13 @@
 
 
 class ScipyGer(Ger):
-    def prepare_node(self, node, storage_map, compute_map, impl):
-        if impl == "py":
-            node.tag.local_ger = _blas_ger_fns[np.dtype(node.inputs[0].type.dtype)]
-
     def perform(self, node, inputs, output_storage):
         cA, calpha, cx, cy = inputs
         (cZ,) = output_storage
         # N.B. some versions of scipy (e.g. mine) don't actually work
         # in-place on a, even when I tell it to.
         A = cA
-        local_ger = node.tag.local_ger
+        local_ger = _blas_ger_fns[cA.dtype]
         if A.size == 0:
             # We don't have to compute anything, A is empty.
             # We need this special case because Numpy considers it