Skip to content

Commit 4be81f5

Browse files
committed
CholeskySolve inherits from BaseSolve
1 parent 3e96e2f commit 4be81f5

File tree

3 files changed

+51
-95
lines changed

3 files changed

+51
-95
lines changed

pytensor/tensor/slinalg.py

+47-91
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class Cholesky(Op):
4949

5050
__props__ = ("lower", "destructive", "on_error")
5151

52-
def __init__(self, lower=True, on_error="raise"):
52+
def __init__(self, *, lower=True, on_error="raise"):
5353
self.lower = lower
5454
self.destructive = False
5555
if on_error not in ("raise", "nan"):
@@ -127,77 +127,8 @@ def conjugate_solve_triangular(outer, inner):
127127
return [grad]
128128

129129

130-
cholesky = Cholesky()
131-
132-
133-
class CholeskySolve(Op):
134-
__props__ = ("lower", "check_finite")
135-
136-
def __init__(
137-
self,
138-
lower=True,
139-
check_finite=True,
140-
):
141-
self.lower = lower
142-
self.check_finite = check_finite
143-
144-
def __repr__(self):
145-
return "CholeskySolve{%s}" % str(self._props())
146-
147-
def make_node(self, C, b):
148-
C = as_tensor_variable(C)
149-
b = as_tensor_variable(b)
150-
assert C.ndim == 2
151-
assert b.ndim in (1, 2)
152-
153-
# infer dtype by solving the most simple
154-
# case with (1, 1) matrices
155-
o_dtype = scipy.linalg.solve(
156-
np.eye(1).astype(C.dtype), np.eye(1).astype(b.dtype)
157-
).dtype
158-
x = tensor(dtype=o_dtype, shape=b.type.shape)
159-
return Apply(self, [C, b], [x])
160-
161-
def perform(self, node, inputs, output_storage):
162-
C, b = inputs
163-
rval = scipy.linalg.cho_solve(
164-
(C, self.lower),
165-
b,
166-
check_finite=self.check_finite,
167-
)
168-
169-
output_storage[0][0] = rval
170-
171-
def infer_shape(self, fgraph, node, shapes):
172-
Cshape, Bshape = shapes
173-
rows = Cshape[1]
174-
if len(Bshape) == 1: # b is a Vector
175-
return [(rows,)]
176-
else:
177-
cols = Bshape[1] # b is a Matrix
178-
return [(rows, cols)]
179-
180-
181-
cho_solve = CholeskySolve()
182-
183-
184-
def cho_solve(c_and_lower, b, check_finite=True):
185-
"""Solve the linear equations A x = b, given the Cholesky factorization of A.
186-
187-
Parameters
188-
----------
189-
(c, lower) : tuple, (array, bool)
190-
Cholesky factorization of a, as given by cho_factor
191-
b : array
192-
Right-hand side
193-
check_finite : bool, optional
194-
Whether to check that the input matrices contain only finite numbers.
195-
Disabling may give a performance gain, but may result in problems
196-
(crashes, non-termination) if the inputs do contain infinities or NaNs.
197-
"""
198-
199-
A, lower = c_and_lower
200-
return CholeskySolve(lower=lower, check_finite=check_finite)(A, b)
130+
def cholesky(x, lower=True, on_error="raise"):
131+
return Cholesky(lower=lower, on_error=on_error)(x)
201132

202133

203134
class SolveBase(Op):
@@ -210,6 +141,7 @@ class SolveBase(Op):
210141

211142
def __init__(
212143
self,
144+
*,
213145
lower=False,
214146
check_finite=True,
215147
):
@@ -276,28 +208,56 @@ def L_op(self, inputs, outputs, output_gradients):
276208

277209
return [A_bar, b_bar]
278210

279-
def __repr__(self):
280-
return f"{type(self).__name__}{self._props()}"
211+
212+
class CholeskySolve(SolveBase):
213+
def __init__(self, **kwargs):
214+
kwargs.setdefault("lower", True)
215+
super().__init__(**kwargs)
216+
217+
def perform(self, node, inputs, output_storage):
218+
C, b = inputs
219+
rval = scipy.linalg.cho_solve(
220+
(C, self.lower),
221+
b,
222+
check_finite=self.check_finite,
223+
)
224+
225+
output_storage[0][0] = rval
226+
227+
def L_op(self, *args, **kwargs):
228+
raise NotImplementedError()
229+
230+
231+
def cho_solve(c_and_lower, b, *, check_finite=True):
232+
"""Solve the linear equations A x = b, given the Cholesky factorization of A.
233+
234+
Parameters
235+
----------
236+
(c, lower) : tuple, (array, bool)
237+
Cholesky factorization of a, as given by cho_factor
238+
b : array
239+
Right-hand side
240+
check_finite : bool, optional
241+
Whether to check that the input matrices contain only finite numbers.
242+
Disabling may give a performance gain, but may result in problems
243+
(crashes, non-termination) if the inputs do contain infinities or NaNs.
244+
"""
245+
A, lower = c_and_lower
246+
return CholeskySolve(lower=lower, check_finite=check_finite)(A, b)
281247

282248

283249
class SolveTriangular(SolveBase):
284250
"""Solve a system of linear equations."""
285251

286252
__props__ = (
287-
"lower",
288253
"trans",
289254
"unit_diagonal",
255+
"lower",
290256
"check_finite",
291257
)
292258

293-
def __init__(
294-
self,
295-
trans=0,
296-
lower=False,
297-
unit_diagonal=False,
298-
check_finite=True,
299-
):
300-
super().__init__(lower=lower, check_finite=check_finite)
259+
def __init__(self, *, trans=0, unit_diagonal=False, **kwargs):
260+
super().__init__(**kwargs)
301261
self.trans = trans
302262
self.unit_diagonal = unit_diagonal
303263

@@ -326,6 +286,7 @@ def L_op(self, inputs, outputs, output_gradients):
326286
def solve_triangular(
327287
a: TensorVariable,
328288
b: TensorVariable,
289+
*,
329290
trans: Union[int, str] = 0,
330291
lower: bool = False,
331292
unit_diagonal: bool = False,
@@ -373,16 +334,11 @@ class Solve(SolveBase):
373334
"check_finite",
374335
)
375336

376-
def __init__(
377-
self,
378-
assume_a="gen",
379-
lower=False,
380-
check_finite=True,
381-
):
337+
def __init__(self, *, assume_a="gen", **kwargs):
382338
if assume_a not in ("gen", "sym", "her", "pos"):
383339
raise ValueError(f"{assume_a} is not a recognized matrix structure")
384340

385-
super().__init__(lower=lower, check_finite=check_finite)
341+
super().__init__(**kwargs)
386342
self.assume_a = assume_a
387343

388344
def perform(self, node, inputs, outputs):
@@ -396,7 +352,7 @@ def perform(self, node, inputs, outputs):
396352
)
397353

398354

399-
def solve(a, b, assume_a="gen", lower=False, check_finite=True):
355+
def solve(a, b, *, assume_a="gen", lower=False, check_finite=True):
400356
"""Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.
401357
402358
If the data matrix is known to be a particular type then supplying the

tests/link/numba/test_nlinalg.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
],
4747
)
4848
def test_Cholesky(x, lower, exc):
49-
g = slinalg.Cholesky(lower)(x)
49+
g = slinalg.Cholesky(lower=lower)(x)
5050

5151
if isinstance(g, list):
5252
g_fg = FunctionGraph(outputs=g)
@@ -91,7 +91,7 @@ def test_Cholesky(x, lower, exc):
9191
],
9292
)
9393
def test_Solve(A, x, lower, exc):
94-
g = slinalg.Solve(lower)(A, x)
94+
g = slinalg.Solve(lower=lower)(A, x)
9595

9696
if isinstance(g, list):
9797
g_fg = FunctionGraph(outputs=g)
@@ -125,7 +125,7 @@ def test_Solve(A, x, lower, exc):
125125
],
126126
)
127127
def test_SolveTriangular(A, x, lower, exc):
128-
g = slinalg.SolveTriangular(lower)(A, x)
128+
g = slinalg.SolveTriangular(lower=lower)(A, x)
129129

130130
if isinstance(g, list):
131131
g_fg = FunctionGraph(outputs=g)

tests/tensor/test_slinalg.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ def setup_method(self):
361361
super().setup_method()
362362

363363
def test_repr(self):
364-
assert repr(CholeskySolve()) == "CholeskySolve{(True, True)}"
364+
assert repr(CholeskySolve()) == "CholeskySolve(lower=True,check_finite=True)"
365365

366366
def test_infer_shape(self):
367367
rng = np.random.default_rng(utt.fetch_seed())

0 commit comments

Comments
 (0)