Skip to content

Commit 959a499

Browse files
author
Mohit Kumar
committed
Add test for 'verify_grad' without explicit RNG
1 parent 33a4d48 commit 959a499

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

pytensor/gradient.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -1770,14 +1770,9 @@ def verify_grad(
17701770
if rel_tol is None:
17711771
rel_tol = max(_type_tol[str(p.dtype)] for p in pt)
17721772

1773+
# Initialize RNG if not provided
17731774
if rng is None:
1774-
raise TypeError(
1775-
"rng should be a valid instance of "
1776-
"numpy.random.RandomState. You may "
1777-
"want to use tests.unittest"
1778-
"_tools.verify_grad instead of "
1779-
"pytensor.gradient.verify_grad."
1780-
)
1775+
rng = np.random.default_rng()
17811776

17821777
# We allow input downcast in `function`, because `numeric_grad` works in
17831778
# the most precise dtype used among the inputs, so we may need to cast

tests/test_gradient.py

+5
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from scipy.optimize import rosen_hess_prod
44

55
import pytensor
6+
import pytensor.tensor as pt
67
import pytensor.tensor.basic as ptb
78
from pytensor.configdefaults import config
89
from pytensor.gradient import (
@@ -602,6 +603,10 @@ def test_grad_constant(self):
602603
+ str(g_one)
603604
)
604605

606+
def test_verify_grad_no_rng(self):
607+
"""Test `verify_grad` works without requiring an explicit RNG."""
608+
utt.verify_grad(pt.log, [2.0])
609+
605610

606611
def test_known_grads():
607612
# Tests that the grad method with no known_grads

0 commit comments

Comments
 (0)