@@ -185,17 +185,22 @@ def test__append_get_with_changelings(self, with_stats):
185
185
slice (None , None , None ),
186
186
slice (2 , None , None ),
187
187
slice (2 , 10 , None ),
188
- slice (2 , 15 , 3 ),
189
- slice (- 8 , None , None ),
188
+ slice (2 , 15 , 3 ), # every 3rd
189
+ slice (15 , 2 , - 3 ), # backwards every 3rd
190
+ slice (2 , 15 , - 3 ), # empty
191
+ slice (- 8 , None , None ), # the last 8
190
192
slice (- 8 , - 2 , 2 ),
191
193
slice (- 50 , - 2 , 2 ),
192
- slice (15 , 10 ),
194
+ slice (15 , 10 ), # empty
195
+ slice (1 , 1 ), # empty
193
196
],
194
197
)
195
198
def test__get_slicing (self , slc : slice ):
196
- rmeta = make_runmeta (
199
+ # "A" are just numbers to make diagnosis easier.
200
+ # "B" are dynamically shaped to cover the edge cases.
201
+ rmeta = RunMeta (
197
202
variables = [Variable ("A" , "uint8" )],
198
- sample_stats = [Variable ("B" , "uint8" )],
203
+ sample_stats = [Variable ("B" , "uint8" , [ 2 , 0 ] )],
199
204
data = [],
200
205
)
201
206
run = self .backend .init_run (rmeta )
@@ -204,19 +209,26 @@ def test__get_slicing(self, slc: slice):
204
209
# Generate draws and add them to the chain
205
210
N = 20
206
211
draws = [dict (A = n ) for n in range (N )]
207
- stats = [dict ( B = n ) for n in range (N )]
212
+ stats = [make_draw ( rmeta . sample_stats ) for n in range (N )]
208
213
for d , s in zip (draws , stats ):
209
214
chain .append (d , s )
210
215
assert len (chain ) == N
211
216
212
217
# slc=None in this test means "don't pass it".
213
218
# The implementations should default to slc=slice(None, None, None).
214
- expected = numpy .arange (N , dtype = "uint8" )[slc or slice (None , None , None )]
215
219
kwargs = dict (slc = slc ) if slc is not None else {}
216
220
act_draws = chain .get_draws ("A" , ** kwargs )
217
221
act_stats = chain .get_stats ("B" , ** kwargs )
218
- numpy .testing .assert_array_equal (act_draws , expected )
219
- numpy .testing .assert_array_equal (act_stats , expected )
222
+ expected_draws = [d ["A" ] for d in draws ][slc or slice (None , None , None )]
223
+ expected_stats = [s ["B" ] for s in stats ][slc or slice (None , None , None )]
224
+ # Variable "A" has a rigid shape
225
+ numpy .testing .assert_array_equal (act_draws , expected_draws )
226
+ # Stat "B" is dynamically shaped, which means we're dealing with
227
+ # dtype=object arrays. These must be checked elementwise.
228
+ assert len (act_stats ) == len (expected_stats )
229
+ assert act_stats .dtype == object
230
+ for a , e in zip (act_stats , expected_stats ):
231
+ numpy .testing .assert_array_equal (a , e )
220
232
pass
221
233
222
234
def test__get_chains (self ):
0 commit comments