|
1 |
| -from ..stats import * |
| 1 | +import pymc as pm |
| 2 | +from pymc import stats |
| 3 | +import numpy as np |
2 | 4 | from numpy.random import random, normal, seed
|
3 | 5 | from numpy.testing import assert_equal, assert_almost_equal, assert_array_almost_equal
|
| 6 | +import warnings |
| 7 | +import nose |
4 | 8 |
|
5 | 9 | seed(111)
|
6 | 10 | normal_sample = normal(0, 1, 1000000)
|
7 | 11 |
|
8 | 12 | def test_autocorr():
|
9 | 13 | """Test autocorrelation and autocovariance functions"""
|
10 | 14 |
|
11 |
| - assert_almost_equal(autocorr(normal_sample), 0, 2) |
| 15 | + assert_almost_equal(stats.autocorr(normal_sample), 0, 2) |
12 | 16 |
|
13 | 17 | y = [(normal_sample[i-1] + normal_sample[i])/2 for i in range(1, len(normal_sample))]
|
14 |
| - assert_almost_equal(autocorr(y), 0.5, 2) |
| 18 | + assert_almost_equal(stats.autocorr(y), 0.5, 2) |
15 | 19 |
|
16 | 20 | def test_hpd():
|
17 | 21 | """Test HPD calculation"""
|
18 | 22 |
|
19 |
| - interval = hpd(normal_sample) |
| 23 | + interval = stats.hpd(normal_sample) |
20 | 24 |
|
21 | 25 | assert_array_almost_equal(interval, [-1.96, 1.96], 2)
|
22 | 26 |
|
23 | 27 | def test_make_indices():
|
24 | 28 | """Test make_indices function"""
|
25 | 29 |
|
26 |
| - from ..stats import make_indices |
27 |
| - |
28 | 30 | ind = [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
|
29 | 31 |
|
30 |
| - assert_equal(ind, make_indices((2, 3))) |
| 32 | + assert_equal(ind, stats.make_indices((2, 3))) |
31 | 33 |
|
32 | 34 | def test_mc_error():
|
33 | 35 | """Test batch standard deviation function"""
|
34 | 36 |
|
35 | 37 | x = random(100000)
|
36 | 38 |
|
37 |
| - assert(mc_error(x) < 0.0025) |
| 39 | + assert(stats.mc_error(x) < 0.0025) |
38 | 40 |
|
39 | 41 | def test_quantiles():
|
40 | 42 | """Test quantiles function"""
|
41 | 43 |
|
42 |
| - q = quantiles(normal_sample) |
| 44 | + q = stats.quantiles(normal_sample) |
43 | 45 |
|
44 | 46 | assert_array_almost_equal(sorted(q.values()), [-1.96, -0.67, 0, 0.67, 1.96], 2)
|
| 47 | + |
| 48 | + |
| 49 | +def test_summary_1_value_model(): |
| 50 | + mu = -2.1 |
| 51 | + tau = 1.3 |
| 52 | + with pm.Model() as model: |
| 53 | + x = pm.Normal('x', mu, tau, testval=.1) |
| 54 | + step = pm.Metropolis(model.vars, np.diag([1.])) |
| 55 | + trace = pm.sample(100, step=step) |
| 56 | + stats.summary(trace) |
| 57 | + |
| 58 | + |
| 59 | +def test_summary_2_value_model(): |
| 60 | + mu = -2.1 |
| 61 | + tau = 1.3 |
| 62 | + with pm.Model() as model: |
| 63 | + x = pm.Normal('x', mu, tau, shape=2, testval=[.1, .1]) |
| 64 | + step = pm.Metropolis(model.vars, np.diag([1.])) |
| 65 | + trace = pm.sample(100, step=step) |
| 66 | + stats.summary(trace) |
| 67 | + |
| 68 | + |
| 69 | +def test_summary_2dim_value_model(): |
| 70 | + mu = -2.1 |
| 71 | + tau = 1.3 |
| 72 | + with pm.Model() as model: |
| 73 | + x = pm.Normal('x', mu, tau, shape=(2, 2), |
| 74 | + testval=np.tile(.1, (2, 2))) |
| 75 | + step = pm.Metropolis(model.vars, np.diag([1.])) |
| 76 | + trace = pm.sample(100, step=step) |
| 77 | + |
| 78 | + with warnings.catch_warnings(record=True) as wrn: |
| 79 | + stats.summary(trace) |
| 80 | + assert len(wrn) == 1 |
| 81 | + assert str(wrn[0].message) == 'Skipping x (above 1 dimension)' |
| 82 | + |
| 83 | + |
| 84 | +def test_summary_format_values(): |
| 85 | + roundto = 2 |
| 86 | + summ = stats._Summary(roundto) |
| 87 | + d = {'nodec': 1, 'onedec': 1.0, 'twodec': 1.00, 'threedec': 1.000} |
| 88 | + summ._format_values(d) |
| 89 | + for val in d.values(): |
| 90 | + assert val == '1.00' |
| 91 | + |
| 92 | + |
| 93 | +def test_stat_summary_format_hpd_values(): |
| 94 | + roundto = 2 |
| 95 | + summ = stats._StatSummary(roundto, None, 0.05) |
| 96 | + d = {'nodec': 1, 'hpd': [1, 1]} |
| 97 | + summ._format_values(d) |
| 98 | + for key, val in d.items(): |
| 99 | + if key == 'hpd': |
| 100 | + assert val == '[1.00, 1.00]' |
| 101 | + else: |
| 102 | + assert val == '1.00' |
| 103 | + |
| 104 | + |
| 105 | +@nose.tools.raises(IndexError) |
| 106 | +def test_calculate_stats_variable_size1_not_adjusted(): |
| 107 | + sample = np.arange(10) |
| 108 | + list(stats._calculate_stats(sample, 5, 0.05)) |
| 109 | + |
| 110 | + |
| 111 | +def test_calculate_stats_variable_size1_adjusted(): |
| 112 | + sample = np.arange(10)[:, None] |
| 113 | + result_size = len(list(stats._calculate_stats(sample, 5, 0.05))) |
| 114 | + assert result_size == 1 |
| 115 | + |
| 116 | +def test_calculate_stats_variable_size2(): |
| 117 | + ## 2 traces of 5 |
| 118 | + sample = np.arange(10).reshape(5, 2) |
| 119 | + result_size = len(list(stats._calculate_stats(sample, 5, 0.05))) |
| 120 | + assert result_size == 2 |
| 121 | + |
| 122 | + |
| 123 | +@nose.tools.raises(IndexError) |
| 124 | +def test_calculate_pquantiles_variable_size1_not_adjusted(): |
| 125 | + sample = np.arange(10) |
| 126 | + qlist = (0.25, 25, 50, 75, 0.98) |
| 127 | + list(stats._calculate_posterior_quantiles(sample, |
| 128 | + qlist)) |
| 129 | + |
| 130 | + |
| 131 | +def test_calculate_pquantiles_variable_size1_adjusted(): |
| 132 | + sample = np.arange(10)[:, None] |
| 133 | + qlist = (0.25, 25, 50, 75, 0.98) |
| 134 | + result_size = len(list(stats._calculate_posterior_quantiles(sample, |
| 135 | + qlist))) |
| 136 | + assert result_size == 1 |
| 137 | + |
| 138 | + |
| 139 | +def test_stats_value_line(): |
| 140 | + roundto = 1 |
| 141 | + summ = stats._StatSummary(roundto, None, 0.05) |
| 142 | + values = [{'mean': 0, 'sd': 1, 'mce': 2, 'hpd': [4, 4]}, |
| 143 | + {'mean': 5, 'sd': 6, 'mce': 7, 'hpd': [8, 8]},] |
| 144 | + |
| 145 | + expected = ['0.0 1.0 2.0 [4.0, 4.0]', |
| 146 | + '5.0 6.0 7.0 [8.0, 8.0]'] |
| 147 | + result = list(summ._create_value_output(values)) |
| 148 | + assert result == expected |
| 149 | + |
| 150 | + |
| 151 | +def test_post_quantile_value_line(): |
| 152 | + roundto = 1 |
| 153 | + summ = stats._PosteriorQuantileSummary(roundto, 0.05) |
| 154 | + values = [{'lo': 0, 'q25': 1, 'q50': 2, 'q75': 4, 'hi': 5}, |
| 155 | + {'lo': 6, 'q25': 7, 'q50': 8, 'q75': 9, 'hi': 10},] |
| 156 | + |
| 157 | + expected = ['0.0 1.0 2.0 4.0 5.0', |
| 158 | + '6.0 7.0 8.0 9.0 10.0'] |
| 159 | + result = list(summ._create_value_output(values)) |
| 160 | + assert result == expected |
| 161 | + |
| 162 | + |
| 163 | +def test_stats_output_lines(): |
| 164 | + roundto = 1 |
| 165 | + x = np.arange(10).reshape(5, 2) |
| 166 | + |
| 167 | + summ = stats._StatSummary(roundto, 5, 0.05) |
| 168 | + |
| 169 | + expected = [' Mean SD MC Error 95% HPD interval', |
| 170 | + ' -------------------------------------------------------------------', |
| 171 | + ' 4.0 2.8 1.3 [0.0, 8.0]', |
| 172 | + ' 5.0 2.8 1.3 [1.0, 9.0]',] |
| 173 | + result = list(summ._get_lines(x)) |
| 174 | + assert result == expected |
| 175 | + |
| 176 | + |
| 177 | +def test_posterior_quantiles_output_lines(): |
| 178 | + roundto = 1 |
| 179 | + x = np.arange(10).reshape(5, 2) |
| 180 | + |
| 181 | + summ = stats._PosteriorQuantileSummary(roundto, 0.05) |
| 182 | + |
| 183 | + expected = [' Posterior quantiles:', |
| 184 | + ' 2.5 25 50 75 97.5', |
| 185 | + ' |--------------|==============|==============|--------------|', |
| 186 | + ' 0.0 2.0 4.0 6.0 8.0', |
| 187 | + ' 1.0 3.0 5.0 7.0 9.0'] |
| 188 | + |
| 189 | + result = list(summ._get_lines(x)) |
| 190 | + assert result == expected |
0 commit comments