12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
-
16
15
import collections
17
16
import sys
18
17
from typing import Optional
19
18
20
19
import arviz as az
21
20
import blackjax
22
21
import jax
23
- import jax .numpy as jnp
24
- import jax .random as random
25
22
import numpy as np
26
23
import pymc as pm
27
- from pymc import modelcontext
24
+ from packaging import version
25
+ from pymc .backends .arviz import coords_and_dims_for_inferencedata
26
+ from pymc .blocking import DictToArrayBijection , RaveledVars
27
+ from pymc .model import modelcontext
28
28
from pymc .sampling .jax import get_jaxified_graph
29
29
from pymc .util import RandomSeed , _get_seeds_per_chain , get_default_varnames
30
30
31
31
32
32
def convert_flat_trace_to_idata (
33
33
samples ,
34
- dims = None ,
35
- coords = None ,
36
34
include_transformed = False ,
37
35
postprocessing_backend = "cpu" ,
38
36
model = None ,
39
37
):
40
38
41
39
model = modelcontext (model )
42
- init_position_dict = model .initial_point ()
40
+ ip = model .initial_point ()
41
+ ip_point_map_info = pm .blocking .DictToArrayBijection .map (ip ).point_map_info
43
42
trace = collections .defaultdict (list )
44
- astart = pm .blocking .DictToArrayBijection .map (init_position_dict )
45
43
for sample in samples :
46
- raveld_vars = pm . blocking . RaveledVars (sample , astart . point_map_info )
47
- point = pm . blocking . DictToArrayBijection .rmap (raveld_vars , init_position_dict )
44
+ raveld_vars = RaveledVars (sample , ip_point_map_info )
45
+ point = DictToArrayBijection .rmap (raveld_vars , ip )
48
46
for p , v in point .items ():
49
47
trace [p ].append (v .tolist ())
50
48
@@ -57,19 +55,19 @@ def convert_flat_trace_to_idata(
57
55
result = jax .vmap (jax .vmap (jax_fn ))(
58
56
* jax .device_put (list (trace .values ()), jax .devices (postprocessing_backend )[0 ])
59
57
)
60
-
61
58
trace = {v .name : r for v , r in zip (vars_to_sample , result )}
59
+ coords , dims = coords_and_dims_for_inferencedata (model )
62
60
idata = az .from_dict (trace , dims = dims , coords = coords )
63
61
64
62
return idata
65
63
66
64
67
65
def fit_pathfinder (
68
- iterations = 5_000 ,
66
+ samples = 1000 ,
69
67
random_seed : Optional [RandomSeed ] = None ,
70
68
postprocessing_backend = "cpu" ,
71
- ftol = 1e-4 ,
72
69
model = None ,
70
+ ** pathfinder_kwargs ,
73
71
):
74
72
"""
75
73
Fit the pathfinder algorithm as implemented in blackjax
@@ -78,15 +76,15 @@ def fit_pathfinder(
78
76
79
77
Parameters
80
78
----------
81
- iterations : int
82
- Number of iterations to run .
79
+ samples : int
80
+ Number of samples to draw from the fitted approximation .
83
81
random_seed : int
84
82
Random seed to set.
85
83
postprocessing_backend : str
86
84
Where to compute transformations of the trace.
87
85
"cpu" or "gpu".
88
- ftol : float
89
- Floating point tolerance
86
+ pathfinder_kwargs:
87
+ kwargs for blackjax.vi.pathfinder.approximate
90
88
91
89
Returns
92
90
-------
@@ -96,53 +94,42 @@ def fit_pathfinder(
96
94
---------
97
95
https://arxiv.org/abs/2108.03782
98
96
"""
99
-
100
- (random_seed ,) = _get_seeds_per_chain (random_seed , 1 )
97
+ # Temporarily helper
98
+ if version .parse (blackjax .__version__ ).major < 1 :
99
+ raise ImportError ("fit_pathfinder requires blackjax 1.0 or above" )
101
100
102
101
model = modelcontext (model )
103
102
104
- rvs = [rv .name for rv in model .value_vars ]
105
- init_position_dict = model .initial_point ()
106
- init_position = [init_position_dict [rv ] for rv in rvs ]
103
+ ip = model .initial_point ()
104
+ ip_map = DictToArrayBijection .map (ip )
107
105
108
106
new_logprob , new_input = pm .pytensorf .join_nonshared_inputs (
109
- init_position_dict , (model .logp (),), model .value_vars , ()
107
+ ip , (model .logp (),), model .value_vars , ()
110
108
)
111
109
112
110
logprob_fn_list = get_jaxified_graph ([new_input ], new_logprob )
113
111
114
112
def logprob_fn (x ):
115
113
return logprob_fn_list (x )[0 ]
116
114
117
- dim = sum (v .size for v in init_position_dict .values ())
118
-
119
- rng_key = random .PRNGKey (random_seed )
120
- w0 = random .multivariate_normal (rng_key , 2.0 + jnp .zeros (dim ), jnp .eye (dim ))
121
- path = blackjax .vi .pathfinder .init (rng_key , logprob_fn , w0 , return_path = True , ftol = ftol )
122
-
123
- pathfinder = blackjax .kernels .pathfinder (rng_key , logprob_fn , ftol = ftol )
124
- state = pathfinder .init (w0 )
125
-
126
- def inference_loop (rng_key , kernel , initial_state , num_samples ):
127
- @jax .jit
128
- def one_step (state , rng_key ):
129
- state , info = kernel (rng_key , state )
130
- return state , (state , info )
115
+ [pathfinder_seed , sample_seed ] = _get_seeds_per_chain (random_seed , 2 )
131
116
132
- keys = jax .random .split (rng_key , num_samples )
133
- return jax .lax .scan (one_step , initial_state , keys )
134
-
135
- _ , rng_key = random .split (rng_key )
136
117
print ("Running pathfinder..." , file = sys .stdout )
137
- _ , (_ , samples ) = inference_loop (rng_key , pathfinder .step , state , iterations )
138
-
139
- dims = {
140
- var_name : [dim for dim in dims if dim is not None ]
141
- for var_name , dims in model .named_vars_to_dims .items ()
142
- }
118
+ pathfinder_state , _ = blackjax .vi .pathfinder .approximate (
119
+ rng_key = jax .random .key (pathfinder_seed ),
120
+ logdensity_fn = logprob_fn ,
121
+ initial_position = ip_map .data ,
122
+ ** pathfinder_kwargs ,
123
+ )
124
+ samples , _ = blackjax .vi .pathfinder .sample (
125
+ rng_key = jax .random .key (sample_seed ),
126
+ state = pathfinder_state ,
127
+ num_samples = samples ,
128
+ )
143
129
144
130
idata = convert_flat_trace_to_idata (
145
- samples , postprocessing_backend = postprocessing_backend , coords = model .coords , dims = dims
131
+ samples ,
132
+ postprocessing_backend = postprocessing_backend ,
133
+ model = model ,
146
134
)
147
-
148
135
return idata
0 commit comments