Skip to content

Refactor ClickHouseChain for efficient slice selects #51

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 49 additions & 4 deletions mcbackend/backends/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,46 @@ def create_chain_table(client: clickhouse_driver.Client, meta: ChainMeta, rmeta:
return client.execute(query)


def where_slice(slc: slice, imax: int, col="_draw_idx") -> Tuple[str, bool]:
    """Creates a WHERE clause to select rows according to a Python slice.

    Parameters
    ----------
    slc : slice
        A slice object to apply.
    imax : int
        End of the range to which the slice is applied.
        A `slice(None)` will return this many rows.
    col : str
        Name of the primary key column.
        Assumed to start at 0 with increments of 1.
        The name is interpolated into the clause verbatim.

    Returns
    -------
    where : str
        WHERE clause for the query.
    reverse : bool
        If True the query result must be reversed because
        the slice had a backwards direction.
    """
    # Determine non-negative slice indices
    istart, istop, istep = slc.indices(imax)
    reverse = istep < 0

    if reverse:
        # For a backwards slice the selected indices are start, start-|step|, ...
        # The stride condition must be anchored on an index that is actually
        # selected; `stop + 1` is only one of them when (start - stop - 1) is a
        # multiple of |step|.  Use the smallest selected index instead.
        selected = range(istart, istop, istep)
        if selected:
            lower, upper = selected[-1], selected[0] + 1
        else:
            # Empty selection: `col < 0` never matches a 0-based key column.
            lower, upper = 0, 0
    else:
        lower, upper = istart, istop

    # Aggregate conditions
    conds = []
    if lower > 0:
        conds.append(f"{col}>={lower}")
    conds.append(f"{col}<{upper}")
    # A stride of +1 or -1 selects every row in the range; no modulo needed.
    if abs(istep) != 1:
        conds.append(f"modulo({col} - {lower}, {abs(istep)}) == 0")
    return "WHERE " + " AND ".join(conds), reverse


class ClickHouseChain(Chain):
"""Represents an MCMC chain stored in ClickHouse."""

Expand Down Expand Up @@ -174,12 +214,17 @@ def _get_rows(
slc: slice = slice(None),
) -> numpy.ndarray:
self._commit()
data = self._client.execute(f"SELECT (`{var_name}`) FROM {self.cid};")
where, reverse = where_slice(slc, self._draw_idx)
data = self._client.execute(f"SELECT (`{var_name}`) FROM {self.cid} {where};")
draws = len(data)
if reverse:
data = reversed(data)

# Safety checks
# Without draws return empty arrays of the correct shape/dtype
if not draws:
raise Exception(f"No draws in chain {self.cid}.")
if is_rigid(nshape):
return numpy.empty(shape=[0] + nshape, dtype=dtype)
return numpy.array([], dtype=object)

# The unpacking must also account for non-rigid shapes
if is_rigid(nshape):
Expand All @@ -198,7 +243,7 @@ def _get_rows(
arr[:] = buffer
return arr
# Otherwise (identical shapes) we can collapse into one ndarray
return numpy.asarray(buffer, dtype=dtype)[slc]
return numpy.asarray(buffer, dtype=dtype)

def get_draws(self, var_name: str, slc: slice = slice(None)) -> numpy.ndarray:
var = self.variables[var_name]
Expand Down
6 changes: 4 additions & 2 deletions mcbackend/test_backend_clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,10 @@ def test_insert_draw(self):
}
chain = chains[0]

with pytest.raises(Exception, match="No draws in chain"):
chain._get_rows("v1", [], "uint16")
# Get empty vector from empty chain
nodraws = chain._get_rows("v1", [], "uint16")
assert nodraws.shape == (0,)
assert nodraws.dtype == numpy.uint16

chain.append(draw)
assert len(chain._insert_queue) == 1
Expand Down
30 changes: 21 additions & 9 deletions mcbackend/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,17 +185,22 @@ def test__append_get_with_changelings(self, with_stats):
slice(None, None, None),
slice(2, None, None),
slice(2, 10, None),
slice(2, 15, 3),
slice(-8, None, None),
slice(2, 15, 3), # every 3rd
slice(15, 2, -3), # backwards every 3rd
slice(2, 15, -3), # empty
slice(-8, None, None), # the last 8
slice(-8, -2, 2),
slice(-50, -2, 2),
slice(15, 10),
slice(15, 10), # empty
slice(1, 1), # empty
],
)
def test__get_slicing(self, slc: slice):
rmeta = make_runmeta(
# "A" are just numbers to make diagnosis easier.
# "B" are dynamically shaped to cover the edge cases.
rmeta = RunMeta(
variables=[Variable("A", "uint8")],
sample_stats=[Variable("B", "uint8")],
sample_stats=[Variable("B", "uint8", [2, 0])],
data=[],
)
run = self.backend.init_run(rmeta)
Expand All @@ -204,19 +209,26 @@ def test__get_slicing(self, slc: slice):
# Generate draws and add them to the chain
N = 20
draws = [dict(A=n) for n in range(N)]
stats = [dict(B=n) for n in range(N)]
stats = [make_draw(rmeta.sample_stats) for n in range(N)]
for d, s in zip(draws, stats):
chain.append(d, s)
assert len(chain) == N

# slc=None in this test means "don't pass it".
# The implementations should default to slc=slice(None, None, None).
expected = numpy.arange(N, dtype="uint8")[slc or slice(None, None, None)]
kwargs = dict(slc=slc) if slc is not None else {}
act_draws = chain.get_draws("A", **kwargs)
act_stats = chain.get_stats("B", **kwargs)
numpy.testing.assert_array_equal(act_draws, expected)
numpy.testing.assert_array_equal(act_stats, expected)
expected_draws = [d["A"] for d in draws][slc or slice(None, None, None)]
expected_stats = [s["B"] for s in stats][slc or slice(None, None, None)]
# Variable "A" has a rigid shape
numpy.testing.assert_array_equal(act_draws, expected_draws)
# Stat "B" is dynamically shaped, which means we're dealing with
# dtype=object arrays. These must be checked elementwise.
assert len(act_stats) == len(expected_stats)
assert act_stats.dtype == object
for a, e in zip(act_stats, expected_stats):
numpy.testing.assert_array_equal(a, e)
pass

def test__get_chains(self):
Expand Down