File tree 2 files changed +7
-7
lines changed
2 files changed +7
-7
lines changed Original file line number Diff line number Diff line change @@ -1770,14 +1770,9 @@ def verify_grad(
1770
1770
if rel_tol is None :
1771
1771
rel_tol = max (_type_tol [str (p .dtype )] for p in pt )
1772
1772
1773
+ # Initialize RNG if not provided
1773
1774
if rng is None :
1774
- raise TypeError (
1775
- "rng should be a valid instance of "
1776
- "numpy.random.RandomState. You may "
1777
- "want to use tests.unittest"
1778
- "_tools.verify_grad instead of "
1779
- "pytensor.gradient.verify_grad."
1780
- )
1775
+ rng = np .random .default_rng ()
1781
1776
1782
1777
# We allow input downcast in `function`, because `numeric_grad` works in
1783
1778
# the most precise dtype used among the inputs, so we may need to cast
Original file line number Diff line number Diff line change 3
3
from scipy .optimize import rosen_hess_prod
4
4
5
5
import pytensor
6
+ import pytensor .tensor as pt
6
7
import pytensor .tensor .basic as ptb
7
8
from pytensor .configdefaults import config
8
9
from pytensor .gradient import (
@@ -602,6 +603,10 @@ def test_grad_constant(self):
602
603
+ str (g_one )
603
604
)
604
605
606
+ def test_verify_grad_no_rng (self ):
607
+ """Test `verify_grad` works without requiring an explicit RNG."""
608
+ utt .verify_grad (pt .log , [2.0 ])
609
+
605
610
606
611
def test_known_grads ():
607
612
# Tests that the grad method with no known_grads
You can’t perform that action at this time.
0 commit comments