Avoid PyTensor function overhead in OpFromGraph #1375


Draft · wants to merge 2 commits into main

114 changes: 108 additions & 6 deletions pytensor/compile/builders.py
@@ -6,8 +6,11 @@
 from functools import partial
 from typing import Union, cast
 
-from pytensor.compile.function import function
-from pytensor.compile.function.pfunc import rebuild_collect_shared
+from pytensor.compile import get_default_mode, insert_deepcopy
+from pytensor.compile.function.pfunc import pfunc, rebuild_collect_shared
+from pytensor.compile.function.types import add_supervisor_to_fgraph
+from pytensor.compile.io import In, Out
+from pytensor.compile.mode import Mode
 from pytensor.compile.sharedvalue import SharedVariable
 from pytensor.configdefaults import config
 from pytensor.gradient import DisconnectedType, Rop, grad
@@ -21,7 +24,7 @@
 )
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.null_type import NullType
-from pytensor.graph.op import HasInnerGraph, Op
+from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
 from pytensor.graph.replace import clone_replace
 from pytensor.graph.utils import MissingInputError

@@ -433,6 +436,9 @@ def __init__(
         assert isinstance(name, str), "name must be None or string object"
         self.name = name
         self.destroy_map = destroy_map if destroy_map is not None else {}
+        self._rewritten_fgraph = {}
+        self._wrapped_inputs = {}
+        self._wrapped_outputs = {}
 
     def __eq__(self, other):
         # TODO: recognize a copy
@@ -847,14 +853,58 @@ def infer_shape(self, fgraph, node, shapes):
 
         return ret
 
+    def _rewrite_fgraph(self, impl):
+        if self._rewritten_fgraph.get(impl, None) is None:
+            mode = get_default_mode()
+            if impl == "py":
+                mode = mode.excluding("cxx")
+            rewriter = mode.optimizer
+
+            # We are cloning fgraph too many times, but one of the existing tests checks for this
+            # TestOpFromGraph.test_outputs_consistency
+            fgraph = self.fgraph.clone()
+            self._wrapped_inputs[impl] = temp_wrapped_inputs = [
+                In(inp, borrow=False, mutable=False) for inp in fgraph.inputs
+            ]
+            # These are just temporary because the graph rewrite may change them
+            temp_wrapped_outputs = [
+                Out(out, borrow=True) for out in self.fgraph.outputs
+            ]
+            add_supervisor_to_fgraph(
+                fgraph,
+                temp_wrapped_inputs,
+                accept_inplace=False,
+            )
+            with config.change_flags(compute_test_value="off"):
+                rewriter(fgraph)
+            insert_deepcopy(fgraph, temp_wrapped_inputs, temp_wrapped_outputs)
+            self._wrapped_outputs[impl] = [
+                Out(out, borrow=True) for out in fgraph.outputs
+            ]
+            self._rewritten_fgraph[impl] = fgraph
+
+        return (
+            self._rewritten_fgraph[impl],
+            self._wrapped_inputs[impl],
+            self._wrapped_outputs[impl],
+        )
+
     @property
     def fn(self):
         """Lazily compile the inner function graph."""
         if getattr(self, "_fn", None) is not None:
             return self._fn
 
-        self._fn = function(self.inner_inputs, self.inner_outputs, **self.kwargs)
-        self._fn.trust_input = True
+        fgraph, wrapped_inputs, wrapped_outputs = self._rewrite_fgraph(impl=None)
+
+        self._fn = pfunc(
+            wrapped_inputs,
+            wrapped_outputs,
+            mode=Mode(linker=get_default_mode().linker, optimizer=None),
+            accept_inplace=True,
+            on_unused_input="ignore",
+            fgraph=fgraph,
+            trust_input=True,
+        )
 
         return self._fn
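The `_rewrite_fgraph`/`fn` pair above boils down to memoizing the expensive rewrite step per `impl` key. A minimal standalone sketch of that caching pattern (hypothetical names, not PyTensor API):

```python
# Sketch of the memoization used by `_rewrite_fgraph`: the rewrite runs at
# most once per implementation key; later calls reuse the cached result.
class RewriteCache:
    def __init__(self, build):
        self._build = build  # the expensive rewrite step
        self._cache = {}     # impl -> rewritten result

    def get(self, impl):
        if self._cache.get(impl) is None:
            self._cache[impl] = self._build(impl)
        return self._cache[impl]


cache = RewriteCache(lambda impl: ("rewritten graph", impl))
assert cache.get("py") is cache.get("py")  # second call is a cache hit
```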

@@ -871,6 +921,58 @@ def clone(self):
         res.fgraph = res.fgraph.clone()
         return res
 
+    def prepare_node(
+        self,
+        node: Apply,
+        storage_map: StorageMapType | None,
+        compute_map: ComputeMapType | None,
+        impl: str | None,
+    ) -> None:
+        self._rewrite_fgraph(impl)
+        self.fn
+
+    def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
+        from pytensor.link.vm import VMLinker
+
+        self.prepare_node(node, storage_map, compute_map, impl)
+        fg, _, _ = self._rewrite_fgraph(impl)
+        fg_no_recycling = [
+            new_o
+            for (new_o, old_o) in zip(fg.outputs, node.outputs, strict=True)
+            if old_o in no_recycling
+        ]
+
+        node_input_storage = [storage_map[r] for r in node.inputs]
+        node_output_storage = [storage_map[r] for r in node.outputs]
+        node_compute_map = [compute_map[r] for r in node.outputs]
+
+        def create_thunk(linker):
+            linker.accept(fg, no_recycling=fg_no_recycling)
+            thunk, _, _ = linker.make_thunk(
+                input_storage=node_input_storage,
+                output_storage=node_output_storage,
+            )
+
+            def thunk_wrapper(thunk=thunk, node_compute_map=node_compute_map):
+                thunk()
+                for cm in node_compute_map:
+                    cm[0] = True
+
+            return thunk_wrapper
+
+        if impl != "py":
+            # try:
+            #     # We default to CLinker because it generates code for the whole graph that the compiler can reason about.
+            #     # Whereas the VMLinker will compile each node separately and call them in a pre-defined VM.
+            #     # It also has less overhead.
+            #     return create_thunk(linker=CLinker())
+            # except NotImplementedError:
+            #     # Some Op doesn't have a C implementation, VM it is
+            return create_thunk(VMLinker(use_cloop=True, c_thunks=True))
+        else:
+            return create_thunk(VMLinker(use_cloop=False, c_thunks=False))
 
     def perform(self, node, inputs, outputs):
         variables = self.fn(*inputs)
         assert len(variables) == len(outputs)
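Taken together, `prepare_node` and `make_thunk` let each `OpFromGraph` application run through a linker-built thunk instead of a full `Function.__call__`, which is where the overhead in the PR title comes from. A small usage sketch with the public PyTensor API (the softmax-like inner graph is only illustrative):

```python
import numpy as np
import pytensor.tensor as pt
from pytensor import function
from pytensor.compile.builders import OpFromGraph

x = pt.vector("x")
# Wrap the inner graph as a single Op; with this change its thunk comes
# straight from a linker rather than a wrapped `pytensor.function`.
step = OpFromGraph([x], [pt.exp(x) / pt.exp(x).sum()])

out = x
for _ in range(5):
    out = step(out)  # repeated applications, one Apply node each

fn = function([x], out, mode="FAST_RUN")
print(fn(np.arange(4, dtype="float64")))
```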
17 changes: 2 additions & 15 deletions pytensor/link/c/c_code/lazylinker_c.c
@@ -676,20 +676,7 @@ static int lazy_rec_eval(CLazyLinker *self, Py_ssize_t var_idx, PyObject *one,
       // rval is new ref
       if (rval) // pycall returned normally (no exception)
       {
-        if (rval == Py_None) {
-          Py_DECREF(rval); // ignore a return of None
-        } else if (PyList_Check(rval)) {
-          PyErr_SetString(PyExc_TypeError,
-                          "non-lazy thunk should return None, not list");
-          err = 1;
-          goto pyfail;
-        } else // don't know what it returned, but it wasn't right.
-        {
-          PyErr_SetObject(PyExc_TypeError, rval);
-          err = 1;
-          // We don't release rval since we put it in the error above
-          goto fail;
-        }
+        Py_DECREF(rval); // ignore whatever was returned
       } else // pycall returned NULL (internal error)
       {
         err = 1;
@@ -981,7 +968,7 @@
 };
 
 static PyObject *get_version(PyObject *dummy, PyObject *args) {
-  PyObject *result = PyFloat_FromDouble(0.3);
+  PyObject *result = PyFloat_FromDouble(0.4);
   return result;
 }

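The deleted branch rejected any non-`None` return from a non-lazy thunk; since thunks built directly by linkers (like those `make_thunk` now produces) may return a value, the lazy linker simply discards it. A hedged Python-level illustration of the old versus new contract (hypothetical helper names):

```python
def old_contract(thunk):
    rval = thunk()
    if rval is not None:
        # anything but None (even a list) used to be a TypeError
        raise TypeError("non-lazy thunk should return None")


def new_contract(thunk):
    thunk()  # whatever is returned is simply ignored


new_contract(lambda: "some linker-specific value")  # fine now
```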
2 changes: 1 addition & 1 deletion pytensor/link/c/lazylinker_c.py
@@ -14,7 +14,7 @@
 _logger = logging.getLogger(__file__)
 
 force_compile = False
-version = 0.3  # must match constant returned in function get_version()
+version = 0.4  # must match constant returned in function get_version()
 lazylinker_ext: ModuleType | None = None


6 changes: 1 addition & 5 deletions pytensor/tensor/rewriting/basic.py
@@ -1120,15 +1120,11 @@ def unconditional_constant_folding(fgraph, node):
         compute_map[o] = [False]
 
     thunk = node.op.make_thunk(node, storage_map, compute_map, no_recycling=[])
-    required = thunk()
-
-    # A node whose inputs are all provided should always return successfully
-    assert not required
+    thunk()
 
     rval = []
     for output in node.outputs:
         data = storage_map[output][0]
-        assert compute_map[output][0], (output, data)
 
         # TODO: `Type` itself should provide an interface for constructing
         # instances appropriate for a given constant.
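For context, a condensed sketch of the evaluation pattern inside `unconditional_constant_folding` (same names as the hunk; the surrounding rewriter machinery is elided, and pre-marking the inputs in `compute_map` is an assumption): the node's thunk runs once over pre-filled storage and results are read back from `storage_map`. The dropped assertions reflect that linker-built thunks no longer promise a falsy return value or an updated `compute_map`:

```python
def fold_node_outputs(node):
    # All inputs are constants, so their storage can be filled up front.
    storage_map = {i: [i.data] for i in node.inputs}
    compute_map = {i: [True] for i in node.inputs}  # assumption: inputs marked computed
    for o in node.outputs:
        storage_map[o] = [None]
        compute_map[o] = [False]

    thunk = node.op.make_thunk(node, storage_map, compute_map, no_recycling=[])
    thunk()  # return value and compute_map state are no longer asserted

    return [storage_map[o][0] for o in node.outputs]
```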
64 changes: 61 additions & 3 deletions tests/compile/test_builders.py
@@ -4,6 +4,7 @@
 import pytest
 
 import pytensor.tensor as pt
+from pytensor import scan
 from pytensor.compile import shared
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.function import function
@@ -15,9 +16,10 @@
     grad,
     verify_grad,
 )
-from pytensor.graph.basic import equal_computations
+from pytensor.graph.basic import Apply, equal_computations
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.null_type import NullType, null_type
+from pytensor.graph.op import Op
 from pytensor.graph.rewriting.utils import rewrite_graph
 from pytensor.graph.utils import MissingInputError
 from pytensor.printing import debugprint
@@ -622,14 +624,15 @@ def test_outputs_consistency(self):
         """Make sure that `OpFromGraph.fn` doesn't change the value of `OpFromGraph.inner_outputs`."""
 
         x = scalar("x")
-        op = OpFromGraph([x], [x**2 / x], mode="FAST_RUN")
+        op = OpFromGraph([x], [x**2 / x])
 
         # Confirm that the inner-graph is as expected
         assert equal_computations(op.inner_outputs, [x**2 / x], op.inner_inputs, [x])
 
         # These outputs of the compiled `op.fgraph` should differ from the
         # original, uncompiled `op.fgraph` outputs
-        fn = op.fn
+        with config.change_flags(mode="FAST_RUN"):
+            fn = op.fn
         new_inputs = fn.maker.fgraph.inputs
         new_outputs = fn.maker.fgraph.outputs
         assert not equal_computations(new_outputs, [x**2 / x], new_inputs, [x])
@@ -740,3 +743,58 @@ def test_debugprint():
 
     for truth, out in zip(exp_res.split("\n"), lines, strict=True):
         assert truth.strip() == out.strip()


+@pytest.mark.parametrize("kind", ("ofg", "inlined", "scan"))
+@pytest.mark.parametrize("c_op", (True, False), ids=lambda x: f"c_op={x}")
+def test_benchmark(c_op, kind, benchmark):
+    class ExpWithoutC(Op):
+        def make_node(self, x):
+            return Apply(self, [x], [x.type()])
+
+        def perform(self, node, inputs, output_storage):
+            output_storage[0][0] = np.exp(inputs[0])
+
+    exp_without_c = ExpWithoutC()
+
+    n = 25
+
+    def _f(x):
+        if isinstance(x, np.ndarray):
+            y = np.exp(x)
+        else:
+            if c_op:
+                y = pt.exp(x)
+            else:
+                y = exp_without_c(x)
+        y /= y.sum()
+        return y
+
+    x = pt.vector("x")
+
+    if kind == "ofg":
+        f = OpFromGraph([x], [_f(x)])
+    else:
+        f = _f
+
+    if kind == "scan":
+        # Scan is included as a reference for how bad the overhead can get
+        outs, _ = scan(fn=f, outputs_info=[x], n_steps=n)
+        out = outs[-1]
+    else:
+        out = x
+        for i in range(n):
+            out = f(out)
+
+    compiled_fn = function([x], out, trust_input=True, mode="FAST_RUN")
+    compiled_fn.vm.allow_gc = False
+
+    rng = np.random.default_rng(1)
+    x_test = rng.normal(size=(10,))
+
+    res = benchmark(compiled_fn, x_test)
+
+    expected_res = x_test
+    for i in range(n):
+        expected_res = _f(expected_res)
+    np.testing.assert_allclose(res, expected_res)
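For a rough reading of the same overhead outside `pytest-benchmark`, a `timeit` sketch along the lines of the `ofg` case above (assumes PyTensor is installed; numbers will vary by machine):

```python
import timeit

import numpy as np
import pytensor.tensor as pt
from pytensor import function
from pytensor.compile.builders import OpFromGraph

x = pt.vector("x")
step = OpFromGraph([x], [pt.exp(x) / pt.exp(x).sum()])

out = x
for _ in range(25):
    out = step(out)

fn = function([x], out, trust_input=True, mode="FAST_RUN")
x_test = np.random.default_rng(1).normal(size=(10,))
fn(x_test)  # warm up before timing

print(timeit.timeit(lambda: fn(x_test), number=1_000))
```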