Commit 2bf9734

[maskedtensor] Add adagrad sparse semantics tutorial
ghstack-source-id: 851625e Pull Request resolved: #2047
1 parent 1a0de45 commit 2bf9734

2 files changed (+182, -0)

@@ -0,0 +1,180 @@
# -*- coding: utf-8 -*-

"""
Efficiency of writing "sparse" semantics for Adagrad
====================================================

`Issue 1369 <https://github.com/pytorch/pytorch/issues/1369>`__ discussed the additional lines of code
that were introduced while writing "sparse" semantics for Adagrad.
But really the code doesn't use sparsity as a compression and optimization technique;
it wants to use masked semantics. We worked around this by introducing one-off semantics and operators
that encode this behavior while forcing users to be aware of storage details such as indices and values.

In particular we'll point out when sparsity is used as a semantic extension, i.e. unspecified values are not zero,
and when it is just used to compress zeros.
We'll also compare and contrast this with equivalent code written using MaskedTensor.
In the end the code snippets are repeated without additional comments to show the difference in brevity.

"""

import torch
from torch.masked.maskedtensor import masked_tensor

######################################################################
# Original sparse implementation
# ------------------------------
#
# First, let's look at the current implementation of
# `Adagrad (functional) <https://github.com/pytorch/pytorch/blob/6c2f235d368b697072699e5ca9485fd97d0b9bcc/torch/optim/_functional.py#L16-L51>`__
#

def _make_sparse(grad, grad_indices, values):
    size = grad.size()
    if grad_indices.numel() == 0 or values.numel() == 0:
        return torch.empty_like(grad)
    return torch.sparse_coo_tensor(grad_indices, values, size)

# Some hyperparameters
eps = 1e-10
clr = 0.1

# We don't support sparse gradients
param = torch.arange(8).reshape(2, 4).float()
i = torch.tensor([[0, 1, 1],
                  [2, 0, 2]])
v = torch.tensor([3, 4, 5], dtype=torch.float32)
grad = torch.sparse_coo_tensor(i, v, [2, 4])
state_sum = torch.full_like(param, 0.5)  # initial value for state sum

print("param:\n", param)
print("grad:\n", grad.to_dense())
print("state_sum:\n", state_sum)

######################################################################
#

state_sum = torch.full_like(param, 0.5)  # initial value for state sum
print(state_sum)

grad = grad.coalesce()  # the update is non-linear so indices must be unique
grad_indices = grad._indices()
grad_values = grad._values()

# pow(2) has the same semantics for both sparse and dense memory layouts since 0^2 is zero
state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))
# We take care to make std sparse, even though state_sum clearly is not.
# This means that we're only applying the gradient to parts of the state_sum
# for which it is specified. This drives home the point even more that
# the passed gradient is not sparse, but masked.
std = state_sum.sparse_mask(grad)
print("state_sum:\n", state_sum)
print("std:\n", std.to_dense())

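# A quick sanity check (not part of the original tutorial) for the claim above:
# squaring the specified values and rewrapping them in a COO tensor agrees with
# squaring the materialized dense tensor, precisely because 0^2 == 0.
assert torch.equal(
    torch.sparse_coo_tensor(grad_indices, grad_values.pow(2), grad.size()).to_dense(),
    grad.to_dense().pow(2),
)
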
######################################################################
# This is where we have a very important divergence.
# The addition of eps should technically be applied to all values, but instead is only applied to specified values.
# Here we're using sparsity as a semantic extension and to enforce a certain pattern of defined and undefined values.
# If parts of the values of the gradient are zero, they are still included if materialized,
# even though they could be compressed by other sparse storage layouts.
# This is technically quite brittle, even though someone could argue that eps is always very small.
#
# Moreover, an implementation of add_ for sparsity as a storage layout and compression scheme should cause densification,
# but we force it not to.
# For this one-off case it is fine, until we want to introduce new compression schemes
# such as CSR, BSR or 2:4 block sparsity. We'll then need to introduce separate Tensor types for each
# and write variations for gradients compressed using different storage formats.
#

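# A small aside (not in the original flow): the sketch below shows that an explicitly
# stored zero in a COO tensor remains a *specified* value after coalescing, which is
# exactly why sparsity here acts as a semantic extension rather than as compression:
# such an entry would also receive the eps update.
demo_grad = torch.sparse_coo_tensor(
    torch.tensor([[0], [1]]), torch.tensor([0.0]), (2, 4)
).coalesce()
print("explicit zero is still specified:\n", demo_grad._values())
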
# We currently dodge all these concerns by using the private method _values.
std_values = std._values().sqrt_().add_(eps)

# We currently don't support div for sparse Tensors because zero / zero is
# not well defined. For a MaskedTensor undefined / undefined is undefined.
param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)
print("param:\n", param)

######################################################################
# MaskedTensor sparse implementation
# ----------------------------------
#
# We've been conflating sparsity as an optimization with sparsity as a semantic extension to PyTorch.
# MaskedTensor proposes to call the semantic extension achieved through sparsity "masked".
# Currently we can't have dense semantics with sparse storage or masked semantics with dense storage, while
# MaskedTensor fixes that because it separates the storage from the semantics.
# Consider the above example using a masked gradient:
#

# Create an entirely new set of parameters to avoid errors
param2 = torch.arange(8).reshape(2, 4).float()
state_sum2 = torch.full_like(param, 0.5)  # initial value for state sum

mask = (grad.to_dense() != 0).to_sparse()
masked_grad = masked_tensor(grad, mask)
print("masked_grad:\n", masked_grad)

######################################################################
#

state_sum2 = state_sum2 + masked_grad.pow(2).data()
std2 = masked_tensor(state_sum2.to_sparse(), mask)

# Let's print both this version and the regular version for easier comparison
print("state_sum:\n", state_sum)
print("std:\n", std)
print("state_sum2:\n", state_sum2)
print("std2:\n", std2)

######################################################################
#

# We can add support for in-place operations later. Notice how this doesn't
# need to access any storage internals and is in general a lot shorter.
std2 = std2.sqrt().add(eps)

print("std:\n", std)
print("std2:\n", std2)

# .data() indeed returns a sparse tensor
param2 = param2.add((masked_grad / std2).data(), alpha=-clr)
print("param2:\n", param2)

######################################################################
# Conclusion: Difference in code
# ------------------------------
#
# For reference, this is the regular, dense code path without masked gradients or sparsity:
# ::
#
#     state_sum.addcmul_(grad, grad, value=1)
#     std = state_sum.sqrt().add_(eps)
#     param.addcdiv_(grad, std, value=-clr)
#
# The vanilla tensor implementation for sparse is:
#

grad = grad.coalesce()  # the update is non-linear so indices must be unique
grad_indices = grad._indices()
grad_values = grad._values()
size = grad.size()

state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))
std = state_sum.sparse_mask(grad)
std_values = std._values().sqrt_().add_(eps)
param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)

######################################################################
# while MaskedTensor minimizes the code to the snippet:
#

state_sum2 = state_sum2 + masked_grad.pow(2).data()
std2 = masked_tensor(state_sum2.to_sparse(), mask)
std2 = std2.sqrt().add(eps)
param2 = param2.add((masked_grad / std2).data(), alpha=-clr)

######################################################################
# And for good measure, let's make sure the parameters match:
#

print("param:\n", param)
print("param2:\n", param2)

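# On top of the printouts above, and assuming the MaskedTensor operations behave as
# described in this tutorial, a numeric comparison (not part of the original) could be:
print("parameters match:", torch.allclose(param, param2))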

index.rst (+2)

@@ -814,6 +814,8 @@ Additional Resources
    beginner/maskedtensor_sparsity
    beginner/maskedtensor_distinguish_gradient
    beginner/maskedtensor_safe_softmax
+   beginner/maskedtensor_missing_nan_ops
+   beginner/maskedtensor_adagrad_sparse_semantics

.. toctree::
