Don't check input alias on Function call #1049


Draft: wants to merge 6 commits into main
5 changes: 3 additions & 2 deletions doc/library/compile/function.rst
@@ -56,8 +56,9 @@ Reference

.. attribute:: mutable

``True`` means the compiled-function is allowed to modify this
argument. ``False`` means it is not allowed.
Defaults to ``True`` if ``update`` is not ``None``, ``False`` otherwise.
When ``True``, permit the compiled function to modify the Python object being passed as the input, to save memory.
When an input is mutable, it should not be an alias (a view) of any other input; otherwise, behavior is undefined and will likely yield wrong results.

.. attribute:: borrow

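The aliasing hazard the new wording warns about is easiest to see in a concrete call. A minimal sketch (the graph and values are invented for illustration; `In` and `pytensor.function` are the APIs documented here):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.compile.io import In

x = pt.dvector("x")
y = pt.dvector("y")

# Both inputs are declared mutable, so the compiled function is free to
# overwrite their buffers as a memory optimization.
f = pytensor.function([In(x, mutable=True), In(y, mutable=True)], x + y)

a = np.ones(3)
b = a[:]  # a view: b aliases a's memory

# Undefined behavior: with the call-time alias check removed, nothing
# copies or rejects the overlapping inputs, so the function may clobber
# `a` mid-computation and return wrong results.
out = f(a, b)
```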
222 changes: 80 additions & 142 deletions pytensor/compile/function/types.py
@@ -326,8 +326,8 @@
def __init__(
self,
vm: "VM",
input_storage,
output_storage,
input_storage: list[Container],
output_storage: list[Container],
indices,
outputs,
defaults,
@@ -372,7 +372,6 @@
name
A string name.
"""
# TODO: Rename to `vm`
self.vm = vm
self.input_storage = input_storage
self.output_storage = output_storage
@@ -388,31 +387,17 @@
self.nodes_with_inner_function = []
self.output_keys = output_keys

# See if we have any mutable / borrow inputs
# TODO: this only need to be set if there is more than one input
self._check_for_aliased_inputs = False
for i in maker.inputs:
# If the input is a shared variable, the memory region is
# under PyTensor control and so we don't need to check if it
# is aliased as we never do that.
if (
isinstance(i, In)
and not i.shared
and (getattr(i, "borrow", False) or getattr(i, "mutable", False))
):
self._check_for_aliased_inputs = True
break
if self.output_keys is not None:
warnings.warn("output_keys is deprecated.", FutureWarning)

assert len(self.input_storage) == len(self.maker.fgraph.inputs)
assert len(self.output_storage) == len(self.maker.fgraph.outputs)

# We will be popping stuff off this `containers` object. It is a copy.
containers = list(self.input_storage)
finder = {}
inv_finder = {}

def distribute(indices, cs, value):
input.distribute(value, indices, cs)
for c in cs:
c.provided += 1

# Store the list of names of named inputs.
named_inputs = []
# Count the number of un-named inputs.
@@ -777,6 +762,13 @@
f_cpy.maker.fgraph.name = name
return f_cpy

def _restore_defaults(self):
for i, (required, refeed, value) in enumerate(self.defaults):
if refeed:
if isinstance(value, Container):
value = value.storage[0]
self[i] = value

def __call__(self, *args, **kwargs):
"""
Evaluates value of a function on given arguments.
@@ -805,52 +797,52 @@
List of outputs on indices/keys from ``output_subset`` or all of them,
if ``output_subset`` is not passed.
"""

def restore_defaults():
for i, (required, refeed, value) in enumerate(self.defaults):
if refeed:
if isinstance(value, Container):
value = value.storage[0]
self[i] = value

input_storage = self.input_storage
profile = self.profile
t0 = time.perf_counter()

if profile:
t0 = time.perf_counter()

output_subset = kwargs.pop("output_subset", None)
if output_subset is not None and self.output_keys is not None:
output_subset = [self.output_keys.index(key) for key in output_subset]
if output_subset is not None:
warnings.warn("output_subset is deprecated.", FutureWarning)
if self.output_keys is not None:
output_subset = [self.output_keys.index(key) for key in output_subset]

# Reinitialize each container's 'provided' counter
if self.trust_input:
i = 0
for arg in args:
s = self.input_storage[i]
s.storage[0] = arg
i += 1
# Set positional arguments
for arg_container, arg in zip(input_storage, args, strict=False):
arg_container.storage[0] = arg

# Set keyword arguments
if kwargs: # for speed, skip the items for empty kwargs
for k, arg in kwargs.items():
self[k] = arg
else:
for c in self.input_storage:
c.provided = 0
# Reinitialize each container's 'provided' counter
for arg_container in input_storage:
arg_container.provided = 0

if len(args) + len(kwargs) > len(self.input_storage):
if len(args) + len(kwargs) > len(input_storage):
raise TypeError("Too many parameter passed to pytensor function")

# Set positional arguments
i = 0
for arg in args:
# TODO: provide a option for skipping the filter if we really
# want speed.
s = self.input_storage[i]
# see this emails for a discuation about None as input
for arg_container, arg in zip(input_storage, args, strict=False):
# See discussion about None as input
# https://groups.google.com/group/theano-dev/browse_thread/thread/920a5e904e8a8525/4f1b311a28fc27e5
if arg is None:
s.storage[0] = arg
arg_container.storage[0] = arg
else:
try:
s.storage[0] = s.type.filter(
arg, strict=s.strict, allow_downcast=s.allow_downcast
arg_container.storage[0] = arg_container.type.filter(
arg,
strict=arg_container.strict,
allow_downcast=arg_container.allow_downcast,
)

except Exception as e:
i = input_storage.index(arg_container)
function_name = "pytensor function"
argument_name = "argument"
if self.name:
@@ -875,93 +867,45 @@
+ function_name
+ f" at index {int(i)} (0-based). {where}"
) + e.args
restore_defaults()
self._restore_defaults()
raise
s.provided += 1
i += 1

# Set keyword arguments
if kwargs: # for speed, skip the items for empty kwargs
for k, arg in kwargs.items():
self[k] = arg

if (
not self.trust_input
and
# The getattr is only needed for old pickle
getattr(self, "_check_for_aliased_inputs", True)
):
# Collect aliased inputs among the storage space
args_share_memory = []
for i in range(len(self.input_storage)):
i_var = self.maker.inputs[i].variable
i_val = self.input_storage[i].storage[0]
if hasattr(i_var.type, "may_share_memory"):
is_aliased = False
for j in range(len(args_share_memory)):
group_j = zip(
[
self.maker.inputs[k].variable
for k in args_share_memory[j]
],
[
self.input_storage[k].storage[0]
for k in args_share_memory[j]
],
)
if any(
(
var.type is i_var.type
and var.type.may_share_memory(val, i_val)
)
for (var, val) in group_j
):
is_aliased = True
args_share_memory[j].append(i)
break

if not is_aliased:
args_share_memory.append([i])

# Check for groups of more than one argument that share memory
for group in args_share_memory:
if len(group) > 1:
# copy all but the first
for j in group[1:]:
self.input_storage[j].storage[0] = copy.copy(
self.input_storage[j].storage[0]
)

# Check if inputs are missing, or if inputs were set more than once, or
# if we tried to provide inputs that are supposed to be implicit.
if not self.trust_input:
for c in self.input_storage:
if c.required and not c.provided:
restore_defaults()
arg_container.provided += 1

# Set keyword arguments
if kwargs: # for speed, skip the items for empty kwargs
for k, arg in kwargs.items():
self[k] = arg

# Check if inputs are missing, or if inputs were set more than once, or
# if we tried to provide inputs that are supposed to be implicit.
for arg_container in input_storage:
if arg_container.required and not arg_container.provided:
self._restore_defaults()
raise TypeError(
f"Missing required input: {getattr(self.inv_finder[c], 'variable', self.inv_finder[c])}"
f"Missing required input: {getattr(self.inv_finder[arg_container], 'variable', self.inv_finder[arg_container])}"
)
if c.provided > 1:
restore_defaults()
if arg_container.provided > 1:
self._restore_defaults()
raise TypeError(
f"Multiple values for input: {getattr(self.inv_finder[c], 'variable', self.inv_finder[c])}"
f"Multiple values for input: {getattr(self.inv_finder[arg_container], 'variable', self.inv_finder[arg_container])}"
)
if c.implicit and c.provided > 0:
restore_defaults()
if arg_container.implicit and arg_container.provided > 0:
self._restore_defaults()
raise TypeError(
f"Tried to provide value for implicit input: {getattr(self.inv_finder[c], 'variable', self.inv_finder[c])}"
f"Tried to provide value for implicit input: {getattr(self.inv_finder[arg_container], 'variable', self.inv_finder[arg_container])}"
)

# Do the actual work
t0_fn = time.perf_counter()
if profile:
t0_fn = time.perf_counter()
try:
outputs = (
self.vm()
if output_subset is None
else self.vm(output_subset=output_subset)
)
except Exception:
restore_defaults()
self._restore_defaults()
if hasattr(self.vm, "position_of_error"):
# this is a new vm-provided function or c linker
# they need this because the exception manipulation
@@ -979,26 +923,24 @@
# old-style linkers raise their own exceptions
raise

dt_fn = time.perf_counter() - t0_fn
self.maker.mode.fn_time += dt_fn
if profile:
dt_fn = time.perf_counter() - t0_fn
self.maker.mode.fn_time += dt_fn
profile.vm_call_time += dt_fn

# Retrieve the values that were computed
if outputs is None:
outputs = [x.data for x in self.output_storage]
assert len(outputs) == len(self.output_storage)

# Remove internal references to required inputs.
# These cannot be re-used anyway.
for c in self.input_storage:
if c.required:
c.storage[0] = None
for arg_container in input_storage:
if arg_container.required:
arg_container.storage[0] = None

# if we are allowing garbage collection, remove the
# output reference from the internal storage cells
if getattr(self.vm, "allow_gc", False):
assert len(self.output_storage) == len(self.maker.fgraph.outputs)
for o_container, o_variable in zip(
self.output_storage, self.maker.fgraph.outputs
):
@@ -1007,37 +949,31 @@
# WARNING: This circumvents the 'readonly' attribute in x
o_container.storage[0] = None

# TODO: Get rid of this and `expanded_inputs`, since all the VMs now
# perform the updates themselves
if getattr(self.vm, "need_update_inputs", True):
# Update the inputs that have an update function
for input, storage in reversed(
list(zip(self.maker.expanded_inputs, self.input_storage))
list(zip(self.maker.expanded_inputs, input_storage))
):
if input.update is not None:
storage.data = outputs.pop()
else:
outputs = outputs[: self.n_returned_outputs]

# Put default values back in the storage
restore_defaults()
#
# NOTE: This logic needs to be replicated in
# scan.
# grep for 'PROFILE_CODE'
#

dt_call = time.perf_counter() - t0
pytensor.compile.profiling.total_fct_exec_time += dt_call
self.maker.mode.call_time += dt_call
self._restore_defaults()

if profile:
dt_call = time.perf_counter() - t0
pytensor.compile.profiling.total_fct_exec_time += dt_call
self.maker.mode.call_time += dt_call
profile.fct_callcount += 1
profile.fct_call_time += dt_call
if hasattr(self.vm, "update_profile"):
self.vm.update_profile(profile)
if profile.ignore_first_call:
profile.reset()
profile.ignore_first_call = False

if self.return_none:
return None
elif self.unpack_single and len(outputs) == 1 and output_subset is None:
@@ -1572,6 +1508,8 @@
)
for i in self.inputs
]
if any(self.refeed):
warnings.warn("Inputs with default values are deprecated.", FutureWarning)

def create(self, input_storage=None, storage_map=None):
"""
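With the `args_share_memory` loop deleted above, `__call__` no longer detects mutable or borrowed inputs that overlap in memory, so breaking the aliasing is now the caller's job. A hedged sketch of the workaround, reusing the hypothetical `f` compiled with two mutable vector inputs in the earlier example:

```python
import numpy as np

base = np.arange(6.0)
first, second = base[:4], base[2:]  # overlapping views of one buffer

# Function.__call__ used to notice that `first` and `second` may share
# memory and silently copy one of them; now pass disjoint buffers yourself.
out = f(first, second.copy())
```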
10 changes: 5 additions & 5 deletions pytensor/compile/io.py
@@ -128,11 +128,11 @@ class In(SymbolicInput):
expression variable after each function call. If update is None, the
update will be the default value of the input.
mutable : bool
Defaults to False if update is None, True if update is not None.
True: permit the compiled function to modify the python object
being passed as the input.
False: do not permit the compiled function to modify the
python object being passed as the input.
Defaults to ``True`` if ``update`` is not ``None``, ``False`` otherwise.
When ``True``, permit the compiled function to modify the Python object
being passed as the input, to save memory. When an input is mutable,
it should not be an alias (a view) of any other input; otherwise,
behavior is undefined and will likely yield wrong results.
borrow : bool
Default : take the same value as mutable.
True: permit the output of the compiled function to be aliased
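The defaulting rule in this docstring ties `mutable` to `update`: an input with an update expression is mutable unless explicitly overridden. A small sketch of that interaction (a hypothetical accumulator graph; note that this PR also deprecates inputs with default values, so the `value=` argument here now emits a `FutureWarning`):

```python
import pytensor
import pytensor.tensor as pt
from pytensor.compile.io import In

state = pt.dscalar("state")

# `update` is not None, so `mutable` defaults to True and the function
# may write the updated value back into the input's storage in place.
step = pytensor.function([In(state, value=0.0, update=state + 1)], state)

print(step())  # 0.0
print(step())  # 1.0 -- the update was fed back into `state`
```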
3 changes: 0 additions & 3 deletions pytensor/gradient.py
@@ -128,9 +128,6 @@ def fiter_variable(self, other):
" a symbolic placeholder."
)

def may_share_memory(a, b):
return False

def value_eq(a, b, force_same_dtype=True):
raise AssertionError(
"If you're assigning to a DisconnectedType you're"
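`DisconnectedType.may_share_memory` (which unconditionally returned `False`) can go because the call-time alias check that consulted it is gone. For tensor types the question it answered typically bottoms out in `numpy.may_share_memory`; a quick illustration of what the deleted check asked about each pair of inputs:

```python
import numpy as np

a = np.arange(6.0)

# True: a[2:] is a view overlapping a's buffer.
assert np.may_share_memory(a, a[2:])
# False: a.copy() owns a distinct buffer.
assert not np.may_share_memory(a, a.copy())
```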