Skip to content

Commit 6dda7e1

Browse files
committed
Move summary to stats module
1 parent 3060c24 commit 6dda7e1

File tree

5 files changed

+319
-317
lines changed

5 files changed

+319
-317
lines changed

pymc/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from .trace import *
88
from .sample import *
9+
from .stats import summary
910
from .step_methods import *
1011
from .tuning import *
1112

pymc/stats.py

Lines changed: 162 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
"""Utility functions for PyMC"""
22

33
import numpy as np
4+
from .trace import MultiTrace
5+
import warnings
46

5-
__all__ = ['autocorr', 'autocov', 'hpd', 'quantiles', 'mc_error']
7+
8+
__all__ = ['autocorr', 'autocov', 'hpd', 'quantiles', 'mc_error', 'summary']
69

710
def statfunc(f):
811
"""
@@ -237,3 +240,161 @@ def quantiles(x, qlist=(2.5, 25, 50, 75, 97.5)):
237240

238241
except IndexError:
239242
print("Too few elements for quantile calculation")
243+
244+
245+
def summary(trace, vars=None, alpha=0.05, start=0, batches=100, roundto=3):
246+
"""
247+
Generate a pretty-printed summary of the node.
248+
249+
:Parameters:
250+
trace : Trace object
251+
Trace containing MCMC sample
252+
253+
vars : list of strings
254+
List of variables to summarize. Defaults to None, which results
255+
in all variables summarized.
256+
257+
alpha : float
258+
The alpha level for generating posterior intervals. Defaults to
259+
0.05.
260+
261+
start : int
262+
The starting index from which to summarize (each) chain. Defaults
263+
to zero.
264+
265+
batches : int
266+
Batch size for calculating standard deviation for non-independent
267+
samples. Defaults to 100.
268+
269+
roundto : int
270+
The number of digits to round posterior statistics.
271+
272+
"""
273+
if vars is None:
274+
vars = trace.varnames
275+
if isinstance(trace, MultiTrace):
276+
trace = trace.combined()
277+
278+
stat_summ = _StatSummary(roundto, batches, alpha)
279+
pq_summ = _PosteriorQuantileSummary(roundto, alpha)
280+
281+
for var in vars:
282+
# Extract sampled values
283+
sample = trace[var][start:]
284+
if sample.ndim == 1:
285+
sample = sample[:, None]
286+
elif sample.ndim > 2:
287+
## trace dimensions greater than 2 (variable greater than 1)
288+
warnings.warn('Skipping {} (above 1 dimension)'.format(var))
289+
continue
290+
291+
print('\n%s:' % var)
292+
print(' ')
293+
294+
stat_summ.print_output(sample)
295+
pq_summ.print_output(sample)
296+
297+
298+
class _Summary(object):
299+
"""Base class for summary output"""
300+
def __init__(self, roundto):
301+
self.roundto = roundto
302+
self.header_lines = None
303+
self.leader = ' '
304+
self.spaces = None
305+
306+
def print_output(self, sample):
307+
print('\n'.join(list(self._get_lines(sample))) + '\n')
308+
309+
def _get_lines(self, sample):
310+
for line in self.header_lines:
311+
yield self.leader + line
312+
summary_lines = self._calculate_values(sample)
313+
for line in self._create_value_output(summary_lines):
314+
yield self.leader + line
315+
316+
def _create_value_output(self, lines):
317+
for values in lines:
318+
self._format_values(values)
319+
yield self.value_line.format(pad=self.spaces, **values).strip()
320+
321+
def _calculate_values(self, sample):
322+
raise NotImplementedError
323+
324+
def _format_values(self, summary_values):
325+
for key, val in summary_values.items():
326+
summary_values[key] = '{:.{ndec}f}'.format(
327+
float(val), ndec=self.roundto)
328+
329+
330+
class _StatSummary(_Summary):
331+
def __init__(self, roundto, batches, alpha):
332+
super(_StatSummary, self).__init__(roundto)
333+
spaces = 17
334+
hpd_name = '{}% HPD interval'.format(int(100 * (1 - alpha)))
335+
value_line = '{mean:<{pad}}{sd:<{pad}}{mce:<{pad}}{hpd:<{pad}}'
336+
header = value_line.format(mean='Mean', sd='SD', mce='MC Error',
337+
hpd=hpd_name, pad=spaces).strip()
338+
hline = '-' * len(header)
339+
340+
self.header_lines = [header, hline]
341+
self.spaces = spaces
342+
self.value_line = value_line
343+
self.batches = batches
344+
self.alpha = alpha
345+
346+
def _calculate_values(self, sample):
347+
return _calculate_stats(sample, self.batches, self.alpha)
348+
349+
def _format_values(self, summary_values):
350+
roundto = self.roundto
351+
for key, val in summary_values.items():
352+
if key == 'hpd':
353+
summary_values[key] = '[{:.{ndec}f}, {:.{ndec}f}]'.format(
354+
*val, ndec=roundto)
355+
else:
356+
summary_values[key] = '{:.{ndec}f}'.format(
357+
float(val), ndec=roundto)
358+
359+
360+
class _PosteriorQuantileSummary(_Summary):
361+
def __init__(self, roundto, alpha):
362+
super(_PosteriorQuantileSummary, self).__init__(roundto)
363+
spaces = 15
364+
title = 'Posterior quantiles:'
365+
value_line = '{lo:<{pad}}{q25:<{pad}}{q50:<{pad}}{q75:<{pad}}{hi:<{pad}}'
366+
lo, hi = 100 * alpha / 2, 100 * (1. - alpha / 2)
367+
qlist = (lo, 25, 50, 75, hi)
368+
header = value_line.format(lo=lo, q25=25, q50=50, q75=75, hi=hi,
369+
pad=spaces).strip()
370+
hline = '|{thin}|{thick}|{thick}|{thin}|'.format(
371+
thin='-' * (spaces - 1), thick='=' * (spaces - 1))
372+
373+
self.header_lines = [title, header, hline]
374+
self.spaces = spaces
375+
self.lo, self.hi = lo, hi
376+
self.qlist = qlist
377+
self.value_line = value_line
378+
379+
def _calculate_values(self, sample):
380+
return _calculate_posterior_quantiles(sample, self.qlist)
381+
382+
383+
def _calculate_stats(sample, batches, alpha):
384+
means = sample.mean(0)
385+
sds = sample.std(0)
386+
mces = mc_error(sample, batches)
387+
intervals = hpd(sample, alpha)
388+
for index in range(sample.shape[1]):
389+
mean, sd, mce = [stat[index] for stat in (means, sds, mces)]
390+
interval = intervals[index].squeeze().tolist()
391+
yield {'mean': mean, 'sd': sd, 'mce': mce, 'hpd': interval}
392+
393+
394+
def _calculate_posterior_quantiles(sample, qlist):
395+
var_quantiles = quantiles(sample, qlist=qlist)
396+
## Replace ends of qlist with 'lo' and 'hi'
397+
qends = {qlist[0]: 'lo', qlist[-1]: 'hi'}
398+
qkeys = {q: qends[q] if q in qends else 'q{}'.format(q) for q in qlist}
399+
for index in range(sample.shape[1]):
400+
yield {qkeys[q]: var_quantiles[q][index] for q in qlist}

pymc/tests/test_stats.py

Lines changed: 155 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,190 @@
1-
from ..stats import *
1+
import pymc as pm
2+
from pymc import stats
3+
import numpy as np
24
from numpy.random import random, normal, seed
35
from numpy.testing import assert_equal, assert_almost_equal, assert_array_almost_equal
6+
import warnings
7+
import nose
48

59
seed(111)
610
normal_sample = normal(0, 1, 1000000)
711

812
def test_autocorr():
913
"""Test autocorrelation and autocovariance functions"""
1014

11-
assert_almost_equal(autocorr(normal_sample), 0, 2)
15+
assert_almost_equal(stats.autocorr(normal_sample), 0, 2)
1216

1317
y = [(normal_sample[i-1] + normal_sample[i])/2 for i in range(1, len(normal_sample))]
14-
assert_almost_equal(autocorr(y), 0.5, 2)
18+
assert_almost_equal(stats.autocorr(y), 0.5, 2)
1519

1620
def test_hpd():
1721
"""Test HPD calculation"""
1822

19-
interval = hpd(normal_sample)
23+
interval = stats.hpd(normal_sample)
2024

2125
assert_array_almost_equal(interval, [-1.96, 1.96], 2)
2226

2327
def test_make_indices():
2428
"""Test make_indices function"""
2529

26-
from ..stats import make_indices
27-
2830
ind = [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
2931

30-
assert_equal(ind, make_indices((2, 3)))
32+
assert_equal(ind, stats.make_indices((2, 3)))
3133

3234
def test_mc_error():
3335
"""Test batch standard deviation function"""
3436

3537
x = random(100000)
3638

37-
assert(mc_error(x) < 0.0025)
39+
assert(stats.mc_error(x) < 0.0025)
3840

3941
def test_quantiles():
4042
"""Test quantiles function"""
4143

42-
q = quantiles(normal_sample)
44+
q = stats.quantiles(normal_sample)
4345

4446
assert_array_almost_equal(sorted(q.values()), [-1.96, -0.67, 0, 0.67, 1.96], 2)
47+
48+
49+
def test_summary_1_value_model():
50+
mu = -2.1
51+
tau = 1.3
52+
with pm.Model() as model:
53+
x = pm.Normal('x', mu, tau, testval=.1)
54+
step = pm.Metropolis(model.vars, np.diag([1.]))
55+
trace = pm.sample(100, step=step)
56+
stats.summary(trace)
57+
58+
59+
def test_summary_2_value_model():
60+
mu = -2.1
61+
tau = 1.3
62+
with pm.Model() as model:
63+
x = pm.Normal('x', mu, tau, shape=2, testval=[.1, .1])
64+
step = pm.Metropolis(model.vars, np.diag([1.]))
65+
trace = pm.sample(100, step=step)
66+
stats.summary(trace)
67+
68+
69+
def test_summary_2dim_value_model():
70+
mu = -2.1
71+
tau = 1.3
72+
with pm.Model() as model:
73+
x = pm.Normal('x', mu, tau, shape=(2, 2),
74+
testval=np.tile(.1, (2, 2)))
75+
step = pm.Metropolis(model.vars, np.diag([1.]))
76+
trace = pm.sample(100, step=step)
77+
78+
with warnings.catch_warnings(record=True) as wrn:
79+
stats.summary(trace)
80+
assert len(wrn) == 1
81+
assert str(wrn[0].message) == 'Skipping x (above 1 dimension)'
82+
83+
84+
def test_summary_format_values():
85+
roundto = 2
86+
summ = stats._Summary(roundto)
87+
d = {'nodec': 1, 'onedec': 1.0, 'twodec': 1.00, 'threedec': 1.000}
88+
summ._format_values(d)
89+
for val in d.values():
90+
assert val == '1.00'
91+
92+
93+
def test_stat_summary_format_hpd_values():
94+
roundto = 2
95+
summ = stats._StatSummary(roundto, None, 0.05)
96+
d = {'nodec': 1, 'hpd': [1, 1]}
97+
summ._format_values(d)
98+
for key, val in d.items():
99+
if key == 'hpd':
100+
assert val == '[1.00, 1.00]'
101+
else:
102+
assert val == '1.00'
103+
104+
105+
@nose.tools.raises(IndexError)
106+
def test_calculate_stats_variable_size1_not_adjusted():
107+
sample = np.arange(10)
108+
list(stats._calculate_stats(sample, 5, 0.05))
109+
110+
111+
def test_calculate_stats_variable_size1_adjusted():
112+
sample = np.arange(10)[:, None]
113+
result_size = len(list(stats._calculate_stats(sample, 5, 0.05)))
114+
assert result_size == 1
115+
116+
def test_calculate_stats_variable_size2():
117+
## 2 traces of 5
118+
sample = np.arange(10).reshape(5, 2)
119+
result_size = len(list(stats._calculate_stats(sample, 5, 0.05)))
120+
assert result_size == 2
121+
122+
123+
@nose.tools.raises(IndexError)
124+
def test_calculate_pquantiles_variable_size1_not_adjusted():
125+
sample = np.arange(10)
126+
qlist = (0.25, 25, 50, 75, 0.98)
127+
list(stats._calculate_posterior_quantiles(sample,
128+
qlist))
129+
130+
131+
def test_calculate_pquantiles_variable_size1_adjusted():
132+
sample = np.arange(10)[:, None]
133+
qlist = (0.25, 25, 50, 75, 0.98)
134+
result_size = len(list(stats._calculate_posterior_quantiles(sample,
135+
qlist)))
136+
assert result_size == 1
137+
138+
139+
def test_stats_value_line():
140+
roundto = 1
141+
summ = stats._StatSummary(roundto, None, 0.05)
142+
values = [{'mean': 0, 'sd': 1, 'mce': 2, 'hpd': [4, 4]},
143+
{'mean': 5, 'sd': 6, 'mce': 7, 'hpd': [8, 8]},]
144+
145+
expected = ['0.0 1.0 2.0 [4.0, 4.0]',
146+
'5.0 6.0 7.0 [8.0, 8.0]']
147+
result = list(summ._create_value_output(values))
148+
assert result == expected
149+
150+
151+
def test_post_quantile_value_line():
152+
roundto = 1
153+
summ = stats._PosteriorQuantileSummary(roundto, 0.05)
154+
values = [{'lo': 0, 'q25': 1, 'q50': 2, 'q75': 4, 'hi': 5},
155+
{'lo': 6, 'q25': 7, 'q50': 8, 'q75': 9, 'hi': 10},]
156+
157+
expected = ['0.0 1.0 2.0 4.0 5.0',
158+
'6.0 7.0 8.0 9.0 10.0']
159+
result = list(summ._create_value_output(values))
160+
assert result == expected
161+
162+
163+
def test_stats_output_lines():
164+
roundto = 1
165+
x = np.arange(10).reshape(5, 2)
166+
167+
summ = stats._StatSummary(roundto, 5, 0.05)
168+
169+
expected = [' Mean SD MC Error 95% HPD interval',
170+
' -------------------------------------------------------------------',
171+
' 4.0 2.8 1.3 [0.0, 8.0]',
172+
' 5.0 2.8 1.3 [1.0, 9.0]',]
173+
result = list(summ._get_lines(x))
174+
assert result == expected
175+
176+
177+
def test_posterior_quantiles_output_lines():
178+
roundto = 1
179+
x = np.arange(10).reshape(5, 2)
180+
181+
summ = stats._PosteriorQuantileSummary(roundto, 0.05)
182+
183+
expected = [' Posterior quantiles:',
184+
' 2.5 25 50 75 97.5',
185+
' |--------------|==============|==============|--------------|',
186+
' 0.0 2.0 4.0 6.0 8.0',
187+
' 1.0 3.0 5.0 7.0 9.0']
188+
189+
result = list(summ._get_lines(x))
190+
assert result == expected

0 commit comments

Comments
 (0)