
Commit d675b72

[maskedtensor] Advanced semantics [4/4]
1 parent 12ea814 commit d675b72

File tree

2 files changed: +171 -0 lines changed
@@ -0,0 +1,161 @@
# -*- coding: utf-8 -*-

"""
(Prototype) MaskedTensor Advanced Semantics
===========================================
"""
######################################################################
#
# Before working on this tutorial, please make sure to review our
# `MaskedTensor Overview tutorial <https://pytorch.org/tutorials/prototype/maskedtensor_overview.html>`__.
#
# The purpose of this tutorial is to help users understand how some of the advanced semantics work
# and how they came to be. We will focus on two particular ones:
#
# *. Differences between MaskedTensor and `NumPy's MaskedArray <https://numpy.org/doc/stable/reference/maskedarray.html>`__
# *. Reduction semantics
#
# Preparation
# -----------
#
import torch
from torch.masked import masked_tensor
import numpy as np
######################################################################
# MaskedTensor vs NumPy's MaskedArray
# -----------------------------------
#
# NumPy's ``MaskedArray`` has a few fundamental semantic differences from MaskedTensor.
#
# *. Their factory function and basic definition invert the mask (similar to ``torch.nn.MHA``); that is, MaskedTensor
#    uses ``True`` to denote "specified" and ``False`` to denote "unspecified", or "valid"/"invalid",
#    whereas NumPy does the opposite. We believe that our mask definition is not only more intuitive,
#    but that it also aligns more closely with the existing semantics of PyTorch as a whole.
# *. Intersection semantics. In NumPy, if either of two corresponding elements is masked out, the resulting
#    element will be masked out as well -- in practice, NumPy
#    `applies the logical_or operator <https://github.com/numpy/numpy/blob/68299575d8595d904aff6f28e12d21bf6428a4ba/numpy/ma/core.py#L1016-L1024>`__.
#
data = torch.arange(5.)
mask = torch.tensor([True, True, False, True, False])
npm0 = np.ma.masked_array(data.numpy(), (~mask).numpy())
npm1 = np.ma.masked_array(data.numpy(), (mask).numpy())
print("npm0:\n", npm0)
print("npm1:\n", npm1)
print("npm0 + npm1:\n", npm0 + npm1)
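To make NumPy's intersection semantics concrete, here is a small self-contained check (plain NumPy, mirroring the data and mask above) that the result mask of ``npm0 + npm1`` is the elementwise ``logical_or`` of the two input masks:

```python
import numpy as np

# Same data/mask as above, but in plain NumPy for a self-contained check
data = np.arange(5.)
mask = np.array([True, True, False, True, False])

npm0 = np.ma.masked_array(data, ~mask)
npm1 = np.ma.masked_array(data, mask)

result = npm0 + npm1
# NumPy combines the masks with logical_or: an element is masked in the
# result if it is masked in either input
combined = np.logical_or(np.ma.getmaskarray(npm0), np.ma.getmaskarray(npm1))
print(np.array_equal(np.ma.getmaskarray(result), combined))  # True
```

Since the two masks here are exact complements, every element of the result ends up masked.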
######################################################################
# Meanwhile, MaskedTensor does not support addition or other binary operators with masks that don't match --
# to understand why, please see the :ref:`section on reductions <reduction-semantics>`.
#
mt0 = masked_tensor(data, mask)
mt1 = masked_tensor(data, ~mask)
print("mt0:\n", mt0)
print("mt1:\n", mt1)

try:
    mt0 + mt1
except ValueError as e:
    print(e)
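If you want to avoid the ``ValueError`` ahead of time, one option (a sketch rather than a dedicated MaskedTensor API) is to compare the two masks with ``torch.equal`` before applying the binary operator:

```python
import torch
from torch.masked import masked_tensor

data = torch.arange(5.)
mask = torch.tensor([True, True, False, True, False])
mt0 = masked_tensor(data, mask)
mt1 = masked_tensor(data, ~mask)

# Binary operators require identical masks, so check before adding
if torch.equal(mt0.get_mask(), mt1.get_mask()):
    result = mt0 + mt1
else:
    print("masks differ -- fill unspecified values explicitly (e.g. with to_tensor) instead")
```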
######################################################################
# However, if this behavior is desired, MaskedTensor does support these semantics by giving access to the data and masks
# and conveniently converting a MaskedTensor to a Tensor with masked values filled in using :func:`to_tensor`.
# For example:
#
t0 = mt0.to_tensor(0)
t1 = mt1.to_tensor(0)
mt2 = masked_tensor(t0 + t1, mt0.get_mask() & mt1.get_mask())
print("t0:\n", t0)
print("t1:\n", t1)
print("mt2 (t0 + t1):\n", mt2)
######################################################################
# Note that the mask is ``mt0.get_mask() & mt1.get_mask()`` since :class:`MaskedTensor`'s mask is the inverse of NumPy's.
#
# .. _reduction-semantics:
#
# Reduction Semantics
# -------------------
#
# Recall in the `MaskedTensor Overview tutorial <https://pytorch.org/tutorials/prototype/maskedtensor_overview.html>`__
# we discussed "Implementing missing torch.nan* ops". Those are examples of reductions -- operators that remove one
# (or more) dimensions from a Tensor and then aggregate the result. In this section, we will use reduction semantics
# to motivate our strict requirements around matching masks from above.
#
# Fundamentally, :class:`MaskedTensor` performs the same reduction operation while ignoring the masked out
# (unspecified) values. By way of example:
#
data = torch.arange(12, dtype=torch.float).reshape(3, 4)
mask = torch.randint(2, (3, 4), dtype=torch.bool)
mt = masked_tensor(data, mask)
print("data:\n", data)
print("mask:\n", mask)
print("mt:\n", mt)
######################################################################
# Now, the different reductions (all on dim=1):
#
print("torch.sum:\n", torch.sum(mt, 1))
print("torch.mean:\n", torch.mean(mt, 1))
print("torch.prod:\n", torch.prod(mt, 1))
print("torch.amin:\n", torch.amin(mt, 1))
print("torch.amax:\n", torch.amax(mt, 1))
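One way to reason about these reductions is through each operator's identity element: for rows with at least one specified value, filling the unspecified entries with the identity before reducing reproduces the masked result. The following is a plain-Tensor sketch of that idea (using a fixed mask instead of the random one above), not MaskedTensor's actual implementation:

```python
import torch

data = torch.arange(12, dtype=torch.float).reshape(3, 4)
mask = torch.tensor([[True, False, True, True],
                     [False, True, True, False],
                     [True, True, False, True]])

# Fill unspecified entries with each op's identity element, then reduce
masked_sum = data.masked_fill(~mask, 0.).sum(1)               # identity of sum is 0
masked_prod = data.masked_fill(~mask, 1.).prod(1)             # identity of prod is 1
masked_amin = data.masked_fill(~mask, float("inf")).amin(1)   # identity of amin is +inf
masked_amax = data.masked_fill(~mask, float("-inf")).amax(1)  # identity of amax is -inf
# mean has no identity element; divide the masked sum by the number of
# specified elements per row instead
masked_mean = masked_sum / mask.sum(1)

print(masked_sum)  # values [5., 11., 28.]
```

This also hints at why fully masked-out rows are problematic: there is no element left to determine the result, so its value cannot be relied upon.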
######################################################################
# Of note, the value under a masked out element is not guaranteed to have any specific value, especially if the
# row or column is entirely masked out (the same is true for normalizations).
# For more details on masked semantics, see this `RFC <https://github.com/pytorch/rfcs/pull/27>`__.
#
# Now, we can revisit the question: why do we enforce the invariant that masks must match for binary operators?
# In other words, why don't we use the same semantics as ``np.ma.masked_array``? Consider the following example:
#
data0 = torch.arange(10.).reshape(2, 5)
data1 = torch.arange(10.).reshape(2, 5) + 10
mask0 = torch.tensor([[True, True, False, False, False], [False, False, False, True, True]])
mask1 = torch.tensor([[False, False, False, True, True], [True, True, False, False, False]])
npm0 = np.ma.masked_array(data0.numpy(), (mask0).numpy())
npm1 = np.ma.masked_array(data1.numpy(), (mask1).numpy())
print("npm0:", npm0)
print("npm1:", npm1)
######################################################################
# Now, let's try addition:
#
print("(npm0 + npm1).sum(0):\n", (npm0 + npm1).sum(0))
print("npm0.sum(0) + npm1.sum(0):\n", npm0.sum(0) + npm1.sum(0))
######################################################################
# Adding the two arrays and then summing should clearly give the same result as summing each array and then
# adding -- that is, the reduction should commute with the addition -- but with NumPy's semantics it does not,
# which can certainly be confusing for the user.
#
# :class:`MaskedTensor`, on the other hand, will simply not allow this operation since ``mask0 != mask1``.
# That being said, if the user wishes, there are ways around this
# (for example, filling in the MaskedTensor's undefined elements with 0 values using :func:`to_tensor`
# as shown below), but the user must now be more explicit about their intentions.
#
mt0 = masked_tensor(data0, ~mask0)
mt1 = masked_tensor(data1, ~mask1)
(mt0.to_tensor(0) + mt1.to_tensor(0)).sum(0)
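If you would rather keep the result as a :class:`MaskedTensor` instead of a plain Tensor, one option (an illustrative sketch, not a prescribed pattern) is to rewrap the filled sum with the union of the two validity masks -- here the union is fully specified, because each element is valid in at least one input:

```python
import torch
from torch.masked import masked_tensor

data0 = torch.arange(10.).reshape(2, 5)
data1 = torch.arange(10.).reshape(2, 5) + 10
mask0 = torch.tensor([[True, True, False, False, False], [False, False, False, True, True]])
mask1 = torch.tensor([[False, False, False, True, True], [True, True, False, False, False]])
mt0 = masked_tensor(data0, ~mask0)
mt1 = masked_tensor(data1, ~mask1)

# Treat unspecified values as 0, and keep an element specified if it was
# specified in either input (union of the validity masks)
filled = mt0.to_tensor(0) + mt1.to_tensor(0)
union = mt0.get_mask() | mt1.get_mask()
mt_sum = masked_tensor(filled, union)
```

The explicit ``to_tensor(0)`` makes the "missing means zero" assumption visible in the code, which is exactly the intentionality MaskedTensor is asking for.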
######################################################################
# Conclusion
# ----------
#
# In this tutorial, we have learned about the different design decisions behind MaskedTensor and
# NumPy's MaskedArray, as well as reduction semantics.
# In general, MaskedTensor is designed to avoid ambiguous and confusing semantics (for example, we try to ensure
# that binary operations and reductions give consistent results), which in turn can require the user
# to be more intentional with their code at times, but we believe this to be the better move.
# If you have any thoughts on this, please `let us know <https://github.com/pytorch/pytorch/issues>`__!
#
prototype_source/prototype_index.rst

+10
@@ -141,6 +141,15 @@ Prototype features are not available as part of binary distributions like PyPI o
    :link: ../prototype/nestedtensor.html
    :tags: NestedTensor

+.. MaskedTensor
+
+.. customcarditem::
+   :header: Masked Tensor Advanced Semantics
+   :card_description: Learn more about Masked Tensor's advanced semantics (reductions and comparing vs. NumPy's MaskedArray)
+   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: ../prototype/maskedtensor_advanced_semantics.html
+   :tags: MaskedTensor
+
 .. End of tutorial card section

 .. raw:: html
@@ -172,3 +181,4 @@ Prototype features are not available as part of binary distributions like PyPI o
     prototype/vmap_recipe.html
     prototype/vulkan_workflow.html
     prototype/nestedtensor.html
+    prototype/maskedtensor_advanced_semantics.html
