Implement ScalarLoop in torch backend #958
Conversation
@Ch0ronomato thanks for taking a stab, I left some comments above
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #958 +/- ##
==========================================
+ Coverage 82.09% 82.10% +0.01%
==========================================
Files 183 185 +2
Lines 48010 48130 +120
Branches 8653 8669 +16
==========================================
+ Hits 39412 39519 +107
- Misses 6435 6444 +9
- Partials 2163 2167 +4
carry = update(*carry, *constants)
return torch.stack(carry)

return torch.compiler.disable(scalar_loop)
Can you do recursive=False?
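For reference, a minimal sketch of what the `recursive=False` suggestion could look like, assuming the `scalar_loop` closure built in the dispatch function above (its body is just a placeholder here):

```python
import torch


def scalar_loop(*args):
    # placeholder for the carry/update loop built in the dispatcher above
    ...


# recursive=False disables compilation of scalar_loop itself only;
# functions it calls can still be traced/compiled by Dynamo.
wrapped_loop = torch.compiler.disable(scalar_loop, recursive=False)
```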
@ricardoV94 - these failures in the CI look a bit strange; I'll look into them before merging... hopefully they go away with merging main 😓
@ricardoV94 #1031 is blocking the elemwise test - how do you want to proceed with this PR?
If we can't Elemwise it, there's not much point to the ScalarLoop. Maybe we need to loop manually instead of vmap for this Op.
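A rough sketch of the manual-loop idea (illustrative only; `manual_elemwise` and `branching_scalar` are hypothetical names, not the PR's final implementation):

```python
import torch


def manual_elemwise(scalar_fn, *inputs):
    # Broadcast, flatten, apply the scalar function element by element,
    # then restore the broadcasted output shape.
    bcast = torch.broadcast_tensors(*inputs)
    out_shape = bcast[0].shape
    flat = [t.ravel() for t in bcast]
    results = [scalar_fn(*scalars) for scalars in zip(*flat)]
    return torch.stack(results).reshape(out_shape)


# A scalar function with data-dependent control flow that vmap struggles with
def branching_scalar(a, b):
    return a + b if a > 0 else a - b


out = manual_elemwise(branching_scalar, torch.randn(5, 1), torch.randn(1, 3))  # shape (5, 3)
```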
I suspect it's in the right direction, but I need a bit more help to understand the new code, if you can provide it :)
tests/link/pytorch/test_basic.py
Outdated
torch_elemwise = pytest.importorskip("pytensor.link.pytorch.dispatch.elemwise")


@pytest.mark.parametrize("input_shapes", [[(5, 1, 1, 8), (3, 1, 1), (8,)]])
I set this up so we can try different shapes, but I stuck with this one to get started. If you think we should add more, let me know.
tests/link/pytorch/test_basic.py
Outdated
np.testing.assert_equal(mock_inner_func.f.call_count, len(result[0]))

expected_args = torch.FloatTensor([1.0] * (len(input_shapes) - 1) + [0.0]).unbind(0)
expected_calls = starmap(call, repeat(expected_args, mock_inner_func.f.call_count))
I'm bullish on itertools stuff, but I think I saw a mention earlier that list comprehensions are preferred. I can refactor it if so.
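For comparison, a list-comprehension equivalent of the `starmap`/`repeat` construction above (plain floats used here just to keep the sketch self-contained):

```python
from itertools import repeat, starmap
from unittest.mock import call

expected_args = (1.0, 1.0, 0.0)
n_calls = 3

# itertools version, as in the test above
itertools_calls = list(starmap(call, repeat(expected_args, n_calls)))

# list-comprehension equivalent
comprehension_calls = [call(*expected_args) for _ in range(n_calls)]

assert itertools_calls == comprehension_calls
```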
pytensor/link/pytorch/linker.py
Outdated
from torch import is_tensor

if is_tensor(out):
    return out.cpu()
As an FYI, this will probably create a conflict when one of my other PRs gets merged.
final_inputs[i] = list(layer)

# make sure we still have the same number of things
assert len(final_inputs) == len(shaped_inputs)
I can put these into the unit test if that's preferred now.
If the assert is executed every time at runtime, then yes, let's not do it here.
tests/link/pytorch/test_basic.py
Outdated
    torch.zeros(*input_shapes[-1])
]
mock_inner_func = MagicMock()
ret_value = torch.rand(2, 2).unbind(0)
Maybe rename to expected
tests/link/pytorch/test_basic.py
Outdated
mock_inner_func.f.return_value = ret_value
elemwise_fn = torch_elemwise.elemwise_scalar_loop(mock_inner_func.f, None, None)
result = elemwise_fn(*args)
for actual, expected in zip(ret_value, result):
These are backwards, FYI.
def elemwise_scalar_loop(base_fn, op, node, **kwargs):
    """
    ScalarLoop + Elemwise is too common
    to not work, but @1031, vmap won't allow it.
Include full link instead of @1031
Elemwise._check_runtime_broadcast(node, inputs)
shaped_inputs = torch.broadcast_tensors(*inputs)
expected_size = shaped_inputs[0].numel()
final_inputs = [s.clone() for s in shaped_inputs]
Why .clone()?
This might be unnecessary now. We need the original number of dimensions for the outer loop. I could just grab that count instead.
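A small sketch of the alternative mentioned here, i.e. keeping only the dimension count rather than cloning (names assumed from the surrounding dispatch code):

```python
import torch

shaped_inputs = torch.broadcast_tensors(torch.rand(5, 1), torch.rand(1, 3))

# Keep only the dimension count needed to drive the outer loop,
# instead of cloning every broadcasted input.
n_dims = shaped_inputs[0].dim()
final_inputs = list(shaped_inputs)  # no data copies here
```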
for _ in range(shaped_inputs[0].dim() - 1):
    for i, _ in enumerate(shaped_inputs):
        layer = chain.from_iterable([s.unbind(0) for s in final_inputs[i]])
        final_inputs[i] = list(layer)
What is more performant? Doing this nesting, or raveling all the inputs after broadcasting and doing a single unbind loop?
Either way, it doesn't avoid the explicit broadcasting copy, or does it?
Ahhhhh, this is basically like ravel, you're right!
According to the torch docs, ravel only copies if needed, so there may be cases where no copying happens.
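A quick illustration of that copy behaviour: `ravel` returns a view when the tensor is already contiguous and only copies otherwise.

```python
import torch

x = torch.arange(6).reshape(2, 3)

flat = x.ravel()        # x is contiguous: this is a view, no copy
flat_t = x.t().ravel()  # the transpose is non-contiguous: this copies

print(flat.data_ptr() == x.data_ptr())    # True, shares storage
print(flat_t.data_ptr() == x.data_ptr())  # False, new storage
```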
assert all(len(x.shape) == 0 for tensor in final_inputs for x in tensor)
res = [base_fn(*args) for args in zip(*final_inputs)]

return [torch.stack(tuple(out[i] for out in res)) for i in range(len(res[0]))]
Will this reintroduce the original shape? Say, if the Elemwise of the ScalarLoop had output shape == (5, 3, 2)?
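One way the original shape could be reintroduced is by reshaping the stacked result with the broadcasted shape saved before flattening; a self-contained sketch for the (5, 3, 2) example, where `res` just mimics one scalar output per element:

```python
import torch

out_shape = torch.Size((5, 3, 2))

# Mimic per-element results: one tuple with a single scalar output per element
res = [(torch.tensor(float(i)),) for i in range(out_shape.numel())]

stacked = torch.stack(tuple(out[0] for out in res))  # shape (30,)
restored = stacked.reshape(out_shape)                # shape (5, 3, 2)
```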
Co-authored-by: Ricardo Vieira <[email protected]>
0905bec to 46e3e72 (compare)
out_shape = bcasted_inputs[0].size()
out_size = out_shape.numel()
raveled_outputs = [torch.zeros(out_size) for out in node.outputs]
Is there no torch.empty?
My bad; I had an old version of torch on my machine (2.2) which didn't have it, but 2.3+ does. Reverted to torch.empty.
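For completeness, the difference is just that `torch.empty` skips the zero-fill (a toy sketch, not the PR code):

```python
import torch

out_size = 30

zeroed = [torch.zeros(out_size) for _ in range(2)]  # allocates and fills with zeros
uninit = [torch.empty(out_size) for _ in range(2)]  # allocates only; contents are arbitrary
```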
pytensor/link/pytorch/linker.py
Outdated
@@ -77,11 +90,11 @@ def __del__(self):
self.gen_functors = []

# Torch does not accept numpy inputs and may return GPU objects
def fn(*inputs, inner_fn=inner_fn):
def create_outputs(*inputs, inner_fn=inner_fn):
Why the new name? Seems less clear
Yeah, this fn was shadowing a local variable fn, so I just renamed one of them.
Sure, but can we use a different name? This doesn't "create_outputs"; it converts inputs to torch tensors and outputs back to pytensor-compatible types.
Sure thing - I can also just keep the shadowing, lol. It's not the end of the world.
From your description I would probably have called it convert_types or something.
We can also put it inside the wrapper __call__, I guess?
Sure, that makes sense too.
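A hypothetical sketch of what folding the conversion into a wrapper's `__call__` could look like (the class and attribute names here are made up for illustration, not the linker's actual API):

```python
import numpy as np
import torch


class TorchConversionWrapper:
    """Hypothetical wrapper: numpy arrays in, torch tensors inside, numpy arrays back out."""

    def __init__(self, inner_fn):
        self.inner_fn = inner_fn

    def __call__(self, *inputs):
        # Torch does not accept numpy inputs and may return GPU objects
        torch_inputs = [
            torch.as_tensor(x) if isinstance(x, np.ndarray) else x for x in inputs
        ]
        outputs = self.inner_fn(*torch_inputs)
        return [
            out.cpu().numpy() if torch.is_tensor(out) else out for out in outputs
        ]
```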
2f70694 to 521ad67 (compare)
Description
Adds ScalarLoop and Elemwise(ScalarLoop(...)) support for pytorch. Due to #1031, we have to have a fallback method for things that do what torch calls "data-dependent logic".

Related Issue
Checklist
Type of change