@@ -29,6 +29,8 @@ def make_runmeta(*, flexibility: bool = False, **kwargs) -> RunMeta:
29
29
Variable ("accepted" , "bool" , list ((3 ,)), dims = ["sampler" ]),
30
30
# But some stats may refer to the iteration.
31
31
Variable ("logp" , "float64" , []),
32
+ # String dtypes may be used for more complex information
33
+ Variable ("message" , "str" ),
32
34
],
33
35
data = [
34
36
DataVariable (
@@ -60,8 +62,13 @@ def make_draw(variables: Sequence[Variable]):
60
62
)
61
63
if "float" in var .dtype :
62
64
draw [var .name ] = numpy .random .normal (size = dshape ).astype (var .dtype )
65
+ elif var .dtype == "str" :
66
+ alphabet = tuple ("abcdef#+*/'" )
67
+ words = ["" .join (numpy .random .choice (alphabet , size = numpy .random .randint (3 , 10 )))]
68
+ draw [var .name ] = numpy .array (words ).reshape (dshape )
63
69
else :
64
70
draw [var .name ] = numpy .random .randint (low = 0 , high = 100 , size = dshape ).astype (var .dtype )
71
+ assert draw [var .name ].shape == dshape
65
72
return draw
66
73
67
74
@@ -149,7 +156,7 @@ def test__append_get_with_changelings(self, with_stats):
149
156
expected = [draw [var .name ] for draw in draws ]
150
157
actual = chain .get_draws (var .name )
151
158
assert isinstance (actual , numpy .ndarray )
152
- if var .name == "changeling " :
159
+ if not is_rigid ( var .shape ) or var . dtype == "str " :
153
160
# Non-ridid variables are returned as object-arrays.
154
161
assert actual .shape == (len (expected ),)
155
162
assert actual .dtype == object
@@ -166,9 +173,13 @@ def test__append_get_with_changelings(self, with_stats):
166
173
expected = [stat [var .name ] for stat in stats ]
167
174
actual = chain .get_stats (var .name )
168
175
assert isinstance (actual , numpy .ndarray )
169
- if is_rigid ( var .shape ) :
176
+ if var .dtype == "str" :
170
177
assert tuple (actual .shape ) == tuple (numpy .shape (expected ))
171
- assert actual .dtype == var .dtype
178
+ # String dtypes have strange names
179
+ assert "str" in actual .dtype .name
180
+ elif is_rigid (var .shape ):
181
+ assert tuple (actual .shape ) == tuple (numpy .shape (expected ))
182
+ assert actual .dtype .name == var .dtype
172
183
numpy .testing .assert_array_equal (actual , expected )
173
184
else :
174
185
# Non-ridid variables are returned as object-arrays.
@@ -200,7 +211,7 @@ def test__get_slicing(self, slc: slice):
200
211
# "A" are just numbers to make diagnosis easier.
201
212
# "B" are dynamically shaped to cover the edge cases.
202
213
rmeta = RunMeta (
203
- variables = [Variable ("A" , "uint8" )],
214
+ variables = [Variable ("A" , "uint8" ), Variable ( "M" , "str" , [ 2 , 3 ]) ],
204
215
sample_stats = [Variable ("B" , "uint8" , [2 , 0 ])],
205
216
data = [],
206
217
)
@@ -209,7 +220,7 @@ def test__get_slicing(self, slc: slice):
209
220
210
221
# Generate draws and add them to the chain
211
222
N = 20
212
- draws = [dict (A = n ) for n in range (N )]
223
+ draws = [dict (A = numpy . array ( n ) ) for n in range (N )]
213
224
stats = [make_draw (rmeta .sample_stats ) for n in range (N )]
214
225
for d , s in zip (draws , stats ):
215
226
chain .append (d , s )
@@ -222,8 +233,13 @@ def test__get_slicing(self, slc: slice):
222
233
act_stats = chain .get_stats ("B" , ** kwargs )
223
234
expected_draws = [d ["A" ] for d in draws ][slc or slice (None , None , None )]
224
235
expected_stats = [s ["B" ] for s in stats ][slc or slice (None , None , None )]
236
+
225
237
# Variable "A" has a rigid shape
226
238
numpy .testing .assert_array_equal (act_draws , expected_draws )
239
+
240
+ # Variable "M" is a string matrix
241
+ numpy .testing .assert_array_equal (act_draws , expected_draws )
242
+
227
243
# Stat "B" is dynamically shaped, which means we're dealing with
228
244
# dtype=object arrays. These must be checked elementwise.
229
245
assert len (act_stats ) == len (expected_stats )
@@ -256,6 +272,7 @@ def test__to_inferencedata(self):
256
272
sample_stats = [
257
273
Variable ("tune" , "bool" ),
258
274
Variable ("sampler_0__logp" , "float32" ),
275
+ Variable ("warning" , "str" ),
259
276
],
260
277
)
261
278
run = self .backend .init_run (rmeta )
0 commit comments