pm.sampling_jax doesn't sample properly with pm.Censored #5897

Open
@kylejcaron

Description of your problem

When using pm.sampling_jax with pm.Censored, the sampled posterior contains the same value for every draw within each chain. For instance, with 4 chains and 1000 draws per chain, the output looks like this:

import numpy as np

example_output = np.c_[
  np.ones(1000)*0.25,
  np.ones(1000)*0.1,
  np.ones(1000)*0.5,
  np.ones(1000)*0.35
]

print(example_output)
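
A quick way to check for this symptom programmatically is to count the distinct values in each chain. The sketch below assumes the draws are laid out as in the example above, with shape (draws, chains). The helper name `unique_values_per_chain` is hypothetical and not part of PyMC or ArviZ.

```python
import numpy as np


def unique_values_per_chain(samples):
    """Count distinct values in each column of a (draws, chains) array."""
    samples = np.asarray(samples)
    return [int(np.unique(samples[:, c]).size) for c in range(samples.shape[1])]


# Same layout as the example output above: 1000 draws (rows) by 4 chains (columns).
stuck = np.c_[
    np.ones(1000) * 0.25,
    np.ones(1000) * 0.1,
    np.ones(1000) * 0.5,
    np.ones(1000) * 0.35,
]
stuck_counts = unique_values_per_chain(stuck)

# For contrast, independent normal draws stand in for a sampler that is moving.
rng = np.random.default_rng(0)
healthy_counts = unique_values_per_chain(rng.normal(size=(1000, 4)))

print(stuck_counts, healthy_counts)
```

A count of 1 in every chain is the behaviour reported here. Note that ArviZ stores the posterior as (chain, draw, ...), so a slice such as `idata.posterior["log_k"].to_numpy()[:, :, 0]` would need a transpose before being passed to this helper.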

Please provide a minimal, self-contained, and reproducible example.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pymc as pm
import pymc.sampling_jax
import arviz as az
import jax.numpy as jnp
from aesara.link.jax.dispatch import jax_funcify
from aesara.scalar import Log1mexp

# implement jax Op for Log1mexp

@jax_funcify.register(Log1mexp)
def jax_funcify_Log1mexp(op, node, **kwargs):
    def log1mexp(x):
        return jnp.where(
            x < jnp.log(0.5), jnp.log1p(-jnp.exp(x)), jnp.log(-jnp.expm1(x))
        )
    
    return log1mexp


np.random.seed(99)
# simulate 100 contexts/units with 100 observations each
contexts = 100
obs_per_context = 100
idxs = np.repeat(range(contexts), obs_per_context)

k_true = np.random.lognormal(0.45, 0.25, contexts)
lambd_true = np.random.lognormal(4.25, 0.5, contexts)

dist = pm.Weibull.dist(k_true[idxs], lambd_true[idxs])
Et = dist.eval()

# Simulate event time data
df_ = pd.DataFrame({
    "group":idxs,
    "event_time":Et
})

# Randomly censor observations
censor_time = np.random.uniform(0,250, size=len(df_))
df = (
    df_
    .assign(censored = lambda df: np.where(df.event_time > censor_time, 1, 0))
    .assign(event_time = lambda df: np.where(df.event_time > censor_time, censor_time, df.event_time) )
)

# Fit model
coords = {"group":df.group.unique()}

with pm.Model(coords=coords) as mW:
    g_ = pm.MutableData("g_", df.group.values)
    y = pm.MutableData("y", df.event_time.values)
    c_ = pm.MutableData("c_", np.where(df.censored==1, df.event_time, np.nan) )
    
    log_k = pm.Normal("log_k", 0.5, 0.5, dims="group")
    log_lambd = pm.Normal("log_lambd", 4.5, 0.5, dims="group")
    
    k = pm.Deterministic("k", pm.math.exp(log_k), dims="group")
    lambd = pm.Deterministic("lambd", pm.math.exp(log_lambd), dims="group")
    y_latent = pm.Weibull.dist(k[g_], lambd[g_])
    y_ = pm.Censored("event", y_latent, lower=None, upper=c_, observed=y)
#     # Using pm.Weibull directly instead of pm.Censored samples fine with pm.sampling_jax
#     y_ = pm.Weibull("event", k[g_], lambd[g_], observed=y)
    
    
with mW:
    idata = pm.sampling_jax.sample_numpyro_nuts()
#     idata = pm.sample(init="adapt_diag") # works normally

# prints 4: one unique value per chain
print(len(
    np.unique(
    idata.posterior["log_k"]
            .to_numpy()[:,:,0])
))

# see warning
az.plot_trace(idata, var_names=["log_k"]);

Please provide the full traceback.

Not applicable - there is no raised error in this case, but the functionality definitely is not working as intended as far as I can tell.
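
For context on why the script registers a JAX implementation of `Log1mexp`: a right-censored observation contributes the log survival probability log(1 - CDF(c)), which can be computed stably as log1mexp applied to the log-CDF, and this is presumably why the op shows up in the model graph. The sketch below checks that identity for the Weibull distribution in plain Python. It assumes the convention log1mexp(x) = log(1 - exp(x)) for x < 0, matching the registered function above. The helper names and parameter values are illustrative only, and the check says nothing about what causes the stuck chains.

```python
import math


def log1mexp(x):
    """log(1 - exp(x)) for x < 0, using the same two branches as the registered op."""
    if x < math.log(0.5):
        return math.log1p(-math.exp(x))
    return math.log(-math.expm1(x))


def weibull_logcdf(t, k, lam):
    """Log CDF of a Weibull with shape k and scale lam: log(1 - exp(-(t/lam)**k))."""
    return log1mexp(-((t / lam) ** k))


def weibull_logsf(t, k, lam):
    """Closed-form log survival function: -(t/lam)**k."""
    return -((t / lam) ** k)


# Illustrative values, not taken from the simulated data.
k, lam, c = 1.5, 70.0, 120.0

# Log-likelihood contribution of an observation right-censored at c,
# computed two ways.
via_log1mexp = log1mexp(weibull_logcdf(c, k, lam))
closed_form = weibull_logsf(c, k, lam)

print(via_log1mexp, closed_form)
```

The two values agree to floating-point precision, which suggests the two-branch formula itself is sound in the forward direction. It does not exercise gradients, which is what NUTS relies on.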

Please provide any additional information below.

Versions and main components

  • PyMC/PyMC3 Version: 4.0.0
  • Aesara/Theano Version: 2.6.6
  • Python Version: 3.9.7
  • Operating system: MacOS (not an M1)
  • How did you install PyMC/PyMC3: (conda/pip) pip
