# -*- coding: utf-8 -*-

"""
(Prototype) Efficiently writing "sparse" semantics for Adagrad with MaskedTensor
================================================================================
"""

######################################################################
# `Issue 1369 <https://github.com/pytorch/pytorch/issues/1369>`__ discussed the additional lines of code
# that were introduced while writing "sparse" semantics for Adagrad.
# But really, the code uses sparsity as a proxy for masked semantics rather than for the intended use case
# of sparsity: a compression and optimization technique.
# Previously, we worked around the lack of formal masked semantics by introducing one-off semantics and operators
# while forcing users to be aware of storage details such as indices and values.
#
# In particular, we'll point out when sparsity is used as a semantic extension (i.e. unspecified values are not zero)
# and when it is just used to compress zeros.
# We'll also compare and contrast this with equivalent code written using MaskedTensor.
# In the end, the code snippets are repeated without additional comments to show the difference in brevity.
#
# Preparations
# ------------
#

import torch

# Some hyperparameters
eps = 1e-10
clr = 0.1

i = torch.tensor([[0, 1, 1], [2, 0, 2]])
v = torch.tensor([3, 4, 5], dtype=torch.float32)
grad = torch.sparse_coo_tensor(i, v, [2, 4])
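
# Quick look at the gradient (illustrative, not part of the original setup): the dense
# view shows which entries are "specified"; the zeros elsewhere are what the
# masked-vs-sparse discussion below is about.
print("grad (dense view):\n", grad.to_dense())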

######################################################################
# Original sparse implementation
# ------------------------------
#
# First, let's break down the current implementation of
# `Adagrad (functional) <https://github.com/pytorch/pytorch/blob/6c2f235d368b697072699e5ca9485fd97d0b9bcc/torch/optim/_functional.py#L16-L51>`__
# in PyTorch:
#

# Rebuild a sparse COO tensor shaped like ``grad`` from the given indices and values
def _make_sparse(grad, grad_indices, values):
    size = grad.size()
    if grad_indices.numel() == 0 or values.numel() == 0:
        return torch.empty_like(grad)
    return torch.sparse_coo_tensor(grad_indices, values, size)

# We don't support sparse gradients
param = torch.arange(8).reshape(2, 4).float()
state_sum = torch.full_like(param, 0.5)  # initial value for state sum

grad = grad.coalesce()  # the update is non-linear so indices must be unique
grad_indices = grad._indices()
grad_values = grad._values()
# pow(2) has the same semantics for both sparse and dense memory layouts since 0^2 is zero
state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))

# We take care to make std sparse, even though state_sum clearly is not.
# This means that we're only applying the gradient to parts of the state_sum
# for which it is specified. This further drives the point home that the passed gradient is not sparse, but masked.
std = state_sum.sparse_mask(grad)

# We currently dodge all these concerns using the private method ``_values``.
std_values = std._values().sqrt_().add_(eps)

# We currently don't support div for sparse Tensors because zero / zero is
# not well defined. For a MaskedTensor, undefined / undefined is undefined.
param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)

######################################################################
# ``std = state_sum.sparse_mask(grad)`` is where we have a very important divergence.
#
# The addition of eps should technically be applied to all values but instead is only applied to specified values.
# Here we're using sparsity as a semantic extension and to enforce a certain pattern of defined and undefined values.
# If parts of the values of the gradient are zero, they are still included in the materialized tensor even though they
# could be compressed away by other sparse storage layouts.
# This is technically quite brittle, even though someone could argue that eps is always very small.
#
# Moreover, an implementation of ``add_`` for sparsity as a storage layout and compression scheme should cause
# densification, but we force it not to for performance.
# For this one-off case that is fine, until we want to introduce new compression schemes
# such as CSR, BSR, or 2:4 block sparsity. We'll then need to introduce separate Tensor types for each
# and write variations for gradients compressed using different storage formats, which is inconvenient.
#
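
# A small, illustrative snippet (not part of the optimizer code) to make the point above
# concrete: ``sparse_mask`` keeps only the entries specified in ``grad``, so an in-place
# add on ``_values()`` never reaches the unspecified entries. We add 1.0 instead of ``eps``
# purely so the effect is visible when printed.
example_std = torch.full((2, 4), 0.5).sparse_mask(grad)
example_std._values().add_(1.0)
# 1.5 appears only at the three specified positions; the unspecified entries
# materialize as 0.0 and never receive the addition.
print(example_std.to_dense())

######################################################################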
# MaskedTensor sparse implementation
# ----------------------------------
#
# We've been conflating sparsity as an optimization with sparsity as a semantic extension to PyTorch.
# MaskedTensor proposes to call the semantic extension through sparsity "masked".
# Currently we can't have dense semantics with sparse storage, or masked semantics with dense storage;
# MaskedTensor enables these ideas by purposefully separating the storage from the semantics.
#
# Consider the above example using a masked gradient:
#

# Let's now import MaskedTensor!
from torch.masked import masked_tensor

# Create an entirely new set of parameters to avoid errors
param2 = torch.arange(8).reshape(2, 4).float()
state_sum2 = torch.full_like(param, 0.5)  # initial value for state sum

mask = (grad.to_dense() != 0).to_sparse()
masked_grad = masked_tensor(grad, mask)

state_sum2 = state_sum2 + masked_grad.pow(2).get_data()
std2 = masked_tensor(state_sum2.to_sparse(), mask)

# We can add support for in-place operations later. Notice how this doesn't
# need to access any storage internals and is in general a lot shorter.
std2 = std2.sqrt().add(eps)

param2 = param2.add((masked_grad / std2).get_data(), alpha=-clr)
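
# A brief aside (illustrative only, not needed for the update above): the same masked
# semantics can be backed by either sparse or dense storage, which is the separation of
# storage from semantics mentioned earlier. ``mt_sparse`` and ``mt_dense`` are throwaway
# names for this demo.
mt_sparse = masked_tensor(grad, mask)
mt_dense = masked_tensor(grad.to_dense(), mask.to_dense())
print("sparse-backed MaskedTensor:\n", mt_sparse)
print("dense-backed MaskedTensor:\n", mt_dense)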

######################################################################
# Note that the implementations look quite similar, but the MaskedTensor implementation is shorter and simpler.
# For example, much of the boilerplate code around ``_make_sparse``
# (and needing to have a separate implementation per layout) is handled for the user with :class:`MaskedTensor`.
#
# At this point, let's print both this version and the original version for easier comparison:
#

print("state_sum:\n", state_sum)
print("state_sum2:\n", state_sum2)

print("std:\n", std)
print("std2:\n", std2)

print("param:\n", param)
print("param2:\n", param2)
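
# A quick numeric check (illustrative; assumes the snippets above ran in order):
# the dense results from both code paths should match.
print("state_sum matches state_sum2:", torch.allclose(state_sum, state_sum2))
print("param matches param2:", torch.allclose(param, param2))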

######################################################################
# which shows that the two implementations indeed compute the same results.
#
# Conclusion: Simpler Code with MaskedTensor
# ------------------------------------------
#
# For reference, this is the regular, dense code path without masked gradients or sparsity:
#
#
# .. code-block:: python
#
#     state_sum.addcmul_(grad, grad, value=1)
#     std = state_sum.sqrt().add_(eps)
#     param.addcdiv_(grad, std, value=-clr)
#
# The vanilla tensor implementation for sparse is:
#
# .. code-block:: python
#
#     grad = grad.coalesce()  # the update is non-linear so indices must be unique
#     grad_indices = grad._indices()
#     grad_values = grad._values()
#
#     state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))  # a different _make_sparse per layout
#     std = state_sum.sparse_mask(grad)
#     std_values = std._values().sqrt_().add_(eps)
#     param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)
#
# while :class:`MaskedTensor` minimizes the code to the snippet:
#
# .. code-block:: python
#
#     state_sum2 = state_sum2 + masked_grad.pow(2).get_data()
#     std2 = masked_tensor(state_sum2.to_sparse(), mask)
#     std2 = std2.sqrt().add(eps)
#     param2 = param2.add((masked_grad / std2).get_data(), alpha=-clr)
#
# One major goal of :class:`MaskedTensor` is to enable sparsity semantics and applications, such as this one.
# To learn more about using sparsity, you can refer to the
# `MaskedTensor sparsity tutorial <https://pytorch.org/tutorials/prototype/maskedtensor_sparsity.html>`__.
# Currently, COO and CSR sparse layouts are supported, though there are immediate plans to add more.