# -*- coding: utf-8 -*-

"""
Efficiency of writing "sparse" semantics for Adagrad
====================================================

`Issue 1369 <https://github.com/pytorch/pytorch/issues/1369>`__ discussed the additional lines of code
that were introduced while writing "sparse" semantics for Adagrad.
But the code doesn't really use sparsity as a compression and optimization technique;
it wants to use masked semantics. We worked around this by introducing one-off semantics and operators
that encode this behavior while forcing users to be aware of storage details such as indices and values.

In particular, we'll point out when sparsity is used as a semantic extension, i.e. unspecified values are not zero,
and when it is just used to compress zeros.
We'll also compare and contrast this with equivalent code written using MaskedTensor.
In the end the code snippets are repeated without additional comments to show the difference in brevity.

"""

import torch
from torch.masked.maskedtensor import masked_tensor
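
######################################################################
# Before diving in, a quick illustration (not part of the original Adagrad
# code; the tensors below are made up): for a sparse COO tensor an unspecified
# entry simply means zero, while for a MaskedTensor an unspecified entry is
# undefined and excluded from computation.
#

sparse_example = torch.sparse_coo_tensor([[0, 2]], [1.0, 2.0], (4,))
print("sparse, unspecified entries materialize as zero:\n", sparse_example.to_dense())

masked_example = masked_tensor(
    torch.tensor([1.0, 0.0, 2.0, 0.0]),
    torch.tensor([True, False, True, False]),
)
print("masked, unspecified entries stay undefined:\n", masked_example)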

######################################################################
# Original sparse implementation
# ------------------------------
#
# First, let's look at the current implementation of
# `Adagrad (functional) <https://github.com/pytorch/pytorch/blob/6c2f235d368b697072699e5ca9485fd97d0b9bcc/torch/optim/_functional.py#L16-L51>`__
#

def _make_sparse(grad, grad_indices, values):
    size = grad.size()
    if grad_indices.numel() == 0 or values.numel() == 0:
        return torch.empty_like(grad)
    return torch.sparse_coo_tensor(grad_indices, values, size)

# Some hyperparameters
eps = 1e-10
clr = 0.1

# We don't support sparse gradients
param = torch.arange(8).reshape(2, 4).float()
i = torch.tensor([[0, 1, 1],
                  [2, 0, 2]])
v = torch.tensor([3, 4, 5], dtype=torch.float32)
grad = torch.sparse_coo_tensor(i, v, [2, 4])
state_sum = torch.full_like(param, 0.5)  # initial value for state sum

print("param:\n", param)
print("grad:\n", grad.to_dense())
print("state_sum:\n", state_sum)

######################################################################
#

state_sum = torch.full_like(param, 0.5)  # initial value for state sum
print(state_sum)

grad = grad.coalesce()  # the update is non-linear so indices must be unique
grad_indices = grad._indices()
grad_values = grad._values()

# pow(2) has the same semantics for both sparse and dense memory layouts since 0^2 is zero
state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))
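# A quick check of the claim above (illustration only, not part of the original
# implementation): squaring the dense view of the gradient matches densifying
# the squared sparse values, because 0 ** 2 == 0.
print("pow(2) agrees between layouts:",
      torch.equal(grad.to_dense().pow(2),
                  _make_sparse(grad, grad_indices, grad_values.pow(2)).to_dense()))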
65
+ # We take care to make std sparse, even though state_sum clearly is not.
66
+ # This means that we're only applying the gradient to parts of the state_sum
67
+ # for which it is specified. This even drives the point home a lot more that
68
+ # the passed gradient is not sparse, but masked.
69
+ std = state_sum .sparse_mask (grad )
70
+ print ("state_sum:\n " , state_sum )
71
+ print ("std:\n " , std .to_dense ())
72
+

######################################################################
# This is where we have a very important divergence.
# The addition of eps should technically be applied to all values, but instead it is only applied to specified values.
# Here we're using sparsity as a semantic extension and to enforce a certain pattern of defined and undefined values.
# If parts of the gradient's values are zero, they are still included if materialized,
# even though they could be compressed by other sparse storage layouts.
# This is technically quite brittle, even though someone could argue that eps is always very small.
#
# Moreover, an implementation of add_ that treats sparsity as a storage layout and compression scheme
# should cause densification, but we force it not to.
# For this one-off case that is fine, until we want to introduce new compression schemes
# such as CSR, BSR, or 2:4 block sparsity. We'll then need to introduce separate Tensor types for each
# and write variations for gradients compressed using different storage formats.
#

# We currently dodge all these concerns by using the private method _values.
std_values = std._values().sqrt_().add_(eps)
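
# To make the divergence above concrete (illustration only): in the dense code
# path eps would be added to every element, while here the unspecified entries
# of std never see sqrt or eps at all.
print("std with eps applied only to specified values:\n", std.to_dense())
print("dense equivalent with eps applied everywhere:\n", state_sum.sqrt().add(eps))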

# We currently don't support div for sparse Tensors because zero / zero is
# not well defined. For a MaskedTensor undefined / undefined is undefined.
param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)
print("param:\n", param)

######################################################################
# MaskedTensor sparse implementation
# ----------------------------------
#
# We've been conflating sparsity as an optimization with sparsity as a semantic extension to PyTorch.
# MaskedTensor proposes to call this semantic extension "masked" instead.
# Currently we can't have dense semantics with sparse storage or masked semantics with dense storage;
# MaskedTensor fixes that because it separates the storage from the semantics.
# Consider the above example using a masked gradient:
#

# Create an entirely new set of parameters to avoid errors
param2 = torch.arange(8).reshape(2, 4).float()
state_sum2 = torch.full_like(param, 0.5)  # initial value for state sum

mask = (grad.to_dense() != 0).to_sparse()
masked_grad = masked_tensor(grad, mask)
print("masked_grad:\n", masked_grad)
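
# Since MaskedTensor separates storage from semantics, the same masked gradient
# can also be backed by dense storage (illustration only; the rest of this
# tutorial keeps the sparse-backed version).
masked_grad_dense = masked_tensor(grad.to_dense(), mask.to_dense())
print("masked_grad with dense storage:\n", masked_grad_dense)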

######################################################################
#

state_sum2 = state_sum2 + masked_grad.pow(2).data()
std2 = masked_tensor(state_sum2.to_sparse(), mask)

# Let's print both this version and the regular version for easier comparison
print("state_sum:\n", state_sum)
print("std:\n", std)
print("state_sum2:\n", state_sum2)
print("std2:\n", std2)

######################################################################
#

# We can add support for in-place operations later. Notice how this doesn't
# need to access any storage internals and is in general a lot shorter.
std2 = std2.sqrt().add(eps)

print("std:\n", std)
print("std2:\n", std2)

# .data() indeed returns a sparse tensor
param2 = param2.add((masked_grad / std2).data(), alpha=-clr)
print("param2:\n", param2)
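
# As a quick confirmation of the comment above (illustration only), the tensor
# returned by .data() is indeed laid out as a sparse tensor.
print("(masked_grad / std2).data() is sparse:", (masked_grad / std2).data().is_sparse)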

######################################################################
# Conclusion: Difference in code
# ------------------------------
#
# For reference, this is the regular, dense code path without masked gradients or sparsity:
# ::
#
#    state_sum.addcmul_(grad, grad, value=1)
#    std = state_sum.sqrt().add_(eps)
#    param.addcdiv_(grad, std, value=-clr)
#
# The vanilla tensor implementation for sparse is:
#

grad = grad.coalesce()  # the update is non-linear so indices must be unique
grad_indices = grad._indices()
grad_values = grad._values()
size = grad.size()

state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))
std = state_sum.sparse_mask(grad)
std_values = std._values().sqrt_().add_(eps)
param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)

######################################################################
# while MaskedTensor minimizes the code to the snippet:
#

state_sum2 = state_sum2 + masked_grad.pow(2).data()
std2 = masked_tensor(state_sum2.to_sparse(), mask)
std2 = std2.sqrt().add(eps)
param2 = param2.add((masked_grad / std2).data(), alpha=-clr)

######################################################################
# And for good measure, let's make sure the parameters match:
#

print("param:\n", param)
print("param2:\n", param2)
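
# A programmatic check (not in the original snippet): the parameters produced
# by the sparse code path and the MaskedTensor code path should agree.
print("parameters match:", torch.allclose(param, param2))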