Fix unsupported dtype from PyMC sampler warnings with ClickHouseBackend #75


Merged · 7 commits · Dec 22, 2022
2 changes: 1 addition & 1 deletion .github/workflows/pipeline.yml
@@ -33,7 +33,7 @@ jobs:
         flake8 . --count --exit-zero --statistics
     - name: Test with pytest
       run: |
-        pytest --cov=./mcbackend --cov-report xml --cov-report term-missing mcbackend/
+        pytest -v --cov=./mcbackend --cov-report xml --cov-report term-missing mcbackend/
     - name: Upload coverage
       uses: codecov/codecov-action@v3
       if: matrix.python-version == 3.9
47 changes: 33 additions & 14 deletions mcbackend/adapters/pymc.py
@@ -3,6 +3,8 @@

 The only PyMC dependency is on the ``BaseTrace`` abstract base class.
 """
+import base64
+import pickle
 from typing import Dict, List, Optional, Sequence, Tuple
 
 import hagelkorn
@@ -26,9 +28,7 @@

 def find_data(pmodel: Model) -> List[DataVariable]:
     """Extracts data variables from a model."""
-    observed_rvs = {
-        rv.tag.observations for rv in pmodel.observed_RVs if hasattr(rv.tag, "observations")
-    }
+    observed_rvs = {pmodel.rvs_to_values[rv] for rv in pmodel.observed_RVs}
     dvars = []
     # All data containers are named vars!
     for name, var in pmodel.named_vars.items():
@@ -39,7 +39,7 @@ def find_data(pmodel: Model) -> List[DataVariable]:
             dv.value = ndarray_from_numpy(var.get_value())
         else:
             continue
-        dv.dims = list(pmodel.RV_dims.get(name, []))
+        dv.dims = list(pmodel.named_vars_to_dims.get(name, []))
         dv.is_observed = var in observed_rvs
         dvars.append(dv)
     return dvars
@@ -142,7 +142,9 @@ def setup(
                 name,
                 str(self.var_dtypes[name]),
                 list(self.var_shapes[name]),
-                dims=list(self.model.RV_dims[name]) if name in self.model.RV_dims else [],
+                dims=list(self.model.named_vars_to_dims[name])
+                if name in self.model.named_vars_to_dims

[Review thread]
Collaborator: Seems a little unorthodox to use a ternary operator to compute the default arg. Couldn't it just be None, and then computed when it is None? Just a nitpick though.

Member Author: I did this because it is inside a list comprehension, and there is no way to conditionally NOT pass the kwarg there.

+                else [],
                 is_deterministic=(name not in free_rv_names),
             )
             for name in self.varnames
@@ -159,12 +161,16 @@ def setup(
             self._stat_groups.append([])
             for statname, dtype in names_dtypes.items():
                 sname = f"sampler_{s}__{statname}"
-                svar = Variable(
-                    name=sname,
-                    dtype=numpy.dtype(dtype).name,
-                    # This 👇 is needed until PyMC provides shapes ahead of time.
-                    undefined_ndim=True,
-                )
+                if statname == "warning":

[Review thread]
Collaborator: Meaning that if we wanted to record warnings with hopsy, we would need to give the stat the name "warning"?

Member Author: Yes, for now this applies only to stats named "warning", which also get special treatment from PyMC. One could extend this to more arbitrary objects, but in general I think we should stick to standard data types and avoid objects, so I did not want to add more flexibility than needed.


+                    # SamplerWarnings will be pickled and stored as string!
+                    svar = Variable(sname, "str")
+                else:
+                    svar = Variable(
+                        name=sname,
+                        dtype=numpy.dtype(dtype).name,
+                        # This 👇 is needed until PyMC provides shapes ahead of time.
+                        undefined_ndim=True,
+                    )
                 self._stat_groups[s].append((sname, statname))
                 sample_stats.append(svar)

@@ -197,8 +203,12 @@ def record(self, point, sampler_states=None):
         for s, sts in enumerate(sampler_states):
             for statname, sval in sts.items():
                 sname = f"sampler_{s}__{statname}"
-                stats[sname] = sval
-                # Make not whether this is a tuning iteration.
+                # Automatically pickle SamplerWarnings
+                if statname == "warning":
+                    sval_bytes = pickle.dumps(sval)
+                    sval = base64.encodebytes(sval_bytes).decode("ascii")
+                stats[sname] = numpy.asarray(sval)
+                # Make note whether this is a tuning iteration.
                 if statname == "tune":
                     stats["tune"] = sval

@@ -214,7 +224,16 @@ def get_values(self, varname, burn=0, thin=1) -> numpy.ndarray:
     def _get_stats(self, varname, burn=0, thin=1) -> numpy.ndarray:
         if self._chain is None:
             raise Exception("Trace setup was not completed. Call `.setup()` first.")
-        return self._chain.get_stats(varname)[burn::thin]
+        values = self._chain.get_stats(varname)[burn::thin]
+        if "warning" in varname:
+            objs = []
+            for v in values:
+                enc = v.encode("ascii")
+                str_ = base64.decodebytes(enc)
+                obj = pickle.loads(str_)
+                objs.append(obj)
+            values = numpy.array(objs, dtype=object)
+        return values

     def _get_sampler_stats(self, stat_name, sampler_idx, burn, thin):
         if self._chain is None:
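To make the new warning handling concrete, here is a minimal sketch of the pickle/base64 round-trip that `record` and `_get_stats` perform. The dict below is a hypothetical stand-in for a PyMC `SamplerWarning`; any picklable object behaves the same.

```python
import base64
import pickle

import numpy

# Hypothetical stand-in for a SamplerWarning; any picklable object works.
warning = {"kind": "divergence", "message": "divergence after tuning"}

# Encoding (as in `record`): pickle to bytes, then base64 so the payload
# survives storage in a plain string column.
encoded = base64.encodebytes(pickle.dumps(warning)).decode("ascii")
stored = numpy.asarray(encoded)  # str-dtyped 0-d ndarray, like the stored stat

# Decoding (as in `_get_stats`): reverse both steps to recover the object.
recovered = pickle.loads(base64.decodebytes(str(stored).encode("ascii")))
assert recovered == warning
```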
24 changes: 19 additions & 5 deletions mcbackend/backends/clickhouse.py
@@ -29,7 +29,8 @@
"int64": "Int64",
"float32": "Float32",
"float64": "Float64",
"bool": "UInt8",
"bool": "Bool",
"str": "String",
}


@@ -155,6 +156,7 @@ def __init__(
         self._client = client
         # The following attributes belong to the batched insert mechanism.
         # Inserting in batches is much faster than inserting single rows.
+        self._str_cols = set()
         self._insert_query: str = ""
         self._insert_queue: List[Dict[str, Any]] = []
         self._last_insert = time.time()
@@ -166,12 +168,22 @@ def append(
         self, draw: Dict[str, numpy.ndarray], stats: Optional[Dict[str, numpy.ndarray]] = None
     ):
         stat = {f"__stat_{sname}": svals for sname, svals in (stats or {}).items()}
-        params: Dict[str, Any] = {"_draw_idx": self._draw_idx, **draw, **stat}
-        self._draw_idx += 1
+        params: Dict[str, numpy.ndarray] = {**draw, **stat}

         # On first append create a query to be used for the batched insert
         if not self._insert_query:
             names = "`,`".join(params.keys())
-            self._insert_query = f"INSERT INTO {self.cid} (`{names}`) VALUES"
+            self._insert_query = f"INSERT INTO {self.cid} (`_draw_idx`,`{names}`) VALUES"
+            self._str_cols = {k for k, v in params.items() if "str" in numpy.asarray(v).dtype.name}
+
+        # Convert str ndarrays to lists
+        for col in self._str_cols:
+            params[col] = params[col].tolist()
+
+        # Queue up for insertion
+        params["_draw_idx"] = self._draw_idx
+        self._insert_queue.append(params)
+        self._draw_idx += 1

         if (
             len(self._insert_queue) >= self._insert_every
@@ -235,7 +247,9 @@ def _get_rows(
             return numpy.array([], dtype=object)

         # The unpacking must also account for non-rigid shapes
-        if is_rigid(nshape):
+        # and str-dtyped empty arrays default to fixed length 1 strings.
+        # The [None] list is slower, but more flexible in this regard.
+        if is_rigid(nshape) and dtype != "str":
             assert nshape is not None
             buffer = numpy.empty((draws, *nshape), dtype)
         else:
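For context on the `CLICKHOUSE_DTYPES` mapping changed above: scalar numpy dtypes map to scalar ClickHouse column types, and each array dimension wraps the column type in another `Array(...)`, which is what `test_create_chain_table` further down asserts. A minimal sketch of that convention (the helper `column_type` is hypothetical, not mcbackend API):

```python
CLICKHOUSE_DTYPES = {
    "int64": "Int64",
    "float32": "Float32",
    "float64": "Float64",
    "bool": "Bool",
    "str": "String",
}

def column_type(dtype: str, ndim: int) -> str:
    """Wrap the scalar ClickHouse type in one Array(...) per dimension."""
    ctype = CLICKHOUSE_DTYPES[dtype]
    for _ in range(ndim):
        ctype = f"Array({ctype})"
    return ctype

assert column_type("float64", 3) == "Array(Array(Array(Float64)))"
assert column_type("bool", 0) == "Bool"
```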
12 changes: 9 additions & 3 deletions mcbackend/backends/numpy.py
@@ -65,7 +65,7 @@ def __init__(self, cmeta: ChainMeta, rmeta: RunMeta, *, preallocate: int) -> None:
             (self._stats, self._stat_is_rigid, rmeta.sample_stats),
         ]:
             for var in variables:
-                rigid = is_rigid(var.shape) and not var.undefined_ndim
+                rigid = is_rigid(var.shape) and not var.undefined_ndim and var.dtype != "str"
                 rigid_dict[var.name] = rigid
                 if preallocate > 0 and rigid:
                     reserve = (preallocate, *var.shape)
@@ -88,13 +88,19 @@ def __len__(self) -> int:
         return self._draw_idx

     def get_draws(self, var_name: str, slc: slice = slice(None)) -> numpy.ndarray:
-        return self._samples[var_name][: self._draw_idx][slc]
+        data = self._samples[var_name][: self._draw_idx][slc]
+        if self.variables[var_name].dtype == "str":
+            return numpy.array(data.tolist(), dtype=str)
+        return data

     def get_draws_at(self, idx: int, var_names: Sequence[str]) -> Dict[str, numpy.ndarray]:
         return {vn: numpy.asarray(self._samples[vn][idx]) for vn in var_names}

     def get_stats(self, stat_name: str, slc: slice = slice(None)) -> numpy.ndarray:
-        return self._stats[stat_name][: self._draw_idx][slc]
+        data = self._stats[stat_name][: self._draw_idx][slc]
+        if self.sample_stats[stat_name].dtype == "str":
+            return numpy.array(data.tolist(), dtype=str)
+        return data

     def get_stats_at(self, idx: int, stat_names: Sequence[str]) -> Dict[str, numpy.ndarray]:
         return {sn: numpy.asarray(self._stats[sn][idx]) for sn in stat_names}
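The `var.dtype != "str"` guard above exists because numpy's fixed-width string buffers truncate silently, so string stats cannot use the rigid preallocated buffer. A short sketch of the failure mode (the same one the `_get_rows` comment in the ClickHouse backend refers to):

```python
import numpy

# Preallocating with dtype "str" yields fixed-width '<U1' (1-character) cells,
# so longer strings are silently truncated on assignment:
rigid = numpy.empty((3,), dtype="str")
rigid[0] = "divergence"
print(rigid[0])  # -> 'd'

# An object-dtyped buffer (like the [None] list fallback) keeps full strings:
flexible = numpy.array([None] * 3, dtype=object)
flexible[0] = "divergence"
print(flexible[0])  # -> 'divergence'
```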
16 changes: 7 additions & 9 deletions mcbackend/test_adapter_pymc.py
@@ -11,7 +11,7 @@

 from .adapters.pymc import TraceBackend, find_data
 from .backends.clickhouse import ClickHouseBackend
-from .test_backend_clickhouse import HAS_REAL_DB
+from .test_backend_clickhouse import DB_KWARGS, HAS_REAL_DB

_log = logging.getLogger(__file__)

@@ -57,9 +57,9 @@ class TestPyMCAdapter:
     def setup_method(self, method):
         """Initializes a fresh database just for this test method."""
         self._db = "testing_" + hagelkorn.random()
-        self._client_main = clickhouse_driver.Client("localhost")
+        self._client_main = clickhouse_driver.Client(**DB_KWARGS)
         self._client_main.execute(f"CREATE DATABASE {self._db};")
-        self._client = clickhouse_driver.Client("localhost", database=self._db)
+        self._client = clickhouse_driver.Client(**DB_KWARGS, database=self._db)
         self.backend = ClickHouseBackend(self._client)
         return

@@ -87,23 +87,21 @@ def wrapper(meta: RunMeta):
         trace = TraceBackend(backend)
         idata = pm.sample(
             trace=trace,
-            tune=3,
-            draws=5,
+            tune=30,
+            draws=50,

[Review thread]
Collaborator: Is the rng set somewhere to guarantee a warning within these steps?

Member Author: No, but with so few steps NUTS is pretty much guaranteed to emit warnings. And even if it occasionally does not, that should not break things.

             chains=2,
             cores=cores,
             step=pm.Metropolis(),
             discard_tuned_samples=False,
             compute_convergence_checks=False,
         )
         if not len(args) == 1:
             _log.warning("Run was initialized multiple times.")
         rmeta = args[0]

         # Chain lengths after conversion
         assert idata.posterior.dims["chain"] == 2
-        assert idata.posterior.dims["draw"] == 5
+        assert idata.posterior.dims["draw"] == 50
         assert idata.warmup_posterior.dims["chain"] == 2
-        assert idata.warmup_posterior.dims["draw"] == 3
+        assert idata.warmup_posterior.dims["draw"] == 30

         # Tracking of named variable dimensions
         vars = {var.name: var for var in rmeta.variables}
51 changes: 40 additions & 11 deletions mcbackend/test_backend_clickhouse.py
@@ -1,5 +1,6 @@
 import base64
 import logging
+import os
 from datetime import datetime, timezone
 from typing import Sequence, Tuple

@@ -21,7 +22,11 @@
 from mcbackend.test_utils import CheckBehavior, CheckPerformance, make_runmeta

 try:
-    client = clickhouse_driver.Client("localhost")
+    DB_HOST = os.environ.get("CLICKHOUSE_HOST", "localhost")
+    DB_PASS = os.environ.get("CLICKHOUSE_PASS", "")
+    DB_PORT = os.environ.get("CLICKHOUSE_PORT", 9000)
+    DB_KWARGS = dict(host=DB_HOST, port=DB_PORT, password=DB_PASS)
+    client = clickhouse_driver.Client(**DB_KWARGS)
     client.execute("SHOW DATABASES;")
     HAS_REAL_DB = True
 except:
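As an aside, these environment variables make the integration tests point at any ClickHouse server. A sketch of how the environment drives the connection (the host, port, and password values below are hypothetical):

```python
import os

# Hypothetical server coordinates; export these before running pytest,
# because DB_KWARGS is built when the test module is imported.
os.environ["CLICKHOUSE_HOST"] = "ch.example.com"
os.environ["CLICKHOUSE_PORT"] = "9001"
os.environ["CLICKHOUSE_PASS"] = "s3cret"

import clickhouse_driver  # noqa: E402  (import after the environment is set)

client = clickhouse_driver.Client(
    host=os.environ.get("CLICKHOUSE_HOST", "localhost"),
    port=os.environ.get("CLICKHOUSE_PORT", 9000),
    password=os.environ.get("CLICKHOUSE_PASS", ""),
)
```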
@@ -51,7 +56,7 @@ def fully_initialized(

 @pytest.mark.skipif(
     condition=not HAS_REAL_DB,
-    reason="Integration tests need a ClickHouse server on localhost:9000 without authentication.",
+    reason="Integration tests need a ClickHouse server.",
 )
 class TestClickHouseBackendInitialization:
     """This is separate because ``TestClickHouseBackend.setup_method`` depends on these things."""
@@ -63,12 +68,12 @@ def test_exceptions(self):

     def test_backend_from_client_object(self):
         db = "testing_" + hagelkorn.random()
-        _client_main = clickhouse_driver.Client("localhost")
+        _client_main = clickhouse_driver.Client(**DB_KWARGS)
         _client_main.execute(f"CREATE DATABASE {db};")

         try:
             # When created from a client object, all chains share the client
-            backend = ClickHouseBackend(client=clickhouse_driver.Client("localhost", database=db))
+            backend = ClickHouseBackend(client=clickhouse_driver.Client(**DB_KWARGS, database=db))
             assert callable(backend._client_fn)
             run = backend.init_run(make_runmeta())
             c1 = run.init_chain(0)
@@ -81,11 +86,11 @@

     def test_backend_from_client_function(self):
         db = "testing_" + hagelkorn.random()
-        _client_main = clickhouse_driver.Client("localhost")
+        _client_main = clickhouse_driver.Client(**DB_KWARGS)
         _client_main.execute(f"CREATE DATABASE {db};")

         def client_fn():
-            return clickhouse_driver.Client("localhost", database=db)
+            return clickhouse_driver.Client(**DB_KWARGS, database=db)

         try:
             # When created from a client function, each chain has its own client
@@ -108,7 +113,7 @@ def client_fn():

 @pytest.mark.skipif(
     condition=not HAS_REAL_DB,
-    reason="Integration tests need a ClickHouse server on localhost:9000 without authentication.",
+    reason="Integration tests need a ClickHouse server.",
 )
 class TestClickHouseBackend(CheckBehavior, CheckPerformance):
     cls_backend = ClickHouseBackend
@@ -118,11 +123,11 @@ class TestClickHouseBackend(CheckBehavior, CheckPerformance):
     def setup_method(self, method):
         """Initializes a fresh database just for this test method."""
         self._db = "testing_" + hagelkorn.random()
-        self._client_main = clickhouse_driver.Client("localhost")
+        self._client_main = clickhouse_driver.Client(**DB_KWARGS)
         self._client_main.execute(f"CREATE DATABASE {self._db};")
-        self._client = clickhouse_driver.Client("localhost", database=self._db)
+        self._client = clickhouse_driver.Client(**DB_KWARGS, database=self._db)
         self.backend = ClickHouseBackend(
-            client_fn=lambda: clickhouse_driver.Client("localhost", database=self._db)
+            client_fn=lambda: clickhouse_driver.Client(**DB_KWARGS, database=self._db)
         )
         return

@@ -197,7 +202,7 @@ def test_create_chain_table(self):
("scalar", "UInt16"),
("1D", "Array(Float32)"),
("3D", "Array(Array(Array(Float64)))"),
("__stat_accepted", "UInt8"),
("__stat_accepted", "Bool"),
]
pass

@@ -266,6 +271,30 @@ def test_insert_draw(self):
             numpy.testing.assert_array_equal(v3, draw["v3"])
         pass

+    @pytest.mark.xfail(
+        reason="Batch inserting assumes identical dict composition every time. See #74."
+    )
+    def test_insert_flaky_stats(self):
+        """Tries to append stats that only sometimes have an entry for a stat."""
+        run, chains = fully_initialized(
+            self.backend,
+            RunMeta(
+                sample_stats=[
+                    Variable("always", "int8"),
+                    Variable("sometimes", "bool"),
+                ]
+            ),
+        )
+
+        chain = chains[0]
+        chain.append({}, dict(always=1, sometimes=True))
+        chain.append({}, dict(always=2))
+        chain._commit()
+
+        assert tuple(chain.get_stats("always")) == (1, 2)
+        assert tuple(chain.get_stats("sometimes")) == (True, None)
+        pass
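A note on the xfail reason: `append` builds the INSERT column list once, from the keys of the first queued draw, so any later draw with a different key set no longer fits the same batch. A simplified sketch of that assumption (not the actual mcbackend code):

```python
# The first append fixes the column list for the whole batch:
first = {"always": 1, "sometimes": True}
insert_query = "INSERT INTO chain (`_draw_idx`,`" + "`,`".join(first) + "`) VALUES"
print(insert_query)
# -> INSERT INTO chain (`_draw_idx`,`always`,`sometimes`) VALUES

# A later append with a different key set no longer matches that column list,
# so the batched insert fails (tracked in #74):
second = {"always": 2}
assert set(second) != set(first)
```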

     def test_get_row_at(self):
         run, chains = fully_initialized(
             self.backend,