Skip to content

Commit d4fbab6

Browse files
committed
Implement size and RNG copy
1 parent caeaf02 commit d4fbab6

File tree

5 files changed

+522
-769
lines changed

5 files changed

+522
-769
lines changed

pytensor/link/numba/dispatch/elemwise.py

+2-152
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,21 @@
88

99
import numba
1010
import numpy as np
11-
from numba import TypingError, types
12-
from numba.core import cgutils
1311
from numba.core.extending import overload
14-
from numba.np import arrayobj
1512
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
1613

1714
from pytensor import config
1815
from pytensor.graph.basic import Apply
1916
from pytensor.graph.op import Op
2017
from pytensor.link.numba.dispatch import basic as numba_basic
21-
from pytensor.link.numba.dispatch import elemwise_codegen
2218
from pytensor.link.numba.dispatch.basic import (
2319
create_numba_signature,
2420
create_tuple_creator,
2521
numba_funcify,
2622
numba_njit,
2723
use_optimized_cheap_pass,
2824
)
25+
from pytensor.link.numba.dispatch.vectorize_codegen import _vectorized
2926
from pytensor.link.utils import compile_function_src, get_name_for_object
3027
from pytensor.scalar.basic import (
3128
AND,
@@ -474,154 +471,6 @@ def axis_apply_fn(x):
474471
}
475472

476473

477-
@numba.extending.intrinsic(jit_options=_jit_options, prefer_literal=True)
478-
def _vectorized(
479-
typingctx,
480-
scalar_func,
481-
input_bc_patterns,
482-
output_bc_patterns,
483-
output_dtypes,
484-
inplace_pattern,
485-
inputs,
486-
):
487-
arg_types = [
488-
scalar_func,
489-
input_bc_patterns,
490-
output_bc_patterns,
491-
output_dtypes,
492-
inplace_pattern,
493-
inputs,
494-
]
495-
496-
if not isinstance(input_bc_patterns, types.Literal):
497-
raise TypingError("input_bc_patterns must be literal.")
498-
input_bc_patterns = input_bc_patterns.literal_value
499-
input_bc_patterns = pickle.loads(base64.decodebytes(input_bc_patterns.encode()))
500-
501-
if not isinstance(output_bc_patterns, types.Literal):
502-
raise TypeError("output_bc_patterns must be literal.")
503-
output_bc_patterns = output_bc_patterns.literal_value
504-
output_bc_patterns = pickle.loads(base64.decodebytes(output_bc_patterns.encode()))
505-
506-
if not isinstance(output_dtypes, types.Literal):
507-
raise TypeError("output_dtypes must be literal.")
508-
output_dtypes = output_dtypes.literal_value
509-
output_dtypes = pickle.loads(base64.decodebytes(output_dtypes.encode()))
510-
511-
if not isinstance(inplace_pattern, types.Literal):
512-
raise TypeError("inplace_pattern must be literal.")
513-
inplace_pattern = inplace_pattern.literal_value
514-
inplace_pattern = pickle.loads(base64.decodebytes(inplace_pattern.encode()))
515-
516-
n_outputs = len(output_bc_patterns)
517-
518-
if not len(inputs) > 0:
519-
raise TypingError("Empty argument list to elemwise op.")
520-
521-
if not n_outputs > 0:
522-
raise TypingError("Empty list of outputs for elemwise op.")
523-
524-
if not all(isinstance(input, types.Array) for input in inputs):
525-
raise TypingError("Inputs to elemwise must be arrays.")
526-
ndim = inputs[0].ndim
527-
528-
if not all(input.ndim == ndim for input in inputs):
529-
raise TypingError("Inputs to elemwise must have the same rank.")
530-
531-
if not all(len(pattern) == ndim for pattern in output_bc_patterns):
532-
raise TypingError("Invalid output broadcasting pattern.")
533-
534-
scalar_signature = typingctx.resolve_function_type(
535-
scalar_func, [in_type.dtype for in_type in inputs], {}
536-
)
537-
538-
# So we can access the constant values in codegen...
539-
input_bc_patterns_val = input_bc_patterns
540-
output_bc_patterns_val = output_bc_patterns
541-
output_dtypes_val = output_dtypes
542-
inplace_pattern_val = inplace_pattern
543-
input_types = inputs
544-
545-
def codegen(
546-
ctx,
547-
builder,
548-
sig,
549-
args,
550-
):
551-
[_, _, _, _, _, inputs] = args
552-
inputs = cgutils.unpack_tuple(builder, inputs)
553-
inputs = [
554-
arrayobj.make_array(ty)(ctx, builder, val)
555-
for ty, val in zip(input_types, inputs)
556-
]
557-
in_shapes = [cgutils.unpack_tuple(builder, obj.shape) for obj in inputs]
558-
559-
iter_shape = elemwise_codegen.compute_itershape(
560-
ctx,
561-
builder,
562-
in_shapes,
563-
input_bc_patterns_val,
564-
)
565-
566-
outputs, output_types = elemwise_codegen.make_outputs(
567-
ctx,
568-
builder,
569-
iter_shape,
570-
output_bc_patterns_val,
571-
output_dtypes_val,
572-
inplace_pattern_val,
573-
inputs,
574-
input_types,
575-
)
576-
577-
elemwise_codegen.make_loop_call(
578-
typingctx,
579-
ctx,
580-
builder,
581-
scalar_func,
582-
scalar_signature,
583-
iter_shape,
584-
inputs,
585-
outputs,
586-
input_bc_patterns_val,
587-
output_bc_patterns_val,
588-
input_types,
589-
output_types,
590-
)
591-
592-
if len(outputs) == 1:
593-
if inplace_pattern:
594-
assert inplace_pattern[0][0] == 0
595-
ctx.nrt.incref(builder, sig.return_type, outputs[0]._getvalue())
596-
return outputs[0]._getvalue()
597-
598-
for inplace_idx in dict(inplace_pattern):
599-
ctx.nrt.incref(
600-
builder,
601-
sig.return_type.types[inplace_idx],
602-
outputs[inplace_idx]._get_value(),
603-
)
604-
return ctx.make_tuple(
605-
builder, sig.return_type, [out._getvalue() for out in outputs]
606-
)
607-
608-
ret_types = [
609-
types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C")
610-
for dtype in output_dtypes
611-
]
612-
613-
for output_idx, input_idx in inplace_pattern:
614-
ret_types[output_idx] = input_types[input_idx]
615-
616-
ret_type = types.Tuple(ret_types)
617-
618-
if len(output_dtypes) == 1:
619-
ret_type = ret_type.types[0]
620-
sig = ret_type(*arg_types)
621-
622-
return sig, codegen
623-
624-
625474
@numba_funcify.register(Elemwise)
626475
def numba_funcify_Elemwise(op, node, **kwargs):
627476
# Creating a new scalar node is more involved and unnecessary
@@ -665,6 +514,7 @@ def elemwise_wrapper(*inputs):
665514
output_bc_patterns_enc,
666515
output_dtypes_enc,
667516
inplace_pattern_enc,
517+
(),
668518
inputs,
669519
)
670520

0 commit comments

Comments
 (0)