Commit f5b048f

[maskedtensor] Adagrad sparse semantics [3/4]

1 parent 04e1ba9 commit f5b048f

2 files changed: +193 -0

Efficiency of writing "sparse" semantics for Adagrad
====================================================

`Issue 1369 <https://github.com/pytorch/pytorch/issues/1369>`__ discussed the additional lines of code
that were introduced while writing "sparse" semantics for Adagrad.
But the code doesn't really use sparsity as a compression and optimization technique;
it wants to use masked semantics. We worked around this by introducing one-off semantics and operators
that encode this behavior while forcing users to be aware of storage details such as indices and values.

In particular, we'll point out when sparsity is used as a semantic extension, i.e. when unspecified values are not zero,
and when it is just used to compress zeros.
We'll also compare and contrast this with equivalent code written using MaskedTensor.
In the end the code snippets are repeated without additional comments to show the difference in brevity.

Set up
------

.. code-block:: python

    import torch
    from torch.masked.maskedtensor import masked_tensor

    # Some hyperparameters
    eps = 1e-10
    clr = 0.1

    i = torch.tensor([[0, 1, 1], [2, 0, 2]])
    v = torch.tensor([3, 4, 5], dtype=torch.float32)
    grad = torch.sparse_coo_tensor(i, v, [2, 4])

Original sparse implementation
------------------------------

First, let's break down the current implementation of
`Adagrad (functional) <https://github.com/pytorch/pytorch/blob/6c2f235d368b697072699e5ca9485fd97d0b9bcc/torch/optim/_functional.py#L16-L51>`__
in PyTorch:

.. code-block:: python
    :linenos:

    def _make_sparse(grad, grad_indices, values):
        size = grad.size()
        if grad_indices.numel() == 0 or values.numel() == 0:
            return torch.empty_like(grad)
        return torch.sparse_coo_tensor(grad_indices, values, size)

    # We don't support sparse gradients
    param = torch.arange(8).reshape(2, 4).float()
    state_sum = torch.full_like(param, 0.5)  # initial value for state sum

    grad = grad.coalesce()  # the update is non-linear so indices must be unique
    grad_indices = grad._indices()
    grad_values = grad._values()
    # pow(2) has the same semantics for both sparse and dense memory layouts since 0^2 is zero
    state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))

    # We take care to make std sparse, even though state_sum clearly is not.
    # This means that we're only applying the gradient to parts of the state_sum
    # for which it is specified. This even drives the point home a lot more that
    # the passed gradient is not sparse, but masked.
    std = state_sum.sparse_mask(grad)

    # We currently dodge all these concerns using the private method _values.
    std_values = std._values().sqrt_().add_(eps)

    # We currently don't support div for sparse Tensors because zero / zero is
    # not well defined. For a MaskedTensor undefined / undefined is undefined.
    param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)

Line 24 is where we have a very important divergence.

The addition of eps should technically be applied to all values, but instead it is only applied to the specified values.
Here we're using sparsity as a semantic extension and to enforce a certain pattern of defined and undefined values.
If parts of the values of the gradient are zero, they are still included if materialized, even though they
could be compressed by other sparse storage layouts.
This is technically quite brittle, even though one could argue that eps is always very small.
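
As a small, hedged illustration of that divergence (reusing ``state_sum`` and ``grad`` from above; the variable names below are ours, not part of the original snippet):

.. code-block:: python

    # Dense semantics: eps lands on every element of the 2x4 state.
    dense_std = state_sum.sqrt() + eps

    # Masked/"sparse" semantics: eps only ever touches the specified values.
    masked_std_values = state_sum.sparse_mask(grad)._values().sqrt() + eps

    print(dense_std.numel())          # 8 entries carry eps
    print(masked_std_values.numel())  # only the 3 specified entries carry eps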

Moreover, an implementation of ``add_`` for sparsity as a storage layout and compression scheme should cause densification,
but we force it not to for performance.
For this one-off case it is fine... until we want to introduce new compression schemes
such as CSR, BSR, or 2:4 block sparsity. We'll then need to introduce separate Tensor types for each
and write variations for gradients compressed using different storage formats, which is inconvenient.
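
To make that concrete, here is a hedged sketch (our illustration, not code from this commit) of the kind of per-layout helper that would have to sit next to ``_make_sparse`` if gradients arrived in CSR form:

.. code-block:: python

    # Hypothetical CSR counterpart of _make_sparse: same job, but it needs
    # different storage details (crow/col indices instead of COO indices),
    # so it is yet another variant to write and maintain per layout.
    def _make_sparse_csr(grad, crow_indices, col_indices, values):
        if values.numel() == 0:
            return torch.empty_like(grad)
        return torch.sparse_csr_tensor(crow_indices, col_indices, values, grad.size())
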
MaskedTensor sparse implementation
----------------------------------

We've been conflating sparsity as an optimization with sparsity as a semantic extension to PyTorch.
MaskedTensor proposes to call the semantic extension introduced through sparsity "masked".
Currently we can't have dense semantics with sparse storage or masked semantics with dense storage, while
MaskedTensor fixes that because it separates the storage from the semantics.
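
As a brief aside (a minimal sketch of our own, assuming only the ``masked_tensor`` import from the set up), the same masked semantics can be backed by either dense or sparse COO storage:

.. code-block:: python

    data = torch.tensor([[0., 0., 3., 0.], [4., 0., 5., 0.]])
    mask = data != 0

    mt_dense = masked_tensor(data, mask)                           # dense storage
    mt_sparse = masked_tensor(data.to_sparse(), mask.to_sparse())  # sparse COO storage
    # Both describe the same three specified values; only the layout underneath differs.
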
Consider the above example using a masked gradient:

.. code-block:: python

    # Create an entirely new set of parameters to avoid errors
    param2 = torch.arange(8).reshape(2, 4).float()
    state_sum2 = torch.full_like(param, 0.5)  # initial value for state sum

    mask = (grad.to_dense() != 0).to_sparse()
    masked_grad = masked_tensor(grad, mask)

    state_sum2 = state_sum2 + masked_grad.pow(2).get_data()
    std2 = masked_tensor(state_sum2.to_sparse(), mask)

    # We can add support for in-place operations later. Notice how this doesn't
    # need to access any storage internals and is in general a lot shorter.
    std2 = std2.sqrt().add(eps)

    param2 = param2.add((masked_grad / std2).get_data(), alpha=-clr)

Note that the implementation looks quite similar, just shorter. Much of the boilerplate code
around ``_make_sparse`` (and needing to have a separate implementation per layout) is handled
for you with :class:`MaskedTensor`.

At this point, let's print both this version and the original version for easier comparison:

>>> state_sum
tensor([[ 0.5000,  0.5000,  9.5000,  0.5000],
        [16.5000,  0.5000, 25.5000,  0.5000]])
>>> state_sum2
tensor([[ 0.5000,  0.5000,  9.5000,  0.5000],
        [16.5000,  0.5000, 25.5000,  0.5000]])
>>> std
tensor(indices=tensor([[0, 1, 1],
                       [2, 0, 2]]),
       values=tensor([3.0822, 4.0620, 5.0498]),
       size=(2, 4), nnz=3, layout=torch.sparse_coo)
>>> std2
MaskedTensor(
  [
    [      --,       --,   3.0822,       --],
    [  4.0620,       --,   5.0498,       --]
  ]
)
>>> param
tensor([[0.0000, 1.0000, 1.9027, 3.0000],
        [3.9015, 5.0000, 5.9010, 7.0000]])
>>> param2
tensor([[0.0000, 1.0000, 1.9027, 3.0000],
        [3.9015, 5.0000, 5.9010, 7.0000]])

which shows that the two implementations are indeed the same.
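
As an optional sanity check (ours, not part of the original snippet), the dense results can also be compared programmatically:

.. code-block:: python

    # Both code paths should land on numerically identical dense tensors.
    assert torch.allclose(state_sum, state_sum2)
    assert torch.allclose(param, param2)
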

Conclusion: Difference in Code
------------------------------

For reference, this is the regular, dense code path without masked gradients or sparsity:

.. code-block:: python

    state_sum.addcmul_(grad, grad, value=1)
    std = state_sum.sqrt().add_(eps)
    param.addcdiv_(grad, std, value=-clr)

The vanilla tensor implementation for sparse gradients is:

.. code-block:: python

    grad = grad.coalesce()  # the update is non-linear so indices must be unique
    grad_indices = grad._indices()
    grad_values = grad._values()

    state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))  # a different _make_sparse per layout
    std = state_sum.sparse_mask(grad)
    std_values = std._values().sqrt_().add_(eps)
    param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)

while :class:`MaskedTensor` minimizes the code to the snippet:

.. code-block:: python

    state_sum2 = state_sum2 + masked_grad.pow(2).get_data()
    std2 = masked_tensor(state_sum2.to_sparse(), mask)
    std2 = std2.sqrt().add(eps)
    param2 = param2.add((masked_grad / std2).get_data(), alpha=-clr)

One major goal of :class:`MaskedTensor` is to enable sparsity semantics and applications, such as this one.
To learn more about using sparsity, see
`this MaskedTensor sparsity tutorial <https://pytorch.org/tutorials/prototype/maskedtensor_sparsity.html>`__.
Currently, COO and CSR sparse layouts are supported, though there are immediate plans to add more.

prototype_source/prototype_index.rst (+11)

@@ -141,6 +141,16 @@ Prototype features are not available as part of binary distributions like PyPI o
    :link: ../prototype/nestedtensor.html
    :tags: NestedTensor

+.. MaskedTensor
+
+.. customcarditem::
+   :header: MaskedTensor Adagrad Sparse Semantics
+   :card_description: See a showcase on how masked tensors can enable sparse semantics and provide for a cleaner dev experience
+   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: ../prototype/maskedtensor_adagrad.html
+   :tags: MaskedTensor
+
+
 .. End of tutorial card section

 .. raw:: html
@@ -172,3 +182,4 @@ Prototype features are not available as part of binary distributions like PyPI o
    prototype/vmap_recipe.html
    prototype/vulkan_workflow.html
    prototype/nestedtensor.html
+   prototype/maskedtensor_adagrad.html
