Skip to content

Commit 22e8f0b

Browse files
authored
Replace fastprogress progress bars with rich (#7233)
* Replace fastprogress with rich * Bugfixes for ADVI progress bars * Bugfixes for MAP progress bars * Fixed final update to progress bar * SMC progress bar working * Fixes to MAP progress bar * Customize progress bar theme * Added progressbar_theme argument * Moved default progressbar theme to util * Convert compute_log_density to use Progress instead of track * Getting rid of mypy complaint
1 parent 207821d commit 22e8f0b

17 files changed

+372
-299
lines changed

conda-envs/environment-dev.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ dependencies:
99
- blas
1010
- cachetools>=4.2.1
1111
- cloudpickle
12-
- fastprogress>=0.2.0
1312
- h5py>=2.7
1413
- numpy>=1.15.0
1514
- pandas>=0.24.0
@@ -28,6 +27,7 @@ dependencies:
2827
- pre-commit>=2.8.0
2928
- pytest-cov>=2.5
3029
- pytest>=3.0
30+
- rich>=13.7.1
3131
- sphinx-copybutton
3232
- sphinx-design
3333
- sphinx-notfound-page

conda-envs/environment-docs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@ dependencies:
88
- arviz>=0.13.0
99
- cachetools>=4.2.1
1010
- cloudpickle
11-
- fastprogress>=0.2.0
1211
- numpy>=1.15.0
1312
- pandas>=0.24.0
1413
- pip
1514
- pytensor>=2.19,<2.20
1615
- python-graphviz
16+
- rich>=13.7.1
1717
- scipy>=1.4.1
1818
- typing-extensions>=3.7.4
1919
# Extra dependencies for docs build

conda-envs/environment-jax.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ dependencies:
99
- blas
1010
- cachetools>=4.2.1
1111
- cloudpickle
12-
- fastprogress>=0.2.0
1312
- h5py>=2.7
1413
# Jaxlib version must not be greater than jax version!
1514
- blackjax>=1.0.0
@@ -24,6 +23,7 @@ dependencies:
2423
- pytensor>=2.19,<2.20
2524
- python-graphviz
2625
- networkx
26+
- rich>=13.7.1
2727
- scipy>=1.4.1
2828
- typing-extensions>=3.7.4
2929
# Extra dependencies for testing

conda-envs/environment-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ dependencies:
99
- blas
1010
- cachetools>=4.2.1
1111
- cloudpickle
12-
- fastprogress>=0.2.0
1312
- h5py>=2.7
1413
- jax
1514
- libblas=*=*mkl
@@ -20,6 +19,7 @@ dependencies:
2019
- pytensor>=2.19,<2.20
2120
- python-graphviz
2221
- networkx
22+
- rich>=13.7.1
2323
- scipy>=1.4.1
2424
- typing-extensions>=3.7.4
2525
# Extra dependencies for testing

conda-envs/windows-environment-dev.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@ dependencies:
99
- blas
1010
- cachetools>=4.2.1
1111
- cloudpickle
12-
- fastprogress>=0.2.0
1312
- h5py>=2.7
1413
- numpy>=1.15.0
1514
- pandas>=0.24.0
1615
- pip
1716
- pytensor>=2.19,<2.20
1817
- python-graphviz
1918
- networkx
19+
- rich>=13.7.1
2020
- scipy>=1.4.1
2121
- typing-extensions>=3.7.4
2222
# Extra dependencies for dev, testing and docs build

conda-envs/windows-environment-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ dependencies:
99
- blas
1010
- cachetools>=4.2.1
1111
- cloudpickle
12-
- fastprogress>=0.2.0
1312
- h5py>=2.7
1413
- libpython
1514
- mkl-service>=2.3.0
@@ -20,6 +19,7 @@ dependencies:
2019
- pytensor>=2.19,<2.20
2120
- python-graphviz
2221
- networkx
22+
- rich>=13.7.1
2323
- scipy>=1.4.1
2424
- typing-extensions>=3.7.4
2525
# Extra dependencies for testing

pymc/sampling/forward.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
import xarray
3131

3232
from arviz import InferenceData
33-
from fastprogress.fastprogress import progress_bar
3433
from pytensor import tensor as pt
3534
from pytensor.graph.basic import (
3635
Apply,
@@ -46,6 +45,9 @@
4645
RandomStateSharedVariable,
4746
)
4847
from pytensor.tensor.sharedvar import SharedVariable
48+
from rich.console import Console
49+
from rich.progress import Progress
50+
from rich.theme import Theme
4951
from typing_extensions import TypeAlias
5052

5153
import pymc as pm
@@ -59,6 +61,7 @@
5961
RandomState,
6062
_get_seeds_per_chain,
6163
dataset_to_point_list,
64+
default_progress_theme,
6265
get_default_varnames,
6366
point_wrapper,
6467
)
@@ -70,7 +73,6 @@
7073
"sample_posterior_predictive",
7174
)
7275

73-
7476
ArrayLike: TypeAlias = Union[np.ndarray, list[float]]
7577
PointList: TypeAlias = list[PointType]
7678

@@ -442,6 +444,7 @@ def sample_posterior_predictive(
442444
sample_dims: Optional[list[str]] = None,
443445
random_seed: RandomState = None,
444446
progressbar: bool = True,
447+
progressbar_theme: Optional[Theme] = default_progress_theme,
445448
return_inferencedata: bool = True,
446449
extend_inferencedata: bool = False,
447450
predictions: bool = False,
@@ -796,10 +799,6 @@ def sample_posterior_predictive(
796799
else:
797800
vars_ = model.observed_RVs + observed_dependent_deterministics(model)
798801

799-
indices = np.arange(samples)
800-
if progressbar:
801-
indices = progress_bar(indices, total=samples, display=progressbar)
802-
803802
vars_to_sample = list(get_default_varnames(vars_, include_transformed=False))
804803

805804
if not vars_to_sample:
@@ -834,25 +833,30 @@ def sample_posterior_predictive(
834833
_log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore
835834
ppc_trace_t = _DefaultTrace(samples)
836835
try:
837-
for idx in indices:
838-
if nchain > 1:
839-
# the trace object will either be a MultiTrace (and have _straces)...
840-
if hasattr(_trace, "_straces"):
841-
chain_idx, point_idx = np.divmod(idx, len_trace)
842-
chain_idx = chain_idx % nchain
843-
param = cast(MultiTrace, _trace)._straces[chain_idx].point(point_idx)
844-
# ... or a PointList
836+
with Progress(console=Console(theme=progressbar_theme)) as progress:
837+
task = progress.add_task("Sampling ...", total=samples, visible=progressbar)
838+
for idx in np.arange(samples):
839+
if nchain > 1:
840+
# the trace object will either be a MultiTrace (and have _straces)...
841+
if hasattr(_trace, "_straces"):
842+
chain_idx, point_idx = np.divmod(idx, len_trace)
843+
chain_idx = chain_idx % nchain
844+
param = cast(MultiTrace, _trace)._straces[chain_idx].point(point_idx)
845+
# ... or a PointList
846+
else:
847+
param = cast(PointList, _trace)[idx % (len_trace * nchain)]
848+
# there's only a single chain, but the index might hit it multiple times if
849+
# the number of indices is greater than the length of the trace.
845850
else:
846-
param = cast(PointList, _trace)[idx % (len_trace * nchain)]
847-
# there's only a single chain, but the index might hit it multiple times if
848-
# the number of indices is greater than the length of the trace.
849-
else:
850-
param = _trace[idx % len_trace]
851+
param = _trace[idx % len_trace]
852+
853+
values = sampler_fn(**param)
854+
855+
for k, v in zip(vars_, values):
856+
ppc_trace_t.insert(k.name, v, idx)
851857

852-
values = sampler_fn(**param)
858+
progress.advance(task)
853859

854-
for k, v in zip(vars_, values):
855-
ppc_trace_t.insert(k.name, v, idx)
856860
except KeyboardInterrupt:
857861
pass
858862

pymc/sampling/mcmc.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,10 @@
3434

3535
from arviz import InferenceData, dict_to_dataset
3636
from arviz.data.base import make_attrs
37-
from fastprogress.fastprogress import progress_bar
3837
from pytensor.graph.basic import Variable
38+
from rich.console import Console
39+
from rich.progress import Progress
40+
from rich.theme import Theme
3941
from typing_extensions import Protocol, TypeAlias
4042

4143
import pymc as pm
@@ -65,6 +67,7 @@
6567
RandomSeed,
6668
RandomState,
6769
_get_seeds_per_chain,
70+
default_progress_theme,
6871
drop_warning_stat,
6972
get_untransformed_name,
7073
is_transformed_name,
@@ -377,6 +380,7 @@ def sample(
377380
cores: Optional[int] = None,
378381
random_seed: RandomState = None,
379382
progressbar: bool = True,
383+
progressbar_theme: Optional[Theme] = default_progress_theme,
380384
step=None,
381385
var_names: Optional[Sequence[str]] = None,
382386
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
@@ -406,6 +410,7 @@ def sample(
406410
cores: Optional[int] = None,
407411
random_seed: RandomState = None,
408412
progressbar: bool = True,
413+
progressbar_theme: Optional[Theme] = default_progress_theme,
409414
step=None,
410415
var_names: Optional[Sequence[str]] = None,
411416
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
@@ -435,6 +440,7 @@ def sample(
435440
cores: Optional[int] = None,
436441
random_seed: RandomState = None,
437442
progressbar: bool = True,
443+
progressbar_theme: Optional[Theme] = default_progress_theme,
438444
step=None,
439445
var_names: Optional[Sequence[str]] = None,
440446
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
@@ -761,6 +767,7 @@ def sample(
761767
"tune": tune,
762768
"var_names": var_names,
763769
"progressbar": progressbar,
770+
"progressbar_theme": progressbar_theme,
764771
"model": model,
765772
"cores": cores,
766773
"callback": callback,
@@ -983,6 +990,7 @@ def _sample(
983990
trace: IBaseTrace,
984991
tune: int,
985992
model: Optional[Model] = None,
993+
progressbar_theme: Optional[Theme] = default_progress_theme,
986994
callback=None,
987995
**kwargs,
988996
) -> None:
@@ -1010,6 +1018,8 @@ def _sample(
10101018
tune : int
10111019
Number of iterations to tune.
10121020
model : Model (optional if in ``with`` context)
1021+
progressbar_theme : Theme
1022+
Optional custom theme for the progress bar.
10131023
"""
10141024
skip_first = kwargs.get("skip_first", 0)
10151025

@@ -1026,19 +1036,16 @@ def _sample(
10261036
)
10271037
_pbar_data = {"chain": chain, "divergences": 0}
10281038
_desc = "Sampling chain {chain:d}, {divergences:,d} divergences"
1029-
if progressbar:
1030-
sampling = progress_bar(sampling_gen, total=draws, display=progressbar)
1031-
sampling.comment = _desc.format(**_pbar_data)
1032-
else:
1033-
sampling = sampling_gen
1034-
try:
1035-
for it, diverging in enumerate(sampling):
1036-
if it >= skip_first and diverging:
1037-
_pbar_data["divergences"] += 1
1038-
if progressbar:
1039-
sampling.comment = _desc.format(**_pbar_data)
1040-
except KeyboardInterrupt:
1041-
pass
1039+
with Progress(console=Console(theme=progressbar_theme)) as progress:
1040+
try:
1041+
task = progress.add_task(_desc.format(**_pbar_data), total=draws, visible=progressbar)
1042+
for it, diverging in enumerate(sampling_gen):
1043+
if it >= skip_first and diverging:
1044+
_pbar_data["divergences"] += 1
1045+
progress.update(task, advance=1)
1046+
progress.update(task, advance=1, completed=True)
1047+
except KeyboardInterrupt:
1048+
pass
10421049

10431050

10441051
def _iter_sample(
@@ -1131,6 +1138,7 @@ def _mp_sample(
11311138
random_seed: Sequence[RandomSeed],
11321139
start: Sequence[PointType],
11331140
progressbar: bool = True,
1141+
progressbar_theme: Optional[Theme] = default_progress_theme,
11341142
traces: Sequence[IBaseTrace],
11351143
model: Optional[Model] = None,
11361144
callback: Optional[SamplingIteratorCallback] = None,
@@ -1158,6 +1166,8 @@ def _mp_sample(
11581166
Dicts must contain numeric (transformed) initial values for all (transformed) free variables.
11591167
progressbar : bool
11601168
Whether or not to display a progress bar in the command line.
1169+
progressbar_theme : Theme
1170+
Optional custom theme for the progress bar.
11611171
traces
11621172
Recording backends for each chain.
11631173
model : Model (optional if in ``with`` context)
@@ -1182,6 +1192,7 @@ def _mp_sample(
11821192
start_points=start,
11831193
step_method=step,
11841194
progressbar=progressbar,
1195+
progressbar_theme=progressbar_theme,
11851196
mp_ctx=mp_ctx,
11861197
)
11871198
try:

0 commit comments

Comments
 (0)