Skip to content

Commit d4e5db1

Browse files
committed
Avoid spurious deprecation warning signature/extended_signature in CustomDist
Also allow multivariate CustomDist to be created when signature suffices to infer core shape.
1 parent 29eef08 commit d4e5db1

File tree

2 files changed

+35
-53
lines changed

2 files changed

+35
-53
lines changed

pymc/distributions/custom.py

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from pytensor.tensor.random.op import RandomVariable
2828
from pytensor.tensor.random.type import RandomGeneratorType, RandomType
2929
from pytensor.tensor.random.utils import normalize_size_param
30-
from pytensor.tensor.utils import _parse_gufunc_signature, safe_signature
30+
from pytensor.tensor.utils import safe_signature
3131

3232
from pymc.distributions.distribution import (
3333
Distribution,
@@ -108,19 +108,9 @@ def dist(
108108
class_name: str = "CustomDist",
109109
**kwargs,
110110
):
111-
if ndim_supp is None or ndims_params is None:
112-
if signature is None:
113-
ndim_supp = 0
114-
ndims_params = [0] * len(dist_params)
115-
else:
116-
inputs, outputs = _parse_gufunc_signature(signature)
117-
ndim_supp = max(len(out) for out in outputs)
118-
ndims_params = [len(inp) for inp in inputs]
119-
120-
if ndim_supp > 0:
121-
raise NotImplementedError(
122-
"CustomDist with ndim_supp > 0 and without a `dist` function are not supported."
123-
)
111+
if ndim_supp is None and signature is None:
112+
# Assume a scalar distribution
113+
signature = safe_signature([0] * len(dist_params), [0])
124114

125115
dist_params = [as_tensor_variable(param) for param in dist_params]
126116

@@ -148,6 +138,7 @@ def dist(
148138
support_point=support_point,
149139
ndim_supp=ndim_supp,
150140
ndims_params=ndims_params,
141+
signature=signature,
151142
dtype=dtype,
152143
class_name=class_name,
153144
**kwargs,
@@ -161,8 +152,9 @@ def rv_op(
161152
logcdf: Callable | None,
162153
random: Callable | None,
163154
support_point: Callable | None,
164-
ndim_supp: int,
165-
ndims_params: Sequence[int],
155+
signature: str | None,
156+
ndim_supp: int | None,
157+
ndims_params: Sequence[int] | None,
166158
dtype: str,
167159
class_name: str,
168160
**kwargs,
@@ -175,6 +167,7 @@ def rv_op(
175167
inplace=False,
176168
ndim_supp=ndim_supp,
177169
ndims_params=ndims_params,
170+
signature=signature,
178171
dtype=dtype,
179172
_print_name=(class_name, f"\\operatorname{{{class_name}}}"),
180173
# Specific to CustomDist
@@ -344,7 +337,7 @@ def change_custom_dist_size(op, rv, new_size, expand):
344337
new_rv_op = rv_type(
345338
inputs=[*dummy_params, *rngs],
346339
outputs=[dummy_rv, *rngs_updates],
347-
signature=signature,
340+
extended_signature=extended_signature,
348341
)
349342
new_rv = new_rv_op(new_size, *dist_params, *rngs)
350343

@@ -357,13 +350,13 @@ def change_custom_dist_size(op, rv, new_size, expand):
357350

358351
inputs = [*dummy_params, *rngs]
359352
outputs = [dummy_rv, *rngs_updates]
360-
signature = cls._infer_final_signature(
353+
extended_signature = cls._infer_final_signature(
361354
signature, n_inputs=len(inputs), n_outputs=len(outputs), n_rngs=len(rngs)
362355
)
363356
rv_op = rv_type(
364357
inputs=inputs,
365358
outputs=outputs,
366-
signature=signature,
359+
extended_signature=extended_signature,
367360
)
368361
return rv_op(size, *dist_params, *rngs)
369362

tests/distributions/test_custom.py

Lines changed: 23 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@
5353
from pymc.step_methods import Metropolis
5454
from pymc.testing import assert_support_point_is_expected
5555

56+
# Raise for any warnings in this file
57+
pytestmark = pytest.mark.filterwarnings("error")
58+
5659

5760
class TestCustomDist:
5861
@pytest.mark.parametrize("size", [(), (3,), (3, 2)], ids=str)
@@ -105,24 +108,24 @@ def test_custom_dist_without_random(self):
105108
with pytest.raises(NotImplementedError):
106109
sample_posterior_predictive(idata, model=model)
107110

108-
@pytest.mark.xfail(
109-
NotImplementedError,
110-
reason="Support shape of multivariate CustomDist cannot be inferred. See https://github.com/pymc-devs/pytensor/pull/388",
111-
)
112111
@pytest.mark.parametrize("size", [(), (3,), (3, 2)], ids=str)
113112
def test_custom_dist_with_random_multivariate(self, size):
113+
def random(mu, rng, size):
114+
return rng.multivariate_normal(
115+
mean=mu.ravel(),
116+
cov=np.eye(mu.shape[-1]),
117+
size=size,
118+
)
119+
114120
supp_shape = 5
115121
with Model() as model:
116122
mu = Normal("mu", 0, 1, size=supp_shape)
117123
obs = CustomDist(
118124
"custom_dist",
119125
mu,
120-
random=lambda mu, rng=None, size=None: rng.multivariate_normal(
121-
mean=mu, cov=np.eye(len(mu)), size=size
122-
),
126+
random=random,
123127
observed=np.random.randn(100, *size, supp_shape),
124-
ndims_params=[1],
125-
ndim_supp=1,
128+
signature="(n)->(n)",
126129
)
127130

128131
assert isinstance(obs.owner.op, CustomDistRV)
@@ -156,20 +159,16 @@ def test_custom_dist_old_api_error(self):
156159
):
157160
CustomDist("a", lambda x: x)
158161

159-
@pytest.mark.xfail(
160-
NotImplementedError,
161-
reason="Support shape of multivariate CustomDist cannot be inferred. See https://github.com/pymc-devs/pytensor/pull/388",
162-
)
163162
@pytest.mark.parametrize("size", [None, (), (2,)], ids=str)
164163
def test_custom_dist_multivariate_logp(self, size):
165164
supp_shape = 5
166165
with Model() as model:
167166

168167
def logp(value, mu):
169-
return MvNormal.logp(value, mu, pt.eye(mu.shape[0]))
168+
return MvNormal.logp(value, mu, pt.eye(mu.shape[-1]))
170169

171170
mu = Normal("mu", size=supp_shape)
172-
a = CustomDist("a", mu, logp=logp, ndims_params=[1], ndim_supp=1, size=size)
171+
a = CustomDist("a", mu, logp=logp, signature="(n)->(n)", size=size)
173172

174173
assert isinstance(a.owner.op, CustomDistRV)
175174
mu_test_value = npr.normal(loc=0, scale=1, size=supp_shape).astype(pytensor.config.floatX)
@@ -219,10 +218,6 @@ def density_support_point(rv, size, mu):
219218
assert evaled_support_point.shape == to_tuple(size)
220219
assert np.all(evaled_support_point == mu_val)
221220

222-
@pytest.mark.xfail(
223-
NotImplementedError,
224-
reason="Support shape of multivariate CustomDist cannot be inferred. See https://github.com/pymc-devs/pytensor/pull/388",
225-
)
226221
@pytest.mark.parametrize("size", [(), (2,), (3, 2)], ids=str)
227222
def test_custom_dist_custom_support_point_multivariate(self, size):
228223
def density_support_point(rv, size, mu):
@@ -235,19 +230,14 @@ def density_support_point(rv, size, mu):
235230
"a",
236231
mu,
237232
support_point=density_support_point,
238-
ndims_params=[1],
239-
ndim_supp=1,
233+
signature="(n)->(n)",
240234
size=size,
241235
)
242236
assert isinstance(a.owner.op, CustomDistRV)
243237
evaled_support_point = support_point(a).eval({mu: mu_val})
244238
assert evaled_support_point.shape == (*to_tuple(size), 5)
245239
assert np.all(evaled_support_point == mu_val)
246240

247-
@pytest.mark.xfail(
248-
NotImplementedError,
249-
reason="Support shape of multivariate CustomDist cannot be inferred. See https://github.com/pymc-devs/pytensor/pull/388",
250-
)
251241
@pytest.mark.parametrize(
252242
"with_random, size",
253243
[
@@ -267,21 +257,14 @@ def _random(mu, rng=None, size=None):
267257
else:
268258
random = None
269259

270-
mu_val = np.random.normal(loc=2, scale=1, size=5).astype(pytensor.config.floatX)
271260
with Model():
272261
mu = Normal("mu", size=5)
273-
a = CustomDist("a", mu, random=random, ndims_params=[1], ndim_supp=1, size=size)
262+
a = CustomDist("a", mu, random=random, signature="(n)->(n)", size=size)
274263
assert isinstance(a.owner.op, CustomDistRV)
275264
if with_random:
276-
evaled_support_point = support_point(a).eval({mu: mu_val})
265+
evaled_support_point = support_point(a).eval()
277266
assert evaled_support_point.shape == (*to_tuple(size), 5)
278267
assert np.all(evaled_support_point == 0)
279-
else:
280-
with pytest.raises(
281-
TypeError,
282-
match="Cannot safely infer the size of a multivariate random variable's support_point.",
283-
):
284-
evaled_support_point = support_point(a).eval({mu: mu_val})
285268

286269
def test_dist(self):
287270
mu = 1
@@ -300,6 +283,12 @@ def test_dist(self):
300283
x_logp = logp(x, test_value)
301284
assert np.allclose(x_logp.eval(), st.norm(1).logpdf(test_value))
302285

286+
def test_multivariate_insufficient_signature(self):
287+
with pytest.raises(
288+
NotImplementedError, match="signature is not sufficient to infer the support shape"
289+
):
290+
CustomDist.dist(signature="(n)->(m)")
291+
303292

304293
class TestCustomSymbolicDist:
305294
def test_basic(self):

0 commit comments

Comments
 (0)