52
52
as_index_constant ,
53
53
)
54
54
55
+ from pymc .logprob .abstract import MeasurableVariable
55
56
from pymc .logprob .basic import conditional_logp , logp
56
- from pymc .logprob .mixture import MixtureRV , expand_indices
57
+ from pymc .logprob .mixture import MeasurableSwitchMixture , MixtureRV , expand_indices
57
58
from pymc .logprob .rewriting import construct_ir_fgraph
58
59
from pymc .logprob .utils import dirac_delta
59
60
from pymc .testing import assert_no_rvs
@@ -907,7 +908,7 @@ def test_mixture_with_DiracDelta():
907
908
assert m_vv in logp_res
908
909
909
910
910
- def test_switch_mixture ():
911
+ def test_scalar_switch_mixture ():
911
912
srng = pt .random .RandomStream (29833 )
912
913
913
914
X_rv = srng .normal (- 10.0 , 0.1 , name = "X" )
@@ -919,6 +920,7 @@ def test_switch_mixture():
919
920
920
921
# When I_rv == True, X_rv flows through otherwise Y_rv does
921
922
Z1_rv = pt .switch (I_rv , X_rv , Y_rv )
923
+ Z1_rv .name = "Z1"
922
924
923
925
assert Z1_rv .eval ({I_rv : 0 }) > 5
924
926
assert Z1_rv .eval ({I_rv : 1 }) < - 5
@@ -927,40 +929,90 @@ def test_switch_mixture():
927
929
z_vv .name = "z1"
928
930
929
931
fgraph , _ , _ = construct_ir_fgraph ({Z1_rv : z_vv , I_rv : i_vv })
932
+ assert isinstance (fgraph .outputs [0 ].owner .op , MeasurableSwitchMixture )
930
933
931
- assert isinstance (fgraph .outputs [0 ].owner .op , MixtureRV )
932
- assert not hasattr (
933
- fgraph .outputs [0 ].tag , "test_value"
934
- ) # pytensor.config.compute_test_value == "off"
935
- assert fgraph .outputs [0 ].name is None
936
-
937
- Z1_rv .name = "Z1"
938
-
939
- fgraph , _ , _ = construct_ir_fgraph ({Z1_rv : z_vv , I_rv : i_vv })
940
-
941
- # building the identical graph but with a stack to check that mixture computations are identical
942
-
934
+ # building the identical graph but with a stack to check that mixture logps are identical
943
935
Z2_rv = pt .stack ((Y_rv , X_rv ))[I_rv ]
944
936
945
937
assert Z2_rv .eval ({I_rv : 0 }) > 5
946
938
assert Z2_rv .eval ({I_rv : 1 }) < - 5
947
939
948
- fgraph2 , _ , _ = construct_ir_fgraph ({Z2_rv : z_vv , I_rv : i_vv })
949
-
950
- assert equal_computations (fgraph .outputs , fgraph2 .outputs )
951
-
952
940
z1_logp = conditional_logp ({Z1_rv : z_vv , I_rv : i_vv })
953
941
z2_logp = conditional_logp ({Z2_rv : z_vv , I_rv : i_vv })
954
942
z1_logp_combined = pt .sum ([pt .sum (factor ) for factor in z1_logp .values ()])
955
943
z2_logp_combined = pt .sum ([pt .sum (factor ) for factor in z2_logp .values ()])
956
-
957
- # below should follow immediately from the equal_computations assertion above
958
- assert equal_computations ([z1_logp_combined ], [z2_logp_combined ])
959
-
960
944
np .testing .assert_almost_equal (0.69049938 , z1_logp_combined .eval ({z_vv : - 10 , i_vv : 1 }))
961
945
np .testing .assert_almost_equal (0.69049938 , z2_logp_combined .eval ({z_vv : - 10 , i_vv : 1 }))
962
946
963
947
948
+ @pytest .mark .parametrize ("switch_cond_scalar" , (True , False ))
949
+ def test_switch_mixture_vector (switch_cond_scalar ):
950
+ if switch_cond_scalar :
951
+ switch_cond = pt .scalar ("switch_cond" , dtype = bool )
952
+ else :
953
+ switch_cond = pt .vector ("switch_cond" , dtype = bool )
954
+ true_branch = pt .exp (pt .random .normal (size = (4 ,)))
955
+ false_branch = pt .abs (pt .random .normal (size = (4 ,)))
956
+
957
+ switch = pt .switch (switch_cond , true_branch , false_branch )
958
+ switch .name = "switch_mix"
959
+ switch_value = switch .clone ()
960
+ switch_logp = logp (switch , switch_value )
961
+
962
+ if switch_cond_scalar :
963
+ test_switch_cond = np .array (0 , dtype = bool )
964
+ else :
965
+ test_switch_cond = np .array ([0 , 1 , 0 , 1 ], dtype = bool )
966
+ test_switch_value = np .linspace (0.1 , 2.5 , 4 )
967
+ np .testing .assert_allclose (
968
+ switch_logp .eval ({switch_cond : test_switch_cond , switch_value : test_switch_value }),
969
+ np .where (
970
+ test_switch_cond ,
971
+ logp (true_branch , test_switch_value ).eval (),
972
+ logp (false_branch , test_switch_value ).eval (),
973
+ ),
974
+ )
975
+
976
+
977
+ def test_switch_mixture_measurable_cond_fails ():
978
+ """Test that logprob inference fails when the switch condition is an unvalued measurable variable.
979
+
980
+ Otherwise, the logp function would have to marginalize over this variable.
981
+
982
+ NOTE: This could be supported in the future, in which case this test can be removed/adapted
983
+ """
984
+ cond_var = 1 - pt .random .bernoulli (p = 0.5 )
985
+ true_branch = pt .random .normal ()
986
+ false_branch = pt .random .normal ()
987
+
988
+ switch = pt .switch (cond_var , true_branch , false_branch )
989
+ with pytest .raises (NotImplementedError , match = "Logprob method not implemented for" ):
990
+ logp (switch , switch .type ())
991
+
992
+
993
+ def test_switch_mixture_invalid_bcast ():
994
+ """Test that we don't mark switches where components are broadcasted as measurable"""
995
+ valid_switch_cond = pt .vector ("switch_cond" , dtype = bool )
996
+ invalid_switch_cond = pt .matrix ("switch_cond" , dtype = bool )
997
+
998
+ valid_true_branch = pt .exp (pt .random .normal (size = (4 ,)))
999
+ valid_false_branch = pt .abs (pt .random .normal (size = (4 ,)))
1000
+ invalid_false_branch = pt .abs (pt .random .normal (size = ()))
1001
+
1002
+ valid_mix = pt .switch (valid_switch_cond , valid_true_branch , valid_false_branch )
1003
+ fgraph , _ , _ = construct_ir_fgraph ({valid_mix : valid_mix .type ()})
1004
+ assert isinstance (fgraph .outputs [0 ].owner .op , MeasurableVariable )
1005
+ assert isinstance (fgraph .outputs [0 ].owner .op , MeasurableSwitchMixture )
1006
+
1007
+ invalid_mix = pt .switch (invalid_switch_cond , valid_true_branch , valid_false_branch )
1008
+ fgraph , _ , _ = construct_ir_fgraph ({invalid_mix : invalid_mix .type ()})
1009
+ assert not isinstance (fgraph .outputs [0 ].owner .op , MeasurableVariable )
1010
+
1011
+ invalid_mix = pt .switch (valid_switch_cond , valid_true_branch , invalid_false_branch )
1012
+ fgraph , _ , _ = construct_ir_fgraph ({invalid_mix : invalid_mix .type ()})
1013
+ assert not isinstance (fgraph .outputs [0 ].owner .op , MeasurableVariable )
1014
+
1015
+
964
1016
def test_ifelse_mixture_one_component ():
965
1017
if_rv = pt .random .bernoulli (0.5 , name = "if" )
966
1018
scale_rv = pt .random .halfnormal (name = "scale" )
0 commit comments