Skip to content

Commit 40a8070

Browse files
committed
Fix Travis Matplotlib errors with function imports
This commit makes two changes to avoid Matplotlib display errors from non-interactive shells. 1. Use Agg backend in test_plots.py. 2. Import pyplot within plotting functions. This allows plotting functions to still be imported into top-level __init__.py without selecting a another backend before Agg can be selected in tests. This also has the advantage that users running scripts in a non-interactive shell will be able to specify the backend after the top-level pymc import.
1 parent 607950e commit 40a8070

File tree

3 files changed

+41
-40
lines changed

3 files changed

+41
-40
lines changed

pymc/glm/glm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import matplotlib.pyplot as plt
21
import numpy as np
32
from pymc import *
43
import patsy
@@ -179,6 +178,7 @@ def plot_posterior_predictive(trace, eval=None, lm=None, samples=30, **kwargs):
179178
Additional keyword arguments are passed to pylab.plot().
180179
181180
"""
181+
import matplotlib.pyplot as plt
182182

183183
if lm is None:
184184
lm = lambda x, sample: sample['Intercept'] + sample['x'] * x

pymc/plots.py

Lines changed: 37 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,3 @@
1-
from pylab import *
2-
import matplotlib.pyplot as plt
3-
try:
4-
import matplotlib.gridspec as gridspec
5-
except ImportError:
6-
gridspec = None
71
import numpy as np
82
from scipy.stats import kde
93
from .stats import *
@@ -40,7 +34,7 @@ def traceplot(trace, vars=None, figsize=None,
4034
fig : figure object
4135
4236
"""
43-
37+
import matplotlib.pyplot as plt
4438
if vars is None:
4539
vars = trace.varnames
4640

@@ -138,7 +132,7 @@ def kde2plot(x, y, grid=200):
138132

139133
def autocorrplot(trace, vars=None, fontmap=None, max_lag=100):
140134
"""Bar plot of the autocorrelation function for a trace"""
141-
135+
import matplotlib.pyplot as plt
142136
try:
143137
# MultiTrace
144138
traces = trace.traces
@@ -159,7 +153,7 @@ def autocorrplot(trace, vars=None, fontmap=None, max_lag=100):
159153
chains = len(traces)
160154

161155
n = len(samples[0])
162-
f, ax = subplots(n, chains, squeeze=False)
156+
f, ax = plt.subplots(n, chains, squeeze=False)
163157

164158
max_lag = min(len(samples[0][vars[0]])-1, max_lag)
165159

@@ -169,7 +163,7 @@ def autocorrplot(trace, vars=None, fontmap=None, max_lag=100):
169163

170164
d = np.squeeze(samples[j][v])
171165

172-
ax[i, j].acorr(d, detrend=mlab.detrend_mean, maxlags=max_lag)
166+
ax[i, j].acorr(d, detrend=plt.mlab.detrend_mean, maxlags=max_lag)
173167

174168
if not j:
175169
ax[i, j].set_ylabel("correlation")
@@ -179,11 +173,11 @@ def autocorrplot(trace, vars=None, fontmap=None, max_lag=100):
179173
ax[i, j].set_title("chain {0}".format(j+1))
180174

181175
# Smaller tick labels
182-
tlabels = gca().get_xticklabels()
183-
setp(tlabels, 'fontsize', fontmap[1])
176+
tlabels = plt.gca().get_xticklabels()
177+
plt.setp(tlabels, 'fontsize', fontmap[1])
184178

185-
tlabels = gca().get_yticklabels()
186-
setp(tlabels, 'fontsize', fontmap[1])
179+
tlabels = plt.gca().get_yticklabels()
180+
plt.setp(tlabels, 'fontsize', fontmap[1])
187181

188182

189183
def var_str(name, shape):
@@ -255,6 +249,11 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
255249
Location of vertical reference line (defaults to 0).
256250
257251
"""
252+
import matplotlib.pyplot as plt
253+
try:
254+
import matplotlib.gridspec as gridspec
255+
except ImportError:
256+
gridspec = None
258257

259258
if not gridspec:
260259
print_('\nYour installation of matplotlib is not recent enough to ' +
@@ -322,7 +321,7 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
322321
gs = gridspec.GridSpec(1, 1)
323322

324323
# Subplot for confidence intervals
325-
interval_plot = subplot(gs[0])
324+
interval_plot = plt.subplot(gs[0])
326325

327326
for j, tr in enumerate(traces):
328327

@@ -381,9 +380,9 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
381380

382381
if quartiles:
383382
# Plot median
384-
plot(q[2], y, 'bo', markersize=4)
383+
plt.plot(q[2], y, 'bo', markersize=4)
385384
# Plot quartile interval
386-
errorbar(
385+
plt.errorbar(
387386
x=(q[1],
388387
q[3]),
389388
y=(y,
@@ -393,10 +392,10 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
393392

394393
else:
395394
# Plot median
396-
plot(q[1], y, 'bo', markersize=4)
395+
plt.plot(q[1], y, 'bo', markersize=4)
397396

398397
# Plot outer interval
399-
errorbar(
398+
plt.errorbar(
400399
x=(q[0],
401400
q[-1]),
402401
y=(y,
@@ -411,9 +410,9 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
411410

412411
if quartiles:
413412
# Plot median
414-
plot(quants[2], y, 'bo', markersize=4)
413+
plt.plot(quants[2], y, 'bo', markersize=4)
415414
# Plot quartile interval
416-
errorbar(
415+
plt.errorbar(
417416
x=(quants[1],
418417
quants[3]),
419418
y=(y,
@@ -422,10 +421,10 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
422421
color='b')
423422
else:
424423
# Plot median
425-
plot(quants[1], y, 'bo', markersize=4)
424+
plt.plot(quants[1], y, 'bo', markersize=4)
426425

427426
# Plot outer interval
428-
errorbar(
427+
plt.errorbar(
429428
x=(quants[0],
430429
quants[-1]),
431430
y=(y,
@@ -443,27 +442,27 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
443442
gs.update(left=left_margin, right=0.95, top=0.9, bottom=0.05)
444443

445444
# Define range of y-axis
446-
ylim(-var + 0.5, -0.5)
445+
plt.ylim(-var + 0.5, -0.5)
447446

448447
datarange = plotrange[1] - plotrange[0]
449-
xlim(plotrange[0] - 0.05 * datarange, plotrange[1] + 0.05 * datarange)
448+
plt.xlim(plotrange[0] - 0.05 * datarange, plotrange[1] + 0.05 * datarange)
450449

451450
# Add variable labels
452-
yticks([-(l + 1) for l in range(len(labels))], labels)
451+
plt.yticks([-(l + 1) for l in range(len(labels))], labels)
453452

454453
# Add title
455454
if main is not False:
456455
plot_title = main or str(int((
457456
1 - alpha) * 100)) + "% Credible Intervals"
458-
title(plot_title)
457+
plt.title(plot_title)
459458

460459
# Add x-axis label
461460
if xtitle is not None:
462-
xlabel(xtitle)
461+
plt.xlabel(xtitle)
463462

464463
# Constrain to specified range
465464
if xrange is not None:
466-
xlim(*xrange)
465+
plt.xlim(*xrange)
467466

468467
# Remove ticklines on y-axes
469468
for ticks in interval_plot.yaxis.get_major_ticks():
@@ -478,23 +477,23 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
478477
spine.set_color('none') # don't draw spine
479478

480479
# Reference line
481-
axvline(vline, color='k', linestyle='--')
480+
plt.axvline(vline, color='k', linestyle='--')
482481

483482
# Genenerate Gelman-Rubin plot
484483
if rhat and chains > 1:
485484

486485
# If there are multiple chains, calculate R-hat
487-
rhat_plot = subplot(gs[1])
486+
rhat_plot = plt.subplot(gs[1])
488487

489488
if main is not False:
490-
title("R-hat")
489+
plt.title("R-hat")
491490

492491
# Set x range
493-
xlim(0.9, 2.1)
492+
plt.xlim(0.9, 2.1)
494493

495494
# X axis labels
496-
xticks((1.0, 1.5, 2.0), ("1", "1.5", "2+"))
497-
yticks([-(l + 1) for l in range(len(labels))], "")
495+
plt.xticks((1.0, 1.5, 2.0), ("1", "1.5", "2+"))
496+
plt.yticks([-(l + 1) for l in range(len(labels))], "")
498497

499498
i = 1
500499
for varname in vars:
@@ -503,15 +502,15 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
503502
k = np.size(value)
504503

505504
if k > 1:
506-
plot([min(r, 2) for r in R[varname]], [-(j + i)
505+
plt.plot([min(r, 2) for r in R[varname]], [-(j + i)
507506
for j in range(k)], 'bo', markersize=4)
508507
else:
509-
plot(min(R[varname], 2), -i, 'bo', markersize=4)
508+
plt.plot(min(R[varname], 2), -i, 'bo', markersize=4)
510509

511510
i += k
512511

513512
# Define range of y-axis
514-
ylim(-i + 0.5, -0.5)
513+
plt.ylim(-i + 0.5, -0.5)
515514

516515
# Remove ticklines on y-axes
517516
for ticks in rhat_plot.yaxis.get_major_ticks():

pymc/tests/test_plots.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
#from ..plots import *
1+
import matplotlib
2+
matplotlib.use('Agg', warn=False)
3+
24
from pymc.plots import *
35
from pymc import psample, Slice, Metropolis, find_hessian, sample
46

0 commit comments

Comments
 (0)