@@ -371,6 +371,7 @@ def sample(
371
371
random_seed : RandomState = None ,
372
372
progressbar : bool = True ,
373
373
step = None ,
374
+ var_names : Optional [Sequence [str ]] = None ,
374
375
nuts_sampler : Literal ["pymc" , "nutpie" , "numpyro" , "blackjax" ] = "pymc" ,
375
376
initvals : Optional [Union [StartDict , Sequence [Optional [StartDict ]]]] = None ,
376
377
init : str = "auto" ,
@@ -399,6 +400,7 @@ def sample(
399
400
random_seed : RandomState = None ,
400
401
progressbar : bool = True ,
401
402
step = None ,
403
+ var_names : Optional [Sequence [str ]] = None ,
402
404
nuts_sampler : Literal ["pymc" , "nutpie" , "numpyro" , "blackjax" ] = "pymc" ,
403
405
initvals : Optional [Union [StartDict , Sequence [Optional [StartDict ]]]] = None ,
404
406
init : str = "auto" ,
@@ -427,6 +429,7 @@ def sample(
427
429
random_seed : RandomState = None ,
428
430
progressbar : bool = True ,
429
431
step = None ,
432
+ var_names : Optional [Sequence [str ]] = None ,
430
433
nuts_sampler : Literal ["pymc" , "nutpie" , "numpyro" , "blackjax" ] = "pymc" ,
431
434
initvals : Optional [Union [StartDict , Sequence [Optional [StartDict ]]]] = None ,
432
435
init : str = "auto" ,
@@ -478,6 +481,8 @@ def sample(
478
481
A step function or collection of functions. If there are variables without step methods,
479
482
step methods for those variables will be assigned automatically. By default the NUTS step
480
483
method will be used, if appropriate to the model.
484
+ var_names : list of str
485
+ Names of variables to be monitored. If None, all named variables are selected automatically.
481
486
nuts_sampler : str
482
487
Which NUTS implementation to run. One of ["pymc", "nutpie", "blackjax", "numpyro"].
483
488
This requires the chosen sampler to be installed.
@@ -722,12 +727,18 @@ def sample(
722
727
model .check_start_vals (ip )
723
728
_check_start_shape (model , ip )
724
729
730
+ if var_names is not None :
731
+ trace_vars = [v for v in model .unobserved_RVs if v .name in var_names ]
732
+ else :
733
+ trace_vars = None
734
+
725
735
# Create trace backends for each chain
726
736
run , traces = init_traces (
727
737
backend = trace ,
728
738
chains = chains ,
729
739
expected_length = draws + tune ,
730
740
step = step ,
741
+ trace_vars = trace_vars ,
731
742
initial_point = ip ,
732
743
model = model ,
733
744
)
@@ -739,6 +750,7 @@ def sample(
739
750
"traces" : traces ,
740
751
"chains" : chains ,
741
752
"tune" : tune ,
753
+ "var_names" : var_names ,
742
754
"progressbar" : progressbar ,
743
755
"model" : model ,
744
756
"cores" : cores ,
0 commit comments