Skip to content

Commit 4a44866

Browse files
Revise slice test to cover empty and reverse selection
To align behaviors, the `ClickHouseChain` may now return empty arrays from empty chains or selections.
1 parent 1b2fe3b commit 4a44866

File tree

3 files changed

+30
-14
lines changed

3 files changed

+30
-14
lines changed

mcbackend/backends/clickhouse.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -177,9 +177,11 @@ def _get_rows(
177177
data = self._client.execute(f"SELECT (`{var_name}`) FROM {self.cid};")
178178
draws = len(data)
179179

180-
# Safety checks
180+
# Without draws return empty arrays of the correct shape/dtype
181181
if not draws:
182-
raise Exception(f"No draws in chain {self.cid}.")
182+
if is_rigid(nshape):
183+
return numpy.empty(shape=[0] + nshape, dtype=dtype)
184+
return numpy.array([], dtype=object)
183185

184186
# The unpacking must also account for non-rigid shapes
185187
if is_rigid(nshape):
@@ -196,7 +198,7 @@ def _get_rows(
196198
# To circumvent NumPy issue #19113
197199
arr = numpy.empty(draws, dtype=object)
198200
arr[:] = buffer
199-
return arr
201+
return arr[slc]
200202
# Otherwise (identical shapes) we can collapse into one ndarray
201203
return numpy.asarray(buffer, dtype=dtype)[slc]
202204

mcbackend/test_backend_clickhouse.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,8 +237,10 @@ def test_insert_draw(self):
237237
}
238238
chain = chains[0]
239239

240-
with pytest.raises(Exception, match="No draws in chain"):
241-
chain._get_rows("v1", [], "uint16")
240+
# Get empty vector from empty chain
241+
nodraws = chain._get_rows("v1", [], "uint16")
242+
assert nodraws.shape == (0,)
243+
assert nodraws.dtype == numpy.uint16
242244

243245
chain.append(draw)
244246
assert len(chain._insert_queue) == 1

mcbackend/test_utils.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -185,17 +185,22 @@ def test__append_get_with_changelings(self, with_stats):
185185
slice(None, None, None),
186186
slice(2, None, None),
187187
slice(2, 10, None),
188-
slice(2, 15, 3),
189-
slice(-8, None, None),
188+
slice(2, 15, 3), # every 3rd
189+
slice(15, 2, -3), # backwards every 3rd
190+
slice(2, 15, -3), # empty
191+
slice(-8, None, None), # the last 8
190192
slice(-8, -2, 2),
191193
slice(-50, -2, 2),
192-
slice(15, 10),
194+
slice(15, 10), # empty
195+
slice(1, 1), # empty
193196
],
194197
)
195198
def test__get_slicing(self, slc: slice):
196-
rmeta = make_runmeta(
199+
# "A" are just numbers to make diagnosis easier.
200+
# "B" are dynamically shaped to cover the edge cases.
201+
rmeta = RunMeta(
197202
variables=[Variable("A", "uint8")],
198-
sample_stats=[Variable("B", "uint8")],
203+
sample_stats=[Variable("B", "uint8", [2, 0])],
199204
data=[],
200205
)
201206
run = self.backend.init_run(rmeta)
@@ -204,19 +209,26 @@ def test__get_slicing(self, slc: slice):
204209
# Generate draws and add them to the chain
205210
N = 20
206211
draws = [dict(A=n) for n in range(N)]
207-
stats = [dict(B=n) for n in range(N)]
212+
stats = [make_draw(rmeta.sample_stats) for n in range(N)]
208213
for d, s in zip(draws, stats):
209214
chain.append(d, s)
210215
assert len(chain) == N
211216

212217
# slc=None in this test means "don't pass it".
213218
# The implementations should default to slc=slice(None, None, None).
214-
expected = numpy.arange(N, dtype="uint8")[slc or slice(None, None, None)]
215219
kwargs = dict(slc=slc) if slc is not None else {}
216220
act_draws = chain.get_draws("A", **kwargs)
217221
act_stats = chain.get_stats("B", **kwargs)
218-
numpy.testing.assert_array_equal(act_draws, expected)
219-
numpy.testing.assert_array_equal(act_stats, expected)
222+
expected_draws = [d["A"] for d in draws][slc or slice(None, None, None)]
223+
expected_stats = [s["B"] for s in stats][slc or slice(None, None, None)]
224+
# Variable "A" has a rigid shape
225+
numpy.testing.assert_array_equal(act_draws, expected_draws)
226+
# Stat "B" is dynamically shaped, which means we're dealing with
227+
# dtype=object arrays. These must be checked elementwise.
228+
assert len(act_stats) == len(expected_stats)
229+
assert act_stats.dtype == object
230+
for a, e in zip(act_stats, expected_stats):
231+
numpy.testing.assert_array_equal(a, e)
220232
pass
221233

222234
def test__get_chains(self):

0 commit comments

Comments
 (0)