
Commit 42fb3fa

[maskedtensor] Overview tutorial [1/4]
1 parent 04e1ba9 commit 42fb3fa

2 files changed: +306 −0

@@ -0,0 +1,296 @@
# -*- coding: utf-8 -*-

"""
(Prototype) MaskedTensor Overview
=================================

**Author**: `George Qi <https://github.com/george-qi>`_
"""

######################################################################
# This tutorial is designed to serve as a starting point for using MaskedTensors
# and to discuss their masking semantics.
#
######################################################################
# Using MaskedTensor
# ++++++++++++++++++
#
# Construction
# ------------
#
# There are a few different ways to construct a MaskedTensor:
#
# * The first way is to directly invoke the MaskedTensor class
# * The second (and the way we recommend) is to use the :func:`masked.masked_tensor` and :func:`masked.as_masked_tensor`
#   factory functions, which are analogous to :func:`torch.tensor` and :func:`torch.as_tensor` (see the sketch below)
#
# Throughout this tutorial, we will be assuming the import line:
# ``from torch.masked import masked_tensor, as_masked_tensor``.
#
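# As a quick sketch of the two factory functions (the data and mask values
# below are purely illustrative):
#

import torch
from torch.masked import masked_tensor, as_masked_tensor

data = torch.arange(6, dtype=torch.float).reshape(2, 3)
mask = torch.tensor([[True, False, False], [True, True, False]])

# masked_tensor copies its inputs, analogous to torch.tensor
mt = masked_tensor(data, mask)
# as_masked_tensor preserves autograd history through the input,
# analogous to torch.as_tensor
mt2 = as_masked_tensor(data, mask)
print("mt:\n", mt)

######################################################################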
# Accessing the data and mask
# ---------------------------
#
# The underlying fields in a MaskedTensor can be accessed through:
#
# * the :meth:`MaskedTensor.get_data` function
# * the :meth:`MaskedTensor.get_mask` function. Recall that ``True`` indicates "specified" or "valid"
#   while ``False`` indicates "unspecified" or "invalid".
#
# In general, the underlying data that is returned may not be valid in the unspecified entries, so we recommend that
# when users require a Tensor without any masked entries, they use :meth:`MaskedTensor.to_tensor` (as shown below) to
# return a Tensor with filled values.
#
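# For example (a sketch reusing the ``mt`` constructed above; the fill value
# of ``0`` passed to ``to_tensor`` is an arbitrary choice for illustration):
#

print("data:\n", mt.get_data())
print("mask:\n", mt.get_mask())
# fill the unspecified entries to get a regular Tensor back
print("filled:\n", mt.to_tensor(0))

######################################################################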
# Indexing and slicing
# --------------------
#
# :class:`MaskedTensor` is a Tensor subclass, which means that it inherits the same semantics for indexing and slicing
# as :class:`torch.Tensor`. Below are some examples of common indexing and slicing patterns:
#

import torch
from torch.masked import masked_tensor, as_masked_tensor

data = torch.arange(24).reshape(2, 3, 4)
mask = data % 2 == 0

print("data:\n", data)
print("mask:\n", mask)

# float is used for cleaner visualization when printed
mt = masked_tensor(data.float(), mask)

print("mt[0]:\n", mt[0])
print("mt[:, :, 2:4]:\n", mt[:, :, 2:4])
######################################################################
# Why is MaskedTensor useful?
# +++++++++++++++++++++++++++
#
# Because of :class:`MaskedTensor`'s treatment of specified and unspecified values as first-class citizens
# instead of an afterthought (with filled values, nans, etc.), it is able to address several of the shortcomings
# that regular Tensors cannot; indeed, :class:`MaskedTensor` was born in large part due to these recurring issues.
#
# Below, we will discuss some of the most common issues that are still unresolved in PyTorch today
# and illustrate how :class:`MaskedTensor` can solve these problems.
#
# Distinguishing between 0 and NaN gradient
# -----------------------------------------
#
# One issue that :class:`torch.Tensor` runs into is the inability to distinguish between gradients that are
# undefined (NaN) and gradients that are actually 0. Because PyTorch does not have a way of marking a value
# as specified/valid vs. unspecified/invalid, it is forced to rely on NaN or 0 (depending on the use case), leading
# to unreliable semantics, since many operations aren't meant to handle NaN values properly. What is even more confusing
# is that, depending on the order of operations, the gradient can vary (for example, depending on how early
# in the chain of operations a NaN value manifests).
#
# :class:`MaskedTensor` is the perfect solution for this!
#
# `Issue 10729 <https://github.com/pytorch/pytorch/issues/10729>`_ -- :func:`torch.where`
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Current result:
#
x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100], requires_grad=True, dtype=torch.float)
y = torch.where(x < 0, torch.exp(x), torch.ones_like(x))
y.sum().backward()
x.grad
######################################################################
# :class:`MaskedTensor` result:
#

x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100])
mask = x < 0
mx = masked_tensor(x, mask, requires_grad=True)
my = masked_tensor(torch.ones_like(x), ~mask, requires_grad=True)
y = torch.where(mask, torch.exp(mx), my)
y.sum().backward()
mx.grad
######################################################################
# The gradient here is only provided to the selected subset. Effectively, this changes the gradient of `where`
# to mask out elements instead of setting them to zero.
#
# `Issue 52248 <https://github.com/pytorch/pytorch/issues/52248>`_ -- another :func:`torch.where`
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Current result:
#
a = torch.randn((), requires_grad=True)
b = torch.tensor(False)
c = torch.ones(())
print("torch.where(b, a/0, c):\n", torch.where(b, a/0, c))
print("torch.autograd.grad(torch.where(b, a/0, c), a):\n", torch.autograd.grad(torch.where(b, a/0, c), a))
######################################################################
# :class:`MaskedTensor` result:
#

a = masked_tensor(torch.randn(()), torch.tensor(True), requires_grad=True)
b = torch.tensor(False)
c = torch.ones(())
print("torch.where(b, a/0, c):\n", torch.where(b, a/0, c))
print("torch.autograd.grad(torch.where(b, a/0, c), a):\n", torch.autograd.grad(torch.where(b, a/0, c), a))
######################################################################
# This issue is similar (and even links to the next issue below) in that it expresses frustration with unexpected behavior
# caused by the inability to differentiate "no gradient" from "zero gradient", which in turn makes
# working with other ops difficult and unreliable.
#
# `Issue 4132 <https://github.com/pytorch/pytorch/issues/4132>`_ -- when using mask, x/0 yields NaN grad
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Current result:
#
x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
y = x/div # => y is [inf, 1]
mask = (div != 0) # => mask is [0, 1]
y[mask].backward()
x.grad # => tensor([nan, 1.])
######################################################################
# :class:`MaskedTensor` result:
#

x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
y = x/div # => y is [inf, 1]
mask = (div != 0) # => mask is [0, 1]
loss = as_masked_tensor(y, mask)
loss.sum().backward()
x.grad
######################################################################
# Linked in the issue above, this issue proposes that `x.grad` should be `[0, 1]` instead of `[nan, 1]`,
# whereas :class:`MaskedTensor` makes this very clear by masking out the gradient altogether.
#
# `Issue 67180 <https://github.com/pytorch/pytorch/issues/67180>`_ -- :func:`torch.nansum` and :func:`torch.nanmean`
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Current result:
#
a = torch.tensor([1., 2., float('nan')])
b = torch.tensor(1.0, requires_grad=True)
c = a * b
c1 = torch.nansum(c)
bgrad1, = torch.autograd.grad(c1, b, retain_graph=True)
bgrad1
######################################################################
# :class:`MaskedTensor` result:
#

a = torch.tensor([1., 2., float('nan')])
b = torch.tensor(1.0, requires_grad=True)
mt = masked_tensor(a, ~torch.isnan(a))
c = mt * b
c1 = torch.sum(c)
bgrad1, = torch.autograd.grad(c1, b, retain_graph=True)
bgrad1
######################################################################
# Here, the gradient isn't even computed correctly (a longstanding issue), whereas :class:`MaskedTensor` handles
# it properly.
#
# Safe Softmax
# ------------
#
# Safe softmax is another great example of `an issue <https://github.com/pytorch/pytorch/issues/55056>`_
# that arises frequently. In a nutshell, if an entire batch is "masked out"
# or consists entirely of padding (which, in the softmax case, translates to being set to `-inf`),
# then this will result in NaNs, which can lead to training divergence.
#
# Luckily, :class:`MaskedTensor` has solved this issue. Consider this setup:
#
data = torch.randn(3, 3)
mask = torch.tensor([[True, False, False], [True, False, True], [False, False, False]])
x = data.masked_fill(~mask, float('-inf'))
mt = masked_tensor(data, mask)
print("x:\n", x)
print("mt:\n", mt)
######################################################################
# Now suppose we want to calculate the softmax along `dim=0`. Note that the second column is "unsafe" (i.e. entirely
# masked out), so when the softmax is calculated, the result will yield `0/0 = nan` since `exp(-inf) = 0`.
# However, what we would really like is for the gradients to be masked out, since they are unspecified and would be
# invalid for training.
#
# PyTorch result:
#

x.softmax(0)
######################################################################
# :class:`MaskedTensor` result:
#

mt.softmax(0)
######################################################################
# `Issue 61474 -- Implementing missing torch.nan* operators <https://github.com/pytorch/pytorch/issues/61474>`_
# --------------------------------------------------------------------------------------------------------------
#
# In the above issue, there is a request to add additional operators to cover the various `torch.nan*` applications,
# such as ``torch.nanmax``, ``torch.nanmin``, etc.
#
# In general, these problems lend themselves more naturally to masked semantics, so instead of introducing additional
# operators, we propose using :class:`MaskedTensor`. Since
# `nanmean has already landed <https://github.com/pytorch/pytorch/issues/21987>`_, we can use it as a comparison point:
#
x = torch.arange(16).float()
y = x * x.fmod(4)
z = y.masked_fill(y == 0, float('nan')) # we want to get the mean of y while ignoring the zeros
print("y:\n", y)
# z is just y with the zeros replaced with nan's
print("z:\n", z)
print("y.mean():\n", y.mean())
print("z.nanmean():\n", z.nanmean())
# MaskedTensor successfully ignores the 0's
print("torch.mean(masked_tensor(y, y != 0)):\n", torch.mean(masked_tensor(y, y != 0)))
######################################################################
# In the above example, we've constructed `y` and would like to calculate the mean of the series while ignoring
# the zeros. `torch.nanmean` can be used to do this, but we don't have implementations for the rest of the
# `torch.nan*` operations. :class:`MaskedTensor` solves this issue by being able to use the base operation,
# and we already have support for the other operations listed in the issue. For example:
#

torch.argmin(masked_tensor(y, y != 0))
######################################################################
# Indeed, when we ignore the zeros, the minimum value is the 1 at index 1.
#
# :class:`MaskedTensor` can also support reductions when the data is fully masked out, which is equivalent
# to the case above where the data Tensor is completely ``nan``. ``nanmean`` would return ``nan``
# (an ambiguous return value), while MaskedTensor would more accurately indicate a masked-out result.
#

x = torch.empty(16).fill_(float('nan'))
print("x:\n", x)
print("torch.nanmean(x):\n", torch.nanmean(x))
print("torch.nanmean via maskedtensor:\n", torch.mean(masked_tensor(x, ~torch.isnan(x))))
######################################################################
# This is a similar problem to safe softmax, where `0/0 = nan` when what we really want is an undefined value.
#
# Conclusion
# ++++++++++
#
# In this tutorial, we've introduced what MaskedTensors are, demonstrated how to use them, and motivated their
# value through a series of examples and issues that they've helped resolve.
#
# Further Reading
# +++++++++++++++
#
# To continue learning more, you can find our
# `Sparsity tutorial <https://github.com/pytorch/tutorials/pull/2050/files>`_ to see how MaskedTensor enables sparsity
# and the different storage formats we currently support.
#
prototype_source/prototype_index.rst

@@ -141,6 -141,15 @@
      :link: ../prototype/nestedtensor.html
      :tags: NestedTensor

.. MaskedTensor

.. customcarditem::
   :header: MaskedTensor Overview
   :card_description: Learn about masked tensors, the source of truth for specified and unspecified values
   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
   :link: ../prototype/maskedtensor_overview.html
   :tags: MaskedTensor

.. End of tutorial card section

.. raw:: html

@@ -172,3 +181,4 @@
   prototype/vmap_recipe.html
   prototype/vulkan_workflow.html
   prototype/nestedtensor.html
   prototype/maskedtensor_overview.html