Skip to content

Commit a9c52dd

Browse files
committed
Move numba subtensor functionality to its own module
1 parent 0353abe commit a9c52dd

File tree

5 files changed

+463
-433
lines changed

5 files changed

+463
-433
lines changed

pytensor/link/numba/dispatch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@
1111
import pytensor.link.numba.dispatch.scan
1212
import pytensor.link.numba.dispatch.sparse
1313
import pytensor.link.numba.dispatch.slinalg
14+
import pytensor.link.numba.dispatch.subtensor
1415

1516
# isort: on

pytensor/link/numba/dispatch/basic.py

Lines changed: 0 additions & 220 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from pytensor.link.utils import (
3030
compile_function_src,
3131
fgraph_to_python,
32-
unique_name_generator,
3332
)
3433
from pytensor.scalar.basic import ScalarType
3534
from pytensor.scalar.math import Softplus
@@ -38,14 +37,6 @@
3837
from pytensor.tensor.math import Dot
3938
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
4039
from pytensor.tensor.slinalg import Solve
41-
from pytensor.tensor.subtensor import (
42-
AdvancedIncSubtensor,
43-
AdvancedIncSubtensor1,
44-
AdvancedSubtensor,
45-
AdvancedSubtensor1,
46-
IncSubtensor,
47-
Subtensor,
48-
)
4940
from pytensor.tensor.type import TensorType
5041
from pytensor.tensor.type_other import MakeSlice, NoneConst
5142

@@ -479,217 +470,6 @@ def numba_funcify_FunctionGraph(
479470
)
480471

481472

482-
def create_index_func(node, objmode=False):
483-
"""Create a Python function that assembles and uses an index on an array."""
484-
485-
unique_names = unique_name_generator(
486-
["subtensor", "incsubtensor", "z"], suffix_sep="_"
487-
)
488-
489-
def convert_indices(indices, entry):
490-
if indices and isinstance(entry, Type):
491-
rval = indices.pop(0)
492-
return unique_names(rval)
493-
elif isinstance(entry, slice):
494-
return (
495-
f"slice({convert_indices(indices, entry.start)}, "
496-
f"{convert_indices(indices, entry.stop)}, "
497-
f"{convert_indices(indices, entry.step)})"
498-
)
499-
elif isinstance(entry, type(None)):
500-
return "None"
501-
else:
502-
raise ValueError()
503-
504-
set_or_inc = isinstance(
505-
node.op, IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor
506-
)
507-
index_start_idx = 1 + int(set_or_inc)
508-
509-
input_names = [unique_names(v, force_unique=True) for v in node.inputs]
510-
op_indices = list(node.inputs[index_start_idx:])
511-
idx_list = getattr(node.op, "idx_list", None)
512-
513-
indices_creation_src = (
514-
tuple(convert_indices(op_indices, idx) for idx in idx_list)
515-
if idx_list
516-
else tuple(input_names[index_start_idx:])
517-
)
518-
519-
if len(indices_creation_src) == 1:
520-
indices_creation_src = f"indices = ({indices_creation_src[0]},)"
521-
else:
522-
indices_creation_src = ", ".join(indices_creation_src)
523-
indices_creation_src = f"indices = ({indices_creation_src})"
524-
525-
if set_or_inc:
526-
fn_name = "incsubtensor"
527-
if node.op.inplace:
528-
index_prologue = f"z = {input_names[0]}"
529-
else:
530-
index_prologue = f"z = np.copy({input_names[0]})"
531-
532-
if node.inputs[1].ndim == 0:
533-
# TODO FIXME: This is a hack to get around a weird Numba typing
534-
# issue. See https://github.com/numba/numba/issues/6000
535-
y_name = f"{input_names[1]}.item()"
536-
else:
537-
y_name = input_names[1]
538-
539-
if node.op.set_instead_of_inc:
540-
index_body = f"z[indices] = {y_name}"
541-
else:
542-
index_body = f"z[indices] += {y_name}"
543-
else:
544-
fn_name = "subtensor"
545-
index_prologue = ""
546-
index_body = f"z = {input_names[0]}[indices]"
547-
548-
if objmode:
549-
output_var = node.outputs[0]
550-
551-
if not set_or_inc:
552-
# Since `z` is being "created" while in object mode, it's
553-
# considered an "outgoing" variable and needs to be manually typed
554-
output_sig = f"z='{output_var.dtype}[{', '.join([':'] * output_var.ndim)}]'"
555-
else:
556-
output_sig = ""
557-
558-
index_body = f"""
559-
with objmode({output_sig}):
560-
{index_body}
561-
"""
562-
563-
subtensor_def_src = f"""
564-
def {fn_name}({", ".join(input_names)}):
565-
{index_prologue}
566-
{indices_creation_src}
567-
{index_body}
568-
return np.asarray(z)
569-
"""
570-
571-
return subtensor_def_src
572-
573-
574-
@numba_funcify.register(Subtensor)
575-
@numba_funcify.register(AdvancedSubtensor1)
576-
def numba_funcify_Subtensor(op, node, **kwargs):
577-
objmode = isinstance(op, AdvancedSubtensor)
578-
if objmode:
579-
warnings.warn(
580-
("Numba will use object mode to allow run " "AdvancedSubtensor."),
581-
UserWarning,
582-
)
583-
584-
subtensor_def_src = create_index_func(node, objmode=objmode)
585-
586-
global_env = {"np": np}
587-
if objmode:
588-
global_env["objmode"] = numba.objmode
589-
590-
subtensor_fn = compile_function_src(
591-
subtensor_def_src, "subtensor", {**globals(), **global_env}
592-
)
593-
594-
return numba_njit(subtensor_fn, boundscheck=True)
595-
596-
597-
@numba_funcify.register(IncSubtensor)
598-
def numba_funcify_IncSubtensor(op, node, **kwargs):
599-
objmode = isinstance(op, AdvancedIncSubtensor)
600-
if objmode:
601-
warnings.warn(
602-
("Numba will use object mode to allow run " "AdvancedIncSubtensor."),
603-
UserWarning,
604-
)
605-
606-
incsubtensor_def_src = create_index_func(node, objmode=objmode)
607-
608-
global_env = {"np": np}
609-
if objmode:
610-
global_env["objmode"] = numba.objmode
611-
612-
incsubtensor_fn = compile_function_src(
613-
incsubtensor_def_src, "incsubtensor", {**globals(), **global_env}
614-
)
615-
616-
return numba_njit(incsubtensor_fn, boundscheck=True)
617-
618-
619-
@numba_funcify.register(AdvancedIncSubtensor1)
620-
def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
621-
inplace = op.inplace
622-
set_instead_of_inc = op.set_instead_of_inc
623-
x, vals, idxs = node.inputs
624-
# TODO: Add explicit expand_dims in make_node so we don't need to worry about this here
625-
broadcast = vals.type.ndim < x.type.ndim or vals.type.broadcastable[0]
626-
627-
if set_instead_of_inc:
628-
if broadcast:
629-
630-
@numba_njit(boundscheck=True)
631-
def advancedincsubtensor1_inplace(x, val, idxs):
632-
if val.ndim == x.ndim:
633-
core_val = val[0]
634-
elif val.ndim == 0:
635-
# Workaround for https://github.com/numba/numba/issues/9573
636-
core_val = val.item()
637-
else:
638-
core_val = val
639-
640-
for idx in idxs:
641-
x[idx] = core_val
642-
return x
643-
644-
else:
645-
646-
@numba_njit(boundscheck=True)
647-
def advancedincsubtensor1_inplace(x, vals, idxs):
648-
if not len(idxs) == len(vals):
649-
raise ValueError("The number of indices and values must match.")
650-
for idx, val in zip(idxs, vals):
651-
x[idx] = val
652-
return x
653-
else:
654-
if broadcast:
655-
656-
@numba_njit(boundscheck=True)
657-
def advancedincsubtensor1_inplace(x, val, idxs):
658-
if val.ndim == x.ndim:
659-
core_val = val[0]
660-
elif val.ndim == 0:
661-
# Workaround for https://github.com/numba/numba/issues/9573
662-
core_val = val.item()
663-
else:
664-
core_val = val
665-
666-
for idx in idxs:
667-
x[idx] += core_val
668-
return x
669-
670-
else:
671-
672-
@numba_njit(boundscheck=True)
673-
def advancedincsubtensor1_inplace(x, vals, idxs):
674-
if not len(idxs) == len(vals):
675-
raise ValueError("The number of indices and values must match.")
676-
for idx, val in zip(idxs, vals):
677-
x[idx] += val
678-
return x
679-
680-
if inplace:
681-
return advancedincsubtensor1_inplace
682-
683-
else:
684-
685-
@numba_njit
686-
def advancedincsubtensor1(x, vals, idxs):
687-
x = x.copy()
688-
return advancedincsubtensor1_inplace(x, vals, idxs)
689-
690-
return advancedincsubtensor1
691-
692-
693473
def deepcopyop(x):
694474
return copy(x)
695475

0 commit comments

Comments
 (0)