
Commit 0ed0c6a

[maskedtensor] Overview tutorial [1/4]
1 parent 04e1ba9 commit 0ed0c6a

File tree: 2 files changed, +311 −0 lines changed

File 1 (+301 lines):
# -*- coding: utf-8 -*-

"""
(Prototype) MaskedTensor Overview
=================================

**Author**: `George Qi <https://github.com/george-qi>`_
"""

######################################################################
# This tutorial is designed to serve as a starting point for using MaskedTensors
# and to discuss their masking semantics.
#

######################################################################
# Using MaskedTensor
# ++++++++++++++++++
#
# Construction
# ------------
#
# There are a few different ways to construct a MaskedTensor:
#
# * The first way is to directly invoke the MaskedTensor class
# * The second (and our recommended) way is to use the :func:`masked.masked_tensor` and
#   :func:`masked.as_masked_tensor` factory functions, which are analogous to
#   :func:`torch.tensor` and :func:`torch.as_tensor`
#
# .. autosummary::
#     :toctree: generated
#     :nosignatures:
#
#     masked.masked_tensor
#     masked.as_masked_tensor
#
# Throughout this tutorial, we will be assuming the import line:
# `from torch.masked import masked_tensor, as_masked_tensor`.
#
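# As a minimal sketch of the factory-function approach (the values here are
# arbitrary), construction takes a data Tensor and a boolean mask of the same shape:

import torch
from torch.masked import masked_tensor, as_masked_tensor

data = torch.arange(6, dtype=torch.float)
mask = torch.tensor([True, False, False, True, True, False])
mt = masked_tensor(data, mask)
print(mt)

######################################################################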
# Accessing the data and mask
# ---------------------------
#
# The underlying fields in a MaskedTensor can be accessed through:
#
# * the :meth:`MaskedTensor.get_data` function
# * the :meth:`MaskedTensor.get_mask` function. Recall that ``True`` indicates "specified" or "valid"
#   while ``False`` indicates "unspecified" or "invalid".
#
# In general, the underlying data that is returned may not be valid in the unspecified entries, so when
# users require a Tensor without any masked entries, we recommend that they use :meth:`MaskedTensor.to_tensor`
# (as shown in the sketch below) to return a Tensor with filled values.
#
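# As a brief sketch of these accessors, reusing ``mt`` from the construction
# example above (we assume here that :meth:`MaskedTensor.to_tensor` takes the
# value used to fill the unspecified entries):

print("mt.get_data():\n", mt.get_data())
print("mt.get_mask():\n", mt.get_mask())
print("mt.to_tensor(0.):\n", mt.to_tensor(0.))

######################################################################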
# Indexing and slicing
# --------------------
#
# :class:`MaskedTensor` is a Tensor subclass, which means that it inherits the same semantics for indexing and slicing
# as :class:`torch.Tensor`. Below are some examples of common indexing and slicing patterns:

import torch
from torch.masked import masked_tensor, as_masked_tensor

data = torch.arange(24).reshape(2, 3, 4)
mask = data % 2 == 0

print("data\n", data)
print("mask\n", mask)

# float is used for cleaner visualization when being printed
mt = masked_tensor(data.float(), mask)

print("mt[0]:\n", mt[0])
print("mt[:, :, 2:4]:\n", mt[:, :, 2:4])

######################################################################
# Why is MaskedTensor useful?
# +++++++++++++++++++++++++++
#
# Because :class:`MaskedTensor` treats specified and unspecified values as first-class citizens
# instead of as an afterthought (with filled values, NaNs, etc.), it is able to address several of the shortcomings
# that regular Tensors cannot; indeed, :class:`MaskedTensor` was born in large part due to these recurring issues.
#
# Below, we will discuss some of the most common issues that are still unresolved in PyTorch today
# and illustrate how :class:`MaskedTensor` can solve these problems.
#
# Distinguishing between 0 and NaN gradient
# -----------------------------------------
#
# One issue that :class:`torch.Tensor` runs into is the inability to distinguish between gradients that are
# undefined (NaN) and gradients that are actually 0. Because PyTorch does not have a way of marking a value
# as specified/valid vs. unspecified/invalid, it is forced to rely on NaN or 0 (depending on the use case), leading
# to unreliable semantics since many operations aren't meant to handle NaN values properly. What is even more confusing
# is that the gradient can vary depending on the order of operations (for example, depending on how early
# in the chain of operations a NaN value manifests).
#
# :class:`MaskedTensor` is the perfect solution for this!
#
# `Issue 10729 <https://github.com/pytorch/pytorch/issues/10729>`_ -- :func:`torch.where`
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Current result:

x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100], requires_grad=True, dtype=torch.float)
y = torch.where(x < 0, torch.exp(x), torch.ones_like(x))
y.sum().backward()
x.grad

######################################################################
# :class:`MaskedTensor` result:
#

x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100])
mask = x < 0
mx = masked_tensor(x, mask, requires_grad=True)
my = masked_tensor(torch.ones_like(x), ~mask, requires_grad=True)
y = torch.where(mask, torch.exp(mx), my)
y.sum().backward()
mx.grad

######################################################################
# The gradient here is only provided to the selected subset. Effectively, this changes the gradient of `where`
# to mask out elements instead of setting them to zero.
#
# `Issue 52248 <https://github.com/pytorch/pytorch/issues/52248>`_ -- another :func:`torch.where`
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Current result:
#

a = torch.randn((), requires_grad=True)
b = torch.tensor(False)
c = torch.ones(())
print("torch.where(b, a/0, c):\n", torch.where(b, a/0, c))
print("torch.autograd.grad(torch.where(b, a/0, c), a):\n", torch.autograd.grad(torch.where(b, a/0, c), a))

######################################################################
# :class:`MaskedTensor` result:
#

a = masked_tensor(torch.randn(()), torch.tensor(True), requires_grad=True)
b = torch.tensor(False)
c = torch.ones(())
print("torch.where(b, a/0, c):\n", torch.where(b, a/0, c))
print("torch.autograd.grad(torch.where(b, a/0, c), a):\n", torch.autograd.grad(torch.where(b, a/0, c), a))

######################################################################
# This issue is similar (and even links to the next issue below) in that it expresses frustration with unexpected behavior
# because of the inability to differentiate "no gradient" vs. "zero gradient", which in turn makes
# working with other ops difficult and unreliable.
#
# `Issue 4132 <https://github.com/pytorch/pytorch/issues/4132>`_ -- when using mask, x/0 yields NaN grad
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Current result:
#

x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
y = x/div # => y is [inf, 1]
mask = (div != 0) # => mask is [0, 1]
y[mask].backward()
x.grad # => x.grad is tensor([nan, 1.])

######################################################################
# :class:`MaskedTensor` result:
#

x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
y = x/div # => y is [inf, 1]
mask = (div != 0) # => mask is [0, 1]
loss = as_masked_tensor(y, mask)
loss.sum().backward()
x.grad

######################################################################
# Linked in the issue above, this issue proposes that `x.grad` should be `[0, 1]` instead of `[nan, 1]`,
# whereas :class:`MaskedTensor` makes the semantics clear by masking out the gradient altogether.
#
# `Issue 67180 <https://github.com/pytorch/pytorch/issues/67180>`_ -- :func:`torch.nansum` and :func:`torch.nanmean`
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Current result:
#

a = torch.tensor([1., 2., float('nan')])
b = torch.tensor(1.0, requires_grad=True)
c = a * b
c1 = torch.nansum(c)
bgrad1, = torch.autograd.grad(c1, b, retain_graph=True)
bgrad1

######################################################################
# :class:`MaskedTensor` result:
#

a = torch.tensor([1., 2., float('nan')])
b = torch.tensor(1.0, requires_grad=True)
mt = masked_tensor(a, ~torch.isnan(a))
c = mt * b
c1 = torch.sum(c)
bgrad1, = torch.autograd.grad(c1, b, retain_graph=True)
bgrad1

######################################################################
# Here, the gradient isn't even computed properly (a longstanding issue), whereas :class:`MaskedTensor` handles
# it correctly.
#
# Safe Softmax
# ------------
#
# Safe softmax is another great example of `an issue <https://github.com/pytorch/pytorch/issues/55056>`_
# that arises frequently. In a nutshell, if there is an entire batch that is "masked out"
# or consists entirely of padding (which, in the softmax case, translates to being set to `-inf`),
# then this will result in NaNs, which can lead to training divergence.
#
# Luckily, :class:`MaskedTensor` has solved this issue. Consider this setup:
#

data = torch.randn(3, 3)
mask = torch.tensor([[True, False, False], [True, False, True], [False, False, False]])
x = data.masked_fill(~mask, float('-inf'))
mt = masked_tensor(data, mask)
print("x:\n", x)
print("mt:\n", mt)

######################################################################
# Suppose that we want to calculate the softmax along `dim=0`. Note that the second column is "unsafe" (i.e. entirely
# masked out), so when the softmax is calculated, the result will yield `0/0 = nan` since `exp(-inf) = 0`.
# However, what we would really like is for the gradients to be masked out, since they are unspecified and would be
# invalid for training.
#
# PyTorch result:
#

x.softmax(0)

######################################################################
# :class:`MaskedTensor` result:
#

mt.softmax(0)

######################################################################
# `Issue 61474 -- Implementing missing torch.nan* operators <https://github.com/pytorch/pytorch/issues/61474>`_
# --------------------------------------------------------------------------------------------------------------
#
# In the above issue, there is a request to add additional operators to cover the various `torch.nan*` applications,
# such as ``torch.nanmax``, ``torch.nanmin``, etc.
#
# In general, these problems lend themselves more naturally to masked semantics, so instead of introducing additional
# operators, we propose using :class:`MaskedTensor`\ s. Since
# `nanmean has already landed <https://github.com/pytorch/pytorch/issues/21987>`_, we can use it as a comparison point:
#

x = torch.arange(16).float()
y = x * x.fmod(4)
z = y.masked_fill(y == 0, float('nan')) # we want to get the mean of y when ignoring the zeros
print("y:\n", y)
# z is just y with the zeros replaced with nan's
print("z:\n", z)
print("y.mean():\n", y.mean())
print("z.nanmean():\n", z.nanmean())
# MaskedTensor successfully ignores the 0's
print("torch.mean(masked_tensor(y, y != 0)):\n", torch.mean(masked_tensor(y, y != 0)))

######################################################################
# In the above example, we've constructed `y` and would like to calculate the mean of the series while ignoring
# the zeros. `torch.nanmean` can be used to do this, but we don't have implementations for the rest of the
# `torch.nan*` operations. :class:`MaskedTensor` solves this issue by being able to use the base operation,
# and we already have support for the other operations listed in the issue. For example:
#

torch.argmin(masked_tensor(y, y != 0))

######################################################################
# Indeed, when the 0's are ignored, the minimum value is the 1 at index 1.
#
# :class:`MaskedTensor` also supports reductions when the data is fully masked out, which is equivalent
# to the case above where the data Tensor is completely ``nan``. ``nanmean`` would return ``nan``
# (an ambiguous return value), while MaskedTensor would more accurately indicate a masked-out result.
#

x = torch.empty(16).fill_(float('nan'))
print("x:\n", x)
print("torch.nanmean(x):\n", torch.nanmean(x))
print("torch.nanmean via maskedtensor:\n", torch.mean(masked_tensor(x, ~torch.isnan(x))))

######################################################################
# This is a similar problem to safe softmax, where `0/0 = nan` when what we really want is an undefined value.
#
# Conclusion
# ++++++++++
#
# In this tutorial, we've introduced what MaskedTensors are, demonstrated how to use them, and motivated their
# value through a series of examples and issues that they've helped resolve.
#
# Further Reading
# +++++++++++++++
#
# To continue learning more, see our
# `Sparsity tutorial <https://github.com/pytorch/tutorials/pull/2050/files>`_ for how MaskedTensor enables sparsity
# and the different storage formats we currently support.
#

prototype_source/prototype_index.rst (+10 lines)

@@ -141,6 +141,15 @@ Prototype features are not available as part of binary distributions like PyPI o
    :link: ../prototype/nestedtensor.html
    :tags: NestedTensor

+.. MaskedTensor
+
+.. customcarditem::
+   :header: MaskedTensor Overview
+   :card_description: Learn about masked tensors, the source of truth for specified and unspecified values
+   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: ../prototype/maskedtensor_overview.html
+   :tags: MaskedTensor
+
 .. End of tutorial card section

 .. raw:: html
@@ -172,3 +181,4 @@ Prototype features are not available as part of binary distributions like PyPI o
    prototype/vmap_recipe.html
    prototype/vulkan_workflow.html
    prototype/nestedtensor.html
+   prototype/maskedtensor_overview.html