Skip to content

Commit 849c3b8

Browse files
committed
Use boolean __props__ in SVD
Fixes failure in more recent versions of jaxlib
1 parent 931297f commit 849c3b8

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

pytensor/tensor/nlinalg.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -541,9 +541,9 @@ class SVD(Op):
541541
# See doc in the docstring of the function just after this class.
542542
__props__ = ("full_matrices", "compute_uv")
543543

544-
def __init__(self, full_matrices=True, compute_uv=True):
545-
self.full_matrices = full_matrices
546-
self.compute_uv = compute_uv
544+
def __init__(self, full_matrices: bool = True, compute_uv: bool = True):
545+
self.full_matrices = bool(full_matrices)
546+
self.compute_uv = bool(compute_uv)
547547

548548
def make_node(self, x):
549549
x = as_tensor_variable(x)
@@ -584,7 +584,7 @@ def infer_shape(self, fgraph, node, shapes):
584584
return [s_shape]
585585

586586

587-
def svd(a, full_matrices=1, compute_uv=1):
587+
def svd(a, full_matrices: bool = True, compute_uv: bool = True):
588588
"""
589589
This function performs the SVD on CPU.
590590

0 commit comments

Comments
 (0)