BUG: model initial_point fails when pt.config.floatX = "float32" #7608

Closed
@nataziel

Describe the issue:

I have a hierarchical MMM set up with PyMC (not pymc-marketing) and had been using it successfully with float32 until the 5.19 release. It samples with NumPyro/JAX.

With 5.19 I am getting errors in _init_jitter: something appears to go wrong when the generated initial points are passed to the compiled logp function. I suspect the ZeroSumNormal distributions are causing the problem, but I'm not sure whether the fault lies in the values returned by ipfn(seed) or in the evaluation by the compiled model_logp_fn. I've included the verbose debug output from my own model, but the reproducible example below uses the example radon model.
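One way to probe the ZeroSumNormal hunch is to check how close a zero-sum vector computed in float32 actually gets to a zero mean. The sketch below is a minimal NumPy illustration under stated assumptions: `zerosum_extend` is a hypothetical helper implementing a Householder-style map from n-1 free values to n values that sum to zero in exact arithmetic; it is not PyMC's code. The shape (20, 6) only mirrors `z_b_state_zerosum__` from the dump below, and `HYPOTHETICAL_TOL` is an illustrative float64-scale tolerance, not a value taken from PyMC.

```python
import numpy as np

# Illustrative float64-scale tolerance; hypothetical, not PyMC's actual value.
HYPOTHETICAL_TOL = 1e-9


def zerosum_extend(free: np.ndarray) -> np.ndarray:
    """Map n-1 free values (last axis) to n values that sum to zero in exact arithmetic.

    Hypothetical helper: a Householder-style construction, not PyMC's implementation.
    Every step stays in the input dtype, so float32 input means float32 rounding.
    """
    n = free.dtype.type(free.shape[-1] + 1)
    sum_vals = free.sum(axis=-1, keepdims=True)
    norm = sum_vals / (np.sqrt(n) + n)
    fill_val = norm - sum_vals / np.sqrt(n)
    return np.concatenate([free, fill_val], axis=-1) - norm


rng = np.random.default_rng(0)
free = rng.uniform(-1.0, 1.0, size=(20, 6))  # same shape as z_b_state_zerosum__

for dtype in (np.float64, np.float32):
    x = zerosum_extend(free.astype(dtype))
    worst = float(np.abs(x.mean(axis=-1)).max())
    print(f"{np.dtype(dtype).name}: max |mean| = {worst:.2e}, within tol: {worst <= HYPOTHETICAL_TOL}")
```

If the float32 row lands well above such a tolerance while float64 does not, any exact-zero or tight-tolerance check on the mean would behave differently between the two precisions; whether PyMC 5.19 performs such a check is not established by this report.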

Reproducible code example:

import pytensor as pt

pt.config.floatX = "float32"
pt.config.warn_float64 = "ignore"

import numpy as np

import pandas as pd
import pymc as pm

# Load the radon dataset
data = pd.read_csv(pm.get_data("radon.csv"))
data["log_radon"] = data["log_radon"].astype(np.float64)
county_idx, counties = pd.factorize(data.county)
coords = {"county": counties, "obs_id": np.arange(len(county_idx))}

# Create a simple hierarchical model for the radon dataset
with pm.Model(coords=coords, check_bounds=False) as model:
    intercept = pm.Normal("intercept", sigma=10)

    # County effects
    raw = pm.ZeroSumNormal("county_raw", dims="county")
    sd = pm.HalfNormal("county_sd")
    county_effect = pm.Deterministic("county_effect", raw * sd, dims="county")

    # Global floor effect
    floor_effect = pm.Normal("floor_effect", sigma=2)

    # County:floor interaction
    raw = pm.ZeroSumNormal("county_floor_raw", dims="county")
    sd = pm.HalfNormal("county_floor_sd")
    county_floor_effect = pm.Deterministic("county_floor_effect", raw * sd, dims="county")

    mu = (
        intercept
        + county_effect[county_idx]
        + floor_effect * data.floor.values
        + county_floor_effect[county_idx] * data.floor.values
    )

    sigma = pm.HalfNormal("sigma", sigma=1.5)
    pm.Normal("log_radon", mu=mu, sigma=sigma, observed=data.log_radon.values, dims="obs_id")

idata = pm.sample(
    model=model,
    chains=1,
    tune=500,
    draws=500,
    progressbar=False,
    compute_convergence_checks=False,
    return_inferencedata=False,
    # compile_kwargs=dict(mode="NUMBA")
)

Error message:

SamplingError: Initial evaluation of model at starting point failed!
Starting values:
{'mu_adstock_logodds__': array([-0.05812943, -1.3623805 , -0.11238763, -0.5489    , -1.0595012 ,
       -1.6993128 , -1.0827137 , -1.0727112 , -1.7406836 , -0.25875893,
       -1.0392814 , -0.94655335,  0.08540058, -0.5202998 ,  0.33476347,
       -0.5190295 , -1.0631133 , -0.57824194, -0.7651843 , -0.87049246],
      dtype=float32), 'mu_lambda_log__': array([0.80796826, 0.6380648 , 1.2703241 , 1.9139447 , 1.5351636 ,
       1.7267554 , 1.0714266 , 1.0567162 , 2.0499995 , 1.5639257 ,
       1.9248102 , 1.6098787 , 0.56195414, 0.22962454, 0.10311574,
       0.36890286, 0.88685906, 1.7213213 , 1.2186754 , 0.36423793],
      dtype=float32), 'mu_a': array(0.37686896, dtype=float32), 'z_a_state_zerosum__': array([ 0.19204606,  0.04281855,  0.5793297 , -0.2614823 , -0.60941106,
       -0.5645286 ], dtype=float32), 'z_a_age_zerosum__': array([-0.39018095,  0.82710207,  0.93642265, -0.30185702,  0.7038302 ],
      dtype=float32), 'z_a_brand_zerosum__': array([-0.19112755], dtype=float32), 'z_a_cohort_zerosum__': array([0.48081324], dtype=float32), 'roas_rv_log__': array([-2.0530562 , -1.1738657 , -2.1438377 , -0.5091131 , -1.7475982 ,
       -0.71832424,  0.7603461 , -0.6206605 , -0.06742464, -1.7866528 ,
       -0.8273122 ,  0.3449509 , -0.4612695 , -1.5151714 , -0.9103267 ,
        0.92829466,  0.05409501, -0.17818904,  0.4767271 ,  1.0659611 ],
      dtype=float32), 'z_b_state_zerosum__': array([[ 0.9404972 ,  0.5173597 ,  0.947904  ,  0.13234697,  0.91516215,
         0.8013973 ],
       [ 0.76825845, -0.00395296,  0.20068735,  0.89948386, -0.8211762 ,
        -0.85832894],
       [-0.23005143, -0.21612945,  0.38762948, -0.17819382, -0.12716204,
         0.5923743 ],
       [ 0.42042527, -0.8524518 ,  0.41847345, -0.27271616,  0.9246246 ,
         0.10817041],
       [ 0.5263319 ,  0.2471719 , -0.7485617 , -0.30133858,  0.48334756,
         0.3941832 ],
       [ 0.20216979,  0.85309595, -0.13996926,  0.13978173,  0.2398329 ,
        -0.6436887 ],
       [ 0.6737369 ,  0.56847805, -0.6305385 ,  0.11638433, -0.88574356,
        -0.28407305],
       [-0.88218284, -0.5086114 ,  0.5899615 ,  0.4378485 ,  0.1587185 ,
        -0.06953551],
       [-0.4915884 ,  0.5323545 ,  0.5476521 ,  0.03720003,  0.93754584,
         0.20691545],
       [ 0.52630943,  0.8224987 ,  0.38608572, -0.48465434, -0.43573558,
         0.8957741 ],
       [-0.36157402, -0.80625093, -0.35036987,  0.58878934, -0.01390497,
        -0.9381669 ],
       [ 0.84330225, -0.24921258, -0.5905653 ,  0.71469283,  0.04288541,
        -0.15973687],
       [ 0.16558522, -0.82091933,  0.6832055 ,  0.9772107 ,  0.05962521,
        -0.7585791 ],
       [-0.34562954, -0.14827304, -0.7131723 , -0.49016312,  0.92288506,
         0.5976615 ],
       [ 0.14006369, -0.78234476,  0.58030427,  0.7087098 ,  0.83111507,
        -0.4035646 ],
       [-0.89755225, -0.7397433 ,  0.87145287,  0.5502826 , -0.75381684,
        -0.08390825],
       [ 0.52683276,  0.2059658 , -0.65518147,  0.9072355 ,  0.02874351,
         0.8622515 ],
       [-0.31716648, -0.22920816, -0.4773452 ,  0.24883471, -0.71415335,
        -0.00118549],
       [ 0.67181724,  0.83293927,  0.06912381,  0.59113485,  0.15970528,
         0.15737481],
       [ 0.91586804, -0.00439743,  0.8568587 , -0.5599965 ,  0.9024629 ,
        -0.92843384]], dtype=float32), 'z_b_age_zerosum__': array([[ 2.76717365e-01, -5.32886267e-01, -4.37204897e-01,
         6.65345013e-01,  9.51409161e-01],
       [ 5.68430305e-01,  5.80501974e-01,  9.02473092e-01,
        -1.05089314e-01, -6.80017248e-02],
       [-7.40278482e-01, -6.54472530e-01,  5.73029280e-01,
        -2.49546096e-01, -9.83492434e-01],
       [ 4.39474136e-01,  4.55991507e-01,  2.91431248e-01,
         7.93459237e-01,  8.33085358e-01],
       [ 7.54141450e-01, -9.88980412e-01,  7.37549663e-01,
        -9.54164326e-01,  3.69425505e-01],
       [-4.45214868e-01,  1.18648775e-01,  4.35143918e-01,
         7.96567798e-01, -5.37025869e-01],
       [-8.19233358e-01, -3.28816384e-01, -4.24525887e-01,
        -4.72912073e-01, -2.51088679e-01],
       [ 4.16932464e-01,  1.86953232e-01, -2.34448180e-01,
        -5.28278828e-01, -7.83707380e-01],
       [-2.08375111e-01, -1.69877082e-01, -9.20472383e-01,
        -5.55105388e-01,  2.24135935e-01],
       [-7.18319640e-02, -8.23212624e-01, -1.14380375e-01,
         3.75080615e-01,  4.38587993e-01],
       [-7.84464240e-01,  9.64653268e-02,  1.33498237e-01,
         2.63148099e-01,  9.03292537e-01],
       [ 3.70964020e-01,  8.96655202e-01, -9.78391707e-01,
         6.00353360e-01, -3.29210430e-01],
       [ 9.80111659e-01,  6.26725018e-01,  8.71558905e-01,
        -7.08010912e-01,  3.21216695e-02],
       [-8.55567873e-01, -4.15038317e-01, -2.70858496e-01,
         7.64281690e-01,  1.69419169e-01],
       [-6.30245388e-01,  5.22969842e-01, -6.22790098e-01,
         8.40588808e-01,  5.42818129e-01],
       [-2.61249971e-02,  5.77672958e-01, -9.52823997e-01,
        -5.49517214e-01, -4.92883384e-01],
       [ 3.19638193e-01,  8.80902350e-01,  2.54505854e-02,
        -4.16665673e-01,  7.45047331e-01],
       [-1.37079775e-01, -9.72663925e-04,  2.88793862e-01,
         9.96275783e-01,  5.16300082e-01],
       [-2.26764768e-01, -9.21454072e-01, -2.66458213e-01,
        -2.89255470e-01, -5.44357836e-01],
       [-7.08415627e-01, -2.39693552e-01,  1.69611976e-01,
         6.88308775e-01, -2.90724158e-01]], dtype=float32), 'z_b_brand_zerosum__': array([[ 0.5283236 ],
       [-0.14984424],
       [ 0.5653758 ],
       [ 0.6604869 ],
       [ 0.5594151 ],
       [ 0.36363953],
       [-0.13847719],
       [ 0.2760732 ],
       [ 0.60931265],
       [ 0.39675766],
       [ 0.13061056],
       [-0.843226  ],
       [-0.24025336],
       [ 0.21590135],
       [ 0.38261482],
       [-0.9853659 ],
       [-0.89518636],
       [-0.73512644],
       [ 0.24093248],
       [ 0.53579485]], dtype=float32), 'z_b_cohort_zerosum__': array([[-0.89432555],
       [ 0.56194794],
       [ 0.8784441 ],
       [ 0.29107055],
       [-0.03464561],
       [-0.7969992 ],
       [ 0.3803715 ],
       [ 0.14672193],
       [-0.72870535],
       [-0.9154726 ],
       [-0.05452913],
       [ 0.3046088 ],
       [ 0.11125816],
       [ 0.47531176],
       [-0.26566276],
       [-0.97336334],
       [ 0.8083869 ],
       [ 0.10919274],
       [ 0.34626842],
       [-0.14968246]], dtype=float32), 'mu_b_pos_con': array([-0.55350554, -1.271216  , -2.4323664 , -2.2574682 , -1.7525793 ],
      dtype=float32), 'z_b_pos_con_state_zerosum__': array([[ 0.7088774 , -0.64973927,  0.01091878,  0.42494458, -0.5738007 ,
        -0.13388953],
       [ 0.2020803 , -0.4415184 ,  0.9188914 , -0.91647035,  0.20404132,
        -0.19665638],
       [-0.460112  ,  0.250589  ,  0.43331107, -0.31884784, -0.20407897,
        -0.85575414],
       [ 0.053673  , -0.31756902,  0.743054  , -0.83611   ,  0.24574716,
         0.31937018],
       [-0.24382843, -0.7048149 , -0.29253975, -0.2640002 , -0.7875447 ,
         0.9551751 ]], dtype=float32), 'z_b_pos_con_age_zerosum__': array([[-0.968712  , -0.18256   ,  0.54134816,  0.5418771 , -0.20875828],
       [ 0.07226439,  0.3872751 ,  0.78559935,  0.56196845, -0.21816622],
       [-0.07638574,  0.56413966, -0.29065445, -0.8537547 ,  0.26876295],
       [-0.03170992,  0.7177836 , -0.14550653,  0.9839035 ,  0.05719684],
       [-0.09502391,  0.7952581 , -0.3214874 ,  0.5325462 , -0.84900486]],
      dtype=float32), 'z_b_pos_con_brand_zerosum__': array([[ 0.9145228 ],
       [-0.576082  ],
       [-0.9886196 ],
       [-0.76018703],
       [ 0.8999498 ]], dtype=float32), 'z_b_pos_con_cohort_zerosum__': array([[-0.32495368],
       [-0.32046455],
       [ 0.20123298],
       [-0.7780437 ],
       [-0.78692883]], dtype=float32), 'mu_b_neg_con': array([-1.6263036], dtype=float32), 'z_b_neg_con_state_zerosum__': array([[ 0.6405368 ,  0.02604678,  0.7836906 , -0.9528276 ,  0.65443355,
        -0.00763485]], dtype=float32), 'z_b_neg_con_age_zerosum__': array([[-0.88560593, -0.3676332 , -0.1298519 , -0.39139104,  0.05216115]],
      dtype=float32), 'z_b_neg_con_brand_zerosum__': array([[0.06288974]], dtype=float32), 'z_b_neg_con_cohort_zerosum__': array([[-0.2893259]], dtype=float32), 'mu_b_lag': array([-3.4769938], dtype=float32), 'z_b_lag_state_zerosum__': array([[-0.5544768 ,  0.8843655 , -0.85776746, -0.54040647,  0.83116657,
        -0.91366553]], dtype=float32), 'z_b_lag_age_zerosum__': array([[-0.27069455,  0.2308255 ,  0.02965806, -0.922125  ,  0.25014895]],
      dtype=float32), 'z_b_lag_brand_zerosum__': array([[-0.95589054]], dtype=float32), 'z_b_lag_cohort_zerosum__': array([[-0.9150893]], dtype=float32), 'mu_b_fourier_year': array([ 0.99425113,  0.6344655 , -0.11898322,  0.8841637 ,  0.21039897,
        0.07808983,  0.8532348 ,  0.44000137,  0.98853546, -0.52862114,
        0.79177517, -0.01937965, -0.7798197 , -0.8342347 , -0.64850366,
        0.97407717,  0.95000124,  0.7092928 ,  0.6113028 , -0.27019796],
      dtype=float32), 'sd_y_log__': array(4.4781275, dtype=float32)}

Logp initial evaluation results:
{'mu_adstock': -20.54, 'mu_lambda': -213.9, 'mu_a': -1.3, 'z_a_state': -inf, 'z_a_age': -inf, 'z_a_brand': -0.44, 'z_a_cohort': -inf, 'roas_rv': -72.16, 'z_b_state': -inf, 'z_b_age': -inf, 'z_b_brand': -inf, 'z_b_cohort': -inf, 'mu_b_pos_con': -5.82, 'z_b_pos_con_state': -inf, 'z_b_pos_con_age': -inf, 'z_b_pos_con_brand': -inf, 'z_b_pos_con_cohort': -66.75, 'mu_b_neg_con': -1.12, 'z_b_neg_con_state': -inf, 'z_b_neg_con_age': -inf, 'z_b_neg_con_brand': 1.19, 'z_b_neg_con_cohort': -2.8, 'mu_b_lag': -1.03, 'z_b_lag_state': -inf, 'z_b_lag_age': -inf, 'z_b_lag_brand': -44.3, 'z_b_lag_cohort': -40.49, 'mu_b_fourier_year': -91.0, 'sd_y': -1.21, 'y_like': -56160.32}
You can call `model.debug()` for more details.

SamplingError                             Traceback (most recent call last)
File ~/.ipykernel/1917/command--1-2479103431:18
     15 entry = [ep for ep in metadata.distribution("mmm_v2").entry_points if ep.name == "mmm"]
     16 if entry:
     17   # Load and execute the entrypoint, assumes no parameters
---> 18   entry[0].load()()
     19 else:
     20   import importlib

File /local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/mmm_v2/main.py:16, in main()
     13 pt.config.floatX = "float32"  # pyright: ignore[reportPrivateImportUsage] // TODO: we can probably set this via env vars or pytensor.rc
     14 config = setup_mmm()
---> 16 run_mmm(config)

-- elided unimportant data-prep pipeline frames --

File /local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/mmm_v2/model/ModelBuilder.py:634, in HLM_ModelBuilder.fit(self)
    631 """Fits the model with the pymc model `sample()` method."""
    632 self.logger.debug("Sampling model.")
    633 self.model_trace.extend(
--> 634     pm.sample(
    635         draws=self.config.draws,
    636         tune=self.config.tune,
    637         chains=self.config.chains,
    638         model=self.model,
    639         nuts_sampler="numpyro",
    640         idata_kwargs={"log_likelihood": False},
    641         var_names=[f"{x}" for x in self.model.free_RVs],
    642         target_accept=self.config.target_accept,
    643         random_seed=self.config.sampler_seed,
    644         progressbar=self.config.progress_bars,
    645     )
    646 )
    647 self.logger.debug("Finished sampling.")
    649 HLM_ModelBuilder._print_diagnostics(self.model_trace, self.logger)

File /local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/pymc/sampling/mcmc.py:773, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)
    768         raise ValueError(
    769             "Model can not be sampled with NUTS alone. It either has discrete variables or a non-differentiable log-probability."
    770         )
    772     with joined_blas_limiter():
--> 773         return _sample_external_nuts(
    774             sampler=nuts_sampler,
    775             draws=draws,
    776             tune=tune,
    777             chains=chains,
    778             target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
    779             random_seed=random_seed,
    780             initvals=initvals,
    781             model=model,
    782             var_names=var_names,
    783             progressbar=progressbar,
    784             idata_kwargs=idata_kwargs,
    785             compute_convergence_checks=compute_convergence_checks,
    786             nuts_sampler_kwargs=nuts_sampler_kwargs,
    787             **kwargs,
    788         )
    790 if exclusive_nuts and not provided_steps:
    791     # Special path for NUTS initialization
    792     if "nuts" in kwargs:

File /local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/pymc/sampling/mcmc.py:389, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, idata_kwargs, compute_convergence_checks, nuts_sampler_kwargs, **kwargs)
    386 elif sampler in ("numpyro", "blackjax"):
    387     import pymc.sampling.jax as pymc_jax
--> 389     idata = pymc_jax.sample_jax_nuts(
    390         draws=draws,
    391         tune=tune,
    392         chains=chains,
    393         target_accept=target_accept,
    394         random_seed=random_seed,
    395         initvals=initvals,
    396         model=model,
    397         var_names=var_names,
    398         progressbar=progressbar,
    399         nuts_sampler=sampler,
    400         idata_kwargs=idata_kwargs,
    401         compute_convergence_checks=compute_convergence_checks,
    402         **nuts_sampler_kwargs,
    403     )
    404     return idata
    406 else:

File /local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/pymc/sampling/jax.py:595, in sample_jax_nuts(draws, tune, chains, target_accept, random_seed, initvals, jitter, model, var_names, nuts_kwargs, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, postprocessing_chunks, idata_kwargs, compute_convergence_checks, nuts_sampler)
    589 vars_to_sample = list(
    590     get_default_varnames(filtered_var_names, include_transformed=keep_untransformed)
    591 )
    593 (random_seed,) = _get_seeds_per_chain(random_seed, 1)
--> 595 initial_points = _get_batched_jittered_initial_points(
    596     model=model,
    597     chains=chains,
    598     initvals=initvals,
    599     random_seed=random_seed,
    600     jitter=jitter,
    601 )
    603 if nuts_sampler == "numpyro":
    604     sampler_fn = _sample_numpyro_nuts

File /local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/pymc/sampling/jax.py:225, in _get_batched_jittered_initial_points(model, chains, initvals, random_seed, jitter, jitter_max_retries)
    209 def _get_batched_jittered_initial_points(
    210     model: Model,
    211     chains: int,
   (...)
    215     jitter_max_retries: int = 10,
    216 ) -> np.ndarray | list[np.ndarray]:
    217     """Get jittered initial point in format expected by NumPyro MCMC kernel.
    218 
    219     Returns
   (...)
    223         Each item has shape `(chains, *var.shape)`
    224     """
--> 225     initial_points = _init_jitter(
    226         model,
    227         initvals,
    228         seeds=_get_seeds_per_chain(random_seed, chains),
    229         jitter=jitter,
    230         jitter_max_retries=jitter_max_retries,
    231     )
    232     initial_points_values = [list(initial_point.values()) for initial_point in initial_points]
    233     if chains == 1:

File /local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/pymc/sampling/mcmc.py:1382, in _init_jitter(model, initvals, seeds, jitter, jitter_max_retries, logp_dlogp_func)
   1379 if not np.isfinite(point_logp):
   1380     if i == jitter_max_retries:
   1381         # Print informative message on last attempted point
-> 1382         model.check_start_vals(point)
   1383     # Retry with a new seed
   1384     seed = rng.integers(2**30, dtype=np.int64)

File /local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/pymc/model/core.py:1769, in Model.check_start_vals(self, start, **kwargs)
   1766 initial_eval = self.point_logps(point=elem, **kwargs)
   1768 if not all(np.isfinite(v) for v in initial_eval.values()):
-> 1769     raise SamplingError(
   1770         "Initial evaluation of model at starting point failed!\n"
   1771         f"Starting values:\n{elem}\n\n"
   1772         f"Logp initial evaluation results:\n{initial_eval}\n"
   1773         "You can call `model.debug()` for more details."
   1774     )

PyMC version information:

Running on a Windows machine in a Linux container.
PyMC installed via Poetry (from PyPI).
Using libopenblas.

annotated-types 0.7.0
arviz 0.19.0
babel 2.16.0
blinker 1.4
build 1.2.2.post1
CacheControl 0.14.1
cachetools 5.5.0
certifi 2024.8.30
cfgv 3.4.0
charset-normalizer 3.4.0
cleo 2.1.0
click 8.1.7
cloudpickle 3.1.0
colorama 0.4.6
cons 0.4.6
contourpy 1.3.1
coverage 7.6.8
crashtest 0.4.1
cryptography 3.4.8
cycler 0.12.1
dbus-python 1.2.18
distlib 0.3.9
distro 1.7.0
distro-info 1.1+ubuntu0.2
dm-tree 0.1.8
dulwich 0.21.7
etuples 0.3.9
exceptiongroup 1.2.2
fastjsonschema 2.21.1
filelock 3.16.1
fonttools 4.55.1
ghp-import 2.1.0
graphviz 0.20.3
griffe 1.5.1
h5netcdf 1.4.1
h5py 3.12.1
httplib2 0.20.2
identify 2.6.3
idna 3.10
importlib_metadata 8.5.0
iniconfig 2.0.0
installer 0.7.0
jaraco.classes 3.4.0
jax 0.4.35
jaxlib 0.4.35
jeepney 0.7.1
Jinja2 3.1.4
joblib 1.4.2
keyring 24.3.1
kiwisolver 1.4.7
launchpadlib 1.10.16
lazr.restfulclient 0.14.4
lazr.uri 1.0.6
logical-unification 0.4.6
loguru 0.7.2
Markdown 3.7
markdown-it-py 3.0.0
MarkupSafe 3.0.2
matplotlib 3.9.3
mdurl 0.1.2
mergedeep 1.3.4
miniKanren 1.0.3
mkdocs 1.6.1
mkdocs-autorefs 1.2.0
mkdocs-gen-files 0.5.0
mkdocs-get-deps 0.2.0
mkdocs-glightbox 0.4.0
mkdocs-literate-nav 0.6.1
mkdocs-material 9.5.47
mkdocs-material-extensions 1.3.1
mkdocs-section-index 0.3.9
mkdocstrings 0.26.2
mkdocstrings-python 1.12.2
ml_dtypes 0.5.0
mmm_v2 0.0.1 /workspaces/mmm_v2
more-itertools 8.10.0
msgpack 1.1.0
multimethod 1.10
multipledispatch 1.0.0
mypy-extensions 1.0.0
nodeenv 1.9.1
numpy 1.26.4
numpyro 0.15.3
oauthlib 3.2.0
opt_einsum 3.4.0
packaging 24.2
paginate 0.5.7
pandas 2.2.3
pandera 0.20.4
pathspec 0.12.1
pexpect 4.9.0
pillow 11.0.0
pip 24.3.1
pkginfo 1.12.0
platformdirs 4.3.6
pluggy 1.5.0
poetry 1.8.4
poetry-core 1.9.1
poetry-plugin-export 1.8.0
pre_commit 4.0.1
ptyprocess 0.7.0
pyarrow 18.1.0
pydantic 2.10.3
pydantic_core 2.27.1
Pygments 2.18.0
PyGObject 3.42.1
PyJWT 2.3.0
pymc 5.19.1
pymc-marketing 0.6.0
pymdown-extensions 10.12
pyparsing 3.2.0
pyproject_hooks 1.2.0
pytensor 2.26.4
pytest 8.3.4
pytest-cov 6.0.0
python-apt 2.4.0+ubuntu3
python-dateutil 2.9.0.post0
pytz 2024.2
PyYAML 6.0.2
pyyaml_env_tag 0.1
RapidFuzz 3.10.1
regex 2024.11.6
requests 2.32.3
requests-toolbelt 1.0.0
rich 13.9.4
ruff 0.8.1
scikit-learn 1.5.2
scipy 1.14.1
seaborn 0.13.2
SecretStorage 3.3.1
setuptools 75.6.0
shellingham 1.5.4
six 1.17.0
ssh-import-id 5.11
threadpoolctl 3.5.0
tomli 2.2.1
tomlkit 0.13.2
toolz 1.0.0
tqdm 4.67.1
trove-classifiers 2024.10.21.16
typeguard 4.4.1
typing_extensions 4.12.2
typing-inspect 0.9.0
tzdata 2024.2
unattended-upgrades 0.1
urllib3 2.2.3
virtualenv 20.28.0
wadllib 1.3.6
watchdog 6.0.0
wheel 0.38.4
wrapt 1.17.0
xarray 2024.9.0
xarray-einstats 0.8.0
zipp 3.21.0

Context for the issue:

I want to use float32 because of memory constraints: I have a big model and dataset with a hierarchy that takes up 40+ GB of RAM when running in float32, so I'd need a huge machine to move up to float64.

Maybe the optimisations in 5.19 make that problem moot? I've managed to get it running with float64 for now, but I'm not sure it will be a long-term solution.
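As a rough sizing check on the memory concern: float64 uses 8 bytes per element against 4 for float32, so if the footprint is dominated by floatX arrays (an assumption; integer index arrays and sampler overhead do not scale this way), switching precision roughly doubles it. A minimal sketch of that arithmetic, with `scale_footprint_gb` as a hypothetical helper:

```python
import numpy as np


def scale_footprint_gb(gb: float, from_dtype=np.float32, to_dtype=np.float64) -> float:
    """Scale a memory footprint between dtypes, assuming it is all floating-point arrays."""
    ratio = np.dtype(to_dtype).itemsize / np.dtype(from_dtype).itemsize
    return gb * ratio


# The reported "40+ GB" in float32 becomes roughly 80+ GB in float64.
print(scale_footprint_gb(40.0))  # 80.0
```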
