Skip to content

Commit cb2c77f

Browse files
committed
Allow non-scalar measurable switch mixtures
1 parent 96adf54 commit cb2c77f

File tree

2 files changed

+127
-55
lines changed

2 files changed

+127
-55
lines changed

pymc/logprob/mixture.py

Lines changed: 53 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@
4444
from pytensor.graph.op import Op, compute_test_value
4545
from pytensor.graph.rewriting.basic import node_rewriter, pre_greedy_node_rewriter
4646
from pytensor.ifelse import IfElse, ifelse
47+
from pytensor.scalar import Switch
48+
from pytensor.scalar import switch as scalar_switch
4749
from pytensor.tensor.basic import Join, MakeVector, switch
4850
from pytensor.tensor.random.rewriting import (
4951
local_dimshuffle_rv_lift,
@@ -55,15 +57,19 @@
5557
AdvancedSubtensor,
5658
AdvancedSubtensor1,
5759
as_index_literal,
58-
as_nontensor_scalar,
5960
get_canonical_form_slice,
6061
is_basic_idx,
6162
)
6263
from pytensor.tensor.type import TensorType
6364
from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceConstant, SliceType
6465
from pytensor.tensor.var import TensorVariable
6566

66-
from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper
67+
from pymc.logprob.abstract import (
68+
MeasurableElemwise,
69+
MeasurableVariable,
70+
_logprob,
71+
_logprob_helper,
72+
)
6773
from pymc.logprob.rewriting import (
6874
PreserveRVMappings,
6975
assume_measured_ir_outputs,
@@ -325,37 +331,6 @@ def find_measurable_index_mixture(fgraph, node):
325331
return [new_mixture_rv]
326332

327333

328-
@node_rewriter([switch])
329-
def find_measurable_switch_mixture(fgraph, node):
330-
rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
331-
332-
if rv_map_feature is None:
333-
return None # pragma: no cover
334-
335-
old_mixture_rv = node.default_output()
336-
idx, *components = node.inputs
337-
338-
if rv_map_feature.request_measurable(components) != components:
339-
return None
340-
341-
mix_op = MixtureRV(
342-
2,
343-
old_mixture_rv.dtype,
344-
old_mixture_rv.broadcastable,
345-
)
346-
new_mixture_rv = mix_op.make_node(
347-
*([NoneConst, as_nontensor_scalar(node.inputs[0])] + components[::-1])
348-
).default_output()
349-
350-
if pytensor.config.compute_test_value != "off":
351-
if not hasattr(old_mixture_rv.tag, "test_value"):
352-
compute_test_value(node)
353-
354-
new_mixture_rv.tag.test_value = old_mixture_rv.tag.test_value
355-
356-
return [new_mixture_rv]
357-
358-
359334
@_logprob.register(MixtureRV)
360335
def logprob_MixtureRV(
361336
op, values, *inputs: Optional[Union[TensorVariable, slice]], name=None, **kwargs
@@ -433,6 +408,51 @@ def logprob_MixtureRV(
433408
return logp_val
434409

435410

411+
class MeasurableSwitchMixture(MeasurableElemwise):
412+
valid_scalar_types = (Switch,)
413+
414+
415+
measurable_switch_mixture = MeasurableSwitchMixture(scalar_switch)
416+
417+
418+
@node_rewriter([switch])
419+
def find_measurable_switch_mixture(fgraph, node):
420+
rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
421+
422+
if rv_map_feature is None:
423+
return None # pragma: no cover
424+
425+
switch_cond, *components = node.inputs
426+
427+
# We don't support broadcasting of components, as that yields dependent (identical) values.
428+
# The current logp implementation assumes all component values are independent.
429+
# Broadcasting of the switch condition is fine
430+
out_bcast = node.outputs[0].type.broadcastable
431+
if any(comp.type.broadcastable != out_bcast for comp in components):
432+
return None
433+
434+
# Check that `switch_cond` is not potentially measurable
435+
valued_rvs = rv_map_feature.rv_values.keys()
436+
if check_potential_measurability([switch_cond], valued_rvs):
437+
return None
438+
439+
if rv_map_feature.request_measurable(components) != components:
440+
return None
441+
442+
return [measurable_switch_mixture(switch_cond, *components)]
443+
444+
445+
@_logprob.register(MeasurableSwitchMixture)
446+
def logprob_switch_mixture(op, values, switch_cond, component_true, component_false, **kwargs):
447+
[value] = values
448+
449+
return switch(
450+
switch_cond,
451+
_logprob_helper(component_true, value),
452+
_logprob_helper(component_false, value),
453+
)
454+
455+
436456
measurable_ir_rewrites_db.register(
437457
"find_measurable_index_mixture",
438458
find_measurable_index_mixture,

tests/logprob/test_mixture.py

Lines changed: 74 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,9 @@
5252
as_index_constant,
5353
)
5454

55+
from pymc.logprob.abstract import MeasurableVariable
5556
from pymc.logprob.basic import conditional_logp, logp
56-
from pymc.logprob.mixture import MixtureRV, expand_indices
57+
from pymc.logprob.mixture import MeasurableSwitchMixture, MixtureRV, expand_indices
5758
from pymc.logprob.rewriting import construct_ir_fgraph
5859
from pymc.logprob.utils import dirac_delta
5960
from pymc.testing import assert_no_rvs
@@ -907,7 +908,7 @@ def test_mixture_with_DiracDelta():
907908
assert m_vv in logp_res
908909

909910

910-
def test_switch_mixture():
911+
def test_scalar_switch_mixture():
911912
srng = pt.random.RandomStream(29833)
912913

913914
X_rv = srng.normal(-10.0, 0.1, name="X")
@@ -919,6 +920,7 @@ def test_switch_mixture():
919920

920921
# When I_rv == True, X_rv flows through otherwise Y_rv does
921922
Z1_rv = pt.switch(I_rv, X_rv, Y_rv)
923+
Z1_rv.name = "Z1"
922924

923925
assert Z1_rv.eval({I_rv: 0}) > 5
924926
assert Z1_rv.eval({I_rv: 1}) < -5
@@ -927,40 +929,90 @@ def test_switch_mixture():
927929
z_vv.name = "z1"
928930

929931
fgraph, _, _ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv})
932+
assert isinstance(fgraph.outputs[0].owner.op, MeasurableSwitchMixture)
930933

931-
assert isinstance(fgraph.outputs[0].owner.op, MixtureRV)
932-
assert not hasattr(
933-
fgraph.outputs[0].tag, "test_value"
934-
) # pytensor.config.compute_test_value == "off"
935-
assert fgraph.outputs[0].name is None
936-
937-
Z1_rv.name = "Z1"
938-
939-
fgraph, _, _ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv})
940-
941-
# building the identical graph but with a stack to check that mixture computations are identical
942-
934+
# building the identical graph but with a stack to check that mixture logps are identical
943935
Z2_rv = pt.stack((Y_rv, X_rv))[I_rv]
944936

945937
assert Z2_rv.eval({I_rv: 0}) > 5
946938
assert Z2_rv.eval({I_rv: 1}) < -5
947939

948-
fgraph2, _, _ = construct_ir_fgraph({Z2_rv: z_vv, I_rv: i_vv})
949-
950-
assert equal_computations(fgraph.outputs, fgraph2.outputs)
951-
952940
z1_logp = conditional_logp({Z1_rv: z_vv, I_rv: i_vv})
953941
z2_logp = conditional_logp({Z2_rv: z_vv, I_rv: i_vv})
954942
z1_logp_combined = pt.sum([pt.sum(factor) for factor in z1_logp.values()])
955943
z2_logp_combined = pt.sum([pt.sum(factor) for factor in z2_logp.values()])
956-
957-
# below should follow immediately from the equal_computations assertion above
958-
assert equal_computations([z1_logp_combined], [z2_logp_combined])
959-
960944
np.testing.assert_almost_equal(0.69049938, z1_logp_combined.eval({z_vv: -10, i_vv: 1}))
961945
np.testing.assert_almost_equal(0.69049938, z2_logp_combined.eval({z_vv: -10, i_vv: 1}))
962946

963947

948+
@pytest.mark.parametrize("switch_cond_scalar", (True, False))
949+
def test_switch_mixture_vector(switch_cond_scalar):
950+
if switch_cond_scalar:
951+
switch_cond = pt.scalar("switch_cond", dtype=bool)
952+
else:
953+
switch_cond = pt.vector("switch_cond", dtype=bool)
954+
true_branch = pt.exp(pt.random.normal(size=(4,)))
955+
false_branch = pt.abs(pt.random.normal(size=(4,)))
956+
957+
switch = pt.switch(switch_cond, true_branch, false_branch)
958+
switch.name = "switch_mix"
959+
switch_value = switch.clone()
960+
switch_logp = logp(switch, switch_value)
961+
962+
if switch_cond_scalar:
963+
test_switch_cond = np.array(0, dtype=bool)
964+
else:
965+
test_switch_cond = np.array([0, 1, 0, 1], dtype=bool)
966+
test_switch_value = np.linspace(0.1, 2.5, 4)
967+
np.testing.assert_allclose(
968+
switch_logp.eval({switch_cond: test_switch_cond, switch_value: test_switch_value}),
969+
np.where(
970+
test_switch_cond,
971+
logp(true_branch, test_switch_value).eval(),
972+
logp(false_branch, test_switch_value).eval(),
973+
),
974+
)
975+
976+
977+
def test_switch_mixture_measurable_cond_fails():
978+
"""Test that logprob inference fails when the switch condition is an unvalued measurable variable.
979+
980+
Otherwise, the logp function would have to marginalize over this variable.
981+
982+
NOTE: This could be supported in the future, in which case this test can be removed/adapted
983+
"""
984+
cond_var = 1 - pt.random.bernoulli(p=0.5)
985+
true_branch = pt.random.normal()
986+
false_branch = pt.random.normal()
987+
988+
switch = pt.switch(cond_var, true_branch, false_branch)
989+
with pytest.raises(NotImplementedError, match="Logprob method not implemented for"):
990+
logp(switch, switch.type())
991+
992+
993+
def test_switch_mixture_invalid_bcast():
994+
"""Test that we don't mark switches where components are broadcasted as measurable"""
995+
valid_switch_cond = pt.vector("switch_cond", dtype=bool)
996+
invalid_switch_cond = pt.matrix("switch_cond", dtype=bool)
997+
998+
valid_true_branch = pt.exp(pt.random.normal(size=(4,)))
999+
valid_false_branch = pt.abs(pt.random.normal(size=(4,)))
1000+
invalid_false_branch = pt.abs(pt.random.normal(size=()))
1001+
1002+
valid_mix = pt.switch(valid_switch_cond, valid_true_branch, valid_false_branch)
1003+
fgraph, _, _ = construct_ir_fgraph({valid_mix: valid_mix.type()})
1004+
assert isinstance(fgraph.outputs[0].owner.op, MeasurableVariable)
1005+
assert isinstance(fgraph.outputs[0].owner.op, MeasurableSwitchMixture)
1006+
1007+
invalid_mix = pt.switch(invalid_switch_cond, valid_true_branch, valid_false_branch)
1008+
fgraph, _, _ = construct_ir_fgraph({invalid_mix: invalid_mix.type()})
1009+
assert not isinstance(fgraph.outputs[0].owner.op, MeasurableVariable)
1010+
1011+
invalid_mix = pt.switch(valid_switch_cond, valid_true_branch, invalid_false_branch)
1012+
fgraph, _, _ = construct_ir_fgraph({invalid_mix: invalid_mix.type()})
1013+
assert not isinstance(fgraph.outputs[0].owner.op, MeasurableVariable)
1014+
1015+
9641016
def test_ifelse_mixture_one_component():
9651017
if_rv = pt.random.bernoulli(0.5, name="if")
9661018
scale_rv = pt.random.halfnormal(name="scale")

0 commit comments

Comments
 (0)