@@ -134,6 +134,7 @@ def __init__(
134
134
model = None ,
135
135
random_seed = None ,
136
136
threshold = 0.5 ,
137
+ compile_kwargs : dict | None = None ,
137
138
):
138
139
"""
139
140
Initialize the SMC_kernel class.
@@ -154,6 +155,8 @@ def __init__(
154
155
Determines the change of beta from stage to stage, i.e.indirectly the number of stages,
155
156
the higher the value of `threshold` the higher the number of stages. Defaults to 0.5.
156
157
It should be between 0 and 1.
158
+ compile_kwargs: dict, optional
159
+ Keyword arguments passed to pytensor.function
157
160
158
161
Attributes
159
162
----------
@@ -172,8 +175,8 @@ def __init__(
172
175
self .model = modelcontext (model )
173
176
self .variables = self .model .value_vars
174
177
175
- self .var_info = {}
176
- self .tempered_posterior = None
178
+ self .var_info : dict [ str , tuple ] = {}
179
+ self .tempered_posterior : np . ndarray
177
180
self .prior_logp = None
178
181
self .likelihood_logp = None
179
182
self .tempered_posterior_logp = None
@@ -184,6 +187,7 @@ def __init__(
184
187
self .iteration = 0
185
188
self .resampling_indexes = None
186
189
self .weights = np .ones (self .draws ) / self .draws
190
+ self .compile_kwargs = compile_kwargs if compile_kwargs is not None else {}
187
191
188
192
def initialize_population (self ) -> dict [str , np .ndarray ]:
189
193
"""Create an initial population from the prior distribution."""
@@ -239,10 +243,10 @@ def _initialize_kernel(self):
239
243
shared = make_shared_replacements (initial_point , self .variables , self .model )
240
244
241
245
self .prior_logp_func = _logp_forw (
242
- initial_point , [self .model .varlogp ], self .variables , shared
246
+ initial_point , [self .model .varlogp ], self .variables , shared , self . compile_kwargs
243
247
)
244
248
self .likelihood_logp_func = _logp_forw (
245
- initial_point , [self .model .datalogp ], self .variables , shared
249
+ initial_point , [self .model .datalogp ], self .variables , shared , self . compile_kwargs
246
250
)
247
251
248
252
priors = [self .prior_logp_func (sample ) for sample in self .tempered_posterior ]
@@ -606,7 +610,7 @@ def systematic_resampling(weights, rng):
606
610
return new_indices
607
611
608
612
609
- def _logp_forw (point , out_vars , in_vars , shared ):
613
+ def _logp_forw (point , out_vars , in_vars , shared , compile_kwargs = None ):
610
614
"""Compile PyTensor function of the model and the input and output variables.
611
615
612
616
Parameters
@@ -617,7 +621,12 @@ def _logp_forw(point, out_vars, in_vars, shared):
617
621
Containing Distribution for the input variables
618
622
shared : list
619
623
Containing TensorVariable for depended shared data
624
+ compile_kwargs: dict, optional
625
+ Additional keyword arguments passed to pytensor.function
620
626
"""
627
+ if compile_kwargs is None :
628
+ compile_kwargs = {}
629
+
621
630
# Replace integer inputs with rounded float inputs
622
631
if any (var .dtype in discrete_types for var in in_vars ):
623
632
replace_int_input = {}
@@ -636,6 +645,6 @@ def _logp_forw(point, out_vars, in_vars, shared):
636
645
out_list , inarray0 = join_nonshared_inputs (
637
646
point = point , outputs = out_vars , inputs = in_vars , shared_inputs = shared
638
647
)
639
- f = compile ([inarray0 ], out_list [0 ])
648
+ f = compile ([inarray0 ], out_list [0 ], ** compile_kwargs )
640
649
f .trust_input = True
641
650
return f
0 commit comments