Skip to content

Commit 07cc7b6

Browse files
committed
Use float time counters for nuts stats
1 parent 334dfe1 commit 07cc7b6

File tree

4 files changed

+27
-23
lines changed

4 files changed

+27
-23
lines changed

pymc3/step_methods/hmc/base_hmc.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,8 @@ def _hamiltonian_step(self, start, p0, step_size):
133133

134134
def astep(self, q0):
135135
"""Perform a single HMC iteration."""
136-
perf_start = time.perf_counter_ns()
137-
process_start = time.process_time_ns()
136+
perf_start = time.perf_counter()
137+
process_start = time.process_time()
138138

139139
p0 = self.potential.random()
140140
start = self.integrator.compute_state(q0, p0)
@@ -170,8 +170,8 @@ def astep(self, q0):
170170

171171
hmc_step = self._hamiltonian_step(start, p0, step_size)
172172

173-
perf_end = time.perf_counter_ns()
174-
process_end = time.process_time_ns()
173+
perf_end = time.perf_counter()
174+
process_end = time.process_time()
175175

176176
self.step_adapt.update(hmc_step.accept_stat, adapt_step)
177177
self.potential.update(hmc_step.end.q, hmc_step.end.q_grad, self.tune)
@@ -201,9 +201,9 @@ def astep(self, q0):
201201
stats = {
202202
"tune": self.tune,
203203
"diverging": bool(hmc_step.divergence_info),
204-
"perf_counter_diff_ns": perf_end - perf_start,
205-
"process_time_diff_ns": process_end - process_start,
206-
"perf_counter_ns": perf_end,
204+
"perf_counter_diff": perf_end - perf_start,
205+
"process_time_diff": process_end - process_start,
206+
"perf_counter_start": perf_start,
207207
}
208208

209209
stats.update(hmc_step.stats)

pymc3/step_methods/hmc/hmc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,9 @@ class HamiltonianMC(BaseHMC):
4848
'path_length': np.float64,
4949
'accepted': np.bool,
5050
'model_logp': np.float64,
51-
'process_time_diff_ns': np.int64,
52-
'perf_counter_diff_ns': np.int64,
53-
'perf_counter_ns': np.int64,
51+
'process_time_diff': np.float64,
52+
'perf_counter_diff': np.float64,
53+
'perf_counter_start': np.float64,
5454
}]
5555

5656
def __init__(self, vars=None, path_length=2., max_steps=1024, **kwargs):

pymc3/step_methods/hmc/nuts.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,13 @@ class NUTS(BaseHMC):
7272
samples, the step size is set to this value. This should converge
7373
during tuning.
7474
- `model_logp`: The model log-likelihood for this sample.
75-
- `process_time_diff_ns`: The time it took to draw the sample, as defined
76-
by the python standard library `time.process_time_ns`. This counts all
75+
- `process_time_diff`: The time it took to draw the sample, as defined
76+
by the python standard library `time.process_time`. This counts all
7777
the CPU time, including worker processes in BLAS and OpenMP.
78-
- `perf_counter_diff_ns`: The time it took to draw the sample, as defined
79-
by the python standard library `time.perf_counter_ns` (wall time).
80-
- `perf_counter_ns`: The value of the `time.perf_counter_ns` after drawing
81-
the sample.
78+
- `perf_counter_diff`: The time it took to draw the sample, as defined
79+
by the python standard library `time.perf_counter` (wall time).
80+
- `perf_counter_start`: The value of `time.perf_counter` at the beginning
81+
of the computation of the draw.
8282
8383
References
8484
----------
@@ -103,9 +103,9 @@ class NUTS(BaseHMC):
103103
"energy": np.float64,
104104
"max_energy_error": np.float64,
105105
"model_logp": np.float64,
106-
"process_time_diff_ns": np.int64,
107-
"perf_counter_diff_ns": np.int64,
108-
"perf_counter_ns": np.int64,
106+
"process_time_diff": np.float64,
107+
"perf_counter_diff": np.float64,
108+
"perf_counter_start": np.float64,
109109
}
110110
]
111111

pymc3/tests/test_step.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -979,7 +979,7 @@ def test_linalg(self, caplog):
979979

980980
def test_sampler_stats(self):
981981
with Model() as model:
982-
x = Normal("x", mu=0, sigma=1)
982+
Normal("x", mu=0, sigma=1)
983983
trace = sample(draws=10, tune=1, chains=1)
984984

985985
# Assert stats exist and have the correct shape.
@@ -995,14 +995,18 @@ def test_sampler_stats(self):
995995
"step_size_bar",
996996
"tree_size",
997997
"tune",
998+
"perf_counter_diff",
999+
"perf_counter_start",
1000+
"process_time_diff",
9981001
}
9991002
assert trace.stat_names == expected_stat_names
10001003
for varname in trace.stat_names:
10011004
assert trace.get_sampler_stats(varname).shape == (10,)
10021005

10031006
# Assert model logp is computed correctly: computing post-sampling
10041007
# and tracking while sampling should give same results.
1005-
model_logp_ = np.array(
1006-
[model.logp(trace.point(i, chain=c)) for c in trace.chains for i in range(len(trace))]
1007-
)
1008+
model_logp_ = np.array([
1009+
model.logp(trace.point(i, chain=c))
1010+
for c in trace.chains for i in range(len(trace))
1011+
])
10081012
assert (trace.model_logp == model_logp_).all()

0 commit comments

Comments
 (0)