-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Allow non-scalar measurable switch mixtures #6796
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -52,8 +52,9 @@ | |
as_index_constant, | ||
) | ||
|
||
from pymc.logprob.abstract import MeasurableVariable | ||
from pymc.logprob.basic import conditional_logp, logp | ||
from pymc.logprob.mixture import MixtureRV, expand_indices | ||
from pymc.logprob.mixture import MeasurableSwitchMixture, MixtureRV, expand_indices | ||
from pymc.logprob.rewriting import construct_ir_fgraph | ||
from pymc.logprob.utils import dirac_delta | ||
from pymc.testing import assert_no_rvs | ||
|
@@ -907,7 +908,7 @@ def test_mixture_with_DiracDelta(): | |
assert m_vv in logp_res | ||
|
||
|
||
def test_switch_mixture(): | ||
def test_scalar_switch_mixture(): | ||
srng = pt.random.RandomStream(29833) | ||
|
||
X_rv = srng.normal(-10.0, 0.1, name="X") | ||
|
@@ -919,6 +920,7 @@ def test_switch_mixture(): | |
|
||
# When I_rv == True, X_rv flows through otherwise Y_rv does | ||
Z1_rv = pt.switch(I_rv, X_rv, Y_rv) | ||
Z1_rv.name = "Z1" | ||
|
||
assert Z1_rv.eval({I_rv: 0}) > 5 | ||
assert Z1_rv.eval({I_rv: 1}) < -5 | ||
|
@@ -927,40 +929,90 @@ def test_switch_mixture(): | |
z_vv.name = "z1" | ||
|
||
fgraph, _, _ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv}) | ||
assert isinstance(fgraph.outputs[0].owner.op, MeasurableSwitchMixture) | ||
|
||
assert isinstance(fgraph.outputs[0].owner.op, MixtureRV) | ||
assert not hasattr( | ||
fgraph.outputs[0].tag, "test_value" | ||
) # pytensor.config.compute_test_value == "off" | ||
assert fgraph.outputs[0].name is None | ||
|
||
Z1_rv.name = "Z1" | ||
|
||
fgraph, _, _ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv}) | ||
|
||
# building the identical graph but with a stack to check that mixture computations are identical | ||
|
||
# building the identical graph but with a stack to check that mixture logps are identical | ||
Z2_rv = pt.stack((Y_rv, X_rv))[I_rv] | ||
|
||
assert Z2_rv.eval({I_rv: 0}) > 5 | ||
assert Z2_rv.eval({I_rv: 1}) < -5 | ||
|
||
fgraph2, _, _ = construct_ir_fgraph({Z2_rv: z_vv, I_rv: i_vv}) | ||
|
||
assert equal_computations(fgraph.outputs, fgraph2.outputs) | ||
|
||
z1_logp = conditional_logp({Z1_rv: z_vv, I_rv: i_vv}) | ||
z2_logp = conditional_logp({Z2_rv: z_vv, I_rv: i_vv}) | ||
z1_logp_combined = pt.sum([pt.sum(factor) for factor in z1_logp.values()]) | ||
z2_logp_combined = pt.sum([pt.sum(factor) for factor in z2_logp.values()]) | ||
|
||
# below should follow immediately from the equal_computations assertion above | ||
assert equal_computations([z1_logp_combined], [z2_logp_combined]) | ||
|
||
np.testing.assert_almost_equal(0.69049938, z1_logp_combined.eval({z_vv: -10, i_vv: 1})) | ||
np.testing.assert_almost_equal(0.69049938, z2_logp_combined.eval({z_vv: -10, i_vv: 1})) | ||
|
||
|
||
@pytest.mark.parametrize("switch_cond_scalar", (True, False))
def test_switch_mixture_vector(switch_cond_scalar):
    """Check the logp of a switch between two length-4 transformed-normal branches.

    The condition is a boolean input, either a scalar or a vector depending on
    ``switch_cond_scalar``. The logp of the switch must equal, elementwise, the
    logp of whichever branch the condition selects.
    """
    # Symbolic boolean condition: scalar or vector, same variable name either way.
    make_cond = pt.scalar if switch_cond_scalar else pt.vector
    cond = make_cond("switch_cond", dtype=bool)

    # Two branches with four entries each, built from normal draws.
    branch_if_true = pt.exp(pt.random.normal(size=(4,)))
    branch_if_false = pt.abs(pt.random.normal(size=(4,)))

    mix = pt.switch(cond, branch_if_true, branch_if_false)
    mix.name = "switch_mix"
    mix_value = mix.clone()
    mix_logp = logp(mix, mix_value)

    # Concrete inputs: an all-False scalar condition or an alternating vector one.
    cond_data = np.array(0 if switch_cond_scalar else [0, 1, 0, 1], dtype=bool)
    value_data = np.linspace(0.1, 2.5, 4)

    observed = mix_logp.eval({cond: cond_data, mix_value: value_data})
    logp_if_true = logp(branch_if_true, value_data).eval()
    logp_if_false = logp(branch_if_false, value_data).eval()

    # Reference: pick each entry's logp from the branch the condition selects.
    np.testing.assert_allclose(
        observed,
        np.where(cond_data, logp_if_true, logp_if_false),
    )
|
||
|
||
def test_switch_mixture_measurable_cond_fails():
    """Logprob inference must fail when the switch condition is an unvalued measurable variable.

    Supporting this case would require the logp function to marginalize over
    the condition variable.

    NOTE: If that is supported in the future, this test can be removed/adapted.
    """
    # The condition is derived from a Bernoulli draw that is given no value variable.
    condition = 1 - pt.random.bernoulli(p=0.5)
    if_true = pt.random.normal()
    if_false = pt.random.normal()
    mix = pt.switch(condition, if_true, if_false)

    expected_msg = "Logprob method not implemented for"
    with pytest.raises(NotImplementedError, match=expected_msg):
        logp(mix, mix.type())
|
||
|
||
def test_switch_mixture_invalid_bcast():
    """Switches whose components are broadcasted must not be marked as measurable.

    A vector condition with two length-4 branches is the valid reference case.
    A matrix condition over the same branches, or a scalar false branch under
    the vector condition, must leave the IR output op non-measurable.
    """

    def ir_output_op(mix):
        # Op that owns the first output of the IR graph built for `mix`,
        # using a fresh variable of the same type as its value.
        ir_graph, _, _ = construct_ir_fgraph({mix: mix.type()})
        return ir_graph.outputs[0].owner.op

    vector_cond = pt.vector("switch_cond", dtype=bool)
    matrix_cond = pt.matrix("switch_cond", dtype=bool)

    true_vec = pt.exp(pt.random.normal(size=(4,)))
    false_vec = pt.abs(pt.random.normal(size=(4,)))
    false_scalar = pt.abs(pt.random.normal(size=()))

    # Valid: condition and both branches share the same shape.
    valid_op = ir_output_op(pt.switch(vector_cond, true_vec, false_vec))
    assert isinstance(valid_op, MeasurableVariable)
    assert isinstance(valid_op, MeasurableSwitchMixture)

    # Invalid: the matrix condition broadcasts the vector branches.
    matrix_cond_op = ir_output_op(pt.switch(matrix_cond, true_vec, false_vec))
    assert not isinstance(matrix_cond_op, MeasurableVariable)

    # Invalid: the scalar false branch is broadcasted against the vector ones.
    scalar_branch_op = ir_output_op(pt.switch(vector_cond, true_vec, false_scalar))
    assert not isinstance(scalar_branch_op, MeasurableVariable)
Review comments on this test (inline code spans and the end of the last comment were lost in extraction): "Would it be more clear to rename the […]" / "sure!" / "On a second thought... rename to what? I feel the names are pretty reasonable?" / "Ah, I was just alluding to the redundancy in names (three presences of […]" |
||
|
||
|
||
def test_ifelse_mixture_one_component(): | ||
if_rv = pt.random.bernoulli(0.5, name="if") | ||
scale_rv = pt.random.halfnormal(name="scale") | ||
|
Uh oh!
There was an error while loading. Please reload this page.