Skip to content

Commit 03f7fd7

Browse files
Add backend support for str dtypes
1 parent d885089 commit 03f7fd7

File tree

3 files changed

+46
-11
lines changed

3 files changed

+46
-11
lines changed

mcbackend/backends/clickhouse.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
"float32": "Float32",
3131
"float64": "Float64",
3232
"bool": "UInt8",
33+
"str": "String",
3334
}
3435

3536

@@ -155,6 +156,7 @@ def __init__(
155156
self._client = client
156157
# The following attributes belong to the batched insert mechanism.
157158
# Inserting in batches is much faster than inserting single rows.
159+
self._str_cols = set()
158160
self._insert_query: str = ""
159161
self._insert_queue: List[Dict[str, Any]] = []
160162
self._last_insert = time.time()
@@ -166,12 +168,22 @@ def append(
166168
self, draw: Dict[str, numpy.ndarray], stats: Optional[Dict[str, numpy.ndarray]] = None
167169
):
168170
stat = {f"__stat_{sname}": svals for sname, svals in (stats or {}).items()}
169-
params: Dict[str, Any] = {"_draw_idx": self._draw_idx, **draw, **stat}
170-
self._draw_idx += 1
171+
params: Dict[str, numpy.ndarray] = {**draw, **stat}
172+
173+
# On first append create a query to be used for the batched insert
171174
if not self._insert_query:
172175
names = "`,`".join(params.keys())
173-
self._insert_query = f"INSERT INTO {self.cid} (`{names}`) VALUES"
176+
self._insert_query = f"INSERT INTO {self.cid} (`_draw_idx`,`{names}`) VALUES"
177+
self._str_cols = {k for k, v in params.items() if "str" in numpy.asarray(v).dtype.name}
178+
179+
# Convert str ndarrays to lists
180+
for col in self._str_cols:
181+
params[col] = params[col].tolist()
182+
183+
# Queue up for insertion
184+
params["_draw_idx"] = self._draw_idx
174185
self._insert_queue.append(params)
186+
self._draw_idx += 1
175187

176188
if (
177189
len(self._insert_queue) >= self._insert_every

mcbackend/backends/numpy.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __init__(self, cmeta: ChainMeta, rmeta: RunMeta, *, preallocate: int) -> Non
6565
(self._stats, self._stat_is_rigid, rmeta.sample_stats),
6666
]:
6767
for var in variables:
68-
rigid = is_rigid(var.shape) and not var.undefined_ndim
68+
rigid = is_rigid(var.shape) and not var.undefined_ndim and var.dtype != "str"
6969
rigid_dict[var.name] = rigid
7070
if preallocate > 0 and rigid:
7171
reserve = (preallocate, *var.shape)
@@ -88,13 +88,19 @@ def __len__(self) -> int:
8888
return self._draw_idx
8989

9090
def get_draws(self, var_name: str, slc: slice = slice(None)) -> numpy.ndarray:
91-
return self._samples[var_name][: self._draw_idx][slc]
91+
data = self._samples[var_name][: self._draw_idx][slc]
92+
if self.variables[var_name].dtype == "str":
93+
return data.astype(str)
94+
return data
9295

9396
def get_draws_at(self, idx: int, var_names: Sequence[str]) -> Dict[str, numpy.ndarray]:
9497
return {vn: numpy.asarray(self._samples[vn][idx]) for vn in var_names}
9598

9699
def get_stats(self, stat_name: str, slc: slice = slice(None)) -> numpy.ndarray:
97-
return self._stats[stat_name][: self._draw_idx][slc]
100+
data = self._stats[stat_name][: self._draw_idx][slc]
101+
if self.sample_stats[stat_name].dtype == "str":
102+
return data.astype(str)
103+
return data
98104

99105
def get_stats_at(self, idx: int, stat_names: Sequence[str]) -> Dict[str, numpy.ndarray]:
100106
return {sn: numpy.asarray(self._stats[sn][idx]) for sn in stat_names}

mcbackend/test_utils.py

+22-5
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ def make_runmeta(*, flexibility: bool = False, **kwargs) -> RunMeta:
2929
Variable("accepted", "bool", list((3,)), dims=["sampler"]),
3030
# But some stats may refer to the iteration.
3131
Variable("logp", "float64", []),
32+
# String dtypes may be used for more complex information
33+
Variable("message", "str"),
3234
],
3335
data=[
3436
DataVariable(
@@ -60,8 +62,13 @@ def make_draw(variables: Sequence[Variable]):
6062
)
6163
if "float" in var.dtype:
6264
draw[var.name] = numpy.random.normal(size=dshape).astype(var.dtype)
65+
elif var.dtype == "str":
66+
alphabet = tuple("abcdef#+*/'")
67+
words = ["".join(numpy.random.choice(alphabet, size=numpy.random.randint(3, 10)))]
68+
draw[var.name] = numpy.array(words).reshape(dshape)
6369
else:
6470
draw[var.name] = numpy.random.randint(low=0, high=100, size=dshape).astype(var.dtype)
71+
assert draw[var.name].shape == dshape
6572
return draw
6673

6774

@@ -149,7 +156,7 @@ def test__append_get_with_changelings(self, with_stats):
149156
expected = [draw[var.name] for draw in draws]
150157
actual = chain.get_draws(var.name)
151158
assert isinstance(actual, numpy.ndarray)
152-
if var.name == "changeling":
159+
if not is_rigid(var.shape) or var.dtype == "str":
153160
# Non-ridid variables are returned as object-arrays.
154161
assert actual.shape == (len(expected),)
155162
assert actual.dtype == object
@@ -166,9 +173,13 @@ def test__append_get_with_changelings(self, with_stats):
166173
expected = [stat[var.name] for stat in stats]
167174
actual = chain.get_stats(var.name)
168175
assert isinstance(actual, numpy.ndarray)
169-
if is_rigid(var.shape):
176+
if var.dtype == "str":
170177
assert tuple(actual.shape) == tuple(numpy.shape(expected))
171-
assert actual.dtype == var.dtype
178+
# String dtypes have strange names
179+
assert "str" in actual.dtype.name
180+
elif is_rigid(var.shape):
181+
assert tuple(actual.shape) == tuple(numpy.shape(expected))
182+
assert actual.dtype.name == var.dtype
172183
numpy.testing.assert_array_equal(actual, expected)
173184
else:
174185
# Non-ridid variables are returned as object-arrays.
@@ -200,7 +211,7 @@ def test__get_slicing(self, slc: slice):
200211
# "A" are just numbers to make diagnosis easier.
201212
# "B" are dynamically shaped to cover the edge cases.
202213
rmeta = RunMeta(
203-
variables=[Variable("A", "uint8")],
214+
variables=[Variable("A", "uint8"), Variable("M", "str", [2, 3])],
204215
sample_stats=[Variable("B", "uint8", [2, 0])],
205216
data=[],
206217
)
@@ -209,7 +220,7 @@ def test__get_slicing(self, slc: slice):
209220

210221
# Generate draws and add them to the chain
211222
N = 20
212-
draws = [dict(A=n) for n in range(N)]
223+
draws = [dict(A=numpy.array(n)) for n in range(N)]
213224
stats = [make_draw(rmeta.sample_stats) for n in range(N)]
214225
for d, s in zip(draws, stats):
215226
chain.append(d, s)
@@ -222,8 +233,13 @@ def test__get_slicing(self, slc: slice):
222233
act_stats = chain.get_stats("B", **kwargs)
223234
expected_draws = [d["A"] for d in draws][slc or slice(None, None, None)]
224235
expected_stats = [s["B"] for s in stats][slc or slice(None, None, None)]
236+
225237
# Variable "A" has a rigid shape
226238
numpy.testing.assert_array_equal(act_draws, expected_draws)
239+
240+
# Variable "M" is a string matrix
241+
numpy.testing.assert_array_equal(act_draws, expected_draws)
242+
227243
# Stat "B" is dynamically shaped, which means we're dealing with
228244
# dtype=object arrays. These must be checked elementwise.
229245
assert len(act_stats) == len(expected_stats)
@@ -256,6 +272,7 @@ def test__to_inferencedata(self):
256272
sample_stats=[
257273
Variable("tune", "bool"),
258274
Variable("sampler_0__logp", "float32"),
275+
Variable("warning", "str"),
259276
],
260277
)
261278
run = self.backend.init_run(rmeta)

0 commit comments

Comments
 (0)