# -*- coding: utf-8 -*-

"""
(Prototype) MaskedTensor Overview
=================================
**Author**: `George Qi <https://github.com/george-qi>`_
"""

######################################################################
# This tutorial is designed to serve as a starting point for using MaskedTensors
# and to discuss their masking semantics.
#

######################################################################
# Using MaskedTensor
# ++++++++++++++++++
#
# Construction
# ------------
#
# There are a few different ways to construct a MaskedTensor:
#
# * The first way is to directly invoke the MaskedTensor class
# * The second (and our recommended way) is to use the :func:`masked.masked_tensor` and :func:`masked.as_masked_tensor`
#   factory functions, which are analogous to :func:`torch.tensor` and :func:`torch.as_tensor`
#
# Throughout this tutorial, we will be assuming the import line: ``from torch.masked import masked_tensor, as_masked_tensor``.
# Both construction approaches are sketched below.
#
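# As a minimal, illustrative sketch (the values here are made up, and we additionally assume that the
# ``MaskedTensor`` class itself is importable from ``torch.masked`` in this prototype):

import torch
from torch.masked import MaskedTensor, masked_tensor, as_masked_tensor

data = torch.tensor([1., 2., 3.])
mask = torch.tensor([True, False, True])

# Directly invoking the class
mt_direct = MaskedTensor(data, mask)
# Using the recommended factory functions (analogous to torch.tensor / torch.as_tensor)
mt_copy = masked_tensor(data, mask)
mt_view = as_masked_tensor(data, mask)
print(mt_copy)

######################################################################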
# Accessing the data and mask
# ---------------------------
#
# The underlying fields in a MaskedTensor can be accessed through:
#
# * the :meth:`MaskedTensor.get_data` function
# * the :meth:`MaskedTensor.get_mask` function. Recall that ``True`` indicates "specified" or "valid"
#   while ``False`` indicates "unspecified" or "invalid".
#
# In general, the underlying data that is returned may not be valid in the unspecified entries, so we recommend that
# when users require a Tensor without any masked entries, they use :meth:`MaskedTensor.to_tensor` (as shown in the
# sketch below) to return a Tensor with filled values.
#
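# A minimal sketch of these accessors (the tensor values here are made up for illustration,
# and ``to_tensor`` is given ``0.`` as the fill value):

import torch
from torch.masked import masked_tensor

mt = masked_tensor(torch.tensor([1., 2., 3.]), torch.tensor([True, False, True]))
print("data:\n", mt.get_data())       # underlying values, including the unspecified slots
print("mask:\n", mt.get_mask())       # True = "specified"/"valid", False = "unspecified"/"invalid"
print("filled:\n", mt.to_tensor(0.))  # a regular Tensor with the unspecified entries filled with 0

######################################################################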
# Indexing and slicing
# --------------------
#
# :class:`MaskedTensor` is a Tensor subclass, which means that it inherits the same semantics for indexing and slicing
# as :class:`torch.Tensor`. Below are some examples of common indexing and slicing patterns:
#

import torch
from torch.masked import masked_tensor, as_masked_tensor

data = torch.arange(24).reshape(2, 3, 4)
mask = data % 2 == 0

print("data\n", data)
print("mask\n", mask)

# float is used for cleaner visualization when being printed
mt = masked_tensor(data.float(), mask)

print("mt[0]:\n", mt[0])
print("mt[:, :, 2:4]:\n", mt[:, :, 2:4])

######################################################################
# Why is MaskedTensor useful?
# +++++++++++++++++++++++++++
#
# Because :class:`MaskedTensor` treats specified and unspecified values as first-class citizens
# instead of an afterthought (with filled values, NaNs, etc.), it is able to address several of the shortcomings
# that regular Tensors cannot; indeed, :class:`MaskedTensor` was born in large part due to these recurring issues.
#
# Below, we will discuss some of the most common issues that are still unresolved in PyTorch today
# and illustrate how :class:`MaskedTensor` can solve these problems.
#
# Distinguishing between 0 and NaN gradient
# -----------------------------------------
#
# One issue that :class:`torch.Tensor` runs into is the inability to distinguish between gradients that are
# undefined (NaN) and gradients that are actually 0. Because PyTorch does not have a way of marking a value
# as specified/valid vs. unspecified/invalid, it is forced to rely on NaN or 0 (depending on the use case), leading
# to unreliable semantics, since many operations aren't meant to handle NaN values properly. What is even more confusing
# is that sometimes, depending on the order of operations, the gradient can vary (for example, depending on how early
# in the chain of operations a NaN value manifests).
#
# :class:`MaskedTensor` is the perfect solution for this!
#
# `Issue 10729 <https://github.com/pytorch/pytorch/issues/10729>`_ -- :func:`torch.where`
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Current result:
#

x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100], requires_grad=True, dtype=torch.float)
y = torch.where(x < 0, torch.exp(x), torch.ones_like(x))
y.sum().backward()
x.grad

######################################################################
# :class:`MaskedTensor` result:
#

x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100])
mask = x < 0
mx = masked_tensor(x, mask, requires_grad=True)
my = masked_tensor(torch.ones_like(x), ~mask, requires_grad=True)
y = torch.where(mask, torch.exp(mx), my)
y.sum().backward()
mx.grad

######################################################################
# The gradient here is only provided to the selected subset. Effectively, this changes the gradient of `where`
# to mask out elements instead of setting them to zero.
#
# `Issue 52248 <https://github.com/pytorch/pytorch/issues/52248>`_ -- another :func:`torch.where`
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Current result:
#

a = torch.randn((), requires_grad=True)
b = torch.tensor(False)
c = torch.ones(())
print("torch.where(b, a/0, c):\n", torch.where(b, a/0, c))
print("torch.autograd.grad(torch.where(b, a/0, c), a):\n", torch.autograd.grad(torch.where(b, a/0, c), a))

######################################################################
# :class:`MaskedTensor` result:
#

a = masked_tensor(torch.randn(()), torch.tensor(True), requires_grad=True)
b = torch.tensor(False)
c = torch.ones(())
print("torch.where(b, a/0, c):\n", torch.where(b, a/0, c))
print("torch.autograd.grad(torch.where(b, a/0, c), a):\n", torch.autograd.grad(torch.where(b, a/0, c), a))

######################################################################
# This issue is similar (and even links to the next issue below) in that it expresses frustration with unexpected behavior
# because of the inability to differentiate "no gradient" vs "zero gradient", which in turn makes
# working with other ops difficult and unreliable.
#
# `Issue 4132 <https://github.com/pytorch/pytorch/issues/4132>`_ -- when using mask, x/0 yields NaN grad
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Current result:
#

x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
y = x/div # => y is [inf, 1]
mask = (div != 0) # => mask is [0, 1]
y[mask].backward()
x.grad  # => tensor([nan, 1.])

######################################################################
# :class:`MaskedTensor` result:
#

x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
y = x/div # => y is [inf, 1]
mask = (div != 0) # => mask is [0, 1]
loss = as_masked_tensor(y, mask)
loss.sum().backward()
x.grad

######################################################################
# Linked in the issue above, this issue proposes that `x.grad` should be `[0, 1]` instead of `[nan, 1]`,
# whereas :class:`MaskedTensor` makes this very clear by masking out the gradient altogether.
#
# `Issue 67180 <https://github.com/pytorch/pytorch/issues/67180>`_ -- :func:`torch.nansum` and :func:`torch.nanmean`
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Current result:
#

a = torch.tensor([1., 2., float('nan')])
b = torch.tensor(1.0, requires_grad=True)
c = a * b
c1 = torch.nansum(c)
bgrad1, = torch.autograd.grad(c1, b, retain_graph=True)
bgrad1

######################################################################
# :class:`MaskedTensor` result:
#

a = torch.tensor([1., 2., float('nan')])
b = torch.tensor(1.0, requires_grad=True)
mt = masked_tensor(a, ~torch.isnan(a))
c = mt * b
c1 = torch.sum(c)
bgrad1, = torch.autograd.grad(c1, b, retain_graph=True)
bgrad1

######################################################################
# Here, the gradient isn't even calculated properly (a longstanding issue), whereas :class:`MaskedTensor` handles
# it correctly.
#
# Safe Softmax
# ------------
#
# Safe softmax is another great example of `an issue <https://github.com/pytorch/pytorch/issues/55056>`_
# that arises frequently. In a nutshell, if there is an entire batch that is "masked out"
# or consists entirely of padding (which, in the softmax case, translates to being set to `-inf`),
# then this will result in NaNs, which can lead to training divergence.
#
# Luckily, :class:`MaskedTensor` has solved this issue. Consider this setup:
#

data = torch.randn(3, 3)
mask = torch.tensor([[True, False, False], [True, False, True], [False, False, False]])
x = data.masked_fill(~mask, float('-inf'))
mt = masked_tensor(data, mask)
print("x:\n", x)
print("mt:\n", mt)

######################################################################
# Suppose, for example, that we want to calculate the softmax along `dim=0`. Note that the second column is "unsafe" (i.e. entirely
# masked out), so when the softmax is calculated, the result will yield `0/0 = nan` since `exp(-inf) = 0`.
# However, what we would really like is for the gradients to be masked out, since they are unspecified and would be
# invalid for training.
#
# PyTorch result:
#

x.softmax(0)

######################################################################
# :class:`MaskedTensor` result:
#

mt.softmax(0)

######################################################################
# `Issue 61474 -- Implementing missing torch.nan* operators <https://github.com/pytorch/pytorch/issues/61474>`_
# --------------------------------------------------------------------------------------------------------------
#
# In the above issue, there is a request to add additional operators to cover the various `torch.nan*` applications,
# such as ``torch.nanmax``, ``torch.nanmin``, etc.
#
# In general, these problems lend themselves more naturally to masked semantics, so instead of introducing additional
# operators, we propose using :class:`MaskedTensor` instead. Since
# `nanmean has already landed <https://github.com/pytorch/pytorch/issues/21987>`_, we can use it as a comparison point:
#

x = torch.arange(16).float()
y = x * x.fmod(4)
z = y.masked_fill(y == 0, float('nan'))  # we want to get the mean of y when ignoring the zeros
print("y:\n", y)
# z is just y with the zeros replaced with nan's
print("z:\n", z)
print("y.mean():\n", y.mean())
print("z.nanmean():\n", z.nanmean())
# MaskedTensor successfully ignores the 0's
print("torch.mean(masked_tensor(y, y != 0)):\n", torch.mean(masked_tensor(y, y != 0)))

######################################################################
# In the above example, we've constructed a `y` and would like to calculate the mean of the series while ignoring
# the zeros. `torch.nanmean` can be used to do this, but we don't have implementations for the rest of the
# `torch.nan*` operations. :class:`MaskedTensor` solves this issue by being able to use the base operation,
# and we already have support for the other operations listed in the issue. For example:
#

torch.argmin(masked_tensor(y, y != 0))

######################################################################
# Indeed, when the 0's are ignored, the smallest remaining value is the 1 at index 1, so that is the argmin returned.
#
# :class:`MaskedTensor` can also support reductions when the data is fully masked out, which is equivalent
# to the case above when the data Tensor is completely ``nan``. ``nanmean`` would return ``nan``
# (an ambiguous return value), while MaskedTensor would more accurately indicate a masked out result.
#

x = torch.empty(16).fill_(float('nan'))
print("x:\n", x)
print("torch.nanmean(x):\n", torch.nanmean(x))
print("torch.nanmean via maskedtensor:\n", torch.mean(masked_tensor(x, ~torch.isnan(x))))

######################################################################
# This is a similar problem to safe softmax, where `0/0 = nan` when what we really want is an undefined value.
#
# Conclusion
# ++++++++++
#
# In this tutorial, we've introduced what MaskedTensors are, demonstrated how to use them, and motivated their
# value through a series of examples and issues that they've helped resolve.
#
# Further Reading
# +++++++++++++++
#
# To continue learning more, you can find our
# `Sparsity tutorial <https://github.com/pytorch/tutorials/pull/2050/files>`_ to see how MaskedTensor enables sparsity
# and the different storage formats we currently support.
#