|
29 | 29 | from pytensor.link.utils import (
|
30 | 30 | compile_function_src,
|
31 | 31 | fgraph_to_python,
|
32 |
| - unique_name_generator, |
33 | 32 | )
|
34 | 33 | from pytensor.scalar.basic import ScalarType
|
35 | 34 | from pytensor.scalar.math import Softplus
|
|
38 | 37 | from pytensor.tensor.math import Dot
|
39 | 38 | from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
|
40 | 39 | from pytensor.tensor.slinalg import Solve
|
41 |
| -from pytensor.tensor.subtensor import ( |
42 |
| - AdvancedIncSubtensor, |
43 |
| - AdvancedIncSubtensor1, |
44 |
| - AdvancedSubtensor, |
45 |
| - AdvancedSubtensor1, |
46 |
| - IncSubtensor, |
47 |
| - Subtensor, |
48 |
| -) |
49 | 40 | from pytensor.tensor.type import TensorType
|
50 | 41 | from pytensor.tensor.type_other import MakeSlice, NoneConst
|
51 | 42 |
|
@@ -479,217 +470,6 @@ def numba_funcify_FunctionGraph(
|
479 | 470 | )
|
480 | 471 |
|
481 | 472 |
|
482 |
| -def create_index_func(node, objmode=False): |
483 |
| - """Create a Python function that assembles and uses an index on an array.""" |
484 |
| - |
485 |
| - unique_names = unique_name_generator( |
486 |
| - ["subtensor", "incsubtensor", "z"], suffix_sep="_" |
487 |
| - ) |
488 |
| - |
489 |
| - def convert_indices(indices, entry): |
490 |
| - if indices and isinstance(entry, Type): |
491 |
| - rval = indices.pop(0) |
492 |
| - return unique_names(rval) |
493 |
| - elif isinstance(entry, slice): |
494 |
| - return ( |
495 |
| - f"slice({convert_indices(indices, entry.start)}, " |
496 |
| - f"{convert_indices(indices, entry.stop)}, " |
497 |
| - f"{convert_indices(indices, entry.step)})" |
498 |
| - ) |
499 |
| - elif isinstance(entry, type(None)): |
500 |
| - return "None" |
501 |
| - else: |
502 |
| - raise ValueError() |
503 |
| - |
504 |
| - set_or_inc = isinstance( |
505 |
| - node.op, IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor |
506 |
| - ) |
507 |
| - index_start_idx = 1 + int(set_or_inc) |
508 |
| - |
509 |
| - input_names = [unique_names(v, force_unique=True) for v in node.inputs] |
510 |
| - op_indices = list(node.inputs[index_start_idx:]) |
511 |
| - idx_list = getattr(node.op, "idx_list", None) |
512 |
| - |
513 |
| - indices_creation_src = ( |
514 |
| - tuple(convert_indices(op_indices, idx) for idx in idx_list) |
515 |
| - if idx_list |
516 |
| - else tuple(input_names[index_start_idx:]) |
517 |
| - ) |
518 |
| - |
519 |
| - if len(indices_creation_src) == 1: |
520 |
| - indices_creation_src = f"indices = ({indices_creation_src[0]},)" |
521 |
| - else: |
522 |
| - indices_creation_src = ", ".join(indices_creation_src) |
523 |
| - indices_creation_src = f"indices = ({indices_creation_src})" |
524 |
| - |
525 |
| - if set_or_inc: |
526 |
| - fn_name = "incsubtensor" |
527 |
| - if node.op.inplace: |
528 |
| - index_prologue = f"z = {input_names[0]}" |
529 |
| - else: |
530 |
| - index_prologue = f"z = np.copy({input_names[0]})" |
531 |
| - |
532 |
| - if node.inputs[1].ndim == 0: |
533 |
| - # TODO FIXME: This is a hack to get around a weird Numba typing |
534 |
| - # issue. See https://github.com/numba/numba/issues/6000 |
535 |
| - y_name = f"{input_names[1]}.item()" |
536 |
| - else: |
537 |
| - y_name = input_names[1] |
538 |
| - |
539 |
| - if node.op.set_instead_of_inc: |
540 |
| - index_body = f"z[indices] = {y_name}" |
541 |
| - else: |
542 |
| - index_body = f"z[indices] += {y_name}" |
543 |
| - else: |
544 |
| - fn_name = "subtensor" |
545 |
| - index_prologue = "" |
546 |
| - index_body = f"z = {input_names[0]}[indices]" |
547 |
| - |
548 |
| - if objmode: |
549 |
| - output_var = node.outputs[0] |
550 |
| - |
551 |
| - if not set_or_inc: |
552 |
| - # Since `z` is being "created" while in object mode, it's |
553 |
| - # considered an "outgoing" variable and needs to be manually typed |
554 |
| - output_sig = f"z='{output_var.dtype}[{', '.join([':'] * output_var.ndim)}]'" |
555 |
| - else: |
556 |
| - output_sig = "" |
557 |
| - |
558 |
| - index_body = f""" |
559 |
| - with objmode({output_sig}): |
560 |
| - {index_body} |
561 |
| - """ |
562 |
| - |
563 |
| - subtensor_def_src = f""" |
564 |
| -def {fn_name}({", ".join(input_names)}): |
565 |
| - {index_prologue} |
566 |
| - {indices_creation_src} |
567 |
| - {index_body} |
568 |
| - return np.asarray(z) |
569 |
| - """ |
570 |
| - |
571 |
| - return subtensor_def_src |
572 |
| - |
573 |
| - |
574 |
| -@numba_funcify.register(Subtensor) |
575 |
| -@numba_funcify.register(AdvancedSubtensor1) |
576 |
| -def numba_funcify_Subtensor(op, node, **kwargs): |
577 |
| - objmode = isinstance(op, AdvancedSubtensor) |
578 |
| - if objmode: |
579 |
| - warnings.warn( |
580 |
| - ("Numba will use object mode to allow run " "AdvancedSubtensor."), |
581 |
| - UserWarning, |
582 |
| - ) |
583 |
| - |
584 |
| - subtensor_def_src = create_index_func(node, objmode=objmode) |
585 |
| - |
586 |
| - global_env = {"np": np} |
587 |
| - if objmode: |
588 |
| - global_env["objmode"] = numba.objmode |
589 |
| - |
590 |
| - subtensor_fn = compile_function_src( |
591 |
| - subtensor_def_src, "subtensor", {**globals(), **global_env} |
592 |
| - ) |
593 |
| - |
594 |
| - return numba_njit(subtensor_fn, boundscheck=True) |
595 |
| - |
596 |
| - |
597 |
| -@numba_funcify.register(IncSubtensor) |
598 |
| -def numba_funcify_IncSubtensor(op, node, **kwargs): |
599 |
| - objmode = isinstance(op, AdvancedIncSubtensor) |
600 |
| - if objmode: |
601 |
| - warnings.warn( |
602 |
| - ("Numba will use object mode to allow run " "AdvancedIncSubtensor."), |
603 |
| - UserWarning, |
604 |
| - ) |
605 |
| - |
606 |
| - incsubtensor_def_src = create_index_func(node, objmode=objmode) |
607 |
| - |
608 |
| - global_env = {"np": np} |
609 |
| - if objmode: |
610 |
| - global_env["objmode"] = numba.objmode |
611 |
| - |
612 |
| - incsubtensor_fn = compile_function_src( |
613 |
| - incsubtensor_def_src, "incsubtensor", {**globals(), **global_env} |
614 |
| - ) |
615 |
| - |
616 |
| - return numba_njit(incsubtensor_fn, boundscheck=True) |
617 |
| - |
618 |
| - |
619 |
| -@numba_funcify.register(AdvancedIncSubtensor1) |
620 |
| -def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs): |
621 |
| - inplace = op.inplace |
622 |
| - set_instead_of_inc = op.set_instead_of_inc |
623 |
| - x, vals, idxs = node.inputs |
624 |
| - # TODO: Add explicit expand_dims in make_node so we don't need to worry about this here |
625 |
| - broadcast = vals.type.ndim < x.type.ndim or vals.type.broadcastable[0] |
626 |
| - |
627 |
| - if set_instead_of_inc: |
628 |
| - if broadcast: |
629 |
| - |
630 |
| - @numba_njit(boundscheck=True) |
631 |
| - def advancedincsubtensor1_inplace(x, val, idxs): |
632 |
| - if val.ndim == x.ndim: |
633 |
| - core_val = val[0] |
634 |
| - elif val.ndim == 0: |
635 |
| - # Workaround for https://github.com/numba/numba/issues/9573 |
636 |
| - core_val = val.item() |
637 |
| - else: |
638 |
| - core_val = val |
639 |
| - |
640 |
| - for idx in idxs: |
641 |
| - x[idx] = core_val |
642 |
| - return x |
643 |
| - |
644 |
| - else: |
645 |
| - |
646 |
| - @numba_njit(boundscheck=True) |
647 |
| - def advancedincsubtensor1_inplace(x, vals, idxs): |
648 |
| - if not len(idxs) == len(vals): |
649 |
| - raise ValueError("The number of indices and values must match.") |
650 |
| - for idx, val in zip(idxs, vals): |
651 |
| - x[idx] = val |
652 |
| - return x |
653 |
| - else: |
654 |
| - if broadcast: |
655 |
| - |
656 |
| - @numba_njit(boundscheck=True) |
657 |
| - def advancedincsubtensor1_inplace(x, val, idxs): |
658 |
| - if val.ndim == x.ndim: |
659 |
| - core_val = val[0] |
660 |
| - elif val.ndim == 0: |
661 |
| - # Workaround for https://github.com/numba/numba/issues/9573 |
662 |
| - core_val = val.item() |
663 |
| - else: |
664 |
| - core_val = val |
665 |
| - |
666 |
| - for idx in idxs: |
667 |
| - x[idx] += core_val |
668 |
| - return x |
669 |
| - |
670 |
| - else: |
671 |
| - |
672 |
| - @numba_njit(boundscheck=True) |
673 |
| - def advancedincsubtensor1_inplace(x, vals, idxs): |
674 |
| - if not len(idxs) == len(vals): |
675 |
| - raise ValueError("The number of indices and values must match.") |
676 |
| - for idx, val in zip(idxs, vals): |
677 |
| - x[idx] += val |
678 |
| - return x |
679 |
| - |
680 |
| - if inplace: |
681 |
| - return advancedincsubtensor1_inplace |
682 |
| - |
683 |
| - else: |
684 |
| - |
685 |
| - @numba_njit |
686 |
| - def advancedincsubtensor1(x, vals, idxs): |
687 |
| - x = x.copy() |
688 |
| - return advancedincsubtensor1_inplace(x, vals, idxs) |
689 |
| - |
690 |
| - return advancedincsubtensor1 |
691 |
| - |
692 |
| - |
693 | 473 | def deepcopyop(x):
|
694 | 474 | return copy(x)
|
695 | 475 |
|
|
0 commit comments