|
8 | 8 |
|
9 | 9 | import numba
|
10 | 10 | import numpy as np
|
11 |
| -from numba import TypingError, types |
12 |
| -from numba.core import cgutils |
13 | 11 | from numba.core.extending import overload
|
14 |
| -from numba.np import arrayobj |
15 | 12 | from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
|
16 | 13 |
|
17 | 14 | from pytensor import config
|
18 | 15 | from pytensor.graph.basic import Apply
|
19 | 16 | from pytensor.graph.op import Op
|
20 | 17 | from pytensor.link.numba.dispatch import basic as numba_basic
|
21 |
| -from pytensor.link.numba.dispatch import elemwise_codegen |
22 | 18 | from pytensor.link.numba.dispatch.basic import (
|
23 | 19 | create_numba_signature,
|
24 | 20 | create_tuple_creator,
|
25 | 21 | numba_funcify,
|
26 | 22 | numba_njit,
|
27 | 23 | use_optimized_cheap_pass,
|
28 | 24 | )
|
| 25 | +from pytensor.link.numba.dispatch.vectorize_codegen import _vectorized |
29 | 26 | from pytensor.link.utils import compile_function_src, get_name_for_object
|
30 | 27 | from pytensor.scalar.basic import (
|
31 | 28 | AND,
|
@@ -474,154 +471,6 @@ def axis_apply_fn(x):
|
474 | 471 | }
|
475 | 472 |
|
476 | 473 |
|
477 |
| -@numba.extending.intrinsic(jit_options=_jit_options, prefer_literal=True) |
478 |
| -def _vectorized( |
479 |
| - typingctx, |
480 |
| - scalar_func, |
481 |
| - input_bc_patterns, |
482 |
| - output_bc_patterns, |
483 |
| - output_dtypes, |
484 |
| - inplace_pattern, |
485 |
| - inputs, |
486 |
| -): |
487 |
| - arg_types = [ |
488 |
| - scalar_func, |
489 |
| - input_bc_patterns, |
490 |
| - output_bc_patterns, |
491 |
| - output_dtypes, |
492 |
| - inplace_pattern, |
493 |
| - inputs, |
494 |
| - ] |
495 |
| - |
496 |
| - if not isinstance(input_bc_patterns, types.Literal): |
497 |
| - raise TypingError("input_bc_patterns must be literal.") |
498 |
| - input_bc_patterns = input_bc_patterns.literal_value |
499 |
| - input_bc_patterns = pickle.loads(base64.decodebytes(input_bc_patterns.encode())) |
500 |
| - |
501 |
| - if not isinstance(output_bc_patterns, types.Literal): |
502 |
| - raise TypeError("output_bc_patterns must be literal.") |
503 |
| - output_bc_patterns = output_bc_patterns.literal_value |
504 |
| - output_bc_patterns = pickle.loads(base64.decodebytes(output_bc_patterns.encode())) |
505 |
| - |
506 |
| - if not isinstance(output_dtypes, types.Literal): |
507 |
| - raise TypeError("output_dtypes must be literal.") |
508 |
| - output_dtypes = output_dtypes.literal_value |
509 |
| - output_dtypes = pickle.loads(base64.decodebytes(output_dtypes.encode())) |
510 |
| - |
511 |
| - if not isinstance(inplace_pattern, types.Literal): |
512 |
| - raise TypeError("inplace_pattern must be literal.") |
513 |
| - inplace_pattern = inplace_pattern.literal_value |
514 |
| - inplace_pattern = pickle.loads(base64.decodebytes(inplace_pattern.encode())) |
515 |
| - |
516 |
| - n_outputs = len(output_bc_patterns) |
517 |
| - |
518 |
| - if not len(inputs) > 0: |
519 |
| - raise TypingError("Empty argument list to elemwise op.") |
520 |
| - |
521 |
| - if not n_outputs > 0: |
522 |
| - raise TypingError("Empty list of outputs for elemwise op.") |
523 |
| - |
524 |
| - if not all(isinstance(input, types.Array) for input in inputs): |
525 |
| - raise TypingError("Inputs to elemwise must be arrays.") |
526 |
| - ndim = inputs[0].ndim |
527 |
| - |
528 |
| - if not all(input.ndim == ndim for input in inputs): |
529 |
| - raise TypingError("Inputs to elemwise must have the same rank.") |
530 |
| - |
531 |
| - if not all(len(pattern) == ndim for pattern in output_bc_patterns): |
532 |
| - raise TypingError("Invalid output broadcasting pattern.") |
533 |
| - |
534 |
| - scalar_signature = typingctx.resolve_function_type( |
535 |
| - scalar_func, [in_type.dtype for in_type in inputs], {} |
536 |
| - ) |
537 |
| - |
538 |
| - # So we can access the constant values in codegen... |
539 |
| - input_bc_patterns_val = input_bc_patterns |
540 |
| - output_bc_patterns_val = output_bc_patterns |
541 |
| - output_dtypes_val = output_dtypes |
542 |
| - inplace_pattern_val = inplace_pattern |
543 |
| - input_types = inputs |
544 |
| - |
545 |
| - def codegen( |
546 |
| - ctx, |
547 |
| - builder, |
548 |
| - sig, |
549 |
| - args, |
550 |
| - ): |
551 |
| - [_, _, _, _, _, inputs] = args |
552 |
| - inputs = cgutils.unpack_tuple(builder, inputs) |
553 |
| - inputs = [ |
554 |
| - arrayobj.make_array(ty)(ctx, builder, val) |
555 |
| - for ty, val in zip(input_types, inputs) |
556 |
| - ] |
557 |
| - in_shapes = [cgutils.unpack_tuple(builder, obj.shape) for obj in inputs] |
558 |
| - |
559 |
| - iter_shape = elemwise_codegen.compute_itershape( |
560 |
| - ctx, |
561 |
| - builder, |
562 |
| - in_shapes, |
563 |
| - input_bc_patterns_val, |
564 |
| - ) |
565 |
| - |
566 |
| - outputs, output_types = elemwise_codegen.make_outputs( |
567 |
| - ctx, |
568 |
| - builder, |
569 |
| - iter_shape, |
570 |
| - output_bc_patterns_val, |
571 |
| - output_dtypes_val, |
572 |
| - inplace_pattern_val, |
573 |
| - inputs, |
574 |
| - input_types, |
575 |
| - ) |
576 |
| - |
577 |
| - elemwise_codegen.make_loop_call( |
578 |
| - typingctx, |
579 |
| - ctx, |
580 |
| - builder, |
581 |
| - scalar_func, |
582 |
| - scalar_signature, |
583 |
| - iter_shape, |
584 |
| - inputs, |
585 |
| - outputs, |
586 |
| - input_bc_patterns_val, |
587 |
| - output_bc_patterns_val, |
588 |
| - input_types, |
589 |
| - output_types, |
590 |
| - ) |
591 |
| - |
592 |
| - if len(outputs) == 1: |
593 |
| - if inplace_pattern: |
594 |
| - assert inplace_pattern[0][0] == 0 |
595 |
| - ctx.nrt.incref(builder, sig.return_type, outputs[0]._getvalue()) |
596 |
| - return outputs[0]._getvalue() |
597 |
| - |
598 |
| - for inplace_idx in dict(inplace_pattern): |
599 |
| - ctx.nrt.incref( |
600 |
| - builder, |
601 |
| - sig.return_type.types[inplace_idx], |
602 |
| - outputs[inplace_idx]._get_value(), |
603 |
| - ) |
604 |
| - return ctx.make_tuple( |
605 |
| - builder, sig.return_type, [out._getvalue() for out in outputs] |
606 |
| - ) |
607 |
| - |
608 |
| - ret_types = [ |
609 |
| - types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C") |
610 |
| - for dtype in output_dtypes |
611 |
| - ] |
612 |
| - |
613 |
| - for output_idx, input_idx in inplace_pattern: |
614 |
| - ret_types[output_idx] = input_types[input_idx] |
615 |
| - |
616 |
| - ret_type = types.Tuple(ret_types) |
617 |
| - |
618 |
| - if len(output_dtypes) == 1: |
619 |
| - ret_type = ret_type.types[0] |
620 |
| - sig = ret_type(*arg_types) |
621 |
| - |
622 |
| - return sig, codegen |
623 |
| - |
624 |
| - |
625 | 474 | @numba_funcify.register(Elemwise)
|
626 | 475 | def numba_funcify_Elemwise(op, node, **kwargs):
|
627 | 476 | # Creating a new scalar node is more involved and unnecessary
|
@@ -665,6 +514,7 @@ def elemwise_wrapper(*inputs):
|
665 | 514 | output_bc_patterns_enc,
|
666 | 515 | output_dtypes_enc,
|
667 | 516 | inplace_pattern_enc,
|
| 517 | + (), |
668 | 518 | inputs,
|
669 | 519 | )
|
670 | 520 |
|
|
0 commit comments