
Commit e7fbf51

[maskedtensor] Adagrad sparse semantics [3/4]
1 parent 04e1ba9 commit e7fbf51

File tree

2 files changed: +186 -0 lines changed
+175
@@ -0,0 +1,175 @@
# -*- coding: utf-8 -*-

"""
(Prototype) Efficiently writing "sparse" semantics for Adagrad with MaskedTensor
================================================================================
"""

######################################################################
# `Issue 1369 <https://github.com/pytorch/pytorch/issues/1369>`__ discussed the additional lines of code
# that were introduced while writing "sparse" semantics for Adagrad.
# But really, the code uses sparsity as a proxy for masked semantics rather than the intended use case of sparsity:
# a compression and optimization technique.
# Previously, we worked around the lack of formal masked semantics by introducing one-off semantics and operators
# while forcing users to be aware of storage details such as indices and values.
#
# In particular, we'll point out when sparsity is used as a semantic extension, i.e. unspecified values are not zero,
# and when it is just used to compress zeros.
# We'll also compare and contrast this with equivalent code written using MaskedTensor.
# In the end, the code snippets are repeated without additional comments to show the difference in brevity.
#
# Preparations
# ------------
#

import torch

# Some hyperparameters
eps = 1e-10
clr = 0.1

i = torch.tensor([[0, 1, 1], [2, 0, 2]])
v = torch.tensor([3, 4, 5], dtype=torch.float32)
grad = torch.sparse_coo_tensor(i, v, [2, 4])

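# As a quick sanity check (not part of the original snippet), we can materialize
# the sparse gradient; the unspecified entries show up as zero in the dense view:
print("grad (dense):\n", grad.to_dense())
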
######################################################################
# Original sparse implementation
# ------------------------------
#
# First, let's break down the current implementation of
# `Adagrad (functional) <https://github.com/pytorch/pytorch/blob/6c2f235d368b697072699e5ca9485fd97d0b9bcc/torch/optim/_functional.py#L16-L51>`__
# in PyTorch:
#

def _make_sparse(grad, grad_indices, values):
    size = grad.size()
    if grad_indices.numel() == 0 or values.numel() == 0:
        return torch.empty_like(grad)
    return torch.sparse_coo_tensor(grad_indices, values, size)

# We don't support sparse gradients
param = torch.arange(8).reshape(2, 4).float()
state_sum = torch.full_like(param, 0.5)  # initial value for state sum

grad = grad.coalesce()  # the update is non-linear so indices must be unique
grad_indices = grad._indices()
grad_values = grad._values()
# pow(2) has the same semantics for both sparse and dense memory layouts since 0^2 is zero
state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))

# We take care to make std sparse, even though state_sum clearly is not.
# This means that we're only applying the gradient to parts of the state_sum
# for which it is specified. This further drives the point home that the passed gradient is not sparse, but masked.
std = state_sum.sparse_mask(grad)

# We currently dodge all these concerns by using the private method ``_values``.
std_values = std._values().sqrt_().add_(eps)

# We currently don't support div for sparse Tensors because zero / zero is
# not well defined. For a MaskedTensor, undefined / undefined is undefined.
param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)

######################################################################
# ``std = state_sum.sparse_mask(grad)`` is where we have a very important divergence.
#
# The addition of eps should technically be applied to all values, but instead it is only applied to specified values.
# Here we're using sparsity as a semantic extension and to enforce a certain pattern of defined and undefined values.
# If parts of the values of the gradient are zero, they are still included if materialized, even though they
# could be compressed by other sparse storage layouts.
# This is technically quite brittle, even though one could argue that eps is always very small.
#
# Moreover, an implementation of ``add_`` for sparsity as a storage layout and compression scheme should cause densification,
# but we force it not to for performance.
# For this one-off case it is fine... until we want to introduce new compression schemes
# such as CSR, BSR, or 2:4 block sparsity. We would then need to introduce separate Tensor types for each
# and write variations for gradients compressed using different storage formats, which is inconvenient.
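
# To make the divergence concrete (this check is not part of the original
# optimizer code): sparse_mask only keeps the entries of state_sum at the
# indices specified in grad, so the in-place add_(eps) above only ever
# touches those entries.
masked_view = state_sum.sparse_mask(grad)
print("entries that eps is applied to:\n", masked_view._values())
print("materialized view (unspecified entries print as zero):\n", masked_view.to_dense())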

######################################################################
# MaskedTensor sparse implementation
# ----------------------------------
#
# We've been conflating sparsity as an optimization with sparsity as a semantic extension to PyTorch.
# MaskedTensor proposes to call the semantic extension through sparsity "masked".
# Currently we can't have dense semantics with sparse storage or masked semantics with dense storage;
# MaskedTensor enables these ideas by purposefully separating the storage from the semantics.
#
# Consider the above example using a masked gradient:
#

# Let's now import MaskedTensor!
from torch.masked import masked_tensor

# Create an entirely new set of parameters to avoid errors
param2 = torch.arange(8).reshape(2, 4).float()
state_sum2 = torch.full_like(param, 0.5)  # initial value for state sum

mask = (grad.to_dense() != 0).to_sparse()
masked_grad = masked_tensor(grad, mask)

state_sum2 = state_sum2 + masked_grad.pow(2).get_data()
std2 = masked_tensor(state_sum2.to_sparse(), mask)

# We can add support for in-place operations later. Notice how this doesn't
# need to access any storage internals and is in general a lot shorter
std2 = std2.sqrt().add(eps)

param2 = param2.add((masked_grad / std2).get_data(), alpha=-clr)

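# As an aside (a minimal sketch, not part of the optimizer logic): because
# MaskedTensor separates storage from semantics, the same masked semantics
# work on top of dense storage as well: here both the data and the boolean
# mask are plain dense tensors.
dense_data = torch.arange(6, dtype=torch.float32).reshape(2, 3)
dense_mask = torch.tensor([[True, False, True], [False, True, False]])
print(masked_tensor(dense_data, dense_mask))
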
######################################################################
# Note that the implementations look quite similar, but the MaskedTensor implementation is shorter and simpler.
# For example, much of the boilerplate code around ``_make_sparse``
# (and needing to have a separate implementation per layout) is handled for the user with :class:`MaskedTensor`.
#
# At this point, let's print both this version and the original version for easier comparison:
#

print("state_sum:\n", state_sum)
127+
print("state_sum2:\n", state_sum2)
128+
129+
print("std:\n", std)
130+
print("std2:\n", std2)
131+
132+
print("param:\n", param)
133+
print("param2:\n", param2)
134+
135+
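# As an optional programmatic check (not part of the original tutorial code):
# both parameter updates produce dense tensors, so they can be compared directly.
assert torch.allclose(param, param2)
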
######################################################################
# which shows that the two implementations are indeed the same.
#
# Conclusion: Simpler Code with MaskedTensor
# ------------------------------------------
#
# For reference, this is the regular, dense code path without masked gradients or sparsity:
#
# .. code-block:: python
#
#     state_sum.addcmul_(grad, grad, value=1)
#     std = state_sum.sqrt().add_(eps)
#     param.addcdiv_(grad, std, value=-clr)
#
# The vanilla tensor implementation for sparse is:
#
# .. code-block:: python
#
#     grad = grad.coalesce()  # the update is non-linear so indices must be unique
#     grad_indices = grad._indices()
#     grad_values = grad._values()
#
#     state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))  # a different _make_sparse per layout
#     std = state_sum.sparse_mask(grad)
#     std_values = std._values().sqrt_().add_(eps)
#     param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)
#
# while :class:`MaskedTensor` minimizes the code to the snippet:
#
# .. code-block:: python
#
#     state_sum2 = state_sum2 + masked_grad.pow(2).get_data()
#     std2 = masked_tensor(state_sum2.to_sparse(), mask)
#     std2 = std2.sqrt().add(eps)
#     param2 = param2.add((masked_grad / std2).get_data(), alpha=-clr)
#
# One major goal of :class:`MaskedTensor` is to enable sparsity semantics and applications, such as this one.
# To learn more about using sparsity, you can refer to the
# `MaskedTensor sparsity tutorial <https://pytorch.org/tutorials/prototype/maskedtensor_sparsity.html>`__.
# Currently, COO and CSR sparse layouts are supported, though there are immediate plans to add more.
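#
# For instance (a minimal sketch based on the sparsity tutorial linked above,
# not part of this commit), a masked tensor over the CSR layout can be
# constructed the same way as the COO one used here:
#
# .. code-block:: python
#
#     csr_data = grad.to_dense().to_sparse_csr()
#     csr_mask = (grad.to_dense() != 0).to_sparse_csr()
#     mt_csr = masked_tensor(csr_data, csr_mask)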

prototype_source/prototype_index.rst

+11
@@ -141,6 +141,16 @@ Prototype features are not available as part of binary distributions like PyPI o
   :link: ../prototype/nestedtensor.html
   :tags: NestedTensor

.. MaskedTensor

.. customcarditem::
   :header: MaskedTensor: Simplifying Adagrad Sparse Semantics
   :card_description: See a showcase on how masked tensors can enable sparse semantics and provide for a cleaner dev experience
   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
   :link: ../prototype/maskedtensor_adagrad.html
   :tags: MaskedTensor

.. End of tutorial card section

.. raw:: html
@@ -172,3 +182,4 @@ Prototype features are not available as part of binary distributions like PyPI o
   prototype/vmap_recipe.html
   prototype/vulkan_workflow.html
   prototype/nestedtensor.html
   prototype/maskedtensor_adagrad.html
