Skip to content

Commit cc90212

Browse files
Allow compile_kwargs in sample_smc (#7702)
1 parent 3ccff92 commit cc90212

File tree

2 files changed

+29
-10
lines changed

2 files changed

+29
-10
lines changed

pymc/smc/kernels.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ def __init__(
134134
model=None,
135135
random_seed=None,
136136
threshold=0.5,
137+
compile_kwargs: dict | None = None,
137138
):
138139
"""
139140
Initialize the SMC_kernel class.
@@ -154,6 +155,8 @@ def __init__(
154155
Determines the change of beta from stage to stage, i.e.indirectly the number of stages,
155156
the higher the value of `threshold` the higher the number of stages. Defaults to 0.5.
156157
It should be between 0 and 1.
158+
compile_kwargs: dict, optional
159+
Keyword arguments passed to pytensor.function
157160
158161
Attributes
159162
----------
@@ -172,8 +175,8 @@ def __init__(
172175
self.model = modelcontext(model)
173176
self.variables = self.model.value_vars
174177

175-
self.var_info = {}
176-
self.tempered_posterior = None
178+
self.var_info: dict[str, tuple] = {}
179+
self.tempered_posterior: np.ndarray
177180
self.prior_logp = None
178181
self.likelihood_logp = None
179182
self.tempered_posterior_logp = None
@@ -184,6 +187,7 @@ def __init__(
184187
self.iteration = 0
185188
self.resampling_indexes = None
186189
self.weights = np.ones(self.draws) / self.draws
190+
self.compile_kwargs = compile_kwargs if compile_kwargs is not None else {}
187191

188192
def initialize_population(self) -> dict[str, np.ndarray]:
189193
"""Create an initial population from the prior distribution."""
@@ -239,10 +243,10 @@ def _initialize_kernel(self):
239243
shared = make_shared_replacements(initial_point, self.variables, self.model)
240244

241245
self.prior_logp_func = _logp_forw(
242-
initial_point, [self.model.varlogp], self.variables, shared
246+
initial_point, [self.model.varlogp], self.variables, shared, self.compile_kwargs
243247
)
244248
self.likelihood_logp_func = _logp_forw(
245-
initial_point, [self.model.datalogp], self.variables, shared
249+
initial_point, [self.model.datalogp], self.variables, shared, self.compile_kwargs
246250
)
247251

248252
priors = [self.prior_logp_func(sample) for sample in self.tempered_posterior]
@@ -606,7 +610,7 @@ def systematic_resampling(weights, rng):
606610
return new_indices
607611

608612

609-
def _logp_forw(point, out_vars, in_vars, shared):
613+
def _logp_forw(point, out_vars, in_vars, shared, compile_kwargs=None):
610614
"""Compile PyTensor function of the model and the input and output variables.
611615
612616
Parameters
@@ -617,7 +621,12 @@ def _logp_forw(point, out_vars, in_vars, shared):
617621
Containing Distribution for the input variables
618622
shared : list
619623
Containing TensorVariable for depended shared data
624+
compile_kwargs: dict, optional
625+
Additional keyword arguments passed to pytensor.function
620626
"""
627+
if compile_kwargs is None:
628+
compile_kwargs = {}
629+
621630
# Replace integer inputs with rounded float inputs
622631
if any(var.dtype in discrete_types for var in in_vars):
623632
replace_int_input = {}
@@ -636,6 +645,6 @@ def _logp_forw(point, out_vars, in_vars, shared):
636645
out_list, inarray0 = join_nonshared_inputs(
637646
point=point, outputs=out_vars, inputs=in_vars, shared_inputs=shared
638647
)
639-
f = compile([inarray0], out_list[0])
648+
f = compile([inarray0], out_list[0], **compile_kwargs)
640649
f.trust_input = True
641650
return f

pymc/smc/sampling.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def sample_smc(
5858
return_inferencedata=True,
5959
idata_kwargs=None,
6060
progressbar=True,
61+
compile_kwargs: dict | None = None,
6162
**kernel_kwargs,
6263
) -> InferenceData | MultiTrace:
6364
r"""
@@ -95,17 +96,21 @@ def sample_smc(
9596
Keyword arguments for :func:`pymc.to_inference_data`.
9697
progressbar : bool, optional, default True
9798
Whether or not to display a progress bar in the command line.
99+
compile_kwargs: dict, optional
100+
Keyword arguments to pass to pytensor.function
101+
98102
**kernel_kwargs : dict, optional
99103
Keyword arguments passed to the SMC_kernel. The default IMH kernel takes the following keywords:
100104
101105
threshold : float, default 0.5
102106
Determines the change of beta from stage to stage, i.e. indirectly the number of stages,
103107
the higher the value of `threshold` the higher the number of stages. Defaults to 0.5.
104108
It should be between 0 and 1.
105-
correlation_threshold : float, default 0.01
106-
The lower the value the higher the number of MCMC steps computed automatically.
107-
Defaults to 0.01. It should be between 0 and 1.
108-
Keyword arguments for other kernels should be checked in the respective docstrings.
109+
correlation_threshold : float, default 0.01
110+
The lower the value the higher the number of MCMC steps computed automatically.
111+
Defaults to 0.01. It should be between 0 and 1.
112+
113+
Additional keyword arguments for other kernels should be checked in the respective docstrings.
109114
110115
Notes
111116
-----
@@ -160,6 +165,11 @@ def sample_smc(
160165
else:
161166
cores = min(chains, cores)
162167

168+
if compile_kwargs is None:
169+
compile_kwargs = {}
170+
171+
kernel_kwargs["compile_kwargs"] = compile_kwargs
172+
163173
random_seed = _get_seeds_per_chain(random_state=random_seed, chains=chains)
164174

165175
model = modelcontext(model)

0 commit comments

Comments
 (0)