Skip to content

Commit 408cca2

Browse files
Merge pull request #75 from michaelosthege/issue-73
Fix unsupported dtype from PyMC sampler warnings with `ClickHouseBackend`
2 parents 15b6082 + dfd4d38 commit 408cca2

File tree

7 files changed

+145
-51
lines changed

7 files changed

+145
-51
lines changed

.github/workflows/pipeline.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ jobs:
3333
flake8 . --count --exit-zero --statistics
3434
- name: Test with pytest
3535
run: |
36-
pytest --cov=./mcbackend --cov-report xml --cov-report term-missing mcbackend/
36+
pytest -v --cov=./mcbackend --cov-report xml --cov-report term-missing mcbackend/
3737
- name: Upload coverage
3838
uses: codecov/codecov-action@v3
3939
if: matrix.python-version == 3.9

mcbackend/adapters/pymc.py

+33-14
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
44
The only PyMC dependency is on the ``BaseTrace`` abstract base class.
55
"""
6+
import base64
7+
import pickle
68
from typing import Dict, List, Optional, Sequence, Tuple
79

810
import hagelkorn
@@ -26,9 +28,7 @@
2628

2729
def find_data(pmodel: Model) -> List[DataVariable]:
2830
"""Extracts data variables from a model."""
29-
observed_rvs = {
30-
rv.tag.observations for rv in pmodel.observed_RVs if hasattr(rv.tag, "observations")
31-
}
31+
observed_rvs = {pmodel.rvs_to_values[rv] for rv in pmodel.observed_RVs}
3232
dvars = []
3333
# All data containers are named vars!
3434
for name, var in pmodel.named_vars.items():
@@ -39,7 +39,7 @@ def find_data(pmodel: Model) -> List[DataVariable]:
3939
dv.value = ndarray_from_numpy(var.get_value())
4040
else:
4141
continue
42-
dv.dims = list(pmodel.RV_dims.get(name, []))
42+
dv.dims = list(pmodel.named_vars_to_dims.get(name, []))
4343
dv.is_observed = var in observed_rvs
4444
dvars.append(dv)
4545
return dvars
@@ -142,7 +142,9 @@ def setup(
142142
name,
143143
str(self.var_dtypes[name]),
144144
list(self.var_shapes[name]),
145-
dims=list(self.model.RV_dims[name]) if name in self.model.RV_dims else [],
145+
dims=list(self.model.named_vars_to_dims[name])
146+
if name in self.model.named_vars_to_dims
147+
else [],
146148
is_deterministic=(name not in free_rv_names),
147149
)
148150
for name in self.varnames
@@ -159,12 +161,16 @@ def setup(
159161
self._stat_groups.append([])
160162
for statname, dtype in names_dtypes.items():
161163
sname = f"sampler_{s}__{statname}"
162-
svar = Variable(
163-
name=sname,
164-
dtype=numpy.dtype(dtype).name,
165-
# This 👇 is needed until PyMC provides shapes ahead of time.
166-
undefined_ndim=True,
167-
)
164+
if statname == "warning":
165+
# SamplerWarnings will be pickled and stored as string!
166+
svar = Variable(sname, "str")
167+
else:
168+
svar = Variable(
169+
name=sname,
170+
dtype=numpy.dtype(dtype).name,
171+
# This 👇 is needed until PyMC provides shapes ahead of time.
172+
undefined_ndim=True,
173+
)
168174
self._stat_groups[s].append((sname, statname))
169175
sample_stats.append(svar)
170176

@@ -197,8 +203,12 @@ def record(self, point, sampler_states=None):
197203
for s, sts in enumerate(sampler_states):
198204
for statname, sval in sts.items():
199205
sname = f"sampler_{s}__{statname}"
200-
stats[sname] = sval
201-
# Make not whether this is a tuning iteration.
206+
# Automatically pickle SamplerWarnings
207+
if statname == "warning":
208+
sval_bytes = pickle.dumps(sval)
209+
sval = base64.encodebytes(sval_bytes).decode("ascii")
210+
stats[sname] = numpy.asarray(sval)
211+
# Make note whether this is a tuning iteration.
202212
if statname == "tune":
203213
stats["tune"] = sval
204214

@@ -214,7 +224,16 @@ def get_values(self, varname, burn=0, thin=1) -> numpy.ndarray:
214224
def _get_stats(self, varname, burn=0, thin=1) -> numpy.ndarray:
215225
if self._chain is None:
216226
raise Exception("Trace setup was not completed. Call `.setup()` first.")
217-
return self._chain.get_stats(varname)[burn::thin]
227+
values = self._chain.get_stats(varname)[burn::thin]
228+
if "warning" in varname:
229+
objs = []
230+
for v in values:
231+
enc = v.encode("ascii")
232+
str_ = base64.decodebytes(enc)
233+
obj = pickle.loads(str_)
234+
objs.append(obj)
235+
values = numpy.array(objs, dtype=object)
236+
return values
218237

219238
def _get_sampler_stats(self, stat_name, sampler_idx, burn, thin):
220239
if self._chain is None:

mcbackend/backends/clickhouse.py

+19-5
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@
2929
"int64": "Int64",
3030
"float32": "Float32",
3131
"float64": "Float64",
32-
"bool": "UInt8",
32+
"bool": "Bool",
33+
"str": "String",
3334
}
3435

3536

@@ -155,6 +156,7 @@ def __init__(
155156
self._client = client
156157
# The following attributes belong to the batched insert mechanism.
157158
# Inserting in batches is much faster than inserting single rows.
159+
self._str_cols = set()
158160
self._insert_query: str = ""
159161
self._insert_queue: List[Dict[str, Any]] = []
160162
self._last_insert = time.time()
@@ -166,12 +168,22 @@ def append(
166168
self, draw: Dict[str, numpy.ndarray], stats: Optional[Dict[str, numpy.ndarray]] = None
167169
):
168170
stat = {f"__stat_{sname}": svals for sname, svals in (stats or {}).items()}
169-
params: Dict[str, Any] = {"_draw_idx": self._draw_idx, **draw, **stat}
170-
self._draw_idx += 1
171+
params: Dict[str, numpy.ndarray] = {**draw, **stat}
172+
173+
# On first append create a query to be used for the batched insert
171174
if not self._insert_query:
172175
names = "`,`".join(params.keys())
173-
self._insert_query = f"INSERT INTO {self.cid} (`{names}`) VALUES"
176+
self._insert_query = f"INSERT INTO {self.cid} (`_draw_idx`,`{names}`) VALUES"
177+
self._str_cols = {k for k, v in params.items() if "str" in numpy.asarray(v).dtype.name}
178+
179+
# Convert str ndarrays to lists
180+
for col in self._str_cols:
181+
params[col] = params[col].tolist()
182+
183+
# Queue up for insertion
184+
params["_draw_idx"] = self._draw_idx
174185
self._insert_queue.append(params)
186+
self._draw_idx += 1
175187

176188
if (
177189
len(self._insert_queue) >= self._insert_every
@@ -235,7 +247,9 @@ def _get_rows(
235247
return numpy.array([], dtype=object)
236248

237249
# The unpacking must also account for non-rigid shapes
238-
if is_rigid(nshape):
250+
# and for the fact that str-dtyped empty arrays default to fixed-length-1 strings.
251+
# The [None] list is slower, but more flexible in this regard.
252+
if is_rigid(nshape) and dtype != "str":
239253
assert nshape is not None
240254
buffer = numpy.empty((draws, *nshape), dtype)
241255
else:

mcbackend/backends/numpy.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __init__(self, cmeta: ChainMeta, rmeta: RunMeta, *, preallocate: int) -> Non
6565
(self._stats, self._stat_is_rigid, rmeta.sample_stats),
6666
]:
6767
for var in variables:
68-
rigid = is_rigid(var.shape) and not var.undefined_ndim
68+
rigid = is_rigid(var.shape) and not var.undefined_ndim and var.dtype != "str"
6969
rigid_dict[var.name] = rigid
7070
if preallocate > 0 and rigid:
7171
reserve = (preallocate, *var.shape)
@@ -88,13 +88,19 @@ def __len__(self) -> int:
8888
return self._draw_idx
8989

9090
def get_draws(self, var_name: str, slc: slice = slice(None)) -> numpy.ndarray:
91-
return self._samples[var_name][: self._draw_idx][slc]
91+
data = self._samples[var_name][: self._draw_idx][slc]
92+
if self.variables[var_name].dtype == "str":
93+
return numpy.array(data.tolist(), dtype=str)
94+
return data
9295

9396
def get_draws_at(self, idx: int, var_names: Sequence[str]) -> Dict[str, numpy.ndarray]:
9497
return {vn: numpy.asarray(self._samples[vn][idx]) for vn in var_names}
9598

9699
def get_stats(self, stat_name: str, slc: slice = slice(None)) -> numpy.ndarray:
97-
return self._stats[stat_name][: self._draw_idx][slc]
100+
data = self._stats[stat_name][: self._draw_idx][slc]
101+
if self.sample_stats[stat_name].dtype == "str":
102+
return numpy.array(data.tolist(), dtype=str)
103+
return data
98104

99105
def get_stats_at(self, idx: int, stat_names: Sequence[str]) -> Dict[str, numpy.ndarray]:
100106
return {sn: numpy.asarray(self._stats[sn][idx]) for sn in stat_names}

mcbackend/test_adapter_pymc.py

+7-9
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from .adapters.pymc import TraceBackend, find_data
1313
from .backends.clickhouse import ClickHouseBackend
14-
from .test_backend_clickhouse import HAS_REAL_DB
14+
from .test_backend_clickhouse import DB_KWARGS, HAS_REAL_DB
1515

1616
_log = logging.getLogger(__file__)
1717

@@ -57,9 +57,9 @@ class TestPyMCAdapter:
5757
def setup_method(self, method):
5858
"""Initializes a fresh database just for this test method."""
5959
self._db = "testing_" + hagelkorn.random()
60-
self._client_main = clickhouse_driver.Client("localhost")
60+
self._client_main = clickhouse_driver.Client(**DB_KWARGS)
6161
self._client_main.execute(f"CREATE DATABASE {self._db};")
62-
self._client = clickhouse_driver.Client("localhost", database=self._db)
62+
self._client = clickhouse_driver.Client(**DB_KWARGS, database=self._db)
6363
self.backend = ClickHouseBackend(self._client)
6464
return
6565

@@ -87,23 +87,21 @@ def wrapper(meta: RunMeta):
8787
trace = TraceBackend(backend)
8888
idata = pm.sample(
8989
trace=trace,
90-
tune=3,
91-
draws=5,
90+
tune=30,
91+
draws=50,
9292
chains=2,
9393
cores=cores,
94-
step=pm.Metropolis(),
9594
discard_tuned_samples=False,
96-
compute_convergence_checks=False,
9795
)
9896
if not len(args) == 1:
9997
_log.warning("Run was initialized multiple times.")
10098
rmeta = args[0]
10199

102100
# Chain lengths after conversion
103101
assert idata.posterior.dims["chain"] == 2
104-
assert idata.posterior.dims["draw"] == 5
102+
assert idata.posterior.dims["draw"] == 50
105103
assert idata.warmup_posterior.dims["chain"] == 2
106-
assert idata.warmup_posterior.dims["draw"] == 3
104+
assert idata.warmup_posterior.dims["draw"] == 30
107105

108106
# Tracking of named variable dimensions
109107
vars = {var.name: var for var in rmeta.variables}

mcbackend/test_backend_clickhouse.py

+40-11
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import base64
22
import logging
3+
import os
34
from datetime import datetime, timezone
45
from typing import Sequence, Tuple
56

@@ -21,7 +22,11 @@
2122
from mcbackend.test_utils import CheckBehavior, CheckPerformance, make_runmeta
2223

2324
try:
24-
client = clickhouse_driver.Client("localhost")
25+
DB_HOST = os.environ.get("CLICKHOUSE_HOST", "localhost")
26+
DB_PASS = os.environ.get("CLICKHOUSE_PASS", "")
27+
DB_PORT = os.environ.get("CLICKHOUSE_PORT", 9000)
28+
DB_KWARGS = dict(host=DB_HOST, port=DB_PORT, password=DB_PASS)
29+
client = clickhouse_driver.Client(**DB_KWARGS)
2530
client.execute("SHOW DATABASES;")
2631
HAS_REAL_DB = True
2732
except:
@@ -51,7 +56,7 @@ def fully_initialized(
5156

5257
@pytest.mark.skipif(
5358
condition=not HAS_REAL_DB,
54-
reason="Integration tests need a ClickHouse server on localhost:9000 without authentication.",
59+
reason="Integration tests need a ClickHouse server.",
5560
)
5661
class TestClickHouseBackendInitialization:
5762
"""This is separate because ``TestClickHouseBackend.setup_method`` depends on these things."""
@@ -63,12 +68,12 @@ def test_exceptions(self):
6368

6469
def test_backend_from_client_object(self):
6570
db = "testing_" + hagelkorn.random()
66-
_client_main = clickhouse_driver.Client("localhost")
71+
_client_main = clickhouse_driver.Client(**DB_KWARGS)
6772
_client_main.execute(f"CREATE DATABASE {db};")
6873

6974
try:
7075
# When created from a client object, all chains share the client
71-
backend = ClickHouseBackend(client=clickhouse_driver.Client("localhost", database=db))
76+
backend = ClickHouseBackend(client=clickhouse_driver.Client(**DB_KWARGS, database=db))
7277
assert callable(backend._client_fn)
7378
run = backend.init_run(make_runmeta())
7479
c1 = run.init_chain(0)
@@ -81,11 +86,11 @@ def test_backend_from_client_object(self):
8186

8287
def test_backend_from_client_function(self):
8388
db = "testing_" + hagelkorn.random()
84-
_client_main = clickhouse_driver.Client("localhost")
89+
_client_main = clickhouse_driver.Client(**DB_KWARGS)
8590
_client_main.execute(f"CREATE DATABASE {db};")
8691

8792
def client_fn():
88-
return clickhouse_driver.Client("localhost", database=db)
93+
return clickhouse_driver.Client(**DB_KWARGS, database=db)
8994

9095
try:
9196
# When created from a client function, each chain has its own client
@@ -108,7 +113,7 @@ def client_fn():
108113

109114
@pytest.mark.skipif(
110115
condition=not HAS_REAL_DB,
111-
reason="Integration tests need a ClickHouse server on localhost:9000 without authentication.",
116+
reason="Integration tests need a ClickHouse server.",
112117
)
113118
class TestClickHouseBackend(CheckBehavior, CheckPerformance):
114119
cls_backend = ClickHouseBackend
@@ -118,11 +123,11 @@ class TestClickHouseBackend(CheckBehavior, CheckPerformance):
118123
def setup_method(self, method):
119124
"""Initializes a fresh database just for this test method."""
120125
self._db = "testing_" + hagelkorn.random()
121-
self._client_main = clickhouse_driver.Client("localhost")
126+
self._client_main = clickhouse_driver.Client(**DB_KWARGS)
122127
self._client_main.execute(f"CREATE DATABASE {self._db};")
123-
self._client = clickhouse_driver.Client("localhost", database=self._db)
128+
self._client = clickhouse_driver.Client(**DB_KWARGS, database=self._db)
124129
self.backend = ClickHouseBackend(
125-
client_fn=lambda: clickhouse_driver.Client("localhost", database=self._db)
130+
client_fn=lambda: clickhouse_driver.Client(**DB_KWARGS, database=self._db)
126131
)
127132
return
128133

@@ -197,7 +202,7 @@ def test_create_chain_table(self):
197202
("scalar", "UInt16"),
198203
("1D", "Array(Float32)"),
199204
("3D", "Array(Array(Array(Float64)))"),
200-
("__stat_accepted", "UInt8"),
205+
("__stat_accepted", "Bool"),
201206
]
202207
pass
203208

@@ -266,6 +271,30 @@ def test_insert_draw(self):
266271
numpy.testing.assert_array_equal(v3, draw["v3"])
267272
pass
268273

274+
@pytest.mark.xfail(
275+
reason="Batch inserting assumes identical dict composition every time. See #74."
276+
)
277+
def test_insert_flaky_stats(self):
278+
"""Tries to append stats that only sometimes have an entry for a stat."""
279+
run, chains = fully_initialized(
280+
self.backend,
281+
RunMeta(
282+
sample_stats=[
283+
Variable("always", "int8"),
284+
Variable("sometimes", "bool"),
285+
]
286+
),
287+
)
288+
289+
chain = chains[0]
290+
chain.append({}, dict(always=1, sometimes=True))
291+
chain.append({}, dict(always=2))
292+
chain._commit()
293+
294+
tuple(chain.get_stats("always")) == (1, 2)
295+
assert tuple(chain.get_stats("sometimes")) == (True, None)
296+
pass
297+
269298
def test_get_row_at(self):
270299
run, chains = fully_initialized(
271300
self.backend,

0 commit comments

Comments
 (0)