# -*- coding: utf-8 -*-

"""
(Prototype) Efficiently writing "sparse" semantics for Adagrad with MaskedTensor
================================================================================
"""

######################################################################
# Before working through this tutorial, please review the MaskedTensor
# `Overview <https://pytorch.org/tutorials/prototype/maskedtensor_overview.html>`__ and
# `Sparsity <https://pytorch.org/tutorials/prototype/maskedtensor_sparsity.html>`__ tutorials.
#
# Introduction and Motivation
# ---------------------------
# `Issue 1369 <https://github.com/pytorch/pytorch/issues/1369>`__ discussed the additional lines of code
# that were introduced while writing "sparse" semantics for Adagrad, but really,
# the code uses sparsity as a proxy for masked semantics rather than the intended use case of sparsity:
# a compression and optimization technique.
# Previously, we worked around the lack of formal masked semantics by introducing one-off semantics and operators
# while forcing users to be aware of storage details such as indices and values.
#
# Now that we have masked semantics, we are better equipped to point out when sparsity is used as a semantic extension.
# We'll also compare and contrast this with equivalent code written using MaskedTensor.
# In the end, the code snippets are repeated without additional comments to show the difference in brevity.
#
# Preparation
# -----------
#

import torch
import warnings

# Disable prototype warnings and such
warnings.filterwarnings(action='ignore', category=UserWarning)

# Some hyperparameters
eps = 1e-10
clr = 0.1

i = torch.tensor([[0, 1, 1], [2, 0, 2]])
v = torch.tensor([3, 4, 5], dtype=torch.float32)
grad = torch.sparse_coo_tensor(i, v, [2, 4])
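
# For illustration only (not part of the Adagrad logic): ``grad`` has specified
# entries at (0, 2), (1, 0), and (1, 2), so a quick dense view makes the
# pattern of specified vs. unspecified values easy to see.
print("grad (dense view):\n", grad.to_dense())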

######################################################################
# Simpler Code with MaskedTensor
# ------------------------------
| 47 | +# |
| 48 | +# Before we get too far in the weeds, let's introduce the problem a bit more concretely. We will be taking a look |
| 49 | +# into the `Adagrad (functional) <https://github.com/pytorch/pytorch/blob/6c2f235d368b697072699e5ca9485fd97d0b9bcc/torch/optim/_functional.py#L16-L51>`__ |
| 50 | +# implementation in PyTorch with the ultimate goal of simplifying and more faithfully representing the masked approach. |
#
# For reference, this is the regular, dense code path without masked gradients or sparsity:
#
# .. code-block:: python
#
#    state_sum.addcmul_(grad, grad, value=1)
#    std = state_sum.sqrt().add_(eps)
#    param.addcdiv_(grad, std, value=-clr)
#
# The vanilla tensor implementation for sparse is:
#
# .. code-block:: python
#
#    def _make_sparse(grad, grad_indices, values):
#        size = grad.size()
#        if grad_indices.numel() == 0 or values.numel() == 0:
#            return torch.empty_like(grad)
#        return torch.sparse_coo_tensor(grad_indices, values, size)
#
#    grad = grad.coalesce() # the update is non-linear so indices must be unique
#    grad_indices = grad._indices()
#    grad_values = grad._values()
#
#    state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2))) # a different _make_sparse per layout
#    std = state_sum.sparse_mask(grad)
#    std_values = std._values().sqrt_().add_(eps)
#    param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)
#
# while :class:`MaskedTensor` minimizes the code to the snippet:
#
# .. code-block:: python
#
#    state_sum2 = state_sum2 + masked_grad.pow(2).get_data()
#    std2 = masked_tensor(state_sum2.to_sparse(), mask)
#    std2 = std2.sqrt().add(eps)
#    param2 = param2.add((masked_grad / std2).get_data(), alpha=-clr)
#
# In this tutorial, we will go through each implementation line by line, but at first glance, we can notice
# (1) how much shorter the MaskedTensor implementation is, and
# (2) how it avoids conversions between dense and sparse tensors.
#

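######################################################################
# Before diving into the sparse path, here is a minimal, runnable sketch of the
# dense code path above on this example. The names ``param_dense``,
# ``state_sum_dense``, ``grad_dense``, and ``std_dense`` are illustrative
# (they are not part of the PyTorch implementation), and ``to_dense()`` is used
# purely so the dense kernels can run. Note that ``sqrt`` and ``eps`` are applied
# to every element here, whereas the sparse and masked paths below only touch
# the specified values.
#

param_dense = torch.arange(8).reshape(2, 4).float()
state_sum_dense = torch.full_like(param_dense, 0.5)
grad_dense = grad.to_dense()  # densified only for this illustration

state_sum_dense.addcmul_(grad_dense, grad_dense, value=1)
std_dense = state_sum_dense.sqrt().add_(eps)
param_dense.addcdiv_(grad_dense, std_dense, value=-clr)
print("param after one dense Adagrad step:\n", param_dense)
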
######################################################################
# Original Sparse Implementation
# ------------------------------
#
# Now, let's break down the code with some inline comments:
#

def _make_sparse(grad, grad_indices, values):
    size = grad.size()
    if grad_indices.numel() == 0 or values.numel() == 0:
        return torch.empty_like(grad)
    return torch.sparse_coo_tensor(grad_indices, values, size)

# We don't support sparse gradients
param = torch.arange(8).reshape(2, 4).float()
state_sum = torch.full_like(param, 0.5) # initial value for state sum

grad = grad.coalesce() # the update is non-linear so indices must be unique
grad_indices = grad._indices()
grad_values = grad._values()
# pow(2) has the same semantics for both sparse and dense memory layouts since 0^2 is zero
state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))

# We take care to make std sparse, even though state_sum clearly is not.
# This means that we're only applying the gradient to parts of the state_sum
# for which it is specified. This further drives the point home that the passed gradient is not sparse, but masked.
# We currently dodge all these concerns using the private method `_values`.
std = state_sum.sparse_mask(grad)
std_values = std._values().sqrt_().add_(eps)

# Note here that we currently don't support div for sparse Tensors because zero / zero is not well defined,
# so we're forced to perform `grad_values / std_values` outside of sparse semantics and then convert back to a
# sparse tensor with `_make_sparse`.
# We'll later see that MaskedTensor handles these operations for us and properly denotes
# undefined / undefined = undefined!
param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)
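
# For illustration only (not part of the algorithm): only the positions that
# ``grad`` specifies have moved away from the initial 0.5 in ``state_sum``;
# every other entry is untouched, which is masked behavior in disguise.
print("state_sum entries updated by the sparse step:\n", state_sum != 0.5)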

######################################################################
# The third to last line -- `std = state_sum.sparse_mask(grad)` -- is where we have a very important divergence.
#
# The addition of eps should technically be applied to all values but instead is only applied to specified values.
# Here we're using sparsity as a semantic extension and to enforce a certain pattern of defined and undefined values.
# If some of the gradient's values are zero, they are still included in the materialized sparse tensor even though they
# could be compressed away by other sparse storage layouts. This is theoretically quite brittle!
# That said, one could argue that eps is always very small, so it might not matter much in practice.
#
# Moreover, an implementation of `add_` that treats sparsity as a storage layout and compression scheme
# should cause densification, but we force it not to for performance.
# For this one-off case it is fine until we want to introduce new compression schemes, such as
# `CSC <https://pytorch.org/docs/master/sparse.html#sparse-csc-docs>`__,
# `BSR <https://pytorch.org/docs/master/sparse.html#sparse-bsr-docs>`__,
# or `BSC <https://pytorch.org/docs/master/sparse.html#sparse-bsc-docs>`__.
# We would then need to introduce separate Tensor types for each and write variations for gradients compressed
# using different storage formats, which is inconvenient, not very scalable, and not clean.
#
# MaskedTensor Sparse Implementation
# ----------------------------------
#
# We've been conflating sparsity as an optimization with sparsity as a semantic extension to PyTorch.
# MaskedTensor proposes to disentangle the sparsity optimization from the semantic extension; for example,
# currently we can't have dense semantics with sparse storage or masked semantics with dense storage.
# MaskedTensor enables these ideas by purposefully separating the storage from the semantics.
#
# Consider the above example using a masked gradient:
#

# Let's now import MaskedTensor!
from torch.masked import masked_tensor

# Create an entirely new set of parameters to avoid errors
param2 = torch.arange(8).reshape(2, 4).float()
state_sum2 = torch.full_like(param2, 0.5) # initial value for state sum

mask = (grad.to_dense() != 0).to_sparse()
masked_grad = masked_tensor(grad, mask)
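
# For illustration only: the mask marks the specified entries of ``grad``, and
# printing ``masked_grad`` shows which entries are considered unspecified.
print("mask (dense view):\n", mask.to_dense())
print("masked_grad:\n", masked_grad)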

state_sum2 = state_sum2 + masked_grad.pow(2).get_data()
std2 = masked_tensor(state_sum2.to_sparse(), mask)

# We can add support for in-place operations later. Notice how this doesn't
# need to access any storage internals and is in general a lot shorter
std2 = std2.sqrt().add(eps)

param2 = param2.add((masked_grad / std2).get_data(), alpha=-clr)

######################################################################
# Note that the implementations look quite similar, but the MaskedTensor implementation is shorter and simpler.
# In particular, much of the boilerplate code around ``_make_sparse``
# (and needing to have a separate implementation per layout) is handled for the user with :class:`MaskedTensor`.
#
# At this point, let's print both this version and the original version for easier comparison:
#

print("state_sum:\n", state_sum)
print("state_sum2:\n", state_sum2)

######################################################################
#

print("std:\n", std)
print("std2:\n", std2)

######################################################################
#

print("param:\n", param)
print("param2:\n", param2)

######################################################################
# Conclusion
# ----------
#
# In this tutorial, we've discussed how native masked semantics can enable a cleaner developer experience for
# Adagrad's existing implementation in PyTorch, which used sparsity as a proxy for writing masked semantics.
# More importantly, making masked semantics a first-class citizen through MaskedTensor
# removes the reliance on sparsity and on unreliable hacks that mimic masking, allowing masked and sparse semantics
# to develop independently while still supporting use cases such as this one.
#
# Further Reading
# ---------------
#
# To continue learning more, you can find our final review (for now) on
# `MaskedTensor Advanced Semantics <https://pytorch.org/tutorials/prototype/maskedtensor_advanced_semantics.html>`__
# to see some of the differences in design decisions between :class:`MaskedTensor` and NumPy's MaskedArray, as well
# as reduction semantics.
#