Skip to content

Allow non-scalar measurable switch mixtures #6796

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions pymc/distributions/truncated.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from pymc.distributions.shape_utils import _change_dist_size, change_dist_size, to_tuple
from pymc.distributions.transforms import _default_transform
from pymc.exceptions import TruncationError
from pymc.logprob.abstract import MeasurableVariable, _logcdf, _logprob
from pymc.logprob.abstract import _logcdf, _logprob
from pymc.logprob.basic import icdf, logcdf
from pymc.math import logdiffexp
from pymc.util import check_dist_not_registered
Expand All @@ -64,9 +64,6 @@ def update(self, node: Node):
return {node.inputs[-1]: node.outputs[0]}


MeasurableVariable.register(TruncatedRV)


@singledispatch
def _truncated(op: Op, lower, upper, size, *params):
"""Return the truncated equivalent of another `RandomVariable`."""
Expand Down
86 changes: 53 additions & 33 deletions pymc/logprob/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
from pytensor.graph.op import Op, compute_test_value
from pytensor.graph.rewriting.basic import node_rewriter, pre_greedy_node_rewriter
from pytensor.ifelse import IfElse, ifelse
from pytensor.scalar import Switch
from pytensor.scalar import switch as scalar_switch
from pytensor.tensor.basic import Join, MakeVector, switch
from pytensor.tensor.random.rewriting import (
local_dimshuffle_rv_lift,
Expand All @@ -55,15 +57,19 @@
AdvancedSubtensor,
AdvancedSubtensor1,
as_index_literal,
as_nontensor_scalar,
get_canonical_form_slice,
is_basic_idx,
)
from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceConstant, SliceType
from pytensor.tensor.var import TensorVariable

from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper
from pymc.logprob.abstract import (
MeasurableElemwise,
MeasurableVariable,
_logprob,
_logprob_helper,
)
from pymc.logprob.rewriting import (
PreserveRVMappings,
assume_measured_ir_outputs,
Expand Down Expand Up @@ -325,37 +331,6 @@ def find_measurable_index_mixture(fgraph, node):
return [new_mixture_rv]


@node_rewriter([switch])
def find_measurable_switch_mixture(fgraph, node):
rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)

if rv_map_feature is None:
return None # pragma: no cover

old_mixture_rv = node.default_output()
idx, *components = node.inputs

if rv_map_feature.request_measurable(components) != components:
return None

mix_op = MixtureRV(
2,
old_mixture_rv.dtype,
old_mixture_rv.broadcastable,
)
new_mixture_rv = mix_op.make_node(
*([NoneConst, as_nontensor_scalar(node.inputs[0])] + components[::-1])
).default_output()

if pytensor.config.compute_test_value != "off":
if not hasattr(old_mixture_rv.tag, "test_value"):
compute_test_value(node)

new_mixture_rv.tag.test_value = old_mixture_rv.tag.test_value

return [new_mixture_rv]


@_logprob.register(MixtureRV)
def logprob_MixtureRV(
op, values, *inputs: Optional[Union[TensorVariable, slice]], name=None, **kwargs
Expand Down Expand Up @@ -433,6 +408,51 @@ def logprob_MixtureRV(
return logp_val


class MeasurableSwitchMixture(MeasurableElemwise):
valid_scalar_types = (Switch,)


measurable_switch_mixture = MeasurableSwitchMixture(scalar_switch)


@node_rewriter([switch])
def find_measurable_switch_mixture(fgraph, node):
rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)

if rv_map_feature is None:
return None # pragma: no cover

switch_cond, *components = node.inputs

# We don't support broadcasting of components, as that yields dependent (identical) values.
# The current logp implementation assumes all component values are independent.
# Broadcasting of the switch condition is fine
out_bcast = node.outputs[0].type.broadcastable
if any(comp.type.broadcastable != out_bcast for comp in components):
return None

# Check that `switch_cond` is not potentially measurable
valued_rvs = rv_map_feature.rv_values.keys()
if check_potential_measurability([switch_cond], valued_rvs):
return None

if rv_map_feature.request_measurable(components) != components:
return None

return [measurable_switch_mixture(switch_cond, *components)]


@_logprob.register(MeasurableSwitchMixture)
def logprob_switch_mixture(op, values, switch_cond, component_true, component_false, **kwargs):
[value] = values

return switch(
switch_cond,
_logprob_helper(component_true, value),
_logprob_helper(component_false, value),
)


measurable_ir_rewrites_db.register(
"find_measurable_index_mixture",
find_measurable_index_mixture,
Expand Down
96 changes: 74 additions & 22 deletions tests/logprob/test_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@
as_index_constant,
)

from pymc.logprob.abstract import MeasurableVariable
from pymc.logprob.basic import conditional_logp, logp
from pymc.logprob.mixture import MixtureRV, expand_indices
from pymc.logprob.mixture import MeasurableSwitchMixture, MixtureRV, expand_indices
from pymc.logprob.rewriting import construct_ir_fgraph
from pymc.logprob.utils import dirac_delta
from pymc.testing import assert_no_rvs
Expand Down Expand Up @@ -907,7 +908,7 @@ def test_mixture_with_DiracDelta():
assert m_vv in logp_res


def test_switch_mixture():
def test_scalar_switch_mixture():
srng = pt.random.RandomStream(29833)

X_rv = srng.normal(-10.0, 0.1, name="X")
Expand All @@ -919,6 +920,7 @@ def test_switch_mixture():

# When I_rv == True, X_rv flows through otherwise Y_rv does
Z1_rv = pt.switch(I_rv, X_rv, Y_rv)
Z1_rv.name = "Z1"

assert Z1_rv.eval({I_rv: 0}) > 5
assert Z1_rv.eval({I_rv: 1}) < -5
Expand All @@ -927,40 +929,90 @@ def test_switch_mixture():
z_vv.name = "z1"

fgraph, _, _ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv})
assert isinstance(fgraph.outputs[0].owner.op, MeasurableSwitchMixture)

assert isinstance(fgraph.outputs[0].owner.op, MixtureRV)
assert not hasattr(
fgraph.outputs[0].tag, "test_value"
) # pytensor.config.compute_test_value == "off"
assert fgraph.outputs[0].name is None

Z1_rv.name = "Z1"

fgraph, _, _ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv})

# building the identical graph but with a stack to check that mixture computations are identical

# building the identical graph but with a stack to check that mixture logps are identical
Z2_rv = pt.stack((Y_rv, X_rv))[I_rv]

assert Z2_rv.eval({I_rv: 0}) > 5
assert Z2_rv.eval({I_rv: 1}) < -5

fgraph2, _, _ = construct_ir_fgraph({Z2_rv: z_vv, I_rv: i_vv})

assert equal_computations(fgraph.outputs, fgraph2.outputs)

z1_logp = conditional_logp({Z1_rv: z_vv, I_rv: i_vv})
z2_logp = conditional_logp({Z2_rv: z_vv, I_rv: i_vv})
z1_logp_combined = pt.sum([pt.sum(factor) for factor in z1_logp.values()])
z2_logp_combined = pt.sum([pt.sum(factor) for factor in z2_logp.values()])

# below should follow immediately from the equal_computations assertion above
assert equal_computations([z1_logp_combined], [z2_logp_combined])

np.testing.assert_almost_equal(0.69049938, z1_logp_combined.eval({z_vv: -10, i_vv: 1}))
np.testing.assert_almost_equal(0.69049938, z2_logp_combined.eval({z_vv: -10, i_vv: 1}))


@pytest.mark.parametrize("switch_cond_scalar", (True, False))
def test_switch_mixture_vector(switch_cond_scalar):
if switch_cond_scalar:
switch_cond = pt.scalar("switch_cond", dtype=bool)
else:
switch_cond = pt.vector("switch_cond", dtype=bool)
true_branch = pt.exp(pt.random.normal(size=(4,)))
false_branch = pt.abs(pt.random.normal(size=(4,)))

switch = pt.switch(switch_cond, true_branch, false_branch)
switch.name = "switch_mix"
switch_value = switch.clone()
switch_logp = logp(switch, switch_value)

if switch_cond_scalar:
test_switch_cond = np.array(0, dtype=bool)
else:
test_switch_cond = np.array([0, 1, 0, 1], dtype=bool)
test_switch_value = np.linspace(0.1, 2.5, 4)
np.testing.assert_allclose(
switch_logp.eval({switch_cond: test_switch_cond, switch_value: test_switch_value}),
np.where(
test_switch_cond,
logp(true_branch, test_switch_value).eval(),
logp(false_branch, test_switch_value).eval(),
),
)


def test_switch_mixture_measurable_cond_fails():
"""Test that logprob inference fails when the switch condition is an unvalued measurable variable.

Otherwise, the logp function would have to marginalize over this variable.

NOTE: This could be supported in the future, in which case this test can be removed/adapted
"""
cond_var = 1 - pt.random.bernoulli(p=0.5)
true_branch = pt.random.normal()
false_branch = pt.random.normal()

switch = pt.switch(cond_var, true_branch, false_branch)
with pytest.raises(NotImplementedError, match="Logprob method not implemented for"):
logp(switch, switch.type())


def test_switch_mixture_invalid_bcast():
"""Test that we don't mark switches where components are broadcasted as measurable"""
valid_switch_cond = pt.vector("switch_cond", dtype=bool)
invalid_switch_cond = pt.matrix("switch_cond", dtype=bool)

valid_true_branch = pt.exp(pt.random.normal(size=(4,)))
valid_false_branch = pt.abs(pt.random.normal(size=(4,)))
invalid_false_branch = pt.abs(pt.random.normal(size=()))

valid_mix = pt.switch(valid_switch_cond, valid_true_branch, valid_false_branch)
fgraph, _, _ = construct_ir_fgraph({valid_mix: valid_mix.type()})
assert isinstance(fgraph.outputs[0].owner.op, MeasurableVariable)
assert isinstance(fgraph.outputs[0].owner.op, MeasurableSwitchMixture)

invalid_mix = pt.switch(invalid_switch_cond, valid_true_branch, valid_false_branch)
fgraph, _, _ = construct_ir_fgraph({invalid_mix: invalid_mix.type()})
assert not isinstance(fgraph.outputs[0].owner.op, MeasurableVariable)

invalid_mix = pt.switch(valid_switch_cond, valid_true_branch, invalid_false_branch)
fgraph, _, _ = construct_ir_fgraph({invalid_mix: invalid_mix.type()})
assert not isinstance(fgraph.outputs[0].owner.op, MeasurableVariable)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be more clear to rename the fgraphs and invalid_mixes in this test?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On a second thought... rename to what? I feel the names are pretty reasonable?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I was just alluding to the redundancy in names (three presences of fgraph, two of invalid_mix).



def test_ifelse_mixture_one_component():
if_rv = pt.random.bernoulli(0.5, name="if")
scale_rv = pt.random.halfnormal(name="scale")
Expand Down