
Implement ScalarLoop in torch backend #958


Merged: 27 commits merged into pymc-devs:main on Dec 8, 2024

Conversation

@Ch0ronomato (Contributor) commented on Aug 1, 2024

Description

Adds ScalarLoop and Elemwise(ScalarLoop(...)) support for PyTorch. Due to #1031, we need a fallback method for operations that involve what torch calls "data-dependent logic".
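For context, a minimal sketch of the kind of graph this enables; it assumes pytensor's ScalarLoop API (pytensor.scalar.loop) and the "PYTORCH" compile mode, and is illustrative rather than code taken from this PR:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.scalar import float64, int64
from pytensor.scalar.loop import ScalarLoop
from pytensor.tensor.elemwise import Elemwise

# Scalar recurrence x <- x + 1, repeated n_steps times
n_steps = int64("n_steps")
x0 = float64("x0")
loop_op = ScalarLoop(init=[x0], update=[x0 + 1])

# Broadcasting the scalar loop over a vector is the Elemwise(ScalarLoop(...)) case
n = pt.lscalar("n")
x = pt.dvector("x")
out = Elemwise(loop_op)(n, x)

fn = pytensor.function([n, x], out, mode="PYTORCH")
print(fn(3, np.arange(4.0)))  # each element incremented 3 times
```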

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@ricardoV94 added the enhancement (New feature or request) and torch (PyTorch backend) labels on Aug 3, 2024
@ricardoV94 (Member):

@Ch0ronomato thanks for taking a stab, I left some comments above

@Ch0ronomato Ch0ronomato requested a review from ricardoV94 August 11, 2024 18:56
codecov bot commented Aug 11, 2024

Codecov Report

Attention: Patch coverage is 81.25000% with 9 lines in your changes missing coverage. Please review.

Project coverage is 82.10%. Comparing base (ef97287) to head (521ad67).
Report is 131 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/link/pytorch/dispatch/scalar.py 77.27% 4 Missing and 1 partial ⚠️
pytensor/link/pytorch/dispatch/elemwise.py 81.81% 2 Missing and 2 partials ⚠️
Additional details and impacted files


@@            Coverage Diff             @@
##             main     #958      +/-   ##
==========================================
+ Coverage   82.09%   82.10%   +0.01%     
==========================================
  Files         183      185       +2     
  Lines       48010    48130     +120     
  Branches     8653     8669      +16     
==========================================
+ Hits        39412    39519     +107     
- Misses       6435     6444       +9     
- Partials     2163     2167       +4     
Files with missing lines Coverage Δ
pytensor/link/pytorch/linker.py 100.00% <100.00%> (ø)
pytensor/link/pytorch/dispatch/elemwise.py 69.11% <81.81%> (+2.45%) ⬆️
pytensor/link/pytorch/dispatch/scalar.py 74.07% <77.27%> (+2.19%) ⬆️

... and 6 files with indirect coverage changes


@ricardoV94 ricardoV94 changed the title Add torch scalar loop Implement ScalarLoop in torch backend Sep 1, 2024
carry = update(*carry, *constants)
return torch.stack(carry)

return torch.compiler.disable(scalar_loop)
Member:
Can you do recursive=False?
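For reference, a sketch of how the suggested recursive=False would apply to the wrapper above; the surrounding make_scalar_loop/scalar_loop definitions are paraphrased here, not copied from the PR:

```python
import torch

def make_scalar_loop(update, constants):
    # Hypothetical wrapper mirroring the snippet under review.
    def scalar_loop(n_steps, *carry):
        for _ in range(int(n_steps)):
            carry = update(*carry, *constants)
        return torch.stack(carry)

    # recursive=False disables torch.compile only for scalar_loop itself,
    # while functions it calls can still be compiled.
    return torch.compiler.disable(scalar_loop, recursive=False)
```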

@Ch0ronomato (Contributor, Author):

@ricardoV94 - these failures in the CI look a bit strange; I'll look into them before merging... hopefully they go away after merging main 😓

@Ch0ronomato (Contributor, Author):

@ricardoV94 #1031 is blocking the elemwise test - how do you want to proceed with this PR?

@ricardoV94 (Member):

> @ricardoV94 #1031 is blocking the elemwise test - how do you want to proceed with this PR?

If we can't Elemwise it, there's not much point to the ScalarLoop. Maybe we need to loop manually instead of using vmap for this Op.
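A hedged illustration of the limitation driving this (not code from the PR): torch.vmap rejects data-dependent Python control flow inside the mapped function, which is why a manual element-by-element loop is the candidate fallback.

```python
import torch

def scalar_fn(x):
    # The branch depends on the runtime value of x, which vmap cannot handle.
    if x > 0:
        return x * 2
    return x - 1

x = torch.tensor([1.0, -1.0, 3.0])

try:
    torch.vmap(scalar_fn)(x)
except RuntimeError as exc:
    print("vmap failed:", exc)

# The manual fallback: loop in Python over the unbound elements instead.
print(torch.stack([scalar_fn(xi) for xi in x.unbind(0)]))
```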

@ricardoV94 (Member) left a comment

I suspect it's in the right direction, but I need a bit more help to understand the new code, if you can provide it :)

torch_elemwise = pytest.importorskip("pytensor.link.pytorch.dispatch.elemwise")


@pytest.mark.parametrize("input_shapes", [[(5, 1, 1, 8), (3, 1, 1), (8,)]])
Contributor Author:

I set this up so we can try different shapes, but I stuck with this one to get started. If you think we should add more, let me know.

np.testing.assert_equal(mock_inner_func.f.call_count, len(result[0]))

expected_args = torch.FloatTensor([1.0] * (len(input_shapes) - 1) + [0.0]).unbind(0)
expected_calls = starmap(call, repeat(expected_args, mock_inner_func.f.call_count))
Contributor Author:

I'm bullish on itertools stuff but I think I saw mention earlier that list comprehensions are preferred. I can refactor it if so.
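For comparison, the starmap/repeat construction above next to the equivalent list comprehension; expected_args and mock_inner_func are the names from the test snippet, not new definitions:

```python
from itertools import repeat, starmap
from unittest.mock import call

# Both build the same list of expected calls to the mocked inner function.
n_calls = mock_inner_func.f.call_count
expected_calls_itertools = list(starmap(call, repeat(expected_args, n_calls)))
expected_calls_listcomp = [call(*expected_args) for _ in range(n_calls)]
assert expected_calls_itertools == expected_calls_listcomp
```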

from torch import is_tensor

if is_tensor(out):
    return out.cpu()
Contributor Author:

FYI, this will probably create a conflict when one of my other PRs gets merged.

final_inputs[i] = list(layer)

# make sure we still have the same number of things
assert len(final_inputs) == len(shaped_inputs)
Contributor Author:

I can put these into the unit test if that's preferred now.

@ricardoV94 (Member) commented on Nov 12, 2024:

If the assert is executed every time at runtime, then yes, let's not do it here.

torch.zeros(*input_shapes[-1])
]
mock_inner_func = MagicMock()
ret_value = torch.rand(2, 2).unbind(0)
Contributor Author:

Maybe rename to expected

mock_inner_func.f.return_value = ret_value
elemwise_fn = torch_elemwise.elemwise_scalar_loop(mock_inner_func.f, None, None)
result = elemwise_fn(*args)
for actual, expected in zip(ret_value, result):
Contributor Author:

These are backwards, FYI.
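That is, with the snippet above, the unpacking order that matches the variable names would be the following (torch.testing.assert_close is used purely for illustration):

```python
# ret_value holds the mocked (expected) outputs; result is what elemwise_fn returned.
for expected, actual in zip(ret_value, result):
    torch.testing.assert_close(actual, expected)
```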

def elemwise_scalar_loop(base_fn, op, node, **kwargs):
    """
    ScalarLoop + Elemwise is too common
    to not work, but @1031, vmap won't allow it.
Member:

Include full link instead of @1031

Elemwise._check_runtime_broadcast(node, inputs)
shaped_inputs = torch.broadcast_tensors(*inputs)
expected_size = shaped_inputs[0].numel()
final_inputs = [s.clone() for s in shaped_inputs]
Member:

Why .clone()?

Contributor Author:

This might be unnecessary now. We need the original number of dimensions for the outer loop. I could just grab that count instead.

Comment on lines 193 to 196
for _ in range(shaped_inputs[0].dim() - 1):
    for i, _ in enumerate(shaped_inputs):
        layer = chain.from_iterable([s.unbind(0) for s in final_inputs[i]])
        final_inputs[i] = list(layer)
@ricardoV94 (Member) commented on Nov 12, 2024:

What is more performant? Doing this nesting, or raveling all the inputs after broadcasting and doing a single unbind loop?

Either way, it doesn't avoid the explicit broadcasting copy, or does it?

Contributor Author:

Ahhhhh, this is basically like ravel, you're right!

According to the torch docs, ravel only copies if needed, so there may be cases where no copying happens.
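A sketch of the ravel-based alternative being discussed, assuming base_fn maps scalar inputs to a tuple of scalar outputs; this is illustrative, not the merged implementation:

```python
import torch

def elemwise_scalar_loop_ravel(base_fn, *inputs):
    # Broadcast once, flatten, and unbind a single dimension instead of
    # unbinding repeatedly per dimension.
    bcasted = torch.broadcast_tensors(*inputs)
    out_shape = bcasted[0].shape
    # ravel only copies when the broadcasted view cannot be flattened in place.
    raveled = [t.ravel() for t in bcasted]
    results = [
        base_fn(*scalars)
        for scalars in zip(*(t.unbind(0) for t in raveled))
    ]
    # Stack each scalar output across elements and restore the broadcast shape.
    return [
        torch.stack([res[i] for res in results]).reshape(out_shape)
        for i in range(len(results[0]))
    ]
```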

assert all(len(x.shape) == 0 for tensor in final_inputs for x in tensor)
res = [base_fn(*args) for args in zip(*final_inputs)]

return [torch.stack(tuple(out[i] for out in res)) for i in range(len(res[0]))]
Member:

Will this reintroduce the original shape? Say if the Elemwise of the Scalar Loop had output shape == (5, 3, 2) ?


out_shape = bcasted_inputs[0].size()
out_size = out_shape.numel()
raveled_outputs = [torch.zeros(out_size) for out in node.outputs]
Member:

Is there no torch.empty?

Contributor Author:

My bad; I had an old version of torch (2.2) on my machine which didn't have it, but 2.3+ does. Reverted to torch.empty.
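For reference, the empty-buffer variant being settled on, as a sketch reusing out_size and node from the snippet above; it assumes every element of each buffer is written before it is read:

```python
# torch.empty allocates without initializing, avoiding the zero-fill of torch.zeros.
raveled_outputs = [torch.empty(out_size) for _ in node.outputs]
```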

@@ -77,11 +90,11 @@ def __del__(self):
self.gen_functors = []

# Torch does not accept numpy inputs and may return GPU objects
-def fn(*inputs, inner_fn=inner_fn):
+def create_outputs(*inputs, inner_fn=inner_fn):
Member:

Why the new name? Seems less clear

Contributor Author:

Yeah, this fn was shadowing a local variable fn, so I just renamed one of them.

Member:

Sure, but can we use a different name? This doesn't "create_outputs"; it converts inputs to torch tensors and converts outputs back to pytensor-compatible types.

Contributor Author:

Sure thing - I can also just keep the shadowing, lol. It's not the end of the world.

From your description I would probably have called it convert_types or something.

@ricardoV94 (Member) commented on Dec 8, 2024:

We can also put it inside the wrapper __call__ I guess?

Contributor Author:

Sure, that makes sense too.

@Ch0ronomato Ch0ronomato merged commit 9858b33 into pymc-devs:main Dec 8, 2024
60 of 62 checks passed
@Ch0ronomato Ch0ronomato deleted the scalarloop branch December 8, 2024 17:37
Labels: enhancement (New feature or request), torch (PyTorch backend)
Projects: None yet
Development: successfully merging this pull request may close these issues: Pytorch vmap limitation
2 participants