Remove Unbroadcast Op #1286

Merged
merged 2 commits on Mar 18, 2025

15 changes: 6 additions & 9 deletions doc/library/tensor/basic.rst
@@ -619,9 +619,8 @@ dimensions, see :meth:`_tensor_py_operators.dimshuffle`.

.. function:: shape_padleft(x, n_ones=1)

Reshape `x` by left padding the shape with `n_ones` 1s. Note that all
this new dimension will be broadcastable. To make them non-broadcastable
see the :func:`unbroadcast`.
Reshape `x` by left padding the shape with `n_ones` 1s.
All new dimensions will be broadcastable.

:param x: variable to be reshaped
:type x: any `TensorVariable` (or compatible)
@@ -633,9 +632,8 @@ dimensions, see :meth:`_tensor_py_operators.dimshuffle`.

.. function:: shape_padright(x, n_ones=1)

Reshape `x` by right padding the shape with `n_ones` ones. Note that all
this new dimension will be broadcastable. To make them non-broadcastable
see the :func:`unbroadcast`.
Reshape `x` by right padding the shape with `n_ones` ones.
All new dimensions will be broadcastable.

:param x: variable to be reshaped
:type x: any TensorVariable (or compatible)
@@ -646,9 +644,8 @@ dimensions, see :meth:`_tensor_py_operators.dimshuffle`.

.. function:: shape_padaxis(t, axis)

Reshape `t` by inserting ``1`` at the dimension `axis`. Note that this new
dimension will be broadcastable. To make it non-broadcastable
see the :func:`unbroadcast`.
Reshape `t` by inserting ``1`` at the dimension `axis`.
The new dimension will be broadcastable.

:type x: any `TensorVariable` (or compatible)
:param x: variable to be reshaped
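For readers skimming the doc change: a quick sketch (not part of this diff) of the behaviour the updated docstrings describe, assuming a recent PyTensor where these helpers are exposed under `pytensor.tensor`:

```python
# Illustrative only; not part of the PR. Shows that the padded dimensions
# come out broadcastable, which is why the old pointers to `unbroadcast`
# could simply be dropped.
import pytensor.tensor as pt

x = pt.matrix("x")                            # broadcastable: (False, False)
print(pt.shape_padleft(x).broadcastable)      # (True, False, False)
print(pt.shape_padright(x, 2).broadcastable)  # (False, False, True, True)
print(pt.shape_padaxis(x, 1).broadcastable)   # (False, True, False)
```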
8 changes: 1 addition & 7 deletions pytensor/compile/function/pfunc.py
@@ -292,14 +292,8 @@ def clone_inputs(i):
f" shared_var.type={store_into.type},"
f" update_val={update_val}, update_val.type={getattr(update_val, 'type', None)})."
)
err_sug = (
"If the difference is related to the broadcast pattern,"
" you can call the"
" tensor.shape.unbroadcast(var, axis_to_unbroadcast[, ...])"
" function to mask broadcastable dimensions."
)

raise TypeError(err_msg, err_sug)
raise TypeError(err_msg)
assert store_into.type.is_super(update_val.type)

update_d[store_into] = update_val
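The deleted `err_sug` used to point users at `tensor.shape.unbroadcast`. Below is a hedged sketch (variable names are made up, not from the PR) of the kind of shared-variable/update mismatch this error covers, and the remaining way to resolve it, assuming current `pytensor.tensor.specify_broadcastable` and `TensorType.is_super` semantics:

```python
# Illustrative only; not part of the PR.
import pytensor.tensor as pt

shared_like = pt.tensor("s", dtype="float64", shape=(1, None))  # shared variable with a broadcastable dim
update = pt.tensor("u", dtype="float64", shape=(None, None))    # update with an unknown first dim

# The updates mechanism requires the shared variable's type to subsume the
# update's type (cf. the `is_super` assert above):
print(shared_like.type.is_super(update.type))  # False -> the TypeError above

# Re-declaring the broadcastable dim on the update makes the types compatible:
fixed = pt.specify_broadcastable(update, 0)
print(shared_like.type.is_super(fixed.type))   # True
```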
3 changes: 1 addition & 2 deletions pytensor/ifelse.py
@@ -26,7 +26,7 @@
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
from pytensor.graph.type import HasDataType, HasShape
from pytensor.tensor.shape import Reshape, Shape, SpecifyShape, Unbroadcast
from pytensor.tensor.shape import Reshape, Shape, SpecifyShape


if TYPE_CHECKING:
@@ -481,7 +481,6 @@ def cond_make_inplace(fgraph, node):
Shape,
SpecifyShape,
Reshape,
Unbroadcast,
pt.math.Dot,
pt.math.Max,
pt.math.Argmax,
10 changes: 1 addition & 9 deletions pytensor/link/jax/dispatch/shape.py
@@ -4,7 +4,7 @@
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from pytensor.tensor.type import TensorType


@@ -104,11 +104,3 @@ def specifyshape(x, *shape):
return x

return specifyshape


@jax_funcify.register(Unbroadcast)
def jax_funcify_Unbroadcast(op, **kwargs):
def unbroadcast(x):
return x

return unbroadcast
10 changes: 0 additions & 10 deletions pytensor/link/numba/dispatch/tensor_basic.py
@@ -17,7 +17,6 @@
Split,
TensorFromScalar,
)
from pytensor.tensor.shape import Unbroadcast


@numba_funcify.register(AllocEmpty)
@@ -232,15 +231,6 @@ def makevector({", ".join(input_names)}):
return numba_basic.numba_njit(makevector_fn)


@numba_funcify.register(Unbroadcast)
def numba_funcify_Unbroadcast(op, **kwargs):
@numba_basic.numba_njit
def unbroadcast(x):
return x

return unbroadcast


@numba_funcify.register(TensorFromScalar)
def numba_funcify_TensorFromScalar(op, **kwargs):
@numba_basic.numba_njit(inline="always")
10 changes: 1 addition & 9 deletions pytensor/link/pytorch/dispatch/shape.py
@@ -2,7 +2,7 @@

from pytensor.graph.basic import Constant
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape


@pytorch_funcify.register(Reshape)
@@ -56,11 +56,3 @@ def specifyshape(x, *shape):
return x

return specifyshape


@pytorch_funcify.register(Unbroadcast)
def pytorch_funcify_Unbroadcast(op, **kwargs):
def unbroadcast(x):
return x

return unbroadcast
10 changes: 5 additions & 5 deletions pytensor/scan/basic.py
@@ -15,7 +15,7 @@
from pytensor.tensor.basic import get_underlying_scalar_constant_value
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import minimum
from pytensor.tensor.shape import shape_padleft, unbroadcast
from pytensor.tensor.shape import shape_padleft
from pytensor.tensor.type import TensorType, integer_dtypes
from pytensor.updates import OrderedUpdates

@@ -748,7 +748,7 @@ def wrap_into_list(x):
# defined in scan utils
sit_sot_scan_inputs.append(
expand_empty(
unbroadcast(shape_padleft(actual_arg), 0),
shape_padleft(actual_arg),
actual_n_steps,
)
)
@@ -865,13 +865,13 @@ def wrap_into_list(x):
if n_fixed_steps in (1, -1):
for pos, inner_out in enumerate(outputs):
# we need to see if we need to pad our sequences with an
# unbroadcastable dimension; case example : we return an
# extra dimension; case example : we return an
# output for which we want all intermediate. If n_steps is 1
# then, if we return the output as given by the inner function
# this will represent only a slice and it will have one
# dimension less.
if isinstance(inner_out.type, TensorType) and return_steps.get(pos, 0) != 1:
outputs[pos] = unbroadcast(shape_padleft(inner_out), 0)
outputs[pos] = shape_padleft(inner_out)

if not return_list and len(outputs) == 1:
outputs = outputs[0]
@@ -1002,7 +1002,7 @@ def wrap_into_list(x):
sit_sot_inner_inputs.append(new_var)
sit_sot_scan_inputs.append(
expand_empty(
unbroadcast(shape_padleft(input.variable), 0),
shape_padleft(input.variable),
actual_n_steps,
)
)
3 changes: 1 addition & 2 deletions pytensor/scan/op.py
@@ -166,8 +166,7 @@ def check_broadcast(v1, v2):
"axis %d in `output_info`. This can happen if one of the "
"dimension is fixed to 1 in the input, while it is still "
"variable in the output, or vice-verca. You have to make "
"them consistent, e.g. using pytensor.tensor."
"{unbroadcast, specify_broadcastable}."
"them consistent, e.g. using pytensor.tensor.specify_broadcastable."
)
size = min(v1.type.ndim, v2.type.ndim)
for n, (b1, b2) in enumerate(
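As a hedged illustration of the inconsistency `check_broadcast` complains about, and of the `specify_broadcastable` fix the message now points to (example names are made up, not from the PR):

```python
# Illustrative only; not part of the PR.
import pytensor.tensor as pt

init = pt.zeros((1, 3))         # outputs_info entry whose first dim is fixed to 1
inner_out = pt.matrix("h")      # inner-function output with an unknown first dim

print(init.broadcastable)       # (True, False)
print(inner_out.broadcastable)  # (False, False)

# Making the patterns consistent, as the error message suggests:
print(pt.specify_broadcastable(inner_out, 0).broadcastable)  # (True, False)
```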
130 changes: 67 additions & 63 deletions pytensor/scan/rewriting.py
@@ -58,7 +58,11 @@
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import Dot, dot, maximum, minimum
from pytensor.tensor.rewriting.basic import constant_folding, local_useless_switch
from pytensor.tensor.rewriting.basic import (
broadcasted_by,
constant_folding,
local_useless_switch,
)
from pytensor.tensor.rewriting.elemwise import local_upcast_elemwise_constant_inputs
from pytensor.tensor.rewriting.math import local_abs_merge, local_mul_switch_sink
from pytensor.tensor.shape import shape
@@ -1182,6 +1186,44 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node):
return subtensor_merge_replacements


def _is_default_scan_buffer(x: TensorVariable) -> bool:
node = x.owner

if node is None:
return False

op = node.op
if not (
isinstance(op, IncSubtensor)
and op.set_instead_of_inc
and op.idx_list == [slice(None, ps.int64)]
):
return False

x, y, *_ = node.inputs
if not (x.owner is not None and isinstance(x.owner.op, AllocEmpty)):
return False

# The value may have been broadcast to fill in the initial taps.
# If the user specified outputs as:
# x = scalar(); init = alloc(x, 2);
# outputs_info=[init, taps=(-2, -1)]
# Scan will generate an initial buffer that looks like
# alloc_empty(2 + nsteps)[:2].set(alloc(x, 2))
# PyTensor will then rewrite it as:
# alloc_empty(2 + nsteps)[:2].set(x)
# When the initial value (x) is being broadcast by the set_subtensor
# we can't recreate a newly sized buffer working with x alone
# We want to check that:
# 1. alloc_empty(2 + nsteps)[:2].broadcastable == x.broadcastable
# But due to laziness we use the slightly more conservative check:
# 2. alloc_empty(2 + nsteps).broadcastable == x.broadcastable
if broadcasted_by(y, x):
return False

return True
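A hedged sketch of the broadcast guard used in `_is_default_scan_buffer` above, assuming `broadcasted_by(x, y)` keeps its documented meaning of "x would be broadcast by y in an Elemwise-style operation"; the tensors below are stand-ins, not taken from the PR:

```python
# Illustrative only; not part of the PR.
import pytensor.tensor as pt
from pytensor.tensor.rewriting.basic import broadcasted_by

buf = pt.tensor("buf", dtype="float64", shape=(None, 3))  # stand-in for alloc_empty(2 + nsteps)
val = pt.tensor("val", dtype="float64", shape=(1, 3))     # initial value with a broadcastable dim

# Setting `val` into `buf[:2]` would broadcast it, so the rewrite refuses to
# rebuild the buffer from `val` alone and falls back to trimming with a slice:
print(broadcasted_by(val, buf))  # True  -> _is_default_scan_buffer returns False
```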


def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: bool):
r"""Graph optimizer that reduces scan memory consumption.

@@ -1520,51 +1562,28 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:

# 3.2 check orphane outputs to see if we can eliminate any
required, not_required = scan_can_remove_outs(node.op, orphane_outs)
# 3.3. compose replace pairs for those nodes that need not
# to store everything in memory ( or ar orphane and required
# by the inner function .. )

# 3.3. compose replace pairs for those nodes that need not store everything in memory
# (or are orphan but required by the inner function)
replaced_outs = []
offset = 1 + op_info.n_seqs + op_info.n_mit_mot
for idx, _val in enumerate(store_steps[op_info.n_mit_mot :]):
for idx, val in enumerate(store_steps[op_info.n_mit_mot :]):
i = idx + op_info.n_mit_mot
if not (isinstance(_val, int) and _val <= 0 and i not in required):
if idx + op_info.n_mit_mot in required:
val = 1
else:
val = _val
if not (isinstance(val, int) and val <= 0 and i not in required):
required_orphan = idx + op_info.n_mit_mot in required
# If the memory for this output has been pre-allocated
# before going into the scan op (by an alloc node)
if idx < op_info.n_mit_sot + op_info.n_sit_sot:
# In case the input is still an alloc node, we
# actually have two options:
# a) the input is a set_subtensor, in that case we
# can replace the initial tensor by a slice,
# b) it is not, and we simply take a slice of it.
# TODO: commit change below with Razvan
if (
nw_inputs[offset + idx].owner
and isinstance(nw_inputs[offset + idx].owner.op, IncSubtensor)
and nw_inputs[offset + idx].owner.op.set_instead_of_inc
and isinstance(
nw_inputs[offset + idx].owner.op.idx_list[0], slice
)
# Don't try to create a smart Alloc, if set_subtensor is broadcasting the fill value
# As it happens in set_subtensor(empty(2)[:], 0)
and not (
nw_inputs[offset + idx].ndim
> nw_inputs[offset + idx].owner.inputs[1].ndim
)
):
_nw_input = nw_inputs[offset + idx].owner.inputs[1]
cval = pt.as_tensor_variable(val)
initl = pt.as_tensor_variable(init_l[i])
tmp_idx = pt.switch(cval < initl, cval + initl, cval - initl)
nw_input = expand_empty(_nw_input, tmp_idx)
nw_input = nw_inputs[offset + idx]

# Recreate default buffers with new size
if _is_default_scan_buffer(nw_input):
extra_size = 1 if required_orphan else val - init_l[i]
nw_input = expand_empty(nw_input.owner.inputs[1], extra_size)
# Otherwise, just trim with a slice
else:
tmp = pt.as_tensor_variable(val)
initl = pt.as_tensor_variable(init_l[i])
tmp = maximum(tmp, initl)
nw_input = nw_inputs[offset + idx][:tmp]
stop = init_l[i] if required_orphan else val
nw_input = nw_input[:stop]

nw_inputs[offset + idx] = nw_input
replaced_outs.append(op_info.n_mit_mot + idx)
@@ -1588,7 +1607,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
+ op_info.n_shared_outs
)
if nw_inputs[pos] == node.inputs[0]:
nw_inputs[pos] = val
nw_inputs[pos] = 1 if required_orphan else val
odx = op_info.n_mit_mot + idx
replaced_outs.append(odx)
old_outputs += [
@@ -1600,37 +1619,22 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
],
)
]
# 3.4. Recompute inputs for everything else based on the new
# number of steps
# 3.4. Recompute inputs for everything else based on the new number of steps
if global_nsteps is not None:
for idx, val in enumerate(store_steps[op_info.n_mit_mot :]):
if val == 0:
# val == 0 means that we want to keep all intermediate
# results for that state, including the initial values.
if idx < op_info.n_mit_sot + op_info.n_sit_sot:
in_idx = offset + idx
# Number of steps in the initial state
initl = init_l[op_info.n_mit_mot + idx]

# If the initial buffer has the form
# inc_subtensor(zeros(...)[...], _nw_input)
# we want to make the zeros tensor as small as
# possible (nw_steps + initl), and call
# inc_subtensor on that instead.
# Otherwise, simply take 0:(nw_steps+initl).
if (
nw_inputs[in_idx].owner
and isinstance(nw_inputs[in_idx].owner.op, IncSubtensor)
and isinstance(
nw_inputs[in_idx].owner.op.idx_list[0], slice
)
):
_nw_input = nw_inputs[in_idx].owner.inputs[1]
nw_input = expand_empty(_nw_input, nw_steps)
nw_inputs[in_idx] = nw_input
nw_input = nw_inputs[in_idx]
if _is_default_scan_buffer(nw_input):
nw_input = expand_empty(nw_input.owner.inputs[1], nw_steps)
else:
# FIXME: This is never used
nw_input = nw_inputs[in_idx][: (initl + nw_steps)]
Member: is that FIXME in the old code still relevant?
Member Author: Fixed with these changes

# Number of steps in the initial state
init_l_pt = pt.as_tensor(init_l[op_info.n_mit_mot + idx])
nw_input = nw_input[: (init_l_pt + nw_steps)]
nw_inputs[in_idx] = nw_input

elif (
idx < op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot