Skip to content

Commit eb639ca

Browse files
Add backend support for str dtypes
1 parent d885089 commit eb639ca

File tree

3 files changed

+63
-15
lines changed

3 files changed

+63
-15
lines changed

mcbackend/backends/clickhouse.py

+18-4
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
"float32": "Float32",
3131
"float64": "Float64",
3232
"bool": "UInt8",
33+
"str": "String",
3334
}
3435

3536

@@ -155,6 +156,7 @@ def __init__(
155156
self._client = client
156157
# The following attributes belong to the batched insert mechanism.
157158
# Inserting in batches is much faster than inserting single rows.
159+
self._str_cols = set()
158160
self._insert_query: str = ""
159161
self._insert_queue: List[Dict[str, Any]] = []
160162
self._last_insert = time.time()
@@ -166,12 +168,22 @@ def append(
166168
self, draw: Dict[str, numpy.ndarray], stats: Optional[Dict[str, numpy.ndarray]] = None
167169
):
168170
stat = {f"__stat_{sname}": svals for sname, svals in (stats or {}).items()}
169-
params: Dict[str, Any] = {"_draw_idx": self._draw_idx, **draw, **stat}
170-
self._draw_idx += 1
171+
params: Dict[str, numpy.ndarray] = {**draw, **stat}
172+
173+
# On first append create a query to be used for the batched insert
171174
if not self._insert_query:
172175
names = "`,`".join(params.keys())
173-
self._insert_query = f"INSERT INTO {self.cid} (`{names}`) VALUES"
176+
self._insert_query = f"INSERT INTO {self.cid} (`_draw_idx`,`{names}`) VALUES"
177+
self._str_cols = {k for k, v in params.items() if "str" in numpy.asarray(v).dtype.name}
178+
179+
# Convert str ndarrays to lists
180+
for col in self._str_cols:
181+
params[col] = params[col].tolist()
182+
183+
# Queue up for insertion
184+
params["_draw_idx"] = self._draw_idx
174185
self._insert_queue.append(params)
186+
self._draw_idx += 1
175187

176188
if (
177189
len(self._insert_queue) >= self._insert_every
@@ -235,7 +247,9 @@ def _get_rows(
235247
return numpy.array([], dtype=object)
236248

237249
# The unpacking must also account for non-rigid shapes
238-
if is_rigid(nshape):
250+
# and str-dtyped empty arrays default to fixed length 1 strings.
251+
# The [None] list is slower, but more flexible in this regard.
252+
if is_rigid(nshape) and dtype != "str":
239253
assert nshape is not None
240254
buffer = numpy.empty((draws, *nshape), dtype)
241255
else:

mcbackend/backends/numpy.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __init__(self, cmeta: ChainMeta, rmeta: RunMeta, *, preallocate: int) -> Non
6565
(self._stats, self._stat_is_rigid, rmeta.sample_stats),
6666
]:
6767
for var in variables:
68-
rigid = is_rigid(var.shape) and not var.undefined_ndim
68+
rigid = is_rigid(var.shape) and not var.undefined_ndim and var.dtype != "str"
6969
rigid_dict[var.name] = rigid
7070
if preallocate > 0 and rigid:
7171
reserve = (preallocate, *var.shape)
@@ -88,13 +88,19 @@ def __len__(self) -> int:
8888
return self._draw_idx
8989

9090
def get_draws(self, var_name: str, slc: slice = slice(None)) -> numpy.ndarray:
91-
return self._samples[var_name][: self._draw_idx][slc]
91+
data = self._samples[var_name][: self._draw_idx][slc]
92+
if self.variables[var_name].dtype == "str":
93+
return numpy.array(data.tolist(), dtype=str)
94+
return data
9295

9396
def get_draws_at(self, idx: int, var_names: Sequence[str]) -> Dict[str, numpy.ndarray]:
9497
return {vn: numpy.asarray(self._samples[vn][idx]) for vn in var_names}
9598

9699
def get_stats(self, stat_name: str, slc: slice = slice(None)) -> numpy.ndarray:
97-
return self._stats[stat_name][: self._draw_idx][slc]
100+
data = self._stats[stat_name][: self._draw_idx][slc]
101+
if self.sample_stats[stat_name].dtype == "str":
102+
return numpy.array(data.tolist(), dtype=str)
103+
return data
98104

99105
def get_stats_at(self, idx: int, stat_names: Sequence[str]) -> Dict[str, numpy.ndarray]:
100106
return {sn: numpy.asarray(self._stats[sn][idx]) for sn in stat_names}

mcbackend/test_utils.py

+36-8
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ def make_runmeta(*, flexibility: bool = False, **kwargs) -> RunMeta:
2929
Variable("accepted", "bool", list((3,)), dims=["sampler"]),
3030
# But some stats may refer to the iteration.
3131
Variable("logp", "float64", []),
32+
# String dtypes may be used for more complex information
33+
Variable("message", "str"),
3234
],
3335
data=[
3436
DataVariable(
@@ -60,8 +62,16 @@ def make_draw(variables: Sequence[Variable]):
6062
)
6163
if "float" in var.dtype:
6264
draw[var.name] = numpy.random.normal(size=dshape).astype(var.dtype)
65+
elif var.dtype == "str":
66+
alphabet = tuple("abcdef#+*/'")
67+
words = [
68+
"".join(numpy.random.choice(alphabet, size=numpy.random.randint(3, 10)))
69+
for _ in range(int(numpy.prod(dshape)))
70+
]
71+
draw[var.name] = numpy.array(words, dtype=var.dtype).reshape(dshape)
6372
else:
6473
draw[var.name] = numpy.random.randint(low=0, high=100, size=dshape).astype(var.dtype)
74+
assert draw[var.name].shape == dshape
6575
return draw
6676

6777

@@ -149,7 +159,7 @@ def test__append_get_with_changelings(self, with_stats):
149159
expected = [draw[var.name] for draw in draws]
150160
actual = chain.get_draws(var.name)
151161
assert isinstance(actual, numpy.ndarray)
152-
if var.name == "changeling":
162+
if not is_rigid(var.shape) or var.dtype == "str":
153163
# Non-rigid variables are returned as object-arrays.
154164
assert actual.shape == (len(expected),)
155165
assert actual.dtype == object
@@ -166,9 +176,13 @@ def test__append_get_with_changelings(self, with_stats):
166176
expected = [stat[var.name] for stat in stats]
167177
actual = chain.get_stats(var.name)
168178
assert isinstance(actual, numpy.ndarray)
169-
if is_rigid(var.shape):
179+
if var.dtype == "str":
170180
assert tuple(actual.shape) == tuple(numpy.shape(expected))
171-
assert actual.dtype == var.dtype
181+
# String dtypes have strange names
182+
assert "str" in actual.dtype.name
183+
elif is_rigid(var.shape):
184+
assert tuple(actual.shape) == tuple(numpy.shape(expected))
185+
assert actual.dtype.name == var.dtype
172186
numpy.testing.assert_array_equal(actual, expected)
173187
else:
174188
# Non-rigid variables are returned as object-arrays.
@@ -200,7 +214,7 @@ def test__get_slicing(self, slc: slice):
200214
# "A" are just numbers to make diagnosis easier.
201215
# "B" are dynamically shaped to cover the edge cases.
202216
rmeta = RunMeta(
203-
variables=[Variable("A", "uint8")],
217+
variables=[Variable("A", "uint8"), Variable("M", "str", [2, 3])],
204218
sample_stats=[Variable("B", "uint8", [2, 0])],
205219
data=[],
206220
)
@@ -209,7 +223,7 @@ def test__get_slicing(self, slc: slice):
209223

210224
# Generate draws and add them to the chain
211225
N = 20
212-
draws = [dict(A=n) for n in range(N)]
226+
draws = [make_draw(rmeta.variables) for n in range(N)]
213227
stats = [make_draw(rmeta.sample_stats) for n in range(N)]
214228
for d, s in zip(draws, stats):
215229
chain.append(d, s)
@@ -218,12 +232,25 @@ def test__get_slicing(self, slc: slice):
218232
# slc=None in this test means "don't pass it".
219233
# The implementations should default to slc=slice(None, None, None).
220234
kwargs = dict(slc=slc) if slc is not None else {}
221-
act_draws = chain.get_draws("A", **kwargs)
235+
act_draws_A = chain.get_draws("A", **kwargs)
236+
act_draws_M = chain.get_draws("M", **kwargs)
222237
act_stats = chain.get_stats("B", **kwargs)
223-
expected_draws = [d["A"] for d in draws][slc or slice(None, None, None)]
238+
expected_draws_A = [d["A"] for d in draws][slc or slice(None, None, None)]
239+
expected_draws_M = [d["M"] for d in draws][slc or slice(None, None, None)]
224240
expected_stats = [s["B"] for s in stats][slc or slice(None, None, None)]
241+
225242
# Variable "A" has a rigid shape
226-
numpy.testing.assert_array_equal(act_draws, expected_draws)
243+
if expected_draws_A:
244+
numpy.testing.assert_array_equal(act_draws_A, expected_draws_A)
245+
else:
246+
assert len(act_draws_A) == 0
247+
248+
# Variable "M" is a string matrix
249+
if expected_draws_M:
250+
numpy.testing.assert_array_equal(act_draws_M, expected_draws_M)
251+
else:
252+
assert len(act_draws_M) == 0
253+
227254
# Stat "B" is dynamically shaped, which means we're dealing with
228255
# dtype=object arrays. These must be checked elementwise.
229256
assert len(act_stats) == len(expected_stats)
@@ -256,6 +283,7 @@ def test__to_inferencedata(self):
256283
sample_stats=[
257284
Variable("tune", "bool"),
258285
Variable("sampler_0__logp", "float32"),
286+
Variable("warning", "str"),
259287
],
260288
)
261289
run = self.backend.init_run(rmeta)

0 commit comments

Comments
 (0)