2
2
import pickle
3
3
from collections .abc import Callable
4
4
from copy import copy
5
+ from functools import singledispatch
5
6
from textwrap import dedent , indent
6
7
from typing import Any
7
8
@@ -168,25 +169,10 @@ def impl(rng):
168
169
return impl
169
170
170
171
171
- @numba_funcify .register (ptr .RandomVariable )
172
- def numba_funcify_RandomVariable (op : RandomVariable , node , ** kwargs ):
173
- _ , size , _ , * args = node .inputs
174
- # None sizes are represented as empty tuple for the time being
175
- # https://github.com/pymc-devs/pytensor/issues/568
176
- [size_len ] = size .type .shape
177
- size_is_None = size_len == 0
178
-
179
- inplace = op .inplace
180
-
181
- if op .ndim_supp > 0 :
182
- raise NotImplementedError ("Multivariate random variables not supported yet" )
183
-
184
- # if any(ndim_param > 0 for ndim_param in op.ndims_params):
185
- # raise NotImplementedError(
186
- # "Random variables with non scalar core inputs not supported yet"
187
- # )
172
+ @singledispatch
173
+ def core_rv_fn (op : Op ):
174
+ """Return the core function for a random variable operation."""
188
175
189
- # TODO: Use dispatch, so users can define the core case
190
176
# Use string repr for default like below
191
177
# inner_code = dedent(f"""
192
178
# @numba_basic.numba_njit
@@ -197,15 +183,67 @@ def numba_funcify_RandomVariable(op: RandomVariable, node, **kwargs):
197
183
# exec(inner_code)
198
184
# scalar_op_fn = locals()['scalar_op_fn']
199
185
200
- # @numba_basic.numba_njit
201
- # def core_op_fn(rng, mu, scale):
202
- # return rng.normal(mu, scale)
186
+ raise NotImplementedError ()
187
+
203
188
189
+ @core_rv_fn .register (ptr .NormalRV )
190
+ def core_NormalRV (op ):
204
191
@numba_basic .numba_njit
205
- def core_op_fn (rng , p ):
192
+ def random_fn (rng , mu , scale , out ):
193
+ out [...] = rng .normal (mu , scale )
194
+
195
+ random_fn .handles_out = True
196
+ return random_fn
197
+
198
+
199
+ @core_rv_fn .register (ptr .CategoricalRV )
200
+ def core_CategoricalRV (op ):
201
+ @numba_basic .numba_njit
202
+ def random_fn (rng , p , out ):
206
203
unif_sample = rng .uniform (0 , 1 )
207
- return np .searchsorted (np .cumsum (p ), unif_sample )
204
+ # TODO: Check if LLVM can lift constant cumsum(p) out of the loop
205
+ out [...] = np .searchsorted (np .cumsum (p ), unif_sample )
206
+
207
+ random_fn .handles_out = True
208
+ return random_fn
209
+
210
+
211
+ @core_rv_fn .register (ptr .MvNormalRV )
212
+ def core_MvNormalRV (op ):
213
+ @numba .njit
214
+ def random_fn (rng , mean , cov , out ):
215
+ chol = np .linalg .cholesky (cov )
216
+ stdnorm = rng .normal (size = cov .shape [- 1 ])
217
+ # np.dot(chol, stdnorm, out=out)
218
+ # out[...] += mean
219
+ out [...] = mean + np .dot (chol , stdnorm )
208
220
221
+ random_fn .handles_out = True
222
+ return random_fn
223
+
224
+
225
+ @numba_funcify .register (ptr .RandomVariable )
226
+ def numba_funcify_RandomVariable (op : RandomVariable , node , ** kwargs ):
227
+ _ , size , _ , * args = node .inputs
228
+ # None sizes are represented as empty tuple for the time being
229
+ # https://github.com/pymc-devs/pytensor/issues/568
230
+ [size_len ] = size .type .shape
231
+ size_is_None = size_len == 0
232
+
233
+ inplace = op .inplace
234
+
235
+ # TODO: Add core_shape to node.inputs
236
+ if op .ndim_supp > 0 :
237
+ raise NotImplementedError ("Multivariate RandomVariable not implemented yet" )
238
+
239
+ # TODO: Create a wrapper (string processing?) that takes a core function without outputs
240
+ # and saves those outputs in the variables passed by `_vectorized`
241
+ core_op_fn = core_rv_fn (op )
242
+ if not getattr (core_op_fn , "handles_out" , False ):
243
+ # core_op_fn = store_core_outputs(op, core_op_fn)
244
+ raise NotImplementedError ()
245
+
246
+ # TODO: Refactor this code, it's the same with Elemwise
209
247
batch_ndim = node .default_output ().ndim - op .ndim_supp
210
248
output_bc_patterns = ((False ,) * batch_ndim ,)
211
249
input_bc_patterns = tuple (
@@ -234,12 +272,14 @@ def random_wrapper(rng, size, dtype, *inputs):
234
272
inplace_pattern_enc ,
235
273
(rng ,),
236
274
inputs ,
237
- None if size_is_None else numba_ndarray .to_fixed_tuple (size , size_len ),
275
+ ((),), # TODO: correct core_shapes
276
+ None
277
+ if size_is_None
278
+ else numba_ndarray .to_fixed_tuple (size , size_len ), # size
238
279
)
239
280
return rng , draws
240
281
241
282
def random (rng , size , dtype , * inputs ):
242
- # TODO: Add code that will be tested for coverage
243
283
pass
244
284
245
285
@overload (random )
@@ -330,35 +370,6 @@ def body_fn(a):
330
370
)
331
371
332
372
333
- # @numba_funcify.register(ptr.CategoricalRV)
334
- def numba_funcify_CategoricalRV (op , node , ** kwargs ):
335
- out_dtype = node .outputs [1 ].type .numpy_dtype
336
- size_len = int (get_vector_length (node .inputs [1 ]))
337
- p_ndim = node .inputs [- 1 ].ndim
338
-
339
- @numba_basic .numba_njit
340
- def categorical_rv (rng , size , dtype , p ):
341
- if not size_len :
342
- size_tpl = p .shape [:- 1 ]
343
- else :
344
- size_tpl = numba_ndarray .to_fixed_tuple (size , size_len )
345
- p = np .broadcast_to (p , size_tpl + p .shape [- 1 :])
346
-
347
- # Workaround https://github.com/numba/numba/issues/8975
348
- if not size_len and p_ndim == 1 :
349
- unif_samples = np .asarray (np .random .uniform (0 , 1 ))
350
- else :
351
- unif_samples = np .random .uniform (0 , 1 , size_tpl )
352
-
353
- res = np .empty (size_tpl , dtype = out_dtype )
354
- for idx in np .ndindex (* size_tpl ):
355
- res [idx ] = np .searchsorted (np .cumsum (p [idx ]), unif_samples [idx ])
356
-
357
- return (rng , res )
358
-
359
- return categorical_rv
360
-
361
-
362
373
@numba_funcify .register (ptr .DirichletRV )
363
374
def numba_funcify_DirichletRV (op , node , ** kwargs ):
364
375
out_dtype = node .outputs [1 ].type .numpy_dtype
0 commit comments