Skip to content

Commit d50efd5

Browse files
committed
Introduce signature instead of ndim_supp and ndims_params
1 parent 5f0c106 commit d50efd5

File tree

10 files changed

+306
-284
lines changed

10 files changed

+306
-284
lines changed

pytensor/link/jax/dispatch/random.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -270,13 +270,19 @@ def sample_fn(rng, size, dtype, *parameters):
270270

271271

272272
@jax_sample_fn.register(ptr.ChoiceRV)
273-
def jax_funcify_choice(op):
273+
def jax_funcify_choice(op: ptr.ChoiceRV):
274274
"""JAX implementation of `ChoiceRV`."""
275275

276+
p_none = op.p_none
277+
276278
def sample_fn(rng, size, dtype, *parameters):
277279
rng_key = rng["jax_state"]
278280
rng_key, sampling_key = jax.random.split(rng_key, 2)
279-
(a, p, replace) = parameters
281+
if p_none:
282+
(a, replace) = parameters
283+
p = None
284+
else:
285+
(a, p, replace) = parameters
280286
smpl_value = jax.random.choice(sampling_key, a, size, replace, p)
281287
rng["jax_state"] = rng_key
282288
return (rng, smpl_value)

0 commit comments

Comments
 (0)