[WIP] Allow function dispatch for constants #1159
PyTorch linker, `fgraph_convert`:

```diff
@@ -10,7 +10,9 @@ def __init__(self, *args, **kwargs):
         self.gen_functors = []
 
     def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
-        from pytensor.link.pytorch.dispatch import pytorch_funcify
+        import torch
+
+        from pytensor.link.pytorch.dispatch import pytorch_funcify, pytorch_typify
 
         # We want to have globally unique names
         # across the entire pytensor graph, not
```
```diff
@@ -25,9 +27,21 @@ def conversion_func_register(*args, **kwargs):
             self.gen_functors.append((f"_{name}", functor))
             return functor
 
+        def constants_wrapper(x, **kwargs):
+            x = pytorch_typify(x)
+
+            @torch.compiler.assume_constant_result
+            def torch_assume_constant(arg=x):
+                return arg
+
+            name = kwargs["unique_name"](torch_assume_constant)
+            self.gen_functors.append((f"_{name}", torch_assume_constant))
+            return torch_assume_constant
+
         built_kwargs = {
             "unique_name": generator,
             "conversion_func": conversion_func_register,
+            "type_conversion_fn": constants_wrapper,
             **kwargs,
         }
         return pytorch_funcify(
```

Review comments on `def torch_assume_constant(arg=x):`:

- This is causing some strange behavior in the compiler where …
- It's possible this approach has issues, since we don't wrap things in an `nn.Module` (pytorch/pytorch#124858); we shouldn't need to.
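For context, here is a standalone sketch (not code from this PR) of what `torch.compiler.assume_constant_result` is being used for: it tells `torch.compile` that the wrapped zero-argument getter always returns the same value, so the traced graph can treat that value as a compile-time constant instead of guarding on it. Per the review comment and pytorch/pytorch#124858, using it outside an `nn.Module` may be shaky, so treat this as illustrative only; the names `get_const` and `f` are hypothetical.

```python
import torch

# A captured value exposed through a getter that the compiler is told
# to treat as constant.
const = torch.tensor([1.0, 2.0, 3.0])


@torch.compiler.assume_constant_result
def get_const():
    return const


@torch.compile
def f(x):
    # During tracing, get_const() is assumed to always return the same
    # value, so its result can be baked into the compiled graph.
    return x + get_const()


print(f(torch.ones(3)))  # tensor([2., 3., 4.])
```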
`fgraph_to_python` (graph-to-Python code generation):

```diff
@@ -749,11 +749,25 @@ def fgraph_to_python(
             )
             if input_storage[0] is not None or isinstance(i, Constant):
                 # Constants need to be assigned locally and referenced
-                global_env[local_input_name] = type_conversion_fn(
+                getter_or_value = type_conversion_fn(
                     input_storage[0], variable=i, storage=input_storage, **kwargs
                 )
-                # TODO: We could attempt to use the storage arrays directly
-                # E.g. `local_input_name = f"{local_input_name}[0]"`
+                if callable(getter_or_value):
+                    # we got passed a function, this could be used to indicate something
+                    # to the backend. We'll embed it
+                    new_output_name = unique_name(i)
+                    getter_unique_name = unique_name(getter_or_value)
+                    global_env[getter_unique_name] = getter_or_value
+                    assign_str = f"{new_output_name} = {getter_unique_name}()"
+                    body_assigns.append(assign_str)
+                    node_input_names.append(new_output_name)
+                    continue
+                else:
+                    global_env[local_input_name] = type_conversion_fn(
+                        input_storage[0], variable=i, storage=input_storage, **kwargs
+                    )
+                    # TODO: We could attempt to use the storage arrays directly
+                    # E.g. `local_input_name = f"{local_input_name}[0]"`
             node_input_names.append(local_input_name)
 
         node_output_names = [unique_name(v) for v in node.outputs]
```

Review comment on the `continue`:

- Hopefully we don't need this and it was an over-optimization. Try to refactor away the `continue`.
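For readers outside the diff, here is a minimal, self-contained sketch (hypothetical names, not PyTensor's real helpers) of the pattern the `callable(getter_or_value)` branch introduces: instead of assigning a plain value into the global environment, the getter itself is placed in the generated function's globals and the generated body calls it.

```python
# Hypothetical illustration of embedding a "getter" into generated source,
# mirroring the callable(getter_or_value) branch above.
def build_generated_fn():
    global_env = {}
    body_assigns = []

    def constant_getter():  # stand-in for what constants_wrapper returns
        return 3.0

    getter_name = "_constant_getter_0"  # stand-in for unique_name(...)
    global_env[getter_name] = constant_getter
    body_assigns.append(f"c_0 = {getter_name}()")

    # Assemble and exec the generated source with the getter in its globals.
    src = "def generated(x):\n"
    for line in body_assigns:
        src += f"    {line}\n"
    src += "    return x * c_0\n"
    exec(compile(src, "<generated>", "exec"), global_env)
    return global_env["generated"]


fn = build_generated_fn()
print(fn(2.0))  # 6.0
```

Presumably the `continue` is there so this branch skips the trailing `node_input_names.append(local_input_name)`, having already appended `new_output_name`; that is the part the review asks to refactor away.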
Review thread:

- This seems wrong?
- Say more? The torch compiler threw asserts when a zero-dim np value was passed back.
- Isn't it going to upcast values to float64 (or whatever Python integers are)? Does torch have scalars (not tensors) with specific dtypes we can use instead?
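A small sketch of the dtype concern being discussed (assumed reproduction with made-up values): pulling a zero-dim NumPy value out as a Python number silently drops its dtype, whereas a 0-dim torch tensor built from it keeps the dtype.

```python
import numpy as np
import torch

c = np.asarray(2.0, dtype="float32")  # zero-dim ndarray, float32
print(c.ndim, c.dtype)                # 0 float32
print(type(c.item()))                 # <class 'float'>: plain (64-bit) Python float
print(torch.as_tensor(c).dtype)       # torch.float32: a 0-dim tensor keeps the dtype
```

This only illustrates why the upcasting question matters; it does not settle whether torch exposes dtyped non-tensor scalars.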