@@ -79,7 +79,7 @@ def test_dimensions(self):
79
79
80
80
def test_mixture_list_of_normals (self ):
81
81
with Model () as model :
82
- w = Dirichlet ('w' , floatX (np .ones_like (self .norm_w )))
82
+ w = Dirichlet ('w' , floatX (np .ones_like (self .norm_w )), shape = self . norm_w . size )
83
83
mu = Normal ('mu' , 0. , 10. , shape = self .norm_w .size )
84
84
tau = Gamma ('tau' , 1. , 1. , shape = self .norm_w .size )
85
85
Mixture ('x_obs' , w ,
@@ -98,7 +98,7 @@ def test_mixture_list_of_normals(self):
98
98
99
99
def test_normal_mixture (self ):
100
100
with Model () as model :
101
- w = Dirichlet ('w' , floatX (np .ones_like (self .norm_w )))
101
+ w = Dirichlet ('w' , floatX (np .ones_like (self .norm_w )), shape = self . norm_w . size )
102
102
mu = Normal ('mu' , 0. , 10. , shape = self .norm_w .size )
103
103
tau = Gamma ('tau' , 1. , 1. , shape = self .norm_w .size )
104
104
NormalMixture ('x_obs' , w , mu , tau = tau , observed = self .norm_x )
@@ -135,7 +135,7 @@ def test_normal_mixture_nd(self, nd, ncomp):
135
135
with Model () as model0 :
136
136
mus = Normal ('mus' , shape = comp_shape )
137
137
taus = Gamma ('taus' , alpha = 1 , beta = 1 , shape = comp_shape )
138
- ws = Dirichlet ('ws' , np .ones (ncomp ))
138
+ ws = Dirichlet ('ws' , np .ones (ncomp ), shape = ( ncomp ,) )
139
139
mixture0 = NormalMixture ('m' , w = ws , mu = mus , tau = taus , shape = nd ,
140
140
comp_shape = comp_shape )
141
141
obs0 = NormalMixture ('obs' , w = ws , mu = mus , tau = taus , shape = nd ,
@@ -145,7 +145,7 @@ def test_normal_mixture_nd(self, nd, ncomp):
145
145
with Model () as model1 :
146
146
mus = Normal ('mus' , shape = comp_shape )
147
147
taus = Gamma ('taus' , alpha = 1 , beta = 1 , shape = comp_shape )
148
- ws = Dirichlet ('ws' , np .ones (ncomp ))
148
+ ws = Dirichlet ('ws' , np .ones (ncomp ), shape = ( ncomp ,) )
149
149
comp_dist = [Normal .dist (mu = mus [..., i ], tau = taus [..., i ],
150
150
shape = nd )
151
151
for i in range (ncomp )]
@@ -163,7 +163,7 @@ def test_normal_mixture_nd(self, nd, ncomp):
163
163
# comp_dists.
164
164
mus = Normal ('mus' , shape = comp_shape )
165
165
taus = Gamma ('taus' , alpha = 1 , beta = 1 , shape = comp_shape )
166
- ws = Dirichlet ('ws' , np .ones (ncomp ))
166
+ ws = Dirichlet ('ws' , np .ones (ncomp ), shape = ( ncomp ,) )
167
167
if len (nd ) > 1 :
168
168
if nd [- 1 ] != ncomp :
169
169
with pytest .raises (ValueError ):
@@ -208,7 +208,7 @@ def test_normal_mixture_nd(self, nd, ncomp):
208
208
209
209
def test_poisson_mixture (self ):
210
210
with Model () as model :
211
- w = Dirichlet ('w' , floatX (np .ones_like (self .pois_w )))
211
+ w = Dirichlet ('w' , floatX (np .ones_like (self .pois_w )), shape = self . pois_w . shape )
212
212
mu = Gamma ('mu' , 1. , 1. , shape = self .pois_w .size )
213
213
Mixture ('x_obs' , w , Poisson .dist (mu ), observed = self .pois_x )
214
214
step = Metropolis ()
@@ -224,7 +224,7 @@ def test_poisson_mixture(self):
224
224
225
225
def test_mixture_list_of_poissons (self ):
226
226
with Model () as model :
227
- w = Dirichlet ('w' , floatX (np .ones_like (self .pois_w )))
227
+ w = Dirichlet ('w' , floatX (np .ones_like (self .pois_w )), shape = self . pois_w . shape )
228
228
mu = Gamma ('mu' , 1. , 1. , shape = self .pois_w .size )
229
229
Mixture ('x_obs' , w ,
230
230
[Poisson .dist (mu [0 ]), Poisson .dist (mu [1 ])],
@@ -247,7 +247,7 @@ def test_mixture_of_mvn(self):
247
247
cov2 = np .diag ([2.5 , 3.5 ])
248
248
obs = np .asarray ([[.5 , .5 ], mu1 , mu2 ])
249
249
with Model () as model :
250
- w = Dirichlet ('w' , floatX (np .ones (2 )), transform = None )
250
+ w = Dirichlet ('w' , floatX (np .ones (2 )), transform = None , shape = ( 2 ,) )
251
251
mvncomp1 = MvNormal .dist (mu = mu1 , cov = cov1 )
252
252
mvncomp2 = MvNormal .dist (mu = mu2 , cov = cov2 )
253
253
y = Mixture ('x_obs' , w , [mvncomp1 , mvncomp2 ],
@@ -291,13 +291,13 @@ def test_mixture_of_mixture(self):
291
291
sigma = 1 ,
292
292
shape = nbr )
293
293
# weight vector for the mixtures
294
- g_w = Dirichlet ('g_w' , a = floatX (np .ones (nbr )* 0.0000001 ), transform = None )
295
- l_w = Dirichlet ('l_w' , a = floatX (np .ones (nbr )* 0.0000001 ), transform = None )
294
+ g_w = Dirichlet ('g_w' , a = floatX (np .ones (nbr )* 0.0000001 ), transform = None , shape = ( nbr ,) )
295
+ l_w = Dirichlet ('l_w' , a = floatX (np .ones (nbr )* 0.0000001 ), transform = None , shape = ( nbr ,) )
296
296
# mixture components
297
297
g_mix = Mixture .dist (w = g_w , comp_dists = g_comp )
298
298
l_mix = Mixture .dist (w = l_w , comp_dists = l_comp )
299
299
# mixture of mixtures
300
- mix_w = Dirichlet ('mix_w' , a = floatX (np .ones (2 )), transform = None )
300
+ mix_w = Dirichlet ('mix_w' , a = floatX (np .ones (2 )), transform = None , shape = ( 2 ,) )
301
301
mix = Mixture ('mix' , w = mix_w ,
302
302
comp_dists = [g_mix , l_mix ],
303
303
observed = np .exp (self .norm_x ))
@@ -378,7 +378,7 @@ def build_toy_dataset(N, K):
378
378
X , y = build_toy_dataset (N , K )
379
379
380
380
with pm .Model () as model :
381
- pi = pm .Dirichlet ('pi' , np .ones (K ))
381
+ pi = pm .Dirichlet ('pi' , np .ones (K ), shape = ( K ,) )
382
382
383
383
comp_dist = []
384
384
mu = []
0 commit comments