Efficiency of writing "sparse" semantics for Adagrad
====================================================

`Issue 1369 <https://github.com/pytorch/pytorch/issues/1369>`__ discussed the additional lines of code
that were introduced while writing "sparse" semantics for Adagrad.
But the code doesn't really use sparsity as a compression and optimization technique;
it wants to use masked semantics. We worked around this by introducing one-off semantics and operators
that encode this behavior while forcing users to be aware of storage details such as indices and values.

In particular, we'll point out when sparsity is used as a semantic extension, i.e. unspecified values are not zero,
and when it is just used to compress zeros.
We'll also compare and contrast this with equivalent code written using MaskedTensor.
At the end, the code snippets are repeated without additional comments to show the difference in brevity.

Set up
------

.. code-block:: python

    import torch
    from torch.masked.maskedtensor import masked_tensor

    # Some hyperparameters
    eps = 1e-10
    clr = 0.1

    i = torch.tensor([[0, 1, 1], [2, 0, 2]])
    v = torch.tensor([3, 4, 5], dtype=torch.float32)
    grad = torch.sparse_coo_tensor(i, v, [2, 4])

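As a quick sanity check of the setup (a small sketch, not part of the original snippet), materializing `grad` shows which elements are specified; the dense form follows directly from the indices and values above:

.. code-block:: python

    # Only (0, 2), (1, 0), and (1, 2) are specified; everything else is an
    # unspecified (compressed) zero in the COO layout.
    print(grad.to_dense())
    # tensor([[0., 0., 3., 0.],
    #         [4., 0., 5., 0.]])
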
Original sparse implementation
------------------------------

First, let's break down the current implementation of
`Adagrad (functional) <https://github.com/pytorch/pytorch/blob/6c2f235d368b697072699e5ca9485fd97d0b9bcc/torch/optim/_functional.py#L16-L51>`__
in PyTorch:

.. code-block:: python
    :linenos:

    def _make_sparse(grad, grad_indices, values):
        size = grad.size()
        if grad_indices.numel() == 0 or values.numel() == 0:
            return torch.empty_like(grad)
        return torch.sparse_coo_tensor(grad_indices, values, size)

    # We don't support sparse gradients
    param = torch.arange(8).reshape(2, 4).float()
    state_sum = torch.full_like(param, 0.5)  # initial value for state sum

    grad = grad.coalesce()  # the update is non-linear so indices must be unique
    grad_indices = grad._indices()
    grad_values = grad._values()
    # pow(2) has the same semantics for both sparse and dense memory layouts since 0^2 is zero
    state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))

    # We take care to make std sparse, even though state_sum clearly is not.
    # This means that we're only applying the gradient to parts of the state_sum
    # for which it is specified. This even drives the point home a lot more that
    # the passed gradient is not sparse, but masked.
    std = state_sum.sparse_mask(grad)

    # We currently dodge all these concerns by using the private method _values.
    std_values = std._values().sqrt_().add_(eps)

    # We currently don't support div for sparse Tensors because zero / zero is
    # not well defined. For a MaskedTensor undefined / undefined is undefined.
    param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)

Line 24 is where we have a very important divergence.

The addition of eps should technically be applied to all values, but instead it is only applied to specified values.
Here we're using sparsity as a semantic extension and to enforce a certain pattern of defined and undefined values.
If parts of the values of the gradient are zero, they are still included if materialized, even though they
could be compressed by other sparse storage layouts.
This is technically quite brittle, even though one could argue that eps is always very small.
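
To make the divergence concrete, here is a small sketch (not from the original code; it reuses `state_sum`, `grad`, and `eps` from the snippets above) contrasting where eps lands on the dense and sparse paths:

.. code-block:: python

    # Dense path: eps is added to every one of the 2 x 4 elements.
    dense_std = state_sum.sqrt().add(eps)

    # Sparse/masked path: only the three specified values ever see eps;
    # the five unspecified elements are never touched.
    masked_std = state_sum.sparse_mask(grad)
    masked_std_values = masked_std._values().sqrt().add(eps)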

Moreover, an implementation of `add_` for sparsity as a storage layout and compression scheme should cause densification,
but we force it not to for performance.
For this one-off case that is fine, until we want to introduce new compression schemes
such as CSR, BSR, or 2:4 block sparsity. We'll then need to introduce separate Tensor types for each
and write variations for gradients compressed using different storage formats, which is inconvenient.
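
As a brief illustration of why per-layout code becomes painful (a sketch, not part of the original tutorial; it reuses `grad` from the setup above), COO and CSR expose entirely different index structures, so a helper like `_make_sparse` would need a variant per layout:

.. code-block:: python

    # COO exposes a 2 x nnz matrix of (row, col) indices...
    grad_coo = grad.coalesce()
    print(grad_coo._indices())      # rows: [0, 1, 1], cols: [2, 0, 2]

    # ...while CSR exposes compressed row pointers plus column indices.
    grad_csr = grad.to_dense().to_sparse_csr()
    print(grad_csr.crow_indices())  # [0, 1, 3]
    print(grad_csr.col_indices())   # [2, 0, 2]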

MaskedTensor sparse implementation
----------------------------------

We've been conflating sparsity as an optimization with sparsity as a semantic extension to PyTorch.
MaskedTensor proposes to call the semantic extension through sparsity "masked".
Currently we can't have dense semantics with sparse storage or masked semantics with dense storage, while
MaskedTensor fixes that because it separates the storage from the semantics.
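
For instance, the same masked semantics can sit on top of either a dense or a sparse layout. The sketch below uses hypothetical values (not from the original tutorial) to show both constructions:

.. code-block:: python

    # Hypothetical data: zeros mark the elements we want to treat as undefined.
    data = torch.tensor([[0., 0., 3., 0.], [4., 0., 5., 0.]])
    mask = data != 0

    # Masked semantics backed by dense storage.
    mt_dense = masked_tensor(data, mask)

    # The same masked semantics backed by sparse COO storage.
    mt_sparse = masked_tensor(data.to_sparse(), mask.to_sparse())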

Consider the above example using a masked gradient:

.. code-block:: python

    # Create an entirely new set of parameters to avoid errors
    param2 = torch.arange(8).reshape(2, 4).float()
    state_sum2 = torch.full_like(param, 0.5)  # initial value for state sum

    mask = (grad.to_dense() != 0).to_sparse()
    masked_grad = masked_tensor(grad, mask)

    state_sum2 = state_sum2 + masked_grad.pow(2).get_data()
    std2 = masked_tensor(state_sum2.to_sparse(), mask)

    # We can add support for in-place operations later. Notice how this doesn't
    # need to access any storage internals and is in general a lot shorter
    std2 = std2.sqrt().add(eps)

    param2 = param2.add((masked_grad / std2).get_data(), alpha=-clr)

Note that the implementations look quite similar, but the MaskedTensor version is noticeably shorter. Much of the boilerplate code
around `_make_sparse` (and needing to have a separate implementation per layout) is handled
for you with :class:`MaskedTensor`.

At this point, let's print both this version and the original version for easier comparison:

    >>> state_sum
    tensor([[ 0.5000,  0.5000,  9.5000,  0.5000],
            [16.5000,  0.5000, 25.5000,  0.5000]])
    >>> state_sum2
    tensor([[ 0.5000,  0.5000,  9.5000,  0.5000],
            [16.5000,  0.5000, 25.5000,  0.5000]])
    >>> std
    tensor(indices=tensor([[0, 1, 1],
                           [2, 0, 2]]),
           values=tensor([3.0822, 4.0620, 5.0498]),
           size=(2, 4), nnz=3, layout=torch.sparse_coo)
    >>> std2
    MaskedTensor(
      [
        [      --,       --,   3.0822,       --],
        [  4.0620,       --,   5.0498,       --]
      ]
    )
    >>> param
    tensor([[0.0000, 1.0000, 1.9027, 3.0000],
            [3.9015, 5.0000, 5.9010, 7.0000]])
    >>> param2
    tensor([[0.0000, 1.0000, 1.9027, 3.0000],
            [3.9015, 5.0000, 5.9010, 7.0000]])

which shows that the two implementations are indeed the same.
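
If you prefer a programmatic check over eyeballing the printouts, a short sketch (not part of the original tutorial) can assert the equivalence directly:

.. code-block:: python

    # Both pairs were computed above; the masked implementation lands on the
    # same dense results as the original sparse implementation.
    assert torch.allclose(state_sum, state_sum2)
    assert torch.allclose(param, param2)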

Conclusion: Difference in Code
------------------------------

For reference, this is the regular, dense code path without masked gradients or sparsity:

.. code-block:: python

    state_sum.addcmul_(grad, grad, value=1)
    std = state_sum.sqrt().add_(eps)
    param.addcdiv_(grad, std, value=-clr)

The vanilla tensor implementation for sparse gradients is:

.. code-block:: python

    grad = grad.coalesce()  # the update is non-linear so indices must be unique
    grad_indices = grad._indices()
    grad_values = grad._values()

    state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))  # a different _make_sparse per layout
    std = state_sum.sparse_mask(grad)
    std_values = std._values().sqrt_().add_(eps)
    param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)

while :class:`MaskedTensor` minimizes the code to the snippet:

.. code-block:: python

    state_sum2 = state_sum2 + masked_grad.pow(2).get_data()
    std2 = masked_tensor(state_sum2.to_sparse(), mask)
    std2 = std2.sqrt().add(eps)
    param2 = param2.add((masked_grad / std2).get_data(), alpha=-clr)

One major goal of :class:`MaskedTensor` is to enable sparsity semantics and applications such as this one.
To learn more about using sparsity, see
`this MaskedTensor sparsity tutorial <https://pytorch.org/tutorials/prototype/maskedtensor_sparsity.html>`__.
Currently, COO and CSR sparse layouts are supported, though there are immediate plans to add more.
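
As a small, hedged sketch of what the CSR support looks like (assumed API, based on the sparsity tutorial linked above; the tensors here are hypothetical and not part of the Adagrad example):

.. code-block:: python

    # The same masked semantics backed by CSR storage.
    data_csr = torch.tensor([[0., 0., 3., 0.], [4., 0., 5., 0.]]).to_sparse_csr()
    mask_csr = torch.tensor([[False, False, True, False],
                             [True, False, True, False]]).to_sparse_csr()
    mt_csr = masked_tensor(data_csr, mask_csr)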