Inline constants in composite graphs #361

Merged
merged 5 commits on Jul 11, 2023
6 changes: 6 additions & 0 deletions pytensor/scalar/basic.py
@@ -432,6 +432,12 @@ def c_literal(self, data):
            return None
        if self.dtype == "bool":
            return "1" if data else "0"
        if data == np.inf:
            return "INFINITY"
        if data == -np.inf:
            return "-INFINITY"
        if np.isnan(data):
            return "NAN"
        return str(data)

def c_declare(self, name, sub, check_input=True):
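For reference, a rough sketch of what the new branches produce (not part of the diff; it assumes the float32 ScalarType instance exported by pytensor.scalar.basic):

import numpy as np
from pytensor.scalar.basic import float32

# With the new branches, non-finite scalars map to C99 macros instead of str(data)
print(float32.c_literal(np.float32(np.inf)))   # "INFINITY"
print(float32.c_literal(np.float32(-np.inf)))  # "-INFINITY"
print(float32.c_literal(np.float32(np.nan)))   # "NAN"
print(float32.c_literal(np.float32(1.5)))      # "1.5"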
4 changes: 2 additions & 2 deletions pytensor/scan/rewriting.py
@@ -69,7 +69,7 @@
get_slice_elements,
set_subtensor,
)
-from pytensor.tensor.var import TensorConstant, get_unique_value
+from pytensor.tensor.var import TensorConstant, get_unique_constant_value


list_opt_slice = [
@@ -136,7 +136,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
node_inp = node.inputs[idx + 1]
if (
isinstance(node_inp, TensorConstant)
-and get_unique_value(node_inp) is not None
+and get_unique_constant_value(node_inp) is not None
):
try:
# This works if input is a constant that has all entries
8 changes: 6 additions & 2 deletions pytensor/tensor/basic.py
@@ -62,7 +62,11 @@
uint_dtypes,
values_eq_approx_always_true,
)
-from pytensor.tensor.var import TensorConstant, TensorVariable, get_unique_value
+from pytensor.tensor.var import (
+    TensorConstant,
+    TensorVariable,
+    get_unique_constant_value,
+)


if TYPE_CHECKING:
@@ -323,7 +327,7 @@ def get_underlying_scalar_constant_value(
raise NotScalarConstantError()

if isinstance(v, Constant):
-unique_value = get_unique_value(v)
+unique_value = get_unique_constant_value(v)
if unique_value is not None:
data = unique_value
else:
58 changes: 56 additions & 2 deletions pytensor/tensor/rewriting/elemwise.py
@@ -33,9 +33,13 @@
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import exp
-from pytensor.tensor.rewriting.basic import register_canonicalize, register_specialize
+from pytensor.tensor.rewriting.basic import (
+    broadcast_like,
+    register_canonicalize,
+    register_specialize,
+)
from pytensor.tensor.shape import shape_padleft
-from pytensor.tensor.var import TensorConstant
+from pytensor.tensor.var import TensorConstant, get_unique_constant_value


class InplaceElemwiseOptimizer(GraphRewriter):
@@ -1203,6 +1207,49 @@ def local_careduce_fusion(fgraph, node):
return [new_car_op(*elm_inputs)]


@node_rewriter([Elemwise])
def local_inline_composite_constants(fgraph, node):
    """Inline scalar constants in Composite graphs."""
    composite_op = node.op.scalar_op

    if not isinstance(composite_op, aes.Composite):
        return None

    new_outer_inputs = []
    new_inner_inputs = []
    inner_replacements = {}
    for outer_inp, inner_inp in zip(node.inputs, composite_op.fgraph.inputs):
        # Complex variables don't have a `c_literal` that can be inlined
        if "complex" not in outer_inp.type.dtype:
            unique_value = get_unique_constant_value(outer_inp)
            if unique_value is not None:
                inner_replacements[inner_inp] = aes.constant(
                    unique_value, dtype=inner_inp.dtype
                )
                continue
        new_outer_inputs.append(outer_inp)
        new_inner_inputs.append(inner_inp)

    if not inner_replacements:
        return None

    new_inner_outs = clone_replace(
        composite_op.fgraph.outputs, replace=inner_replacements
    )
    new_composite_op = aes.Composite(new_inner_inputs, new_inner_outs)
    new_outputs = Elemwise(new_composite_op).make_node(*new_outer_inputs).outputs

    # Some of the inlined constants were broadcasting the output shape
    if node.outputs[0].type.broadcastable != new_outputs[0].type.broadcastable:
Review comment from a project member: How could this happen if we only changed scalar inputs?

ricardoV94 (Member, PR author) replied on Jul 8, 2023: get_unique_constant_value works for homogeneous constants of any rank

        new_outputs = [
            broadcast_like(new_out, template=node.outputs[0], fgraph=fgraph)
            for new_out in new_outputs
        ]

    copy_stack_trace(node.outputs, new_outputs)
    return new_outputs


# Register fusion database just before AddDestroyHandler(49.5) (inplace rewrites)
fuse_seqopt = SequenceDB()
compile.optdb.register(
@@ -1243,6 +1290,13 @@ def local_careduce_fusion(fgraph, node):
"fusion",
position=10,
)
fuse_seqopt.register(
    "local_inline_composite_constants",
    in2out(local_inline_composite_constants),
    "fast_run",
    "fusion",
    position=20,
)


def _rebuild_partial_2f1grad_loop(node, wrt):
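To illustrate the point raised in the review thread above, here is a minimal sketch (not part of the PR): get_unique_constant_value also matches homogeneous constants of rank greater than zero, so inlining one as a scalar inside the Composite can drop the broadcasting that constant used to impose on the output, which is why broadcast_like is applied afterwards.

import numpy as np
import pytensor.tensor as at
from pytensor.tensor.var import get_unique_constant_value

scalar_const = at.constant(2.5)
homogeneous = at.constant(np.full((2, 5), 2.5))    # rank 2, single repeated value
ragged = at.constant(np.arange(10).reshape(2, 5))  # not homogeneous

print(get_unique_constant_value(scalar_const))  # 2.5
print(get_unique_constant_value(homogeneous))   # 2.5 (would be inlined as a scalar)
print(get_unique_constant_value(ragged))        # None (kept as an outer input)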
10 changes: 6 additions & 4 deletions pytensor/tensor/rewriting/math.py
@@ -101,7 +101,7 @@
values_eq_approx_remove_inf_nan,
values_eq_approx_remove_nan,
)
-from pytensor.tensor.var import TensorConstant, get_unique_value
+from pytensor.tensor.var import TensorConstant, get_unique_constant_value


def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False):
@@ -133,7 +133,7 @@ def get_constant(v):

"""
if isinstance(v, Constant):
-unique_value = get_unique_value(v)
+unique_value = get_unique_constant_value(v)
if unique_value is not None:
data = unique_value
else:
@@ -1135,10 +1135,12 @@ def same(x, y):
if new.type.dtype != out.type.dtype:
new = cast(new, out.type.dtype)

-if new.type != out.type:
+if new.type.broadcastable != out.type.broadcastable:
new = fill_chain(new, node.inputs)[0]

-if new.type == out.type:
+if (new.type.dtype == out.type.dtype) and (
+    new.type.broadcastable == out.type.broadcastable
+):
new.tag.values_eq_approx = values_eq_approx_remove_inf_nan
copy_stack_trace(out, new)
return [new]
46 changes: 46 additions & 0 deletions pytensor/tensor/rewriting/shape.py
@@ -1026,6 +1026,52 @@ def local_Shape_of_SpecifyShape(fgraph, node):
return [stack(shape).astype(np.int64)]


@register_canonicalize
@register_specialize
@node_rewriter([SpecifyShape])
def local_specify_shape_lift(fgraph, node):
    """Lift SpecifyShape of Elemwise towards the inputs."""
    inp, *shape = node.inputs
    if inp.owner and isinstance(inp.owner.op, Elemwise):
        if len(inp.owner.outputs) != 1:
            return None

        elem_inps = inp.owner.inputs
        if len(elem_inps) == 1:
            new_elem_inps = [specify_shape(elem_inps[0], shape)]
        else:
            # Rewrite does not support the case where specify_shape provides new
            # broadcastable information, as that may require a specify_shape for
            # each input
            out_broadcastable = node.outputs[0].type.broadcastable
            if out_broadcastable != inp.type.broadcastable:
                return None

            # All non-broadcastable dimensions of inputs must match the
            # non-broadcastable specify_shape dims.
            # We look for a sufficient input to assign all the specify_shape dims.
            # We could consider distributing the SpecifyShape across multiple
            # inputs, when none is sufficient.

            nonbcast_dims = {
                i
                for i, (dim, bcast) in enumerate(zip(shape, out_broadcastable))
                if (not bcast and not NoneConst.equals(dim))
            }
            new_elem_inps = elem_inps.copy()
            for i, elem_inp in enumerate(elem_inps):
                if all(
                    bcast_dim is False
                    for dim, bcast_dim in enumerate(elem_inp.type.broadcastable)
                    if dim in nonbcast_dims
                ):
                    new_elem_inps[i] = specify_shape(elem_inp, shape)
                    break
            else:  # no-break, no sufficient candidate found
                return None

        new_out = inp.owner.op.make_node(*new_elem_inps).outputs
        copy_stack_trace(node.outputs, new_out)
        return new_out


@register_useless
@register_canonicalize
@node_rewriter([Shape_i])
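As a rough illustration of the new rewrite (a sketch, not part of the diff; it assumes rewrite_graph from pytensor.graph.rewriting.utils applies the registered canonicalizations): a SpecifyShape sitting on top of a single-input Elemwise is pushed onto that input.

import pytensor.tensor as at
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.printing import debugprint

x = at.vector("x")
out = at.specify_shape(at.exp(x), (5,))

# Expected to end up roughly as exp(specify_shape(x, (5,))) after canonicalization
debugprint(rewrite_graph(out))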
2 changes: 1 addition & 1 deletion pytensor/tensor/var.py
@@ -986,7 +986,7 @@ def no_nan(self):
return self._no_nan


-def get_unique_value(x: TensorVariable) -> Optional[Number]:
+def get_unique_constant_value(x: TensorVariable) -> Optional[Number]:
"""Return the unique value of a tensor, if there is one"""
if isinstance(x, Constant):
data = x.data
22 changes: 17 additions & 5 deletions tests/scalar/test_basic.py
@@ -128,21 +128,33 @@ def test_flatten(self):
# We don't flatten that case.
assert isinstance(CC.outputs[0].owner.op, Composite)

-def test_with_constants(self):
+@pytest.mark.parametrize("literal_value", (70.0, -np.inf, np.float32("nan")))
+def test_with_constants(self, literal_value):
x, y, z = floats("xyz")
-e = mul(add(70.0, y), true_div(x, y))
+e = mul(add(literal_value, y), true_div(x, y))
comp_op = Composite([x, y], [e])
comp_node = comp_op.make_node(x, y)

c_code = comp_node.op.c_code(comp_node, "dummy", ["x", "y"], ["z"], dict(id=0))
assert "70.0" in c_code
assert constant(literal_value).type.c_literal(literal_value) in c_code

# Make sure caching of the c_code template works
assert hasattr(comp_node.op, "_c_code")

g = FunctionGraph([x, y], [comp_node.out])
-fn = make_function(DualLinker().accept(g))
-assert fn(1.0, 2.0) == 36.0

+# Default checker does not allow `nan`
+def checker(x, y):
+    np.testing.assert_equal(x, y)
+
+fn = make_function(DualLinker(checker=checker).accept(g))
+
+test_x = 1.0
+test_y = 2.0
+np.testing.assert_equal(
+    fn(test_x, test_y),
+    (literal_value + test_y) * (test_x / test_y),
+)

def test_many_outputs(self):
x, y, z = floats("xyz")
26 changes: 26 additions & 0 deletions tests/tensor/rewriting/test_elemwise.py
@@ -1461,6 +1461,32 @@ def test_local_useless_composite_outputs():
utt.assert_allclose(f([[np.nan]], [[1.0]], [[np.nan]]), [[0.0]])


@pytest.mark.parametrize("const_shape", [(), (1,), (5,), (1, 5), (2, 5)])
@pytest.mark.parametrize("op, np_op", [(at.pow, np.power), (at.add, np.add)])
def test_local_inline_composite_constants(op, np_op, const_shape):
const = np.full(shape=const_shape, fill_value=2.5).astype(config.floatX)
x = vector("x")
y = vector("y")
out = at.exp(op(x, const)) + y

fn = pytensor.function(
[x, y], out, mode=get_default_mode().including("specialize", "fusion")
)
# There should be a single Composite after optimization
[node] = [
node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Elemwise)
]
assert isinstance(node.op.scalar_op, Composite)
assert len(node.inputs) == 2 # x and y, but not const

x_test_value = np.arange(5).astype(config.floatX)
y_test_value = np.ones(5).astype(config.floatX)
np.testing.assert_allclose(
fn(x_test_value, y_test_value),
np.exp(np_op(x_test_value, const)) + y_test_value,
)


def test_local_useless_dimshuffle_makevector():
a = scalar()
x = MakeVector(config.floatX)(a)
26 changes: 24 additions & 2 deletions tests/tensor/rewriting/test_math.py
@@ -30,7 +30,7 @@
from pytensor.misc.safe_asarray import _asarray
from pytensor.printing import debugprint
from pytensor.tensor import inplace
-from pytensor.tensor.basic import Alloc, join, switch
+from pytensor.tensor.basic import Alloc, join, second, switch
from pytensor.tensor.blas import Dot22, Gemv
from pytensor.tensor.blas_c import CGemv
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
@@ -96,7 +96,7 @@
perform_sigm_times_exp,
simplify_mul,
)
-from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape
+from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape, specify_shape
from pytensor.tensor.type import (
TensorType,
cmatrix,
@@ -979,6 +979,28 @@ def test_mismatching_types(self):
# No rewrite was applied
assert z_rewritten is z

    def test_shape_specified_by_constant(self):
        x = vector("x")
        const = np.full(shape=(5,), fill_value=2.0).astype(config.floatX)
        out = x * const

        new_out = rewrite_graph(
            out, custom_rewrite=in2out(local_mul_canonizer, name="test")
        )
        expected_out = np.array([2.0]).astype(config.floatX) * specify_shape(x, (5,))
        assert equal_computations([new_out], [expected_out])

    def test_broadcasted_by_constant(self):
        x = vector("x")
        const = np.full(shape=(3, 5), fill_value=2.0).astype(config.floatX)
        out = x * const

        new_out = rewrite_graph(
            out, custom_rewrite=in2out(local_mul_canonizer, name="test")
        )
        expected_out = second(const, np.array([[2.0]], dtype=config.floatX) * x)
        assert equal_computations([new_out], [expected_out])


def test_local_merge_abs():
x, y, z = matrices("xyz")
8 changes: 8 additions & 0 deletions tests/tensor/rewriting/test_shape.py
@@ -491,6 +491,14 @@ def test_local_Shape_of_SpecifyShape_partial(s1):
assert not any(isinstance(apply.op, SpecifyShape) for apply in fgraph.apply_nodes)


def test_local_specify_shape_lift():
    x = vector("x")
    out = specify_shape([1.0] + x, shape=(5,))

    new_out = rewrite_graph(out)
    assert equal_computations([new_out], [[1.0] + specify_shape(x, shape=(5,))])


def test_local_Shape_i_ground():
x = tensor(dtype=np.float64, shape=(None, 2))
s = Shape_i(1)(x)