Open
Description
Description of your problem
When using `pm.sampling_jax` with `pm.Censored`, the sampled posterior contains the same value for all draws within each chain. For instance, with 4 chains and 1000 draws per chain, I see output that looks like the following:
import numpy as np

# Illustration of the symptom: each column holds a single value repeated
# 1000 times, i.e. every draw in that column is identical.
stuck_values = (0.25, 0.1, 0.5, 0.35)
example_output = np.column_stack(
    [np.full(1000, value) for value in stuck_values]
)
print(example_output)
Please provide a minimal, self-contained, and reproducible example.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pymc as pm
import pymc.sampling_jax
import arviz as az
import jax.numpy as jnp
from aesara.link.jax.dispatch import jax_funcify
from aesara.scalar import Log1mexp
# Provide a JAX implementation for the Log1mexp scalar Op.
@jax_funcify.register(Log1mexp)
def jax_funcify_Log1mexp(op, node, **kwargs):
    """Return a JAX callable that evaluates ``log(1 - exp(x))`` elementwise.

    ``op``, ``node`` and ``kwargs`` are accepted to match the dispatch
    signature but are not used by this implementation.
    """

    def log1mexp(x):
        # Both expressions equal log(1 - exp(x)); which one is used for a
        # given element depends on which side of log(0.5) it falls.
        via_log1p = jnp.log1p(-jnp.exp(x))
        via_expm1 = jnp.log(-jnp.expm1(x))
        use_log1p = x < jnp.log(0.5)
        return jnp.where(use_log1p, via_log1p, via_expm1)

    return log1mexp
# Fix the seed of NumPy's global RNG used by the draws below.
np.random.seed(99)

# Simulate `contexts` groups/units with `obs_per_context` observations each.
contexts = 100
obs_per_context = 100
idxs = np.repeat(np.arange(contexts), obs_per_context)

# One true Weibull (k, lambd) pair per group.
k_true = np.random.lognormal(0.45, 0.25, contexts)
lambd_true = np.random.lognormal(4.25, 0.5, contexts)

# Draw one event time per observation from its group's Weibull distribution.
dist = pm.Weibull.dist(k_true[idxs], lambd_true[idxs])
Et = dist.eval()

# Simulate event time data
df_ = pd.DataFrame({"group": idxs, "event_time": Et})

# Randomly censor observations: an observation is censored when its event
# time exceeds its uniformly drawn censoring time, in which case the recorded
# event time is replaced by the censoring time.
censor_time = np.random.uniform(0, 250, size=len(df_))
is_censored = df_.event_time > censor_time
df = df_.assign(
    censored=np.where(is_censored, 1, 0),
    event_time=np.where(is_censored, censor_time, df_.event_time),
)
# Fit model
# Hierarchical Weibull model with one (k, lambd) pair per group and a
# right-censored likelihood built with pm.Censored.
coords = {"group": df.group.unique()}
with pm.Model(coords=coords) as mW:
    g_ = pm.MutableData("g_", df.group.values)
    y = pm.MutableData("y", df.event_time.values)
    # Upper bound per observation: the recorded (censoring) time for censored
    # rows, NaN for uncensored rows.
    # `np.nan` is used instead of the `np.NaN` alias, which NumPy 2.0 removed.
    c_ = pm.MutableData(
        "c_", np.where(df.censored == 1, df.event_time, np.nan)
    )

    # Priors on the log scale, exponentiated to obtain positive parameters.
    log_k = pm.Normal("log_k", 0.5, 0.5, dims="group")
    log_lambd = pm.Normal("log_lambd", 4.5, 0.5, dims="group")
    k = pm.Deterministic("k", pm.math.exp(log_k), dims="group")
    lambd = pm.Deterministic("lambd", pm.math.exp(log_lambd), dims="group")

    # Uncensored (latent) distribution, indexed by each observation's group.
    y_latent = pm.Weibull.dist(k[g_], lambd[g_])
    y_ = pm.Censored("event", y_latent, lower=None, upper=c_, observed=y)

    # # not using pm.censored samples fine with pm.sampling_jax
    # y_ = pm.Weibull("event", k[g_], lambd[g_], observed=y)
# Sample the model with the NumPyro NUTS sampler from pm.sampling_jax.
with mW:
    idata = pm.sampling_jax.sample_numpyro_nuts()
    # idata = pm.sample(init="adapt_diag") # works normally

# Count the distinct sampled values of log_k for the first group across all
# chains and draws.
# returns 4 - one value for each chain
log_k_draws = idata.posterior["log_k"].to_numpy()
first_group_draws = log_k_draws[:, :, 0]
distinct_values = np.unique(first_group_draws)
print(len(distinct_values))

# see warning
az.plot_trace(idata, var_names=["log_k"]);
Please provide the full traceback.
Not applicable - there is no raised error in this case, but the functionality definitely is not working as intended as far as I can tell.
Complete error traceback
[The complete error output here]
Please provide any additional information below.
Versions and main components
- PyMC/PyMC3 Version: 4.0.0
- Aesara/Theano Version: 2.6.6
- Python Version: 3.9.7
- Operating system: MacOS (not an M1)
- How did you install PyMC/PyMC3: (conda/pip) pip