@@ -29,6 +29,8 @@ def make_runmeta(*, flexibility: bool = False, **kwargs) -> RunMeta:
29
29
Variable ("accepted" , "bool" , list ((3 ,)), dims = ["sampler" ]),
30
30
# But some stats may refer to the iteration.
31
31
Variable ("logp" , "float64" , []),
32
+ # String dtypes may be used for more complex information
33
+ Variable ("message" , "str" ),
32
34
],
33
35
data = [
34
36
DataVariable (
@@ -60,8 +62,16 @@ def make_draw(variables: Sequence[Variable]):
60
62
)
61
63
if "float" in var .dtype :
62
64
draw [var .name ] = numpy .random .normal (size = dshape ).astype (var .dtype )
65
+ elif var .dtype == "str" :
66
+ alphabet = tuple ("abcdef#+*/'" )
67
+ words = [
68
+ "" .join (numpy .random .choice (alphabet , size = numpy .random .randint (3 , 10 )))
69
+ for _ in range (int (numpy .prod (dshape )))
70
+ ]
71
+ draw [var .name ] = numpy .array (words , dtype = var .dtype ).reshape (dshape )
63
72
else :
64
73
draw [var .name ] = numpy .random .randint (low = 0 , high = 100 , size = dshape ).astype (var .dtype )
74
+ assert draw [var .name ].shape == dshape
65
75
return draw
66
76
67
77
@@ -149,7 +159,7 @@ def test__append_get_with_changelings(self, with_stats):
149
159
expected = [draw [var .name ] for draw in draws ]
150
160
actual = chain .get_draws (var .name )
151
161
assert isinstance (actual , numpy .ndarray )
152
- if var .name == "changeling " :
162
+ if not is_rigid ( var .shape ) or var . dtype == "str " :
153
163
# Non-ridid variables are returned as object-arrays.
154
164
assert actual .shape == (len (expected ),)
155
165
assert actual .dtype == object
@@ -166,9 +176,13 @@ def test__append_get_with_changelings(self, with_stats):
166
176
expected = [stat [var .name ] for stat in stats ]
167
177
actual = chain .get_stats (var .name )
168
178
assert isinstance (actual , numpy .ndarray )
169
- if is_rigid ( var .shape ) :
179
+ if var .dtype == "str" :
170
180
assert tuple (actual .shape ) == tuple (numpy .shape (expected ))
171
- assert actual .dtype == var .dtype
181
+ # String dtypes have strange names
182
+ assert "str" in actual .dtype .name
183
+ elif is_rigid (var .shape ):
184
+ assert tuple (actual .shape ) == tuple (numpy .shape (expected ))
185
+ assert actual .dtype .name == var .dtype
172
186
numpy .testing .assert_array_equal (actual , expected )
173
187
else :
174
188
# Non-ridid variables are returned as object-arrays.
@@ -200,7 +214,7 @@ def test__get_slicing(self, slc: slice):
200
214
# "A" are just numbers to make diagnosis easier.
201
215
# "B" are dynamically shaped to cover the edge cases.
202
216
rmeta = RunMeta (
203
- variables = [Variable ("A" , "uint8" )],
217
+ variables = [Variable ("A" , "uint8" ), Variable ( "M" , "str" , [ 2 , 3 ]) ],
204
218
sample_stats = [Variable ("B" , "uint8" , [2 , 0 ])],
205
219
data = [],
206
220
)
@@ -209,7 +223,7 @@ def test__get_slicing(self, slc: slice):
209
223
210
224
# Generate draws and add them to the chain
211
225
N = 20
212
- draws = [dict ( A = n ) for n in range (N )]
226
+ draws = [make_draw ( rmeta . variables ) for n in range (N )]
213
227
stats = [make_draw (rmeta .sample_stats ) for n in range (N )]
214
228
for d , s in zip (draws , stats ):
215
229
chain .append (d , s )
@@ -218,12 +232,25 @@ def test__get_slicing(self, slc: slice):
218
232
# slc=None in this test means "don't pass it".
219
233
# The implementations should default to slc=slice(None, None, None).
220
234
kwargs = dict (slc = slc ) if slc is not None else {}
221
- act_draws = chain .get_draws ("A" , ** kwargs )
235
+ act_draws_A = chain .get_draws ("A" , ** kwargs )
236
+ act_draws_M = chain .get_draws ("M" , ** kwargs )
222
237
act_stats = chain .get_stats ("B" , ** kwargs )
223
- expected_draws = [d ["A" ] for d in draws ][slc or slice (None , None , None )]
238
+ expected_draws_A = [d ["A" ] for d in draws ][slc or slice (None , None , None )]
239
+ expected_draws_M = [d ["M" ] for d in draws ][slc or slice (None , None , None )]
224
240
expected_stats = [s ["B" ] for s in stats ][slc or slice (None , None , None )]
241
+
225
242
# Variable "A" has a rigid shape
226
- numpy .testing .assert_array_equal (act_draws , expected_draws )
243
+ if expected_draws_A :
244
+ numpy .testing .assert_array_equal (act_draws_A , expected_draws_A )
245
+ else :
246
+ assert len (act_draws_A ) == 0
247
+
248
+ # Variable "M" is a string matrix
249
+ if expected_draws_M :
250
+ numpy .testing .assert_array_equal (act_draws_M , expected_draws_M )
251
+ else :
252
+ assert len (act_draws_M ) == 0
253
+
227
254
# Stat "B" is dynamically shaped, which means we're dealing with
228
255
# dtype=object arrays. These must be checked elementwise.
229
256
assert len (act_stats ) == len (expected_stats )
@@ -256,6 +283,7 @@ def test__to_inferencedata(self):
256
283
sample_stats = [
257
284
Variable ("tune" , "bool" ),
258
285
Variable ("sampler_0__logp" , "float32" ),
286
+ Variable ("warning" , "str" ),
259
287
],
260
288
)
261
289
run = self .backend .init_run (rmeta )
0 commit comments