Description
In trying to simplify `Function.__call__` (see #1024 and #222), I noticed some complicated logic that checks whether inputs marked as mutable (or borrowable) alias the same memory as one another.
See pytensor/pytensor/compile/function/types.py, lines 888 to 933 at commit be358ed.
To avoid erroneous computation, `__call__` tries to copy aliased inputs. However, this logic is wrong: it assumes that only variables of the same type can alias each other, which does not hold. In the example below a matrix and a vector alias each other. The check misses this, so the function returns wrong values and corrupts the input `y`, which was not marked as mutable.
```python
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor import In

x = pt.vector()
y = pt.matrix()
# x may be overwritten in place; y must not be
fn = pytensor.function([In(x, mutable=True), In(y, mutable=False)], [x * 2, y * 2])
fn.dprint(print_destroy_map=True)
# Mul [id A] d={0: [1]} 0
# ├─ [2.] [id B]
# └─ <Vector(float64, shape=(?,))> [id C]
# Mul [id D] d={0: [1]} 1
# ├─ [[2.]] [id E]
# └─ <Matrix(float64, shape=(?, ?))> [id F]

y_val = np.ones((2, 5))
x_val = y_val[0]  # x is an alias of y (a view of its first row)
res1, res2 = fn(x_val, y_val)
print(res1)
# [2. 2. 2. 2. 2.]
print(res2)  # Wrong: every entry should be 2.
# [[4. 4. 4. 4. 4.]
#  [2. 2. 2. 2. 2.]]
print(y_val)  # Corrupted, although y was not marked as mutable
# [[2. 2. 2. 2. 2.]
#  [1. 1. 1. 1. 1.]]
```
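The same outcome can be reproduced with NumPy alone, which also shows that the two inputs share one buffer despite having different types. This is a simplified sketch, not what PyTensor actually executes: it only mimics the in-place `Mul` on `x` (destroy map `d={0: [1]}`) and computes `y * 2` out of place.

```python
import numpy as np

# Same inputs as the PyTensor example: x_val is a view of the first row of y_val
y_val = np.ones((2, 5))
x_val = y_val[0]

# A vector and a matrix have different types, yet they share one buffer,
# so a check that only compares inputs of the same type misses this alias.
print(x_val.ndim, y_val.ndim)          # 1 2
print(np.shares_memory(x_val, y_val))  # True

# Mimic the first Mul, which is allowed to overwrite x in place
np.multiply(x_val, 2, out=x_val)
res1 = x_val

# The second Mul now reads a y whose first row was already doubled
res2 = y_val * 2

print(res1)   # [2. 2. 2. 2. 2.]
print(res2)   # [[4. 4. 4. 4. 4.]
              #  [2. 2. 2. 2. 2.]]
print(y_val)  # [[2. 2. 2. 2. 2.]
              #  [1. 1. 1. 1. 1.]]
```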
My suggestion is not to make the alias check more robust (which would further increase the `Function` call overhead), but to drop it entirely. If users indicate that an input is mutable, it shouldn't be too surprising that views of that input (or other variables sharing its underlying memory) are corrupted as well.
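If the check were dropped, callers passing views into mutable inputs would have to copy them themselves. A minimal caller-side sketch, assuming NumPy inputs; `copy_aliased_mutable_inputs` is a hypothetical helper for illustration, not part of PyTensor:

```python
import numpy as np


def copy_aliased_mutable_inputs(args, mutable):
    """Hypothetical helper (not a PyTensor API).

    Returns a new list in which every argument flagged as mutable that may
    share memory with any other argument is replaced by a private copy.
    """
    out = list(args)
    for i, (arg, is_mutable) in enumerate(zip(args, mutable)):
        if not is_mutable:
            continue
        others = (a for j, a in enumerate(args) if j != i)
        # may_share_memory only compares memory bounds: it is cheap but can
        # report false positives, which merely cost an extra copy.
        if any(np.may_share_memory(arg, other) for other in others):
            out[i] = np.array(arg, copy=True)
    return out


y_val = np.ones((2, 5))
x_val = y_val[0]  # view of y_val, as in the example above
safe_x, safe_y = copy_aliased_mutable_inputs([x_val, y_val], [True, False])
print(np.shares_memory(safe_x, y_val))  # False: x now has its own buffer
print(safe_y is y_val)                  # True: non-mutable inputs pass through
```

With `fn` from the example above, a call would look like `fn(*copy_aliased_mutable_inputs([x_val, y_val], [True, False]))`. The cost of this check is paid only by callers who opt into it, rather than on every `Function` call.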