# -*- coding: utf-8 -*-

"""
(Prototype) MaskedTensor Overview
*********************************
"""

######################################################################
# This tutorial is designed to serve as a starting point for using MaskedTensors
# and to discuss their masking semantics.
#
# MaskedTensor serves as an extension to :class:`torch.Tensor` that provides the user with the ability to:
#
# * use any masked semantics (for example, variable length tensors, nan* operators, etc.)
# * differentiate between 0 and NaN gradients
# * support various sparse applications (see the tutorial below)
#
# For a more detailed introduction to what MaskedTensors are, please refer to the
# `torch.masked documentation <https://pytorch.org/docs/master/masked.html>`__.
#
# Using MaskedTensor
# ==================
#
# In this section we discuss how to use MaskedTensor, including how to construct and access the data
# and mask, as well as how to index and slice.
#
# Preparation
# -----------
#
# We'll begin by doing the necessary setup for the tutorial:
#

import torch
from torch.masked import masked_tensor, as_masked_tensor
import warnings

# Disable prototype warnings and such
warnings.filterwarnings(action='ignore', category=UserWarning)

######################################################################
# Construction
# ------------
#
# There are a few different ways to construct a MaskedTensor:
#
# * The first way is to directly invoke the MaskedTensor class
# * The second (and our recommended way) is to use :func:`masked.masked_tensor` and :func:`masked.as_masked_tensor`
#   factory functions, which are analogous to :func:`torch.tensor` and :func:`torch.as_tensor`
#
# Throughout this tutorial, we will be assuming the import line: `from torch.masked import masked_tensor`.
#
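# As a quick illustration of the factory-function approach described above (a minimal sketch;
# the variable names here are made up for this example):
#

# data values and a boolean mask of the same shape (True marks a valid entry)
values = torch.tensor([0., 1., 2., 3.])
valid = torch.tensor([True, False, True, False])

mt_example = masked_tensor(values, valid)
print("mt_example:\n", mt_example)

######################################################################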
# Accessing the data and mask
# ---------------------------
#
# The underlying fields in a MaskedTensor can be accessed through:
#
# * the :meth:`MaskedTensor.get_data` function
# * the :meth:`MaskedTensor.get_mask` function. Recall that ``True`` indicates "specified" or "valid"
#   while ``False`` indicates "unspecified" or "invalid".
#
# In general, the underlying data that is returned may not be valid in the unspecified entries, so we recommend that
# when users require a Tensor without any masked entries, they use :meth:`MaskedTensor.to_tensor` (as shown in the
# example below) to return a Tensor with filled values.
#
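# For example (a short sketch reusing ``mt_example`` from above; the fill value ``0`` passed to
# ``to_tensor`` is an arbitrary choice for this illustration):
#

print("data:\n", mt_example.get_data())
print("mask:\n", mt_example.get_mask())
# masked-out entries are replaced by the given fill value
print("filled:\n", mt_example.to_tensor(0))

######################################################################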
# Indexing and slicing
# --------------------
#
# :class:`MaskedTensor` is a Tensor subclass, which means that it inherits the same semantics for indexing and slicing
# as :class:`torch.Tensor`. Below are some examples of common indexing and slicing patterns:
#

data = torch.arange(24).reshape(2, 3, 4)
mask = data % 2 == 0

print("data:\n", data)
print("mask:\n", mask)

######################################################################
#

# float is used for cleaner visualization when being printed
mt = masked_tensor(data.float(), mask)

print("mt[0]:\n", mt[0])
print("mt[:, :, 2:4]:\n", mt[:, :, 2:4])

######################################################################
# Why is MaskedTensor useful?
# ===========================
#
# Because :class:`MaskedTensor` treats specified and unspecified values as first-class citizens
# instead of an afterthought (with filled values, NaNs, and so on), it is able to address several of the shortcomings
# that regular Tensors cannot; indeed, :class:`MaskedTensor` was born in large part due to these recurring issues.
#
# Below, we will discuss some of the most common issues that are still unresolved in PyTorch today
# and illustrate how :class:`MaskedTensor` can solve these problems.
#
# Distinguishing between 0 and NaN gradient
# -----------------------------------------
#
# One issue that :class:`torch.Tensor` runs into is the inability to distinguish between gradients that are
# undefined (NaN) vs. gradients that are actually 0. Because PyTorch does not have a way of marking a value
# as specified/valid vs. unspecified/invalid, it is forced to rely on NaN or 0 (depending on the use case), leading
# to unreliable semantics since many operations aren't meant to handle NaN values properly. What is even more confusing
# is that the gradient can vary depending on the order of operations (for example, depending on how early
# in the chain of operations a NaN value manifests).
#
# :class:`MaskedTensor` is the perfect solution for this!
#
# torch.where
# ^^^^^^^^^^^
#
# In `Issue 10729 <https://github.com/pytorch/pytorch/issues/10729>`__, we notice a case where the order of operations
# can matter when using :func:`torch.where`, because we have trouble differentiating whether the 0 is a real 0
# or one that comes from an undefined gradient. Therefore, we remain consistent and mask out the results:
#
# Current result:
#

x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100], requires_grad=True, dtype=torch.float)
y = torch.where(x < 0, torch.exp(x), torch.ones_like(x))
y.sum().backward()
x.grad

######################################################################
# :class:`MaskedTensor` result:
#

x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100])
mask = x < 0
mx = masked_tensor(x, mask, requires_grad=True)
my = masked_tensor(torch.ones_like(x), ~mask, requires_grad=True)
y = torch.where(mask, torch.exp(mx), my)
y.sum().backward()
mx.grad

######################################################################
# The gradient here is only provided to the selected subset. Effectively, this changes the gradient of `where`
# to mask out elements instead of setting them to zero.
#
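# As a quick sanity check (a sketch; the exact printed form may differ across PyTorch versions),
# we can also inspect the gradient flowing into the other branch, which should only be populated
# at the entries where ``my`` was actually selected:
#

my.grad

######################################################################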
# Another torch.where
# ^^^^^^^^^^^^^^^^^^^
#
# `Issue 52248 <https://github.com/pytorch/pytorch/issues/52248>`__ is another example.
#
# Current result:
#

a = torch.randn((), requires_grad=True)
b = torch.tensor(False)
c = torch.ones(())
print("torch.where(b, a/0, c):\n", torch.where(b, a/0, c))
print("torch.autograd.grad(torch.where(b, a/0, c), a):\n", torch.autograd.grad(torch.where(b, a/0, c), a))

######################################################################
# :class:`MaskedTensor` result:
#

a = masked_tensor(torch.randn(()), torch.tensor(True), requires_grad=True)
b = torch.tensor(False)
c = torch.ones(())
print("torch.where(b, a/0, c):\n", torch.where(b, a/0, c))
print("torch.autograd.grad(torch.where(b, a/0, c), a):\n", torch.autograd.grad(torch.where(b, a/0, c), a))

######################################################################
# This issue is similar (and even links to the next issue below) in that it expresses frustration with
# unexpected behavior because of the inability to differentiate "no gradient" vs "zero gradient",
# which in turn makes working with other ops difficult to reason about.
#
# When using mask, x/0 yields NaN grad
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# In `Issue 4132 <https://github.com/pytorch/pytorch/issues/4132>`__, the user proposes that
# `x.grad` should be `[0, 1]` instead of `[nan, 1]`,
# whereas :class:`MaskedTensor` makes this very clear by masking out the gradient altogether.
#
# Current result:
#

x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
y = x/div # => y is [inf, 1]
mask = (div != 0) # => mask is [0, 1]
y[mask].backward()
x.grad

######################################################################
# :class:`MaskedTensor` result:
#

x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
y = x/div # => y is [inf, 1]
mask = (div != 0) # => mask is [0, 1]
loss = as_masked_tensor(y, mask)
loss.sum().backward()
x.grad

######################################################################
# :func:`torch.nansum` and :func:`torch.nanmean`
# ----------------------------------------------
#
# In `Issue 67180 <https://github.com/pytorch/pytorch/issues/67180>`__,
# the gradient isn't calculated properly (a longstanding issue), whereas :class:`MaskedTensor` handles it correctly.
#
# Current result:
#

a = torch.tensor([1., 2., float('nan')])
b = torch.tensor(1.0, requires_grad=True)
c = a * b
c1 = torch.nansum(c)
bgrad1, = torch.autograd.grad(c1, b, retain_graph=True)
bgrad1

######################################################################
# :class:`MaskedTensor` result:
#

a = torch.tensor([1., 2., float('nan')])
b = torch.tensor(1.0, requires_grad=True)
mt = masked_tensor(a, ~torch.isnan(a))
c = mt * b
c1 = torch.sum(c)
bgrad1, = torch.autograd.grad(c1, b, retain_graph=True)
bgrad1

######################################################################
# Safe Softmax
# ------------
#
# Safe softmax is another great example of `an issue <https://github.com/pytorch/pytorch/issues/55056>`__
# that arises frequently. In a nutshell, if there is an entire batch that is "masked out"
# or consists entirely of padding (which, in the softmax case, translates to being set to `-inf`),
# then this will result in NaNs, which can lead to training divergence.
#
# Luckily, :class:`MaskedTensor` has solved this issue. Consider this setup:
#

data = torch.randn(3, 3)
mask = torch.tensor([[True, False, False], [True, False, True], [False, False, False]])
x = data.masked_fill(~mask, float('-inf'))
mt = masked_tensor(data, mask)
print("x:\n", x)
print("mt:\n", mt)

######################################################################
# Suppose we want to calculate the softmax along `dim=0`. Note that the second column is "unsafe" (that is, entirely
# masked out), so when the softmax is calculated, the result will yield `0/0 = nan` since `exp(-inf) = 0`.
# However, what we would really like is for the gradients to be masked out since they are unspecified and would be
# invalid for training.
#
# PyTorch result:
#

x.softmax(0)

######################################################################
# :class:`MaskedTensor` result:
#

mt.softmax(0)

######################################################################
# Implementing missing torch.nan* operators
# -----------------------------------------
#
# In `Issue 61474 <https://github.com/pytorch/pytorch/issues/61474>`__,
# there is a request to add additional operators to cover the various `torch.nan*` applications,
# such as ``torch.nanmax``, ``torch.nanmin``, etc.
#
# In general, these problems lend themselves more naturally to masked semantics, so instead of introducing additional
# operators, we propose using :class:`MaskedTensor` instead.
# Since `nanmean has already landed <https://github.com/pytorch/pytorch/issues/21987>`__,
# we can use it as a comparison point:
#

x = torch.arange(16).float()
y = x * x.fmod(4)
z = y.masked_fill(y == 0, float('nan')) # we want to get the mean of y when ignoring the zeros

######################################################################
#
print("y:\n", y)
# z is just y with the zeros replaced with nan's
print("z:\n", z)

######################################################################
#

print("y.mean():\n", y.mean())
print("z.nanmean():\n", z.nanmean())
# MaskedTensor successfully ignores the 0's
print("torch.mean(masked_tensor(y, y != 0)):\n", torch.mean(masked_tensor(y, y != 0)))

######################################################################
# In the above example, we've constructed a `y` and would like to calculate the mean of the series while ignoring
# the zeros. `torch.nanmean` can be used to do this, but we don't have implementations for the rest of the
# `torch.nan*` operations. :class:`MaskedTensor` solves this issue by being able to use the base operation,
# and we already have support for the other operations listed in the issue. For example:
#

torch.argmin(masked_tensor(y, y != 0))

######################################################################
# Indeed, the index of the minimum argument when ignoring the 0's is the 1 at index 1.
#
# :class:`MaskedTensor` can also support reductions when the data is fully masked out, which is equivalent
# to the case above when the data Tensor is completely ``nan``. ``nanmean`` would return ``nan``
# (an ambiguous return value), while MaskedTensor would more accurately indicate a masked out result.
#

x = torch.empty(16).fill_(float('nan'))
print("x:\n", x)
print("torch.nanmean(x):\n", torch.nanmean(x))
print("torch.nanmean via maskedtensor:\n", torch.mean(masked_tensor(x, ~torch.isnan(x))))

######################################################################
# This is a similar problem to safe softmax, where `0/0 = nan` when what we really want is an undefined value.
#
# Conclusion
# ==========
#
# In this tutorial, we've introduced what MaskedTensors are, demonstrated how to use them, and motivated their
# value through a series of examples and issues that they've helped resolve.
#
# Further Reading
# ===============
#
# To continue learning more, you can refer to our
# `MaskedTensor Sparsity tutorial <https://pytorch.org/tutorials/prototype/maskedtensor_sparsity.html>`__
# to see how MaskedTensor enables sparsity and the different storage formats we currently support.
#