Skip to content

Commit 117d011

Browse files
committed
CholeskySolve inherits from BaseSolve
1 parent ca969f2 commit 117d011

File tree

2 files changed

+48
-98
lines changed

2 files changed

+48
-98
lines changed

pytensor/tensor/slinalg.py

+47-97
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class Cholesky(Op):
4747

4848
__props__ = ("lower", "destructive", "on_error")
4949

50-
def __init__(self, lower=True, on_error="raise"):
50+
def __init__(self, *, lower=True, on_error="raise"):
5151
self.lower = lower
5252
self.destructive = False
5353
if on_error not in ("raise", "nan"):
@@ -125,77 +125,8 @@ def conjugate_solve_triangular(outer, inner):
125125
return [grad]
126126

127127

128-
cholesky = Cholesky()
129-
130-
131-
class CholeskySolve(Op):
132-
__props__ = ("lower", "check_finite")
133-
134-
def __init__(
135-
self,
136-
lower=True,
137-
check_finite=True,
138-
):
139-
self.lower = lower
140-
self.check_finite = check_finite
141-
142-
def __repr__(self):
143-
return "CholeskySolve{%s}" % str(self._props())
144-
145-
def make_node(self, C, b):
146-
C = as_tensor_variable(C)
147-
b = as_tensor_variable(b)
148-
assert C.ndim == 2
149-
assert b.ndim in (1, 2)
150-
151-
# infer dtype by solving the most simple
152-
# case with (1, 1) matrices
153-
o_dtype = scipy.linalg.solve(
154-
np.eye(1).astype(C.dtype), np.eye(1).astype(b.dtype)
155-
).dtype
156-
x = tensor(dtype=o_dtype, shape=b.type.shape)
157-
return Apply(self, [C, b], [x])
158-
159-
def perform(self, node, inputs, output_storage):
160-
C, b = inputs
161-
rval = scipy.linalg.cho_solve(
162-
(C, self.lower),
163-
b,
164-
check_finite=self.check_finite,
165-
)
166-
167-
output_storage[0][0] = rval
168-
169-
def infer_shape(self, fgraph, node, shapes):
170-
Cshape, Bshape = shapes
171-
rows = Cshape[1]
172-
if len(Bshape) == 1: # b is a Vector
173-
return [(rows,)]
174-
else:
175-
cols = Bshape[1] # b is a Matrix
176-
return [(rows, cols)]
177-
178-
179-
cho_solve = CholeskySolve()
180-
181-
182-
def cho_solve(c_and_lower, b, check_finite=True):
183-
"""Solve the linear equations A x = b, given the Cholesky factorization of A.
184-
185-
Parameters
186-
----------
187-
(c, lower) : tuple, (array, bool)
188-
Cholesky factorization of a, as given by cho_factor
189-
b : array
190-
Right-hand side
191-
check_finite : bool, optional
192-
Whether to check that the input matrices contain only finite numbers.
193-
Disabling may give a performance gain, but may result in problems
194-
(crashes, non-termination) if the inputs do contain infinities or NaNs.
195-
"""
196-
197-
A, lower = c_and_lower
198-
return CholeskySolve(lower=lower, check_finite=check_finite)(A, b)
128+
def cholesky(x, lower=True, on_error="raise"):
129+
return Cholesky(lower=lower, on_error=on_error)(x)
199130

200131

201132
class SolveBase(Op):
@@ -208,6 +139,7 @@ class SolveBase(Op):
208139

209140
def __init__(
210141
self,
142+
*,
211143
lower=False,
212144
check_finite=True,
213145
):
@@ -274,28 +206,56 @@ def L_op(self, inputs, outputs, output_gradients):
274206

275207
return [A_bar, b_bar]
276208

277-
def __repr__(self):
278-
return f"{type(self).__name__}{self._props()}"
209+
210+
class CholeskySolve(SolveBase):
211+
def __init__(self, **kwargs):
212+
kwargs.setdefault("lower", True)
213+
super().__init__(**kwargs)
214+
215+
def perform(self, node, inputs, output_storage):
216+
C, b = inputs
217+
rval = scipy.linalg.cho_solve(
218+
(C, self.lower),
219+
b,
220+
check_finite=self.check_finite,
221+
)
222+
223+
output_storage[0][0] = rval
224+
225+
def L_op(self, *args, **kwargs):
226+
raise NotImplementedError()
227+
228+
229+
def cho_solve(c_and_lower, b, *, check_finite=True):
230+
"""Solve the linear equations A x = b, given the Cholesky factorization of A.
231+
232+
Parameters
233+
----------
234+
(c, lower) : tuple, (array, bool)
235+
Cholesky factorization of a, as given by cho_factor
236+
b : array
237+
Right-hand side
238+
check_finite : bool, optional
239+
Whether to check that the input matrices contain only finite numbers.
240+
Disabling may give a performance gain, but may result in problems
241+
(crashes, non-termination) if the inputs do contain infinities or NaNs.
242+
"""
243+
A, lower = c_and_lower
244+
return CholeskySolve(lower=lower, check_finite=check_finite)(A, b)
279245

280246

281247
class SolveTriangular(SolveBase):
282248
"""Solve a system of linear equations."""
283249

284250
__props__ = (
285-
"lower",
286251
"trans",
287252
"unit_diagonal",
253+
"lower",
288254
"check_finite",
289255
)
290256

291-
def __init__(
292-
self,
293-
trans=0,
294-
lower=False,
295-
unit_diagonal=False,
296-
check_finite=True,
297-
):
298-
super().__init__(lower=lower, check_finite=check_finite)
257+
def __init__(self, *, trans=0, unit_diagonal=False, **kwargs):
258+
super().__init__(**kwargs)
299259
self.trans = trans
300260
self.unit_diagonal = unit_diagonal
301261

@@ -321,12 +281,10 @@ def L_op(self, inputs, outputs, output_gradients):
321281
return res
322282

323283

324-
solvetriangular = SolveTriangular()
325-
326-
327284
def solve_triangular(
328285
a: TensorVariable,
329286
b: TensorVariable,
287+
*,
330288
trans: Union[int, str] = 0,
331289
lower: bool = False,
332290
unit_diagonal: bool = False,
@@ -374,16 +332,11 @@ class Solve(SolveBase):
374332
"check_finite",
375333
)
376334

377-
def __init__(
378-
self,
379-
assume_a="gen",
380-
lower=False,
381-
check_finite=True,
382-
):
335+
def __init__(self, *, assume_a="gen", **kwargs):
383336
if assume_a not in ("gen", "sym", "her", "pos"):
384337
raise ValueError(f"{assume_a} is not a recognized matrix structure")
385338

386-
super().__init__(lower=lower, check_finite=check_finite)
339+
super().__init__(**kwargs)
387340
self.assume_a = assume_a
388341

389342
def perform(self, node, inputs, outputs):
@@ -397,10 +350,7 @@ def perform(self, node, inputs, outputs):
397350
)
398351

399352

400-
solve = Solve()
401-
402-
403-
def solve(a, b, assume_a="gen", lower=False, check_finite=True):
353+
def solve(a, b, *, assume_a="gen", lower=False, check_finite=True):
404354
"""Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.
405355
406356
If the data matrix is known to be a particular type then supplying the

tests/tensor/test_slinalg.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ def setup_method(self):
360360
super().setup_method()
361361

362362
def test_repr(self):
363-
assert repr(CholeskySolve()) == "CholeskySolve{(True, True)}"
363+
assert repr(CholeskySolve()) == "CholeskySolve(lower=True,check_finite=True)"
364364

365365
def test_infer_shape(self):
366366
rng = np.random.default_rng(utt.fetch_seed())

0 commit comments

Comments
 (0)