Commit f551718

Fix bug in implicit_size_from_params
1 parent: 416346e

2 files changed: +10 -1 lines changed

pymc/distributions/shape_utils.py (+1 -1)

@@ -461,7 +461,7 @@ def implicit_size_from_params(
     for param, ndim in zip(params, ndims_params):
         batch_shape = list(param.shape[:-ndim] if ndim > 0 else param.shape)
         # Overwrite broadcastable dims
-        for i, broadcastable in enumerate(param.type.broadcastable):
+        for i, broadcastable in enumerate(param.type.broadcastable[: len(batch_shape)]):
             if broadcastable:
                 batch_shape[i] = 1
         batch_shapes.append(batch_shape)
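Why the one-line fix matters: `batch_shape` only covers a parameter's leading batch dimensions, while the old loop enumerated every entry of `param.type.broadcastable`, including the trailing core dimensions. For a parameter whose core dimension is broadcastable (length 1), the index then runs past the end of the plain Python list `batch_shape`. A minimal plain-Python sketch of the failure mode (the shape mirrors the `x` used in the new test; it is an illustration, not code from the repository):

# Parameter of static shape (5, 1) with one core dimension:
# PyTensor reports its broadcastable flags as (False, True).
broadcastable = (False, True)
ndim = 1                       # number of core (event) dimensions
batch_shape = [5]              # only the leading batch dimensions

# Old loop: enumerate(broadcastable) also visits the core dim at i=1,
# so `batch_shape[1] = 1` would raise IndexError on this list.
# Fixed loop: only the flags covering the batch dims are considered.
for i, flag in enumerate(broadcastable[: len(batch_shape)]):
    if flag:
        batch_shape[i] = 1     # broadcastable batch dims recorded as size 1

assert batch_shape == [5]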

tests/distributions/test_shape_utils.py (+9)

@@ -35,9 +35,11 @@
     convert_size,
     get_support_shape,
     get_support_shape_1d,
+    implicit_size_from_params,
     rv_size_is_none,
 )
 from pymc.model import Model
+from pymc.pytensorf import constant_fold

 test_shapes = [
     ((), (1,), (4,), (5, 4)),
@@ -630,3 +632,10 @@ def test_get_support_shape(
         assert (f() == expected_support_shape).all()
     with pytest.raises(AssertionError, match="support_shape does not match"):
         inferred_support_shape.eval()
+
+
+def test_implicit_size_from_params():
+    x = pt.tensor(shape=(5, 1))
+    y = pt.tensor(shape=(3, 3))
+    res = implicit_size_from_params(x, y, ndims_params=[1, 2])
+    assert constant_fold([res]) == (5,)
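The expected result in the new test follows from broadcasting the per-parameter batch shapes: `x` has shape (5, 1) with one core dim, leaving batch shape (5,); `y` has shape (3, 3) with two core dims, leaving batch shape (); their broadcast is (5,). A quick NumPy sanity check of that arithmetic (independent of the pymc helper, assuming standard broadcasting rules):

import numpy as np

# x: (5, 1) minus 1 core dim  -> batch shape (5,)
# y: (3, 3) minus 2 core dims -> batch shape ()
assert np.broadcast_shapes((5,), ()) == (5,)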
