Skip to content

Commit 2b64589

Browse files
committed
StandandardNormalRV is now just a helper function
1 parent c3dfc5a commit 2b64589

File tree

4 files changed

+25
-42
lines changed

4 files changed

+25
-42
lines changed

pytensor/link/jax/dispatch/random.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,6 @@ def sample_fn(rng, size, dtype, *parameters):
145145
@jax_sample_fn.register(ptr.LaplaceRV)
146146
@jax_sample_fn.register(ptr.LogisticRV)
147147
@jax_sample_fn.register(ptr.NormalRV)
148-
@jax_sample_fn.register(ptr.StandardNormalRV)
149148
def jax_sample_fn_loc_scale(op):
150149
"""JAX implementation of random variables in the loc-scale families.
151150

pytensor/tensor/random/basic.py

Lines changed: 13 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -278,38 +278,24 @@ def __call__(self, loc=0.0, scale=1.0, size=None, **kwargs):
278278
normal = NormalRV()
279279

280280

281-
class StandardNormalRV(NormalRV):
282-
r"""A standard normal continuous random variable.
281+
def standard_normal(*, size=None, rng=None, dtype=None):
282+
"""Draw samples from a standard normal distribution.
283283
284-
The probability density function for `standard_normal` is:
284+
Signature
285+
---------
285286
286-
.. math::
287+
`nil -> ()`
287288
288-
f(x) = \frac{1}{\sqrt{2 \pi}} e^{-\frac{x^2}{2}}
289+
Parameters
290+
----------
291+
size
292+
Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k`
293+
independent, identically distributed random variables are
294+
returned. Default is `None` in which case a single random variable
295+
is returned.
289296
290297
"""
291-
292-
def __call__(self, size=None, **kwargs):
293-
"""Draw samples from a standard normal distribution.
294-
295-
Signature
296-
---------
297-
298-
`nil -> ()`
299-
300-
Parameters
301-
----------
302-
size
303-
Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k`
304-
independent, identically distributed random variables are
305-
returned. Default is `None` in which case a single random variable
306-
is returned.
307-
308-
"""
309-
return super().__call__(loc=0.0, scale=1.0, size=size, **kwargs)
310-
311-
312-
standard_normal = StandardNormalRV()
298+
return normal(0.0, 1.0, size=size, rng=rng, dtype=dtype)
313299

314300

315301
class HalfNormalRV(ScipyRandomVariable):

pytensor/tensor/random/utils.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -209,9 +209,9 @@ def __init__(
209209
if namespace is None:
210210
from pytensor.tensor.random import basic # pylint: disable=import-self
211211

212-
self.namespaces = [basic]
212+
self.namespaces = [(basic, set(basic.__all__))]
213213
else:
214-
self.namespaces = [namespace]
214+
self.namespaces = [(namespace, set(namespace.__all__))]
215215

216216
self.default_instance_seed = seed
217217
self.state_updates = []
@@ -226,22 +226,20 @@ def rng_ctor(seed):
226226

227227
def __getattr__(self, obj):
228228
ns_obj = next(
229-
(getattr(ns, obj) for ns in self.namespaces if hasattr(ns, obj)), None
229+
(
230+
getattr(ns, obj)
231+
for ns, all_ in self.namespaces
232+
if obj in all_ and hasattr(ns, obj)
233+
),
234+
None,
230235
)
231236

232237
if ns_obj is None:
233238
raise AttributeError(f"No attribute {obj}.")
234239

235-
from pytensor.tensor.random.op import RandomVariable
236-
237-
if isinstance(ns_obj, RandomVariable):
238-
239-
@wraps(ns_obj)
240-
def meta_obj(*args, **kwargs):
241-
return self.gen(ns_obj, *args, **kwargs)
242-
243-
else:
244-
raise AttributeError(f"No attribute {obj}.")
240+
@wraps(ns_obj)
241+
def meta_obj(*args, **kwargs):
242+
return self.gen(ns_obj, *args, **kwargs)
245243

246244
setattr(self, obj, meta_obj)
247245
return getattr(self, obj)

tests/tensor/random/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def test_basics(self, rng_ctor):
114114
assert hasattr(random, "standard_normal")
115115

116116
with pytest.raises(AttributeError):
117-
np_random = RandomStream(namespace=np, rng_ctor=rng_ctor)
117+
np_random = RandomStream(namespace=np.random, rng_ctor=rng_ctor)
118118
np_random.ndarray
119119

120120
fn = function([], random.uniform(0, 1, size=(2, 2)), updates=random.updates())

0 commit comments

Comments
 (0)