Skip to content

Commit b6eeec8

Browse files
mohammed052ricardoV94
authored andcommitted
Remove unused comp_shape from NormalMixture
1 parent a033261 commit b6eeec8

File tree

2 files changed

+6
-15
lines changed

2 files changed

+6
-15
lines changed

pymc/distributions/mixture.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -524,10 +524,6 @@ class NormalMixture:
524524
the component standard deviations
525525
tau : tensor_like of float
526526
the component precisions
527-
comp_shape : shape of the Normal component
528-
notice that it should be different than the shape
529-
of the mixture distribution, with the last axis representing
530-
the number of components.
531527
532528
Notes
533529
-----
@@ -554,16 +550,16 @@ class NormalMixture:
554550
y = pm.NormalMixture("y", w=weights, mu=μ, sigma=σ, observed=data)
555551
"""
556552

557-
def __new__(cls, name, w, mu, sigma=None, tau=None, comp_shape=(), **kwargs):
553+
def __new__(cls, name, w, mu, sigma=None, tau=None, **kwargs):
558554
_, sigma = get_tau_sigma(tau=tau, sigma=sigma)
559555

560-
return Mixture(name, w, Normal.dist(mu, sigma=sigma, size=comp_shape), **kwargs)
556+
return Mixture(name, w, Normal.dist(mu, sigma=sigma), **kwargs)
561557

562558
@classmethod
563-
def dist(cls, w, mu, sigma=None, tau=None, comp_shape=(), **kwargs):
559+
def dist(cls, w, mu, sigma=None, tau=None, **kwargs):
564560
_, sigma = get_tau_sigma(tau=tau, sigma=sigma)
565561

566-
return Mixture.dist(w, Normal.dist(mu, sigma=sigma, size=comp_shape), **kwargs)
562+
return Mixture.dist(w, Normal.dist(mu, sigma=sigma), **kwargs)
567563

568564

569565
def _zero_inflated_mixture(*, name, nonzero_p, nonzero_dist, **kwargs):

tests/distributions/test_mixture.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -820,10 +820,8 @@ def test_normal_mixture_nd(self, seeded_test, nd, ncomp):
820820
mus = Normal("mus", shape=comp_shape)
821821
taus = Gamma("taus", alpha=1, beta=1, shape=comp_shape)
822822
ws = Dirichlet("ws", np.ones(ncomp), shape=(ncomp,))
823-
mixture0 = NormalMixture("m", w=ws, mu=mus, tau=taus, shape=nd, comp_shape=comp_shape)
824-
obs0 = NormalMixture(
825-
"obs", w=ws, mu=mus, tau=taus, comp_shape=comp_shape, observed=observed
826-
)
823+
mixture0 = NormalMixture("m", w=ws, mu=mus, tau=taus, shape=nd)
824+
obs0 = NormalMixture("obs", w=ws, mu=mus, tau=taus, observed=observed)
827825

828826
with Model() as model1:
829827
mus = Normal("mus", shape=comp_shape)
@@ -867,7 +865,6 @@ def ref_rand(size, w, mu, sigma):
867865
"mu": Domain([[0.05, 2.5], [-5.0, 1.0]], edges=(None, None)),
868866
"sigma": Domain([[1, 1], [1.5, 2.0]], edges=(None, None)),
869867
},
870-
extra_args={"comp_shape": 2},
871868
size=1000,
872869
ref_rand=ref_rand,
873870
)
@@ -878,7 +875,6 @@ def ref_rand(size, w, mu, sigma):
878875
"mu": Domain([[-5.0, 1.0, 2.5]], edges=(None, None)),
879876
"sigma": Domain([[1.5, 2.0, 3.0]], edges=(None, None)),
880877
},
881-
extra_args={"comp_shape": 3},
882878
size=1000,
883879
ref_rand=ref_rand,
884880
)
@@ -902,7 +898,6 @@ def test_scalar_components(self):
902898
w=np.ones(npop) / npop,
903899
mu=mus,
904900
sigma=1e-5,
905-
comp_shape=(nd, npop),
906901
shape=nd,
907902
)
908903
z = Categorical("z", p=np.ones(npop) / npop, shape=nd)

0 commit comments

Comments
 (0)