Description
Describe the issue:
I have a hierarchical MMM model (built with PyMC directly, not pymc-marketing) that had been sampling successfully with float32 up to the 5.19 release. It uses numpyro/jax to sample.
With 5.19 I am getting errors in _init_jitter: something appears to go wrong when the generated initial points are passed to the compiled logp function. I think the use of ZeroSumNormal distributions is causing the problem, but I'm not sure whether it's the values returned by ipfn(seed) or the evaluation in the compiled model_logp_fn. The error message below shows the verbose output from my own model, but the reproducible example uses the example radon model.
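One way to narrow down which of the two it is would be to generate a jittered initial point directly and inspect its dtypes before it ever reaches the logp function. A rough sketch, assuming make_initial_point_fn (the helper _init_jitter uses to build ipfn) behaves the same way when called like this as it does internally:

from pymc.initial_point import make_initial_point_fn

# Build the same kind of jittered initial-point function that _init_jitter uses
ipfn = make_initial_point_fn(model=model, jitter_rvs=set(model.free_RVs))
point = ipfn(42)  # takes a seed, like ipfn(seed) inside _init_jitter

# With floatX = "float32", every value here should come back as float32
for name, value in point.items():
    print(name, value.dtype, value.shape)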
Reproducible code example:
import pytensor as pt

pt.config.floatX = "float32"
pt.config.warn_float64 = "ignore"

import numpy as np
import pandas as pd
import pymc as pm

# Load the radon dataset
data = pd.read_csv(pm.get_data("radon.csv"))
data["log_radon"] = data["log_radon"].astype(np.float64)
county_idx, counties = pd.factorize(data.county)
coords = {"county": counties, "obs_id": np.arange(len(county_idx))}

# Create a simple hierarchical model for the radon dataset
with pm.Model(coords=coords, check_bounds=False) as model:
    intercept = pm.Normal("intercept", sigma=10)

    # County effects
    raw = pm.ZeroSumNormal("county_raw", dims="county")
    sd = pm.HalfNormal("county_sd")
    county_effect = pm.Deterministic("county_effect", raw * sd, dims="county")

    # Global floor effect
    floor_effect = pm.Normal("floor_effect", sigma=2)

    # County:floor interaction
    raw = pm.ZeroSumNormal("county_floor_raw", dims="county")
    sd = pm.HalfNormal("county_floor_sd")
    county_floor_effect = pm.Deterministic("county_floor_effect", raw * sd, dims="county")

    mu = (
        intercept
        + county_effect[county_idx]
        + floor_effect * data.floor.values
        + county_floor_effect[county_idx] * data.floor.values
    )

    sigma = pm.HalfNormal("sigma", sigma=1.5)
    pm.Normal("log_radon", mu=mu, sigma=sigma, observed=data.log_radon.values, dims="obs_id")

idata = pm.sample(
    model=model,
    chains=1,
    tune=500,
    draws=500,
    progressbar=False,
    compute_convergence_checks=False,
    return_inferencedata=False,
    # compile_kwargs=dict(mode="NUMBA")
)
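As a follow-up check on the reproducer (not needed to trigger the error), evaluating the compiled logp directly on a jittered point should separate bad values coming out of ipfn(seed) from a bad evaluation in the compiled function. A sketch, assuming compile_logp and point_logps behave here as they do in normal usage:

from pymc.initial_point import make_initial_point_fn

ipfn = make_initial_point_fn(model=model, jitter_rvs=set(model.free_RVs))
point = ipfn(0)  # any seed

logp_fn = model.compile_logp()
print(logp_fn(point))                  # scalar logp; -inf here implicates the evaluation
print(model.point_logps(point=point))  # per-variable breakdown, as in the error below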
Error message:
SamplingError: Initial evaluation of model at starting point failed!
Starting values:
{'mu_adstock_logodds__': array([-0.05812943, -1.3623805 , -0.11238763, -0.5489 , -1.0595012 ,
-1.6993128 , -1.0827137 , -1.0727112 , -1.7406836 , -0.25875893,
-1.0392814 , -0.94655335, 0.08540058, -0.5202998 , 0.33476347,
-0.5190295 , -1.0631133 , -0.57824194, -0.7651843 , -0.87049246],
dtype=float32), 'mu_lambda_log__': array([0.80796826, 0.6380648 , 1.2703241 , 1.9139447 , 1.5351636 ,
1.7267554 , 1.0714266 , 1.0567162 , 2.0499995 , 1.5639257 ,
1.9248102 , 1.6098787 , 0.56195414, 0.22962454, 0.10311574,
0.36890286, 0.88685906, 1.7213213 , 1.2186754 , 0.36423793],
dtype=float32), 'mu_a': array(0.37686896, dtype=float32), 'z_a_state_zerosum__': array([ 0.19204606, 0.04281855, 0.5793297 , -0.2614823 , -0.60941106,
-0.5645286 ], dtype=float32), 'z_a_age_zerosum__': array([-0.39018095, 0.82710207, 0.93642265, -0.30185702, 0.7038302 ],
dtype=float32), 'z_a_brand_zerosum__': array([-0.19112755], dtype=float32), 'z_a_cohort_zerosum__': array([0.48081324], dtype=float32), 'roas_rv_log__': array([-2.0530562 , -1.1738657 , -2.1438377 , -0.5091131 , -1.7475982 ,
-0.71832424, 0.7603461 , -0.6206605 , -0.06742464, -1.7866528 ,
-0.8273122 , 0.3449509 , -0.4612695 , -1.5151714 , -0.9103267 ,
0.92829466, 0.05409501, -0.17818904, 0.4767271 , 1.0659611 ],
dtype=float32), 'z_b_state_zerosum__': array([[ 0.9404972 , 0.5173597 , 0.947904 , 0.13234697, 0.91516215,
0.8013973 ],
[ 0.76825845, -0.00395296, 0.20068735, 0.89948386, -0.8211762 ,
-0.85832894],
[-0.23005143, -0.21612945, 0.38762948, -0.17819382, -0.12716204,
0.5923743 ],
[ 0.42042527, -0.8524518 , 0.41847345, -0.27271616, 0.9246246 ,
0.10817041],
[ 0.5263319 , 0.2471719 , -0.7485617 , -0.30133858, 0.48334756,
0.3941832 ],
[ 0.20216979, 0.85309595, -0.13996926, 0.13978173, 0.2398329 ,
-0.6436887 ],
[ 0.6737369 , 0.56847805, -0.6305385 , 0.11638433, -0.88574356,
-0.28407305],
[-0.88218284, -0.5086114 , 0.5899615 , 0.4378485 , 0.1587185 ,
-0.06953551],
[-0.4915884 , 0.5323545 , 0.5476521 , 0.03720003, 0.93754584,
0.20691545],
[ 0.52630943, 0.8224987 , 0.38608572, -0.48465434, -0.43573558,
0.8957741 ],
[-0.36157402, -0.80625093, -0.35036987, 0.58878934, -0.01390497,
-0.9381669 ],
[ 0.84330225, -0.24921258, -0.5905653 , 0.71469283, 0.04288541,
-0.15973687],
[ 0.16558522, -0.82091933, 0.6832055 , 0.9772107 , 0.05962521,
-0.7585791 ],
[-0.34562954, -0.14827304, -0.7131723 , -0.49016312, 0.92288506,
0.5976615 ],
[ 0.14006369, -0.78234476, 0.58030427, 0.7087098 , 0.83111507,
-0.4035646 ],
[-0.89755225, -0.7397433 , 0.87145287, 0.5502826 , -0.75381684,
-0.08390825],
[ 0.52683276, 0.2059658 , -0.65518147, 0.9072355 , 0.02874351,
0.8622515 ],
[-0.31716648, -0.22920816, -0.4773452 , 0.24883471, -0.71415335,
-0.00118549],
[ 0.67181724, 0.83293927, 0.06912381, 0.59113485, 0.15970528,
0.15737481],
[ 0.91586804, -0.00439743, 0.8568587 , -0.5599965 , 0.9024629 ,
-0.92843384]], dtype=float32), 'z_b_age_zerosum__': array([[ 2.76717365e-01, -5.32886267e-01, -4.37204897e-01,
6.65345013e-01, 9.51409161e-01],
[ 5.68430305e-01, 5.80501974e-01, 9.02473092e-01,
-1.05089314e-01, -6.80017248e-02],
[-7.40278482e-01, -6.54472530e-01, 5.73029280e-01,
-2.49546096e-01, -9.83492434e-01],
[ 4.39474136e-01, 4.55991507e-01, 2.91431248e-01,
7.93459237e-01, 8.33085358e-01],
[ 7.54141450e-01, -9.88980412e-01, 7.37549663e-01,
-9.54164326e-01, 3.69425505e-01],
[-4.45214868e-01, 1.18648775e-01, 4.35143918e-01,
7.96567798e-01, -5.37025869e-01],
[-8.19233358e-01, -3.28816384e-01, -4.24525887e-01,
-4.72912073e-01, -2.51088679e-01],
[ 4.16932464e-01, 1.86953232e-01, -2.34448180e-01,
-5.28278828e-01, -7.83707380e-01],
[-2.08375111e-01, -1.69877082e-01, -9.20472383e-01,
-5.55105388e-01, 2.24135935e-01],
[-7.18319640e-02, -8.23212624e-01, -1.14380375e-01,
3.75080615e-01, 4.38587993e-01],
[-7.84464240e-01, 9.64653268e-02, 1.33498237e-01,
2.63148099e-01, 9.03292537e-01],
[ 3.70964020e-01, 8.96655202e-01, -9.78391707e-01,
6.00353360e-01, -3.29210430e-01],
[ 9.80111659e-01, 6.26725018e-01, 8.71558905e-01,
-7.08010912e-01, 3.21216695e-02],
[-8.55567873e-01, -4.15038317e-01, -2.70858496e-01,
7.64281690e-01, 1.69419169e-01],
[-6.30245388e-01, 5.22969842e-01, -6.22790098e-01,
8.40588808e-01, 5.42818129e-01],
[-2.61249971e-02, 5.77672958e-01, -9.52823997e-01,
-5.49517214e-01, -4.92883384e-01],
[ 3.19638193e-01, 8.80902350e-01, 2.54505854e-02,
-4.16665673e-01, 7.45047331e-01],
[-1.37079775e-01, -9.72663925e-04, 2.88793862e-01,
9.96275783e-01, 5.16300082e-01],
[-2.26764768e-01, -9.21454072e-01, -2.66458213e-01,
-2.89255470e-01, -5.44357836e-01],
[-7.08415627e-01, -2.39693552e-01, 1.69611976e-01,
6.88308775e-01, -2.90724158e-01]], dtype=float32), 'z_b_brand_zerosum__': array([[ 0.5283236 ],
[-0.14984424],
[ 0.5653758 ],
[ 0.6604869 ],
[ 0.5594151 ],
[ 0.36363953],
[-0.13847719],
[ 0.2760732 ],
[ 0.60931265],
[ 0.39675766],
[ 0.13061056],
[-0.843226 ],
[-0.24025336],
[ 0.21590135],
[ 0.38261482],
[-0.9853659 ],
[-0.89518636],
[-0.73512644],
[ 0.24093248],
[ 0.53579485]], dtype=float32), 'z_b_cohort_zerosum__': array([[-0.89432555],
[ 0.56194794],
[ 0.8784441 ],
[ 0.29107055],
[-0.03464561],
[-0.7969992 ],
[ 0.3803715 ],
[ 0.14672193],
[-0.72870535],
[-0.9154726 ],
[-0.05452913],
[ 0.3046088 ],
[ 0.11125816],
[ 0.47531176],
[-0.26566276],
[-0.97336334],
[ 0.8083869 ],
[ 0.10919274],
[ 0.34626842],
[-0.14968246]], dtype=float32), 'mu_b_pos_con': array([-0.55350554, -1.271216 , -2.4323664 , -2.2574682 , -1.7525793 ],
dtype=float32), 'z_b_pos_con_state_zerosum__': array([[ 0.7088774 , -0.64973927, 0.01091878, 0.42494458, -0.5738007 ,
-0.13388953],
[ 0.2020803 , -0.4415184 , 0.9188914 , -0.91647035, 0.20404132,
-0.19665638],
[-0.460112 , 0.250589 , 0.43331107, -0.31884784, -0.20407897,
-0.85575414],
[ 0.053673 , -0.31756902, 0.743054 , -0.83611 , 0.24574716,
0.31937018],
[-0.24382843, -0.7048149 , -0.29253975, -0.2640002 , -0.7875447 ,
0.9551751 ]], dtype=float32), 'z_b_pos_con_age_zerosum__': array([[-0.968712 , -0.18256 , 0.54134816, 0.5418771 , -0.20875828],
[ 0.07226439, 0.3872751 , 0.78559935, 0.56196845, -0.21816622],
[-0.07638574, 0.56413966, -0.29065445, -0.8537547 , 0.26876295],
[-0.03170992, 0.7177836 , -0.14550653, 0.9839035 , 0.05719684],
[-0.09502391, 0.7952581 , -0.3214874 , 0.5325462 , -0.84900486]],
dtype=float32), 'z_b_pos_con_brand_zerosum__': array([[ 0.9145228 ],
[-0.576082 ],
[-0.9886196 ],
[-0.76018703],
[ 0.8999498 ]], dtype=float32), 'z_b_pos_con_cohort_zerosum__': array([[-0.32495368],
[-0.32046455],
[ 0.20123298],
[-0.7780437 ],
[-0.78692883]], dtype=float32), 'mu_b_neg_con': array([-1.6263036], dtype=float32), 'z_b_neg_con_state_zerosum__': array([[ 0.6405368 , 0.02604678, 0.7836906 , -0.9528276 , 0.65443355,
-0.00763485]], dtype=float32), 'z_b_neg_con_age_zerosum__': array([[-0.88560593, -0.3676332 , -0.1298519 , -0.39139104, 0.05216115]],
dtype=float32), 'z_b_neg_con_brand_zerosum__': array([[0.06288974]], dtype=float32), 'z_b_neg_con_cohort_zerosum__': array([[-0.2893259]], dtype=float32), 'mu_b_lag': array([-3.4769938], dtype=float32), 'z_b_lag_state_zerosum__': array([[-0.5544768 , 0.8843655 , -0.85776746, -0.54040647, 0.83116657,
-0.91366553]], dtype=float32), 'z_b_lag_age_zerosum__': array([[-0.27069455, 0.2308255 , 0.02965806, -0.922125 , 0.25014895]],
dtype=float32), 'z_b_lag_brand_zerosum__': array([[-0.95589054]], dtype=float32), 'z_b_lag_cohort_zerosum__': array([[-0.9150893]], dtype=float32), 'mu_b_fourier_year': array([ 0.99425113, 0.6344655 , -0.11898322, 0.8841637 , 0.21039897,
0.07808983, 0.8532348 , 0.44000137, 0.98853546, -0.52862114,
0.79177517, -0.01937965, -0.7798197 , -0.8342347 , -0.64850366,
0.97407717, 0.95000124, 0.7092928 , 0.6113028 , -0.27019796],
dtype=float32), 'sd_y_log__': array(4.4781275, dtype=float32)}
Logp initial evaluation results:
{'mu_adstock': -20.54, 'mu_lambda': -213.9, 'mu_a': -1.3, 'z_a_state': -inf, 'z_a_age': -inf, 'z_a_brand': -0.44, 'z_a_cohort': -inf, 'roas_rv': -72.16, 'z_b_state': -inf, 'z_b_age': -inf, 'z_b_brand': -inf, 'z_b_cohort': -inf, 'mu_b_pos_con': -5.82, 'z_b_pos_con_state': -inf, 'z_b_pos_con_age': -inf, 'z_b_pos_con_brand': -inf, 'z_b_pos_con_cohort': -66.75, 'mu_b_neg_con': -1.12, 'z_b_neg_con_state': -inf, 'z_b_neg_con_age': -inf, 'z_b_neg_con_brand': 1.19, 'z_b_neg_con_cohort': -2.8, 'mu_b_lag': -1.03, 'z_b_lag_state': -inf, 'z_b_lag_age': -inf, 'z_b_lag_brand': -44.3, 'z_b_lag_cohort': -40.49, 'mu_b_fourier_year': -91.0, 'sd_y': -1.21, 'y_like': -56160.32}
You can call `model.debug()` for more details.
SamplingError Traceback (most recent call last)
File ~/.ipykernel/1917/command--1-2479103431:18
15 entry = [ep for ep in metadata.distribution("mmm_v2").entry_points if ep.name == "mmm"]
16 if entry:
17 # Load and execute the entrypoint, assumes no parameters
---> 18 entry[0].load()()
19 else:
20 import importlib
File /local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/mmm_v2/main.py:16, in main()
13 pt.config.floatX = "float32" # pyright: ignore[reportPrivateImportUsage] // TODO: we can probably set this via env vars or pytensor.rc
14 config = setup_mmm()
---> 16 run_mmm(config)
-- elided unimportant data-prep pipeline steps --
File /local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/mmm_v2/model/ModelBuilder.py:634, in HLM_ModelBuilder.fit(self)
631 """Fits the model with the pymc model `sample()` method."""
632 self.logger.debug("Sampling model.")
633 self.model_trace.extend(
--> 634 pm.sample(
635 draws=self.config.draws,
636 tune=self.config.tune,
637 chains=self.config.chains,
638 model=self.model,
639 nuts_sampler="numpyro",
640 idata_kwargs={"log_likelihood": False},
641 var_names=[f"{x}" for x in self.model.free_RVs],
642 target_accept=self.config.target_accept,
643 random_seed=self.config.sampler_seed,
644 progressbar=self.config.progress_bars,
645 )
646 )
647 self.logger.debug("Finished sampling.")
649 HLM_ModelBuilder._print_diagnostics(self.model_trace, self.logger)
File /local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/pymc/sampling/mcmc.py:773, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)
768 raise ValueError(
769 "Model can not be sampled with NUTS alone. It either has discrete variables or a non-differentiable log-probability."
770 )
772 with joined_blas_limiter():
--> 773 return _sample_external_nuts(
774 sampler=nuts_sampler,
775 draws=draws,
776 tune=tune,
777 chains=chains,
778 target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
779 random_seed=random_seed,
780 initvals=initvals,
781 model=model,
782 var_names=var_names,
783 progressbar=progressbar,
784 idata_kwargs=idata_kwargs,
785 compute_convergence_checks=compute_convergence_checks,
786 nuts_sampler_kwargs=nuts_sampler_kwargs,
787 **kwargs,
788 )
790 if exclusive_nuts and not provided_steps:
791 # Special path for NUTS initialization
792 if "nuts" in kwargs:
File /local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/pymc/sampling/mcmc.py:389, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, idata_kwargs, compute_convergence_checks, nuts_sampler_kwargs, **kwargs)
386 elif sampler in ("numpyro", "blackjax"):
387 import pymc.sampling.jax as pymc_jax
--> 389 idata = pymc_jax.sample_jax_nuts(
390 draws=draws,
391 tune=tune,
392 chains=chains,
393 target_accept=target_accept,
394 random_seed=random_seed,
395 initvals=initvals,
396 model=model,
397 var_names=var_names,
398 progressbar=progressbar,
399 nuts_sampler=sampler,
400 idata_kwargs=idata_kwargs,
401 compute_convergence_checks=compute_convergence_checks,
402 **nuts_sampler_kwargs,
403 )
404 return idata
406 else:
File /local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/pymc/sampling/jax.py:595, in sample_jax_nuts(draws, tune, chains, target_accept, random_seed, initvals, jitter, model, var_names, nuts_kwargs, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, postprocessing_chunks, idata_kwargs, compute_convergence_checks, nuts_sampler)
589 vars_to_sample = list(
590 get_default_varnames(filtered_var_names, include_transformed=keep_untransformed)
591 )
593 (random_seed,) = _get_seeds_per_chain(random_seed, 1)
--> 595 initial_points = _get_batched_jittered_initial_points(
596 model=model,
597 chains=chains,
598 initvals=initvals,
599 random_seed=random_seed,
600 jitter=jitter,
601 )
603 if nuts_sampler == "numpyro":
604 sampler_fn = _sample_numpyro_nuts
File /local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/pymc/sampling/jax.py:225, in _get_batched_jittered_initial_points(model, chains, initvals, random_seed, jitter, jitter_max_retries)
209 def _get_batched_jittered_initial_points(
210 model: Model,
211 chains: int,
(...)
215 jitter_max_retries: int = 10,
216 ) -> np.ndarray | list[np.ndarray]:
217 """Get jittered initial point in format expected by NumPyro MCMC kernel.
218
219 Returns
(...)
223 Each item has shape `(chains, *var.shape)`
224 """
--> 225 initial_points = _init_jitter(
226 model,
227 initvals,
228 seeds=_get_seeds_per_chain(random_seed, chains),
229 jitter=jitter,
230 jitter_max_retries=jitter_max_retries,
231 )
232 initial_points_values = [list(initial_point.values()) for initial_point in initial_points]
233 if chains == 1:
File /local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/pymc/sampling/mcmc.py:1382, in _init_jitter(model, initvals, seeds, jitter, jitter_max_retries, logp_dlogp_func)
1379 if not np.isfinite(point_logp):
1380 if i == jitter_max_retries:
1381 # Print informative message on last attempted point
-> 1382 model.check_start_vals(point)
1383 # Retry with a new seed
1384 seed = rng.integers(2**30, dtype=np.int64)
File /local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/pymc/model/core.py:1769, in Model.check_start_vals(self, start, **kwargs)
1766 initial_eval = self.point_logps(point=elem, **kwargs)
1768 if not all(np.isfinite(v) for v in initial_eval.values()):
-> 1769 raise SamplingError(
1770 "Initial evaluation of model at starting point failed!\n"
1771 f"Starting values:\n{elem}\n\n"
1772 f"Logp initial evaluation results:\n{initial_eval}\n"
1773 "You can call `model.debug()` for more details."
1774 )
PyMC version information:
Running on a Windows machine in a Linux container.
PyMC installed via poetry (PyPI).
Using libopenblas.
annotated-types 0.7.0
arviz 0.19.0
babel 2.16.0
blinker 1.4
build 1.2.2.post1
CacheControl 0.14.1
cachetools 5.5.0
certifi 2024.8.30
cfgv 3.4.0
charset-normalizer 3.4.0
cleo 2.1.0
click 8.1.7
cloudpickle 3.1.0
colorama 0.4.6
cons 0.4.6
contourpy 1.3.1
coverage 7.6.8
crashtest 0.4.1
cryptography 3.4.8
cycler 0.12.1
dbus-python 1.2.18
distlib 0.3.9
distro 1.7.0
distro-info 1.1+ubuntu0.2
dm-tree 0.1.8
dulwich 0.21.7
etuples 0.3.9
exceptiongroup 1.2.2
fastjsonschema 2.21.1
filelock 3.16.1
fonttools 4.55.1
ghp-import 2.1.0
graphviz 0.20.3
griffe 1.5.1
h5netcdf 1.4.1
h5py 3.12.1
httplib2 0.20.2
identify 2.6.3
idna 3.10
importlib_metadata 8.5.0
iniconfig 2.0.0
installer 0.7.0
jaraco.classes 3.4.0
jax 0.4.35
jaxlib 0.4.35
jeepney 0.7.1
Jinja2 3.1.4
joblib 1.4.2
keyring 24.3.1
kiwisolver 1.4.7
launchpadlib 1.10.16
lazr.restfulclient 0.14.4
lazr.uri 1.0.6
logical-unification 0.4.6
loguru 0.7.2
Markdown 3.7
markdown-it-py 3.0.0
MarkupSafe 3.0.2
matplotlib 3.9.3
mdurl 0.1.2
mergedeep 1.3.4
miniKanren 1.0.3
mkdocs 1.6.1
mkdocs-autorefs 1.2.0
mkdocs-gen-files 0.5.0
mkdocs-get-deps 0.2.0
mkdocs-glightbox 0.4.0
mkdocs-literate-nav 0.6.1
mkdocs-material 9.5.47
mkdocs-material-extensions 1.3.1
mkdocs-section-index 0.3.9
mkdocstrings 0.26.2
mkdocstrings-python 1.12.2
ml_dtypes 0.5.0
mmm_v2 0.0.1 /workspaces/mmm_v2
more-itertools 8.10.0
msgpack 1.1.0
multimethod 1.10
multipledispatch 1.0.0
mypy-extensions 1.0.0
nodeenv 1.9.1
numpy 1.26.4
numpyro 0.15.3
oauthlib 3.2.0
opt_einsum 3.4.0
packaging 24.2
paginate 0.5.7
pandas 2.2.3
pandera 0.20.4
pathspec 0.12.1
pexpect 4.9.0
pillow 11.0.0
pip 24.3.1
pkginfo 1.12.0
platformdirs 4.3.6
pluggy 1.5.0
poetry 1.8.4
poetry-core 1.9.1
poetry-plugin-export 1.8.0
pre_commit 4.0.1
ptyprocess 0.7.0
pyarrow 18.1.0
pydantic 2.10.3
pydantic_core 2.27.1
Pygments 2.18.0
PyGObject 3.42.1
PyJWT 2.3.0
pymc 5.19.1
pymc-marketing 0.6.0
pymdown-extensions 10.12
pyparsing 3.2.0
pyproject_hooks 1.2.0
pytensor 2.26.4
pytest 8.3.4
pytest-cov 6.0.0
python-apt 2.4.0+ubuntu3
python-dateutil 2.9.0.post0
pytz 2024.2
PyYAML 6.0.2
pyyaml_env_tag 0.1
RapidFuzz 3.10.1
regex 2024.11.6
requests 2.32.3
requests-toolbelt 1.0.0
rich 13.9.4
ruff 0.8.1
scikit-learn 1.5.2
scipy 1.14.1
seaborn 0.13.2
SecretStorage 3.3.1
setuptools 75.6.0
shellingham 1.5.4
six 1.17.0
ssh-import-id 5.11
threadpoolctl 3.5.0
tomli 2.2.1
tomlkit 0.13.2
toolz 1.0.0
tqdm 4.67.1
trove-classifiers 2024.10.21.16
typeguard 4.4.1
typing_extensions 4.12.2
typing-inspect 0.9.0
tzdata 2024.2
unattended-upgrades 0.1
urllib3 2.2.3
virtualenv 20.28.0
wadllib 1.3.6
watchdog 6.0.0
wheel 0.38.4
wrapt 1.17.0
xarray 2024.9.0
xarray-einstats 0.8.0
zipp 3.21.0
Context for the issue:
I want to use float32 due to memory constraints: I've got a big model and dataset with a hierarchy that takes up 40+ GB of RAM when running with float32, so I'd need a huge box to go up to float64.
Maybe the optimisations in 5.19 make that problem moot? I've managed to get it running with float64 temporarily, but I'm not sure it'll be a long-term solution.
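As a stopgap other than float64, I may try skipping the jitter step so that _init_jitter never runs. This is untested and assumes the jitter flag is forwarded through nuts_sampler_kwargs to sample_jax_nuts (it does appear in that function's signature in the traceback above):

idata = pm.sample(
    draws=500,
    tune=500,
    model=model,
    nuts_sampler="numpyro",
    # jitter is a parameter of sample_jax_nuts; passing it through should
    # start each chain from the unjittered initial point instead
    nuts_sampler_kwargs={"jitter": False},
)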