[WIP] Allow function dispatch for constants #1159

Closed
10 changes: 8 additions & 2 deletions pytensor/link/pytorch/dispatch/basic.py
@@ -36,11 +36,15 @@ def pytorch_typify_tensor(data, dtype=None, **kwargs):

@pytorch_typify.register(slice)
@pytorch_typify.register(NoneType)
@pytorch_typify.register(np.number)
def pytorch_typify_no_conversion_needed(data, **kwargs):
return data


@pytorch_typify.register(np.number)
def pytorch_typify_extract(data, **kwargs):
return data.item()
Member: This seems wrong?

Contributor Author: Say more?

The torch compiler threw asserts when a zero-dim NumPy value was passed back.

Member: Isn't it going to upcast values to float64, or to whatever Python integers are? Does torch have scalars (not tensors) with specific dtypes we can use instead?
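For reference, a small illustration of that dtype concern (nothing pytensor-specific; assumes only NumPy and torch):

```python
import numpy as np
import torch

x = np.float32(1.5)

# .item() returns a plain Python float (double precision); the original
# np.float32 dtype is no longer carried along.
print(type(x.item()))  # <class 'float'>

# Handing torch the NumPy scalar itself preserves the dtype, so a 0-dim
# tensor may be the safer representation for typed constants.
print(torch.tensor(x).dtype)         # torch.float32
print(torch.tensor(x.item()).dtype)  # torch.float32, but only because
                                     # torch's default dtype is float32
```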



@singledispatch
def pytorch_funcify(op, node=None, storage_map=None, **kwargs):
"""Create a PyTorch compatible function from an PyTensor `Op`."""
@@ -57,11 +61,13 @@ def pytorch_funcify_FunctionGraph(
conversion_func=pytorch_funcify,
**kwargs,
):
if "type_conversion_fn" not in kwargs:
kwargs["type_conversion_fn"] = pytorch_typify

built_kwargs = {"conversion_func": conversion_func, **kwargs}
return fgraph_to_python(
fgraph,
conversion_func,
type_conversion_fn=pytorch_typify,
fgraph_name=fgraph_name,
**built_kwargs,
)
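For background: the `.register` calls above are the `functools.singledispatch` pattern, and `np.number` is the common base class of all NumPy scalar types, so one registration covers every NumPy scalar. A minimal sketch of the mechanism (illustrative names, not the actual pytensor module):

```python
from functools import singledispatch

import numpy as np

@singledispatch
def typify(data, **kwargs):
    raise NotImplementedError(f"No typify rule for {type(data)}")

@typify.register(slice)
@typify.register(type(None))
def typify_no_conversion_needed(data, **kwargs):
    return data

@typify.register(np.number)
def typify_extract(data, **kwargs):
    # np.float32, np.int64, ... all subclass np.number, so this single
    # registration handles every NumPy scalar; .item() unwraps it to a
    # plain Python number.
    return data.item()

print(typify(np.float32(2.0)))  # 2.0, a Python float
```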
16 changes: 15 additions & 1 deletion pytensor/link/pytorch/linker.py
@@ -10,7 +10,9 @@ def __init__(self, *args, **kwargs):
self.gen_functors = []

def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
from pytensor.link.pytorch.dispatch import pytorch_funcify
import torch

from pytensor.link.pytorch.dispatch import pytorch_funcify, pytorch_typify

# We want to have globally unique names
# across the entire pytensor graph, not
@@ -25,9 +27,21 @@ def conversion_func_register(*args, **kwargs):
self.gen_functors.append((f"_{name}", functor))
return functor

def constants_wrapper(x, **kwargs):
x = pytorch_typify(x)

@torch.compiler.assume_constant_result
def torch_assume_constant(arg=x):
Contributor Author: This is causing some strange behavior in the compiler where outs is a bunch of functions (the constants specifically), which doesn't make a lot of sense to me. I'm still investigating the cause.

Contributor Author: It's possible this approach has issues since we don't wrap things in an nn.Module (pytorch/pytorch#124858); we shouldn't need to.

return arg

name = kwargs["unique_name"](torch_assume_constant)
self.gen_functors.append((f"_{name}", torch_assume_constant))
return torch_assume_constant

built_kwargs = {
"unique_name": generator,
"conversion_func": conversion_func_register,
"type_conversion_fn": constants_wrapper,
**kwargs,
}
return pytorch_funcify(
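For context, this is the `torch.compiler.assume_constant_result` pattern the wrapper relies on, shown as a self-contained sketch (assumes a torch version recent enough to expose the API; as the thread above notes, its behavior outside an `nn.Module` still has open questions):

```python
import torch

CONST = torch.tensor([1.0, 2.0, 3.0])

@torch.compiler.assume_constant_result
def get_const():
    # Dynamo treats the return value as a compile-time constant instead
    # of tracing through this call on every execution.
    return CONST

@torch.compile
def f(x):
    return x + get_const()

print(f(torch.ones(3)))  # tensor([2., 3., 4.])
```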
20 changes: 17 additions & 3 deletions pytensor/link/utils.py
@@ -749,11 +749,25 @@ def fgraph_to_python(
)
if input_storage[0] is not None or isinstance(i, Constant):
# Constants need to be assigned locally and referenced
global_env[local_input_name] = type_conversion_fn(
getter_or_value = type_conversion_fn(
input_storage[0], variable=i, storage=input_storage, **kwargs
)
# TODO: We could attempt to use the storage arrays directly
# E.g. `local_input_name = f"{local_input_name}[0]"`
if callable(getter_or_value):
# We were passed a function; this can be used to signal something
# to the backend. We'll embed it and call it once in the generated code.
new_output_name = unique_name(i)
getter_unique_name = unique_name(getter_or_value)
global_env[getter_unique_name] = getter_or_value
assign_str = f"{new_output_name} = {getter_unique_name}()"
body_assigns.append(assign_str)
node_input_names.append(new_output_name)
continue
Contributor Author: Hopefully we don't need this and it was an over-optimization. Try to refactor away the `continue`.

else:
global_env[local_input_name] = type_conversion_fn(
input_storage[0], variable=i, storage=input_storage, **kwargs
)
# TODO: We could attempt to use the storage arrays directly
# E.g. `local_input_name = f"{local_input_name}[0]"`
node_input_names.append(local_input_name)

node_output_names = [unique_name(v) for v in node.outputs]
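To make the new callable branch concrete, here is a toy version of the codegen it enables. The names and generated source are hypothetical, but the mechanism (store the getter in `global_env`, emit `name = getter()` in the function body) mirrors the code above:

```python
# Stand-in for the callable returned by type_conversion_fn, stored in
# the generated function's globals under a unique name.
global_env = {"torch_assume_constant_1": lambda: 41}

# Analogue of: assign_str = f"{new_output_name} = {getter_unique_name}()"
src = (
    "def fgraph_func(x):\n"
    "    constant_1 = torch_assume_constant_1()\n"
    "    return x + constant_1\n"
)
exec(compile(src, "<fgraph>", "exec"), global_env)
print(global_env["fgraph_func"](1))  # 42
```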