
Commit 1edc818

Authored by george-qi, Svetlana Karslioglu, and malfet

[maskedtensor] Overview tutorial [1/4] (#2050)

Add MaskedTensor prototype tutorial
Co-authored-by: Svetlana Karslioglu <[email protected]>
Co-authored-by: Nikita Shulga <[email protected]>
1 parent 6d21237 commit 1edc818

5 files changed: +344 −2 lines changed

.jenkins/build.sh

-1
@@ -47,7 +47,6 @@ if [[ "${JOB_BASE_NAME}" == *worker_* ]]; then
 # python $DIR/remove_runnable_code.py intermediate_source/spatial_transformer_tutorial.py intermediate_source/spatial_transformer_tutorial.py || true
 # Temp remove for 1.10 release.
 # python $DIR/remove_runnable_code.py advanced_source/neural_style_tutorial.py advanced_source/neural_style_tutorial.py || true
-
 # TODO: Fix bugs in these tutorials to make them runnable again
 # python $DIR/remove_runnable_code.py beginner_source/audio_classifier_tutorial.py beginner_source/audio_classifier_tutorial.py || true

.jenkins/validate_tutorials_built.py

+1
@@ -50,6 +50,7 @@
     "recipes/Captum_Recipe",
     "hyperparameter_tuning_tutorial",
     "flask_rest_api_tutorial",
+    "fx_numeric_suite_tutorial",  # remove when https://github.com/pytorch/tutorials/pull/2089 is fixed
 ]

prototype_source/maskedtensor_overview.py

+333

@@ -0,0 +1,333 @@
# -*- coding: utf-8 -*-

"""
(Prototype) MaskedTensor Overview
*********************************
"""

######################################################################
# This tutorial is designed to serve as a starting point for using MaskedTensors
# and to discuss their masking semantics.
#
# MaskedTensor serves as an extension to :class:`torch.Tensor` that provides the user with the ability to:
#
# * use any masked semantics (for example, variable length tensors, nan* operators, and so on)
# * differentiate between 0 and NaN gradients
# * support various sparse applications (see the Further Reading section below)
#
# For a more detailed introduction on what MaskedTensors are, please refer to the
# `torch.masked documentation <https://pytorch.org/docs/master/masked.html>`__.
#
# Using MaskedTensor
# ==================
#
# In this section, we discuss how to use MaskedTensor, including how to construct it,
# access its data and mask, and index and slice it.
#
# Preparation
# -----------
#
# We'll begin by doing the necessary setup for the tutorial:
#

import torch
from torch.masked import masked_tensor, as_masked_tensor
import warnings

# Disable prototype warnings and such
warnings.filterwarnings(action='ignore', category=UserWarning)

######################################################################
# Construction
# ------------
#
# There are a few different ways to construct a MaskedTensor:
#
# * The first way is to directly invoke the MaskedTensor class
# * The second (and our recommended way) is to use the :func:`masked.masked_tensor` and :func:`masked.as_masked_tensor`
#   factory functions, which are analogous to :func:`torch.tensor` and :func:`torch.as_tensor`
#
# Throughout this tutorial, we will be assuming the import line: `from torch.masked import masked_tensor`.
# A small sketch of both factory functions follows below.
#
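
# The toy data and mask below are hypothetical; any data Tensor and boolean
# mask of the same shape will do, where True marks a valid element.
data = torch.tensor([1., 2., 3.])
mask = torch.tensor([True, False, True])

# masked_tensor copies its inputs (like torch.tensor), while as_masked_tensor
# preserves the autograd history of the data (like torch.as_tensor)
mt = masked_tensor(data, mask)
mt_view = as_masked_tensor(data, mask)
print("mt:\n", mt)

######################################################################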
# Accessing the data and mask
# ---------------------------
#
# The underlying fields in a MaskedTensor can be accessed through:
#
# * the :meth:`MaskedTensor.get_data` function
# * the :meth:`MaskedTensor.get_mask` function. Recall that ``True`` indicates "specified" or "valid"
#   while ``False`` indicates "unspecified" or "invalid".
#
# In general, the underlying data that is returned may not be valid in the unspecified entries, so we recommend
# that when users require a Tensor without any masked entries, they use :meth:`MaskedTensor.to_tensor`
# (as sketched below) to return a Tensor with filled values.
#
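
# A quick sketch of these accessors; the fill value passed to to_tensor
# (here, 0) is an arbitrary choice
mt = masked_tensor(torch.tensor([1., 2., 3.]), torch.tensor([True, False, True]))
print("mt.get_data():\n", mt.get_data())
print("mt.get_mask():\n", mt.get_mask())
print("mt.to_tensor(0):\n", mt.to_tensor(0))

######################################################################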
# Indexing and slicing
# --------------------
#
# :class:`MaskedTensor` is a Tensor subclass, which means that it inherits the same semantics for indexing and slicing
# as :class:`torch.Tensor`. Below are some examples of common indexing and slicing patterns:
#

data = torch.arange(24).reshape(2, 3, 4)
mask = data % 2 == 0

print("data:\n", data)
print("mask:\n", mask)

######################################################################
#

# float is used for cleaner visualization when being printed
mt = masked_tensor(data.float(), mask)

print("mt[0]:\n", mt[0])
print("mt[:, :, 2:4]:\n", mt[:, :, 2:4])

######################################################################
# Why is MaskedTensor useful?
# ===========================
#
# Because :class:`MaskedTensor` treats specified and unspecified values as first-class citizens
# instead of an afterthought (with filled values, NaNs, and so on), it is able to address several of the shortcomings
# that regular Tensors cannot; indeed, :class:`MaskedTensor` was born in large part due to these recurring issues.
#
# Below, we will discuss some of the most common issues that are still unresolved in PyTorch today
# and illustrate how :class:`MaskedTensor` can solve these problems.
#
# Distinguishing between 0 and NaN gradient
# -----------------------------------------
#
# One issue that :class:`torch.Tensor` runs into is the inability to distinguish between gradients that are
# undefined (NaN) and gradients that are actually 0. Because PyTorch does not have a way of marking a value
# as specified/valid vs. unspecified/invalid, it is forced to rely on NaN or 0 (depending on the use case), leading
# to unreliable semantics since many operations aren't meant to handle NaN values properly. What is even more confusing
# is that the gradient can vary depending on the order of operations (for example, depending on how early
# in the chain of operations a NaN value manifests).
#
# :class:`MaskedTensor` is the perfect solution for this!
#
# torch.where
# ^^^^^^^^^^^
#
# In `Issue 10729 <https://github.com/pytorch/pytorch/issues/10729>`__, we notice a case where the order of operations
# can matter when using :func:`torch.where`, because we have trouble differentiating whether the 0 is a real 0
# or one from an undefined gradient. Therefore, we remain consistent and mask out the results:
#
# Current result:
#

x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100], requires_grad=True, dtype=torch.float)
y = torch.where(x < 0, torch.exp(x), torch.ones_like(x))
y.sum().backward()
x.grad

######################################################################
# :class:`MaskedTensor` result:
#

x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100])
mask = x < 0
mx = masked_tensor(x, mask, requires_grad=True)
my = masked_tensor(torch.ones_like(x), ~mask, requires_grad=True)
y = torch.where(mask, torch.exp(mx), my)
y.sum().backward()
mx.grad

######################################################################
# The gradient here is only provided to the selected subset. Effectively, this changes the gradient of `where`
# to mask out elements instead of setting them to zero. We can check the other branch as well, as sketched below.
#
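
# A sketch of that check: the gradient of the other branch (my) should be
# masked out at the complementary positions, assuming .grad on the leaf
# MaskedTensor my is populated the same way as for mx above
my.grad

######################################################################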
# Another torch.where
# ^^^^^^^^^^^^^^^^^^^
#
# `Issue 52248 <https://github.com/pytorch/pytorch/issues/52248>`__ is another example.
#
# Current result:
#

a = torch.randn((), requires_grad=True)
b = torch.tensor(False)
c = torch.ones(())
print("torch.where(b, a/0, c):\n", torch.where(b, a/0, c))
print("torch.autograd.grad(torch.where(b, a/0, c), a):\n", torch.autograd.grad(torch.where(b, a/0, c), a))

######################################################################
# :class:`MaskedTensor` result:
#

a = masked_tensor(torch.randn(()), torch.tensor(True), requires_grad=True)
b = torch.tensor(False)
c = torch.ones(())
print("torch.where(b, a/0, c):\n", torch.where(b, a/0, c))
print("torch.autograd.grad(torch.where(b, a/0, c), a):\n", torch.autograd.grad(torch.where(b, a/0, c), a))

######################################################################
# This issue is similar (and even links to the next issue below) in that it expresses frustration with
# unexpected behavior because of the inability to differentiate "no gradient" vs. "zero gradient",
# which in turn makes working with other ops difficult to reason about.
#
# When using mask, x/0 yields NaN grad
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# In `Issue 4132 <https://github.com/pytorch/pytorch/issues/4132>`__, the user proposes that
# `x.grad` should be `[0, 1]` instead of `[nan, 1]`,
# whereas :class:`MaskedTensor` makes this very clear by masking out the gradient altogether.
#
# Current result:
#

x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
y = x/div # => y is [inf, 1]
mask = (div != 0) # => mask is [0, 1]
y[mask].backward()
x.grad

######################################################################
# :class:`MaskedTensor` result:
#

x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
y = x/div # => y is [inf, 1]
mask = (div != 0) # => mask is [0, 1]
loss = as_masked_tensor(y, mask)
loss.sum().backward()
x.grad

######################################################################
# :func:`torch.nansum` and :func:`torch.nanmean`
# ----------------------------------------------
#
# In `Issue 67180 <https://github.com/pytorch/pytorch/issues/67180>`__,
# the gradient isn't calculated properly (a longstanding issue), whereas :class:`MaskedTensor` handles it correctly.
#
# Current result:
#

a = torch.tensor([1., 2., float('nan')])
b = torch.tensor(1.0, requires_grad=True)
c = a * b
c1 = torch.nansum(c)
bgrad1, = torch.autograd.grad(c1, b, retain_graph=True)
bgrad1

######################################################################
# :class:`MaskedTensor` result:
#

a = torch.tensor([1., 2., float('nan')])
b = torch.tensor(1.0, requires_grad=True)
mt = masked_tensor(a, ~torch.isnan(a))
c = mt * b
c1 = torch.sum(c)
bgrad1, = torch.autograd.grad(c1, b, retain_graph=True)
bgrad1

######################################################################
# Safe Softmax
# ------------
#
# Safe softmax is another great example of `an issue <https://github.com/pytorch/pytorch/issues/55056>`__
# that arises frequently. In a nutshell, if there is an entire batch that is "masked out"
# or consists entirely of padding (which, in the softmax case, translates to being set to ``-inf``),
# then this will result in NaNs, which can lead to training divergence.
#
# Luckily, :class:`MaskedTensor` has solved this issue. Consider this setup:
#

data = torch.randn(3, 3)
mask = torch.tensor([[True, False, False], [True, False, True], [False, False, False]])
x = data.masked_fill(~mask, float('-inf'))
mt = masked_tensor(data, mask)
print("x:\n", x)
print("mt:\n", mt)

######################################################################
# Suppose that we want to calculate the softmax along `dim=0`. Note that the second column is "unsafe" (that is, entirely
# masked out), so when the softmax is calculated, the result will yield `0/0 = nan` since `exp(-inf) = 0`.
# However, what we would really like is for the gradients to be masked out, since they are unspecified and would be
# invalid for training.
#
# PyTorch result:
#

x.softmax(0)

######################################################################
# :class:`MaskedTensor` result:
#

mt.softmax(0)

######################################################################
# Implementing missing torch.nan* operators
# -----------------------------------------
#
# In `Issue 61474 <https://github.com/pytorch/pytorch/issues/61474>`__,
# there is a request to add additional operators to cover the various `torch.nan*` applications,
# such as ``torch.nanmax``, ``torch.nanmin``, and so on.
#
# In general, these problems lend themselves more naturally to masked semantics, so instead of introducing additional
# operators, we propose using :class:`MaskedTensor`.
# Since `nanmean has already landed <https://github.com/pytorch/pytorch/issues/21987>`__,
# we can use it as a comparison point:
#

x = torch.arange(16).float()
y = x * x.fmod(4)
z = y.masked_fill(y == 0, float('nan')) # we want to get the mean of y when ignoring the zeros

######################################################################
#

print("y:\n", y)
# z is just y with the zeros replaced with nan's
print("z:\n", z)

######################################################################
#

print("y.mean():\n", y.mean())
print("z.nanmean():\n", z.nanmean())
# MaskedTensor successfully ignores the 0's
print("torch.mean(masked_tensor(y, y != 0)):\n", torch.mean(masked_tensor(y, y != 0)))

######################################################################
# In the above example, we've constructed a `y` and would like to calculate the mean of the series while ignoring
# the zeros. `torch.nanmean` can be used to do this, but we don't have implementations for the rest of the
# `torch.nan*` operations. :class:`MaskedTensor` solves this issue by being able to use the base operation,
# and we already have support for the other operations listed in the issue. For example:
#

torch.argmin(masked_tensor(y, y != 0))

######################################################################
# Indeed, when the 0's are ignored, the smallest remaining value is the 1 at index 1, so the argmin is 1.
#
# :class:`MaskedTensor` also supports reductions when the data is fully masked out, which is equivalent
# to the case above where the data Tensor is completely ``nan``. ``nanmean`` would return ``nan``
# (an ambiguous return value), while MaskedTensor would more accurately indicate a masked-out result.
#

x = torch.empty(16).fill_(float('nan'))
print("x:\n", x)
print("torch.nanmean(x):\n", torch.nanmean(x))
print("torch.nanmean via maskedtensor:\n", torch.mean(masked_tensor(x, ~torch.isnan(x))))

######################################################################
# This is a problem similar to safe softmax, where `0/0 = nan` when what we really want is an undefined value.
#
# Conclusion
# ==========
#
# In this tutorial, we've introduced what MaskedTensors are, demonstrated how to use them, and motivated their
# value through a series of examples and issues that they've helped resolve.
#
# Further Reading
# ===============
#
# To continue learning more, see our
# `MaskedTensor Sparsity tutorial <https://pytorch.org/tutorials/prototype/maskedtensor_sparsity.html>`__,
# which shows how MaskedTensor enables sparsity and the different storage formats we currently support.
#

prototype_source/prototype_index.rst

+10
@@ -141,6 +141,15 @@ Prototype features are not available as part of binary distributions like PyPI o
    :link: ../prototype/nestedtensor.html
    :tags: NestedTensor

+.. MaskedTensor
+
+.. customcarditem::
+   :header: MaskedTensor Overview
+   :card_description: Learn about masked tensors, the source of truth for specified and unspecified values
+   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: ../prototype/maskedtensor_overview.html
+   :tags: MaskedTensor
+
 .. End of tutorial card section

 .. raw:: html
@@ -172,3 +181,4 @@ Prototype features are not available as part of binary distributions like PyPI o
    prototype/vmap_recipe.html
    prototype/vulkan_workflow.html
    prototype/nestedtensor.html
+   prototype/maskedtensor_overview.html

requirements.txt

-1
@@ -13,7 +13,6 @@ torchvision
 torchtext
 torchaudio
 torchdata
-functorch>=0.2.1
 networkx
 PyHamcrest
 bs4

 (0)