Reconsider checking for input alias during function calls #1026

Open
@ricardoV94

Description

While trying to simplify Function.__call__ (see #1024 and #222), I noticed some complicated logic that checks whether inputs marked as mutable (or borrowable) alias the same memory as one another.

if (
    not self.trust_input
    and
    # The getattr is only needed for old pickle
    getattr(self, "_check_for_aliased_inputs", True)
):
    # Collect aliased inputs among the storage space
    args_share_memory = []
    for i in range(len(self.input_storage)):
        i_var = self.maker.inputs[i].variable
        i_val = self.input_storage[i].storage[0]
        if hasattr(i_var.type, "may_share_memory"):
            is_aliased = False
            for j in range(len(args_share_memory)):
                group_j = zip(
                    [
                        self.maker.inputs[k].variable
                        for k in args_share_memory[j]
                    ],
                    [
                        self.input_storage[k].storage[0]
                        for k in args_share_memory[j]
                    ],
                )
                if any(
                    (
                        var.type is i_var.type
                        and var.type.may_share_memory(val, i_val)
                    )
                    for (var, val) in group_j
                ):
                    is_aliased = True
                    args_share_memory[j].append(i)
                    break
            if not is_aliased:
                args_share_memory.append([i])
    # Check for groups of more than one argument that share memory
    for group in args_share_memory:
        if len(group) > 1:
            # copy all but the first
            for j in group[1:]:
                self.input_storage[j].storage[0] = copy.copy(
                    self.input_storage[j].storage[0]
                )

To avoid erroneous computation, __call__ tries to copy aliased inputs. However, this logic is wrong: it assumes that only variables of the same type can be aliased, which does not hold. In the example below a matrix and a vector are aliased. The check misses this, so the function returns wrong values and corrupts the input y, which was not marked as mutable.

import pytensor
import pytensor.tensor as pt
from pytensor import In
import numpy as np

x = pt.vector()
y = pt.matrix()

fn = pytensor.function([In(x, mutable=True), In(y, mutable=False)], [x * 2, y * 2])
fn.dprint(print_destroy_map=True)
# Mul [id A] d={0: [1]} 0
#  ├─ [2.] [id B]
#  └─ <Vector(float64, shape=(?,))> [id C]
# Mul [id D] d={0: [1]} 1
#  ├─ [[2.]] [id E]
#  └─ <Matrix(float64, shape=(?, ?))> [id F]

y_val = np.ones((2, 5))
x_val = y_val[0]  # x is an alias of y
res1, res2 = fn(x_val, y_val)
print(res1)
# [2. 2. 2. 2. 2.]

print(res2)  # Wrong
# [[4. 4. 4. 4. 4.]
#  [2. 2. 2. 2. 2.]]

print(y_val)  # Corrupted
# [[2. 2. 2. 2. 2.]
#  [1. 1. 1. 1. 1.]]
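The effect of the type-equality condition can be reproduced in isolation with plain NumPy. The following is a minimal sketch, not PyTensor code: `group_aliased`, `type_keys`, and the `same_type_only` flag are hypothetical names introduced here, string labels stand in for `var.type`, and label equality stands in for the `is` comparison in the original check.

```python
import numpy as np


def group_aliased(values, type_keys, same_type_only=True):
    """Group input indices whose buffers may overlap.

    Hypothetical stand-in for the grouping loop in Function.__call__.
    `type_keys[i]` plays the role of `var.type` for input i.
    """
    groups = []
    for i, val in enumerate(values):
        for group in groups:
            if any(
                (not same_type_only or type_keys[k] == type_keys[i])
                and np.may_share_memory(values[k], val)
                for k in group
            ):
                group.append(i)
                break
        else:
            # No existing group overlaps with input i
            groups.append([i])
    return groups


y_val = np.ones((2, 5))
x_val = y_val[0]  # a vector view into the matrix
values = [x_val, y_val]
type_keys = ["vector", "matrix"]  # assumed labels, one per input

groups_typed = group_aliased(values, type_keys, same_type_only=True)
groups_any = group_aliased(values, type_keys, same_type_only=False)
print(groups_typed)  # [[0], [1]]
print(groups_any)    # [[0, 1]]
```

With the type condition in place the two inputs land in separate groups, so nothing would be copied. Without it they are grouped together, but the pairwise memory check then runs for every pair of inputs regardless of type.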

My suggestion is not to make the alias check more robust (and thereby increase the Function call overhead), but to drop it completely. If users indicate that an input is mutable, it shouldn't be too surprising that views of that input (or other variables sharing the same underlying memory) also get corrupted.
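If the check were dropped, callers who pass possibly overlapping arrays could guard against it themselves. Below is a rough sketch of such a caller-side guard using only NumPy; `unalias_mutable` is a hypothetical helper, not part of PyTensor, and the in-place multiplication merely stands in for what a compiled function might do to a mutable input.

```python
import numpy as np


def unalias_mutable(mutable_val, *other_vals):
    """Return a copy of `mutable_val` if it may overlap any other input.

    Hypothetical helper. np.may_share_memory only compares memory bounds,
    so it can report overlap that is not real, but it is cheap.
    """
    if any(np.may_share_memory(mutable_val, other) for other in other_vals):
        return mutable_val.copy()
    return mutable_val


y_val = np.ones((2, 5))
x_val = y_val[0]  # x is an alias of y

safe_x = unalias_mutable(x_val, y_val)
safe_x *= 2  # stand-in for an in-place operation on the mutable input

print(np.shares_memory(safe_x, y_val))  # False
print(y_val[0])  # [1. 1. 1. 1. 1.]
```

This moves the cost of the check to the callers who actually need it. `np.shares_memory` could be used instead for an exact answer, at a potentially higher cost.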
