-
Notifications
You must be signed in to change notification settings - Fork 6
Fix unsupported dtype from PyMC sampler warnings with ClickHouseBackend
#75
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b647b66
16400d9
d885089
eb639ca
a7809c6
1927579
dfd4d38
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,8 @@ | |
|
||
The only PyMC dependency is on the ``BaseTrace`` abstract base class. | ||
""" | ||
import base64 | ||
import pickle | ||
from typing import Dict, List, Optional, Sequence, Tuple | ||
|
||
import hagelkorn | ||
|
@@ -26,9 +28,7 @@ | |
|
||
def find_data(pmodel: Model) -> List[DataVariable]: | ||
"""Extracts data variables from a model.""" | ||
observed_rvs = { | ||
rv.tag.observations for rv in pmodel.observed_RVs if hasattr(rv.tag, "observations") | ||
} | ||
observed_rvs = {pmodel.rvs_to_values[rv] for rv in pmodel.observed_RVs} | ||
dvars = [] | ||
# All data containers are named vars! | ||
for name, var in pmodel.named_vars.items(): | ||
|
@@ -39,7 +39,7 @@ def find_data(pmodel: Model) -> List[DataVariable]: | |
dv.value = ndarray_from_numpy(var.get_value()) | ||
else: | ||
continue | ||
dv.dims = list(pmodel.RV_dims.get(name, [])) | ||
dv.dims = list(pmodel.named_vars_to_dims.get(name, [])) | ||
dv.is_observed = var in observed_rvs | ||
dvars.append(dv) | ||
return dvars | ||
|
@@ -142,7 +142,9 @@ def setup( | |
name, | ||
str(self.var_dtypes[name]), | ||
list(self.var_shapes[name]), | ||
dims=list(self.model.RV_dims[name]) if name in self.model.RV_dims else [], | ||
dims=list(self.model.named_vars_to_dims[name]) | ||
if name in self.model.named_vars_to_dims | ||
else [], | ||
is_deterministic=(name not in free_rv_names), | ||
) | ||
for name in self.varnames | ||
|
@@ -159,12 +161,16 @@ def setup( | |
self._stat_groups.append([]) | ||
for statname, dtype in names_dtypes.items(): | ||
sname = f"sampler_{s}__{statname}" | ||
svar = Variable( | ||
name=sname, | ||
dtype=numpy.dtype(dtype).name, | ||
# This 👇 is needed until PyMC provides shapes ahead of time. | ||
undefined_ndim=True, | ||
) | ||
if statname == "warning": | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Meaning that if we would want to record warnings with hopsy we need to give the stat the name "warning"? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, for now this is only for stats named "warning", which also get special treatment by PyMC. One could extend this to more arbitrary objects, but generally I think we should try to stick to standard data types and avoid objects, so I didn't want to add more flexibility than needed. |
||
# SamplerWarnings will be pickled and stored as string! | ||
svar = Variable(sname, "str") | ||
else: | ||
svar = Variable( | ||
name=sname, | ||
dtype=numpy.dtype(dtype).name, | ||
# This 👇 is needed until PyMC provides shapes ahead of time. | ||
undefined_ndim=True, | ||
) | ||
self._stat_groups[s].append((sname, statname)) | ||
sample_stats.append(svar) | ||
|
||
|
@@ -197,8 +203,12 @@ def record(self, point, sampler_states=None): | |
for s, sts in enumerate(sampler_states): | ||
for statname, sval in sts.items(): | ||
sname = f"sampler_{s}__{statname}" | ||
stats[sname] = sval | ||
# Make not whether this is a tuning iteration. | ||
# Automatically pickle SamplerWarnings | ||
if statname == "warning": | ||
sval_bytes = pickle.dumps(sval) | ||
sval = base64.encodebytes(sval_bytes).decode("ascii") | ||
stats[sname] = numpy.asarray(sval) | ||
# Make note whether this is a tuning iteration. | ||
if statname == "tune": | ||
stats["tune"] = sval | ||
|
||
|
@@ -214,7 +224,16 @@ def get_values(self, varname, burn=0, thin=1) -> numpy.ndarray: | |
def _get_stats(self, varname, burn=0, thin=1) -> numpy.ndarray: | ||
if self._chain is None: | ||
raise Exception("Trace setup was not completed. Call `.setup()` first.") | ||
return self._chain.get_stats(varname)[burn::thin] | ||
values = self._chain.get_stats(varname)[burn::thin] | ||
if "warning" in varname: | ||
objs = [] | ||
for v in values: | ||
enc = v.encode("ascii") | ||
str_ = base64.decodebytes(enc) | ||
obj = pickle.loads(str_) | ||
objs.append(obj) | ||
values = numpy.array(objs, dtype=object) | ||
return values | ||
|
||
def _get_sampler_stats(self, stat_name, sampler_idx, burn, thin): | ||
if self._chain is None: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,7 +11,7 @@ | |
|
||
from .adapters.pymc import TraceBackend, find_data | ||
from .backends.clickhouse import ClickHouseBackend | ||
from .test_backend_clickhouse import HAS_REAL_DB | ||
from .test_backend_clickhouse import DB_KWARGS, HAS_REAL_DB | ||
|
||
_log = logging.getLogger(__file__) | ||
|
||
|
@@ -57,9 +57,9 @@ class TestPyMCAdapter: | |
def setup_method(self, method): | ||
"""Initializes a fresh database just for this test method.""" | ||
self._db = "testing_" + hagelkorn.random() | ||
self._client_main = clickhouse_driver.Client("localhost") | ||
self._client_main = clickhouse_driver.Client(**DB_KWARGS) | ||
self._client_main.execute(f"CREATE DATABASE {self._db};") | ||
self._client = clickhouse_driver.Client("localhost", database=self._db) | ||
self._client = clickhouse_driver.Client(**DB_KWARGS, database=self._db) | ||
self.backend = ClickHouseBackend(self._client) | ||
return | ||
|
||
|
@@ -87,23 +87,21 @@ def wrapper(meta: RunMeta): | |
trace = TraceBackend(backend) | ||
idata = pm.sample( | ||
trace=trace, | ||
tune=3, | ||
draws=5, | ||
tune=30, | ||
draws=50, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is the rng set somewhere to guarantee a warning? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. within these steps There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no, but with this few steps the NUTS is pretty much guaranteed to emit warnings. |
||
chains=2, | ||
cores=cores, | ||
step=pm.Metropolis(), | ||
discard_tuned_samples=False, | ||
compute_convergence_checks=False, | ||
) | ||
if not len(args) == 1: | ||
_log.warning("Run was initialized multiple times.") | ||
rmeta = args[0] | ||
|
||
# Chain lengths after conversion | ||
assert idata.posterior.dims["chain"] == 2 | ||
assert idata.posterior.dims["draw"] == 5 | ||
assert idata.posterior.dims["draw"] == 50 | ||
assert idata.warmup_posterior.dims["chain"] == 2 | ||
assert idata.warmup_posterior.dims["draw"] == 3 | ||
assert idata.warmup_posterior.dims["draw"] == 30 | ||
|
||
# Tracking of named variable dimensions | ||
vars = {var.name: var for var in rmeta.variables} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems a little unorthodox to use a ternary operator to compute the default arg. Couldn't it be just None and if it is None then it is computed? Just a nitpick though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I did this because it's in a list comprehension, and there's no way to conditionally NOT pass the kwarg