Description
When trying to selectively save certain variables from sampling using var_names in sample_numpyro_nuts, PyMC throws a mysterious error.
More specifically, I have a large model but only care about saving a few of its variables. I'm using NumPyro NUTS for sampling on the GPU. To save memory, I found the var_names argument of sample_numpyro_nuts, which seemed to be exactly what I was looking for. However, whenever I use it, it throws the same AttributeError (full traceback below). This looks like a bug.
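For reference, here is a minimal sketch of the behaviour var_names is expected to have, assuming it simply restricts which variables are returned in the posterior. PyMC's own implementation may differ, for example in whether unselected variables are still held on the device during sampling. The keep_selected helper and the fake draws are hypothetical illustrations, not PyMC code.

```python
import numpy as np


def keep_selected(samples, var_names=None):
    """Return only the requested entries of a {name: draws} dict; all of them if var_names is None."""
    if var_names is None:
        return dict(samples)
    unknown = set(var_names) - set(samples)
    if unknown:
        raise KeyError(f"not in the posterior: {sorted(unknown)}")
    return {name: samples[name] for name in var_names}


rng = np.random.default_rng(0)
chains, draws = 1, 1000
# Fake draws shaped (chains, draws), standing in for real sampler output.
samples = {
    "loc": rng.normal(1.0, 0.1, size=(chains, draws)),
    "scale": rng.lognormal(0.0, 0.1, size=(chains, draws)),
}

posterior = keep_selected(samples, var_names=["loc"])
print(sorted(posterior))  # ['loc']
print(sum(a.nbytes for a in samples.values()))    # 16000
print(sum(a.nbytes for a in posterior.values()))  # 8000
```

With float64 draws, each (1, 1000) array takes 8000 bytes, so keeping one of the two variables halves what is returned.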
Please provide a minimal, self-contained, and reproducible example.
import numpy as np
import pymc as pm
import pymc.sampling_jax

# Simulated data: 100 draws from Normal(1, 1)
size = 100
Y = 1 + np.random.normal(size=size, scale=1)

basic_model = pm.Model()
with basic_model:
    scale = pm.HalfNormal("scale", sigma=1)
    loc = pm.Normal("loc", mu=0, sigma=10)
    Y_obs = pm.Normal("Y_obs", mu=loc, sigma=scale, observed=Y)

with basic_model:
    # Requesting only "loc" via var_names raises the AttributeError below
    trace = pymc.sampling_jax.sample_numpyro_nuts(
        chains=1,
        tune=1000,
        draws=1000,
        var_names=["loc"],
    )
Please provide the full traceback.
Complete error traceback
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
/Users/ryandew/routines/code/pymc/sim/mwe.py in <cell line: 14>()
12 Y_obs = pm.Normal("Y_obs", mu=loc, sigma=scale, observed=Y)
14 with basic_model:
---> 15 trace = pymc.sampling_jax.sample_numpyro_nuts(
16 chains = 1,
17 tune = 1000,
18 draws = 1000,
19 var_names = ["loc"]
20 )
File ~/anaconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/sampling_jax.py:533, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, postprocessing_backend, idata_kwargs, nuts_kwargs)
530 print("Sampling time = ", tic3 - tic2, file=sys.stdout)
532 print("Transforming variables...", file=sys.stdout)
--> 533 jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
534 result = jax.vmap(jax.vmap(jax_fn))(
535 *jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0])
536 )
537 mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
File ~/anaconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/sampling_jax.py:81, in get_jaxified_graph(inputs, outputs)
75 def get_jaxified_graph(
76 inputs: Optional[List[TensorVariable]] = None,
77 outputs: Optional[List[TensorVariable]] = None,
...
850 def expand(r: Variable) -> Optional[Iterator[Variable]]:
--> 851 if r.owner and (not blockers or r not in blockers):
852 return reversed(r.owner.inputs)
AttributeError: 'str' object has no attribute 'owner'
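The last frame suggests what goes wrong: the entries of var_names appear to reach Aesara's graph-walking code as plain strings, and expand() then evaluates r.owner on a str. The library-free sketch below reproduces that failure mode and shows the obvious remedy of resolving names to variable objects first. The Variable and Apply classes and the ancestors and resolve helpers are hypothetical stand-ins written for illustration, not PyMC or Aesara API; an actual fix inside sample_numpyro_nuts may need to select the model's matching value-variable expressions rather than do a plain name lookup.

```python
from typing import Dict, List, Optional


class Apply:
    """Hypothetical stand-in for the graph node that produced a variable."""

    def __init__(self, inputs: List["Variable"]):
        self.inputs = inputs


class Variable:
    """Hypothetical stand-in for a graph variable; owner is None for root inputs."""

    def __init__(self, name: str, owner: Optional[Apply] = None):
        self.name = name
        self.owner = owner


def ancestors(outputs):
    """Depth-first walk up the graph, reading r.owner the way the failing expand() does."""
    seen, stack, order = set(), list(outputs), []
    while stack:
        r = stack.pop()
        if id(r) in seen:
            continue
        seen.add(id(r))
        order.append(r)
        if r.owner:  # AttributeError here when r is a plain str
            stack.extend(reversed(r.owner.inputs))
    return order


def resolve(named_vars: Dict[str, Variable], var_names: List[str]) -> List[Variable]:
    """Turn user-supplied names into variable objects before any graph code sees them."""
    missing = [n for n in var_names if n not in named_vars]
    if missing:
        raise KeyError(f"unknown variable names: {missing}")
    return [named_vars[n] for n in var_names]


# Toy graph: y depends on loc and scale.
loc, scale = Variable("loc"), Variable("scale")
y = Variable("y", owner=Apply([loc, scale]))
named_vars = {"loc": loc, "scale": scale, "y": y}

message = None
try:
    ancestors(["loc"])  # names passed straight through, as in the bug
except AttributeError as err:
    message = str(err)
print(message)  # 'str' object has no attribute 'owner'

print([v.name for v in ancestors(resolve(named_vars, ["y"]))])  # ['y', 'loc', 'scale']
```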
Please provide any additional information below.
Versions and main components
- PyMC version: 4.1.3
- Aesara version: 2.7.7
- Python version: 3.10.5
- Operating system: macOS (the error also occurs on Linux)
- How you installed PyMC: conda