1
- from pylab import *
2
- import matplotlib .pyplot as plt
3
- try :
4
- import matplotlib .gridspec as gridspec
5
- except ImportError :
6
- gridspec = None
7
1
import numpy as np
8
2
from scipy .stats import kde
9
3
from .stats import *
@@ -40,7 +34,7 @@ def traceplot(trace, vars=None, figsize=None,
40
34
fig : figure object
41
35
42
36
"""
43
-
37
+ import matplotlib . pyplot as plt
44
38
if vars is None :
45
39
vars = trace .varnames
46
40
@@ -138,7 +132,7 @@ def kde2plot(x, y, grid=200):
138
132
139
133
def autocorrplot (trace , vars = None , fontmap = None , max_lag = 100 ):
140
134
"""Bar plot of the autocorrelation function for a trace"""
141
-
135
+ import matplotlib . pyplot as plt
142
136
try :
143
137
# MultiTrace
144
138
traces = trace .traces
@@ -159,7 +153,7 @@ def autocorrplot(trace, vars=None, fontmap=None, max_lag=100):
159
153
chains = len (traces )
160
154
161
155
n = len (samples [0 ])
162
- f , ax = subplots (n , chains , squeeze = False )
156
+ f , ax = plt . subplots (n , chains , squeeze = False )
163
157
164
158
max_lag = min (len (samples [0 ][vars [0 ]])- 1 , max_lag )
165
159
@@ -169,7 +163,7 @@ def autocorrplot(trace, vars=None, fontmap=None, max_lag=100):
169
163
170
164
d = np .squeeze (samples [j ][v ])
171
165
172
- ax [i , j ].acorr (d , detrend = mlab .detrend_mean , maxlags = max_lag )
166
+ ax [i , j ].acorr (d , detrend = plt . mlab .detrend_mean , maxlags = max_lag )
173
167
174
168
if not j :
175
169
ax [i , j ].set_ylabel ("correlation" )
@@ -179,11 +173,11 @@ def autocorrplot(trace, vars=None, fontmap=None, max_lag=100):
179
173
ax [i , j ].set_title ("chain {0}" .format (j + 1 ))
180
174
181
175
# Smaller tick labels
182
- tlabels = gca ().get_xticklabels ()
183
- setp (tlabels , 'fontsize' , fontmap [1 ])
176
+ tlabels = plt . gca ().get_xticklabels ()
177
+ plt . setp (tlabels , 'fontsize' , fontmap [1 ])
184
178
185
- tlabels = gca ().get_yticklabels ()
186
- setp (tlabels , 'fontsize' , fontmap [1 ])
179
+ tlabels = plt . gca ().get_yticklabels ()
180
+ plt . setp (tlabels , 'fontsize' , fontmap [1 ])
187
181
188
182
189
183
def var_str (name , shape ):
@@ -255,6 +249,11 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
255
249
Location of vertical reference line (defaults to 0).
256
250
257
251
"""
252
+ import matplotlib .pyplot as plt
253
+ try :
254
+ import matplotlib .gridspec as gridspec
255
+ except ImportError :
256
+ gridspec = None
258
257
259
258
if not gridspec :
260
259
print_ ('\n Your installation of matplotlib is not recent enough to ' +
@@ -322,7 +321,7 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
322
321
gs = gridspec .GridSpec (1 , 1 )
323
322
324
323
# Subplot for confidence intervals
325
- interval_plot = subplot (gs [0 ])
324
+ interval_plot = plt . subplot (gs [0 ])
326
325
327
326
for j , tr in enumerate (traces ):
328
327
@@ -381,9 +380,9 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
381
380
382
381
if quartiles :
383
382
# Plot median
384
- plot (q [2 ], y , 'bo' , markersize = 4 )
383
+ plt . plot (q [2 ], y , 'bo' , markersize = 4 )
385
384
# Plot quartile interval
386
- errorbar (
385
+ plt . errorbar (
387
386
x = (q [1 ],
388
387
q [3 ]),
389
388
y = (y ,
@@ -393,10 +392,10 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
393
392
394
393
else :
395
394
# Plot median
396
- plot (q [1 ], y , 'bo' , markersize = 4 )
395
+ plt . plot (q [1 ], y , 'bo' , markersize = 4 )
397
396
398
397
# Plot outer interval
399
- errorbar (
398
+ plt . errorbar (
400
399
x = (q [0 ],
401
400
q [- 1 ]),
402
401
y = (y ,
@@ -411,9 +410,9 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
411
410
412
411
if quartiles :
413
412
# Plot median
414
- plot (quants [2 ], y , 'bo' , markersize = 4 )
413
+ plt . plot (quants [2 ], y , 'bo' , markersize = 4 )
415
414
# Plot quartile interval
416
- errorbar (
415
+ plt . errorbar (
417
416
x = (quants [1 ],
418
417
quants [3 ]),
419
418
y = (y ,
@@ -422,10 +421,10 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
422
421
color = 'b' )
423
422
else :
424
423
# Plot median
425
- plot (quants [1 ], y , 'bo' , markersize = 4 )
424
+ plt . plot (quants [1 ], y , 'bo' , markersize = 4 )
426
425
427
426
# Plot outer interval
428
- errorbar (
427
+ plt . errorbar (
429
428
x = (quants [0 ],
430
429
quants [- 1 ]),
431
430
y = (y ,
@@ -443,27 +442,27 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
443
442
gs .update (left = left_margin , right = 0.95 , top = 0.9 , bottom = 0.05 )
444
443
445
444
# Define range of y-axis
446
- ylim (- var + 0.5 , - 0.5 )
445
+ plt . ylim (- var + 0.5 , - 0.5 )
447
446
448
447
datarange = plotrange [1 ] - plotrange [0 ]
449
- xlim (plotrange [0 ] - 0.05 * datarange , plotrange [1 ] + 0.05 * datarange )
448
+ plt . xlim (plotrange [0 ] - 0.05 * datarange , plotrange [1 ] + 0.05 * datarange )
450
449
451
450
# Add variable labels
452
- yticks ([- (l + 1 ) for l in range (len (labels ))], labels )
451
+ plt . yticks ([- (l + 1 ) for l in range (len (labels ))], labels )
453
452
454
453
# Add title
455
454
if main is not False :
456
455
plot_title = main or str (int ((
457
456
1 - alpha ) * 100 )) + "% Credible Intervals"
458
- title (plot_title )
457
+ plt . title (plot_title )
459
458
460
459
# Add x-axis label
461
460
if xtitle is not None :
462
- xlabel (xtitle )
461
+ plt . xlabel (xtitle )
463
462
464
463
# Constrain to specified range
465
464
if xrange is not None :
466
- xlim (* xrange )
465
+ plt . xlim (* xrange )
467
466
468
467
# Remove ticklines on y-axes
469
468
for ticks in interval_plot .yaxis .get_major_ticks ():
@@ -478,23 +477,23 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
478
477
spine .set_color ('none' ) # don't draw spine
479
478
480
479
# Reference line
481
- axvline (vline , color = 'k' , linestyle = '--' )
480
+ plt . axvline (vline , color = 'k' , linestyle = '--' )
482
481
483
482
# Genenerate Gelman-Rubin plot
484
483
if rhat and chains > 1 :
485
484
486
485
# If there are multiple chains, calculate R-hat
487
- rhat_plot = subplot (gs [1 ])
486
+ rhat_plot = plt . subplot (gs [1 ])
488
487
489
488
if main is not False :
490
- title ("R-hat" )
489
+ plt . title ("R-hat" )
491
490
492
491
# Set x range
493
- xlim (0.9 , 2.1 )
492
+ plt . xlim (0.9 , 2.1 )
494
493
495
494
# X axis labels
496
- xticks ((1.0 , 1.5 , 2.0 ), ("1" , "1.5" , "2+" ))
497
- yticks ([- (l + 1 ) for l in range (len (labels ))], "" )
495
+ plt . xticks ((1.0 , 1.5 , 2.0 ), ("1" , "1.5" , "2+" ))
496
+ plt . yticks ([- (l + 1 ) for l in range (len (labels ))], "" )
498
497
499
498
i = 1
500
499
for varname in vars :
@@ -503,15 +502,15 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
503
502
k = np .size (value )
504
503
505
504
if k > 1 :
506
- plot ([min (r , 2 ) for r in R [varname ]], [- (j + i )
505
+ plt . plot ([min (r , 2 ) for r in R [varname ]], [- (j + i )
507
506
for j in range (k )], 'bo' , markersize = 4 )
508
507
else :
509
- plot (min (R [varname ], 2 ), - i , 'bo' , markersize = 4 )
508
+ plt . plot (min (R [varname ], 2 ), - i , 'bo' , markersize = 4 )
510
509
511
510
i += k
512
511
513
512
# Define range of y-axis
514
- ylim (- i + 0.5 , - 0.5 )
513
+ plt . ylim (- i + 0.5 , - 0.5 )
515
514
516
515
# Remove ticklines on y-axes
517
516
for ticks in rhat_plot .yaxis .get_major_ticks ():
0 commit comments