
Commit f141231

tutorial fixes!
1 parent 12ea814 commit f141231

File tree: 2 files changed, +219 -0 lines changed

+208
@@ -0,0 +1,208 @@
# -*- coding: utf-8 -*-

"""
(Prototype) Efficiently writing "sparse" semantics for Adagrad with MaskedTensor
================================================================================
"""

######################################################################
# Before working through this tutorial, please review the MaskedTensor
# `Overview <https://pytorch.org/tutorials/prototype/maskedtensor_overview.html>`__ and
# `Sparsity <https://pytorch.org/tutorials/prototype/maskedtensor_sparsity.html>`__ tutorials.
#
# Introduction and Motivation
# ---------------------------
# `Issue 1369 <https://github.com/pytorch/pytorch/issues/1369>`__ discussed the additional lines of code
# that were introduced while writing "sparse" semantics for Adagrad.
# But really, the code uses sparsity as a proxy for masked semantics rather than for the intended use case
# of sparsity: a compression and optimization technique.
# Previously, we worked around the lack of formal masked semantics by introducing one-off semantics and operators
# while forcing users to be aware of storage details such as indices and values.
#
# In particular, we'll point out when sparsity is used as a semantic extension, i.e. unspecified values are not zero,
# and when it is just used to compress zeros.
# We'll also compare and contrast this with equivalent code written using MaskedTensor.
# In the end, the code snippets are repeated without additional comments to show the difference in brevity.
#
# Preparation
# -----------
#

import torch

# Some hyperparameters
eps = 1e-10
clr = 0.1

i = torch.tensor([[0, 1, 1], [2, 0, 2]])
v = torch.tensor([3, 4, 5], dtype=torch.float32)
grad = torch.sparse_coo_tensor(i, v, [2, 4])

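# A quick sanity check: materializing the sparse gradient shows which of the
# 2x4 entries are specified and which are only implicit zeros.
print(grad.to_dense())
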
######################################################################
# Simpler Code with MaskedTensor
# ------------------------------
#
# Before we get too far in the weeds, let's introduce the problem a bit more concretely. We will be taking a look
# at the `Adagrad (functional) <https://github.com/pytorch/pytorch/blob/6c2f235d368b697072699e5ca9485fd97d0b9bcc/torch/optim/_functional.py#L16-L51>`__
# implementation in PyTorch with the ultimate goal of simplifying and more faithfully representing the sparse approach.
#
# For reference, this is the regular, dense code path without masked gradients or sparsity:
#
# .. code-block:: python
#
#     state_sum.addcmul_(grad, grad, value=1)
#     std = state_sum.sqrt().add_(eps)
#     param.addcdiv_(grad, std, value=-clr)
#
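# For readers less familiar with the fused in-place ops, the same update can be
# written out explicitly (an equivalent sketch, not the code PyTorch ships):
#
# .. code-block:: python
#
#     state_sum = state_sum + grad * grad
#     std = state_sum.sqrt() + eps
#     param = param - clr * grad / std
#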
# The vanilla tensor implementation for sparse is:
#
# .. code-block:: python
#
#     def _make_sparse(grad, grad_indices, values):
#         size = grad.size()
#         if grad_indices.numel() == 0 or values.numel() == 0:
#             return torch.empty_like(grad)
#         return torch.sparse_coo_tensor(grad_indices, values, size)
#
#     grad = grad.coalesce()  # the update is non-linear so indices must be unique
#     grad_indices = grad._indices()
#     grad_values = grad._values()
#
#     state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))  # a different _make_sparse per layout
#     std = state_sum.sparse_mask(grad)
#     std_values = std._values().sqrt_().add_(eps)
#     param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)
#
# while :class:`MaskedTensor` minimizes the code to the snippet:
#
# .. code-block:: python
#
#     state_sum2 = state_sum2 + masked_grad.pow(2).get_data()
#     std2 = masked_tensor(state_sum2.to_sparse(), mask)
#     std2 = std2.sqrt().add(eps)
#     param2 = param2.add((masked_grad / std2).get_data(), alpha=-clr)
#
# In this tutorial, we will go through each implementation line by line, but at first glance, we can notice
# (1) how much shorter the MaskedTensor implementation is, and (2) how it avoids conversions between dense
# and sparse tensors.
#

######################################################################
# Original Sparse Implementation
# ------------------------------
#
# Now, let's break down the code with some inline comments:
#

def _make_sparse(grad, grad_indices, values):
    size = grad.size()
    if grad_indices.numel() == 0 or values.numel() == 0:
        return torch.empty_like(grad)
    return torch.sparse_coo_tensor(grad_indices, values, size)

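# In other words, ``_make_sparse`` just re-wraps a set of indices and values into a COO tensor
# of the original shape; below we use it to wrap transformed values, such as
# ``grad_values.pow(2)``, back into the same sparse layout as ``grad``.
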
# We don't support sparse gradients
param = torch.arange(8).reshape(2, 4).float()
state_sum = torch.full_like(param, 0.5)  # initial value for state sum

grad = grad.coalesce()  # the update is non-linear so indices must be unique
grad_indices = grad._indices()
grad_values = grad._values()
# pow(2) has the same semantics for both sparse and dense memory layouts since 0^2 is zero
state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))

# We take care to make std sparse, even though state_sum clearly is not.
# This means that we're only applying the gradient to parts of the state_sum
# for which it is specified. This further drives the point home that the passed gradient is not sparse, but masked.
# We currently dodge all these concerns using the private method `_values`.
std = state_sum.sparse_mask(grad)
std_values = std._values().sqrt_().add_(eps)

# Note here that we currently don't support div for sparse Tensors because zero / zero is not well defined,
# so we're forced to perform `grad_values / std_values` outside the sparse semantic and then convert back to a
# sparse tensor with `_make_sparse`.
# We'll later see that MaskedTensor will actually handle these operations for us as well as properly denote
# undefined / undefined = undefined!
param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)

######################################################################
# The third to last line -- `std = state_sum.sparse_mask(grad)` -- is where we have a very important divergence.
#
# The addition of eps should technically be applied to all values but instead is only applied to specified values.
# Here we're using sparsity as a semantic extension and to enforce a certain pattern of defined and undefined values.
# If parts of the values of the gradient are zero, they are still included if they are materialized, even though they
# could be compressed by other sparse storage layouts. This is theoretically quite brittle!
# That said, one could argue that eps is always very small, so it might not matter so much in practice.
#
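# To make this concrete, compare what a fully dense update would do: it adds ``eps`` to
# every element of ``state_sum``, whereas the masked path above only touched the three
# specified values (an illustrative aside).

dense_std = state_sum.sqrt().add(eps)
print(dense_std)  # eps has been added to all 8 elements
print(std)        # eps only lives in the 3 specified values

######################################################################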
# Moreover, an implementation of `add_` for sparsity as a storage layout and compression scheme
# should cause densification, but we force it not to for performance.
# For this one-off case it is fine, until we want to introduce a new compression scheme, such as
# `CSC <https://pytorch.org/docs/master/sparse.html#sparse-csc-docs>`__,
# `BSR <https://pytorch.org/docs/master/sparse.html#sparse-bsr-docs>`__,
# or `BSC <https://pytorch.org/docs/master/sparse.html#sparse-bsc-docs>`__.
# We will then need to introduce separate Tensor types for each and write variations for gradients compressed
# using different storage formats, which is inconvenient and neither scalable nor clean.
#
# MaskedTensor Sparse Implementation
# ----------------------------------
#
# We've been conflating sparsity as an optimization with sparsity as a semantic extension to PyTorch.
# MaskedTensor proposes to disentangle the sparsity optimization from the semantic extension; for example,
# currently we can't have dense semantics with sparse storage or masked semantics with dense storage.
# MaskedTensor enables these ideas by purposefully separating the storage from the semantics.
#
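# For instance, the same masked semantics can be backed by either a dense or a sparse
# storage layout (an illustrative sketch using the gradient from above):
#
# .. code-block:: python
#
#     data = grad.to_dense()
#     mask = data != 0
#     mt_dense = masked_tensor(data, mask)                            # dense storage
#     mt_sparse = masked_tensor(data.to_sparse(), mask.to_sparse())   # sparse storage
#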
# Consider the above example using a masked gradient:
#

# Let's now import MaskedTensor!
from torch.masked import masked_tensor

# Create an entirely new set of parameters to avoid errors
param2 = torch.arange(8).reshape(2, 4).float()
state_sum2 = torch.full_like(param, 0.5)  # initial value for state sum

mask = (grad.to_dense() != 0).to_sparse()
masked_grad = masked_tensor(grad, mask)
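
# Printing the MaskedTensor makes the semantics explicit: unspecified entries show up
# as masked out rather than as zeros.
print(masked_grad)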

state_sum2 = state_sum2 + masked_grad.pow(2).get_data()
std2 = masked_tensor(state_sum2.to_sparse(), mask)

# We can add support for in-place operations later. Notice how this doesn't
# need to access any storage internals and is in general a lot shorter.
std2 = std2.sqrt().add(eps)

param2 = param2.add((masked_grad / std2).get_data(), alpha=-clr)

######################################################################
# Note that the implementations look quite similar, but the MaskedTensor implementation is shorter and simpler.
# In particular, much of the boilerplate code around ``_make_sparse``
# (and needing to have a separate implementation per layout) is handled for the user with :class:`MaskedTensor`.
#
# At this point, let's print both this version and the original version for easier comparison:
#

print("state_sum:\n", state_sum)
print("state_sum2:\n", state_sum2)

print("std:\n", std)
print("std2:\n", std2)

print("param:\n", param)
print("param2:\n", param2)

######################################################################
# Conclusion
# ----------
#
# In this tutorial, we've discussed how native masked semantics can enable a cleaner developer experience for
# Adagrad's existing implementation in PyTorch, which used sparsity as a proxy for writing masked semantics.
# But more importantly, allowing masked semantics to be a first-class citizen through MaskedTensor
# removes the reliance on sparsity and on unreliable hacks that mimic masking, allowing masked semantics
# to develop independently while still supporting sparse semantics, such as the use case shown here.
#
# Further Reading
# ---------------
#
# To continue learning more, you can find our final review (for now) on
# `MaskedTensor Advanced Semantics <https://pytorch.org/tutorials/prototype/maskedtensor_advanced_semantics.html>`__
# to see some of the differences in design decisions between :class:`MaskedTensor` and NumPy's MaskedArray, as well
# as reduction semantics.
#

prototype_source/prototype_index.rst

+11
@@ -141,6 +141,16 @@ Prototype features are not available as part of binary distributions like PyPI o
   :link: ../prototype/nestedtensor.html
   :tags: NestedTensor

.. MaskedTensor

.. customcarditem::
   :header: MaskedTensor: Simplifying Adagrad Sparse Semantics
   :card_description: See a showcase on how masked tensors can enable sparse semantics and provide for a cleaner dev experience
   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
   :link: ../prototype/maskedtensor_adagrad.html
   :tags: MaskedTensor

.. End of tutorial card section

.. raw:: html
@@ -172,3 +182,4 @@ Prototype features are not available as part of binary distributions like PyPI o
   prototype/vmap_recipe.html
   prototype/vulkan_workflow.html
   prototype/nestedtensor.html
   prototype/maskedtensor_adagrad.html
