[maskedtensor] Add adagrad sparse semantics tutorial #2047

Status: Closed — wants to merge 6 commits
180 changes: 180 additions & 0 deletions beginner_source/maskedtensor_adagrad_sparse_semantics.py
# -*- coding: utf-8 -*-

"""
Efficiency of writing "sparse" semantics for Adagrad
====================================================

`Issue 1369 <https://github.com/pytorch/pytorch/issues/1369>`__ discussed the additional lines of code
that were introduced while writing "sparse" semantics for Adagrad.
In reality, however, the code doesn't use sparsity as a compression and optimization technique;
it wants to use masked semantics. We worked around this by introducing one-off semantics and operators
that encode this behavior while forcing users to be aware of storage details such as indices and values.

In particular, we'll point out when sparsity is used as a semantic extension, i.e. unspecified values are not zero,
and when it is merely used to compress zeros.
We'll also compare and contrast this with equivalent code written using MaskedTensor.
At the end, the code snippets are repeated without additional comments to show the difference in brevity.

""""

import torch
from torch.masked.maskedtensor import masked_tensor

######################################################################
# Original sparse implementation
# ------------------------------
#
# First, let's look at the current implementation of
# `Adagrad (functional) <https://github.com/pytorch/pytorch/blob/6c2f235d368b697072699e5ca9485fd97d0b9bcc/torch/optim/_functional.py#L16-L51>`__
#

def _make_sparse(grad, grad_indices, values):
    size = grad.size()
    if grad_indices.numel() == 0 or values.numel() == 0:
        return torch.empty_like(grad)
    return torch.sparse_coo_tensor(grad_indices, values, size)

# Some hyperparameters
eps = 1e-10
clr = 0.1

# We don't support sparse gradients
param = torch.arange(8).reshape(2, 4).float()
i = torch.tensor([[0, 1, 1],
                  [2, 0, 2]])
v = torch.tensor([3, 4, 5], dtype=torch.float32)
grad = torch.sparse_coo_tensor(i, v, [2, 4])
state_sum = torch.full_like(param, 0.5) # initial value for state sum

print("param:\n", param)
print("grad:\n", grad.to_dense())
print("state_sum:\n", state_sum)

######################################################################
#

state_sum = torch.full_like(param, 0.5) # initial value for state sum
print(state_sum)

grad = grad.coalesce() # the update is non-linear so indices must be unique
grad_indices = grad._indices()
grad_values = grad._values()

# pow(2) has the same semantics for both sparse and dense memory layouts since 0^2 is zero
state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))
# We take care to make std sparse, even though state_sum clearly is not.
# This means that we're only applying the gradient to the parts of state_sum
# for which it is specified, which drives home the point that the passed
# gradient is not sparse, but masked.
std = state_sum.sparse_mask(grad)
print("state_sum:\n", state_sum)
print("std:\n", std.to_dense())

######################################################################
# This is where we have a very important divergence.
# The addition of eps should technically be applied to all values, but instead it is only applied to the specified values.
# Here we're using sparsity as a semantic extension and to enforce a certain pattern of defined and undefined values.
# If parts of the gradient's values are zero, they are still included when materialized,
# even though they could be compressed by other sparse storage layouts.
# This is technically quite brittle, even though one could argue that eps is always very small.
#
# Moreover, an implementation of add_ for sparsity as a storage layout and compression scheme should cause densification,
# but we force it not to.
# For this one-off case that is fine, until we want to introduce new compression schemes
# such as CSR, BSR, or 2:4 block sparsity. We'd then need to introduce separate Tensor types for each
# and write variations for gradients compressed using different storage formats.
#

# We currently dodge all these concerns using the private method _values.
std_values = std._values().sqrt_().add_(eps)

# We currently don't support div for sparse Tensors because zero / zero is
# not well defined. For a MaskedTensor undefined / undefined is undefined.
param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)
print("param:\n", param)

######################################################################
# MaskedTensor sparse implementation
# ----------------------------------
#
# We've been conflating sparsity as an optimization with sparsity as a semantic extension to PyTorch.
# MaskedTensor proposes calling this semantic extension "masked" instead.
# Currently we can't have dense semantics with sparse storage, or masked semantics with dense storage;
# MaskedTensor fixes that because it separates the storage from the semantics.
# Consider the above example using a masked gradient:
#

# Create an entirely new set of parameters to avoid errors
param2 = torch.arange(8).reshape(2, 4).float()
state_sum2 = torch.full_like(param2, 0.5)  # initial value for state sum

mask = (grad.to_dense() != 0).to_sparse()
masked_grad = masked_tensor(grad, mask)
print("masked_grad:\n", masked_grad)

######################################################################
#

state_sum2 = state_sum2 + masked_grad.pow(2).data()
std2 = masked_tensor(state_sum2.to_sparse(), mask)

# Let's print both this version and the regular version for easier comparison
print("state_sum:\n", state_sum)
print("std:\n", std)
print("state_sum2:\n", state_sum2)
print("std2:\n", std2)

######################################################################
#

# We can add support for in-place operations later. Notice how this doesn't
# need to access any storage internals and is in general a lot shorter
std2 = std2.sqrt().add(eps)

print("std:\n", std)
print("std2:\n", std2)

# .data() indeed returns a sparse tensor
param2 = param2.add((masked_grad / std2).data(), alpha=-clr)
print("param2:\n", param2)

######################################################################
# Conclusion: Difference in code
# ------------------------------
#
# For reference, this is the regular, dense code path without masked gradients or sparsity:
# ::
#
# state_sum.addcmul_(grad, grad, value=1)
# std = state_sum.sqrt().add_(eps)
# param.addcdiv_(grad, std, value=-clr)
#
# The vanilla tensor implementation for sparse is:
#

grad = grad.coalesce() # the update is non-linear so indices must be unique
grad_indices = grad._indices()
grad_values = grad._values()
size = grad.size()

state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))
std = state_sum.sparse_mask(grad)
std_values = std._values().sqrt_().add_(eps)
param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)

######################################################################
# while MaskedTensor minimizes the code to the snippet:
#

state_sum2 = state_sum2 + masked_grad.pow(2).data()
std2 = masked_tensor(state_sum2.to_sparse(), mask)
std2 = std2.sqrt().add(eps)
param2 = param2.add((masked_grad / std2).data(), alpha=-clr)

######################################################################
# And for good measure, let's make sure the parameters match:
#

print("param:\n", param)
print("param2:\n", param2)

108 changes: 108 additions & 0 deletions beginner_source/maskedtensor_distinguish_gradient.rst
Distinguishing between 0 and NaN gradient
-----------------------------------------

One issue that :class:`torch.Tensor` runs into is the inability to distinguish between gradients that are
undefined (NaN) and gradients that are actually 0. Below are several different issues where
:class:`MaskedTensor` can resolve and/or work around this NaN gradient problem.

`Issue 10729 <https://github.com/pytorch/pytorch/issues/10729>`__ -- torch.where
--------------------------------------------------------------------------------

Current result:

>>> x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100], requires_grad=True, dtype=torch.float)
>>> y = torch.where(x < 0, torch.exp(x), torch.ones_like(x))
>>> y.sum().backward()
>>> x.grad
tensor([4.5400e-05, 6.7379e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00,        nan,        nan])

:class:`MaskedTensor` result:

>>> x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100])
>>> mask = x < 0
>>> mx = masked_tensor(x, mask, requires_grad=True)
>>> my = masked_tensor(torch.ones_like(x), ~mask, requires_grad=True)
>>> y = torch.where(mask, torch.exp(mx), my)
>>> y.sum().backward()
>>> mx.grad
MaskedTensor(
[ 0.0000, 0.0067, --, --, --, --, --, --, --, --, --]
)

The gradient here is only provided to the selected subset. Effectively, this changes the gradient of `where`
to mask out elements instead of setting them to zero.

`Issue 52248 <https://github.com/pytorch/pytorch/issues/52248>`__ -- another torch.where
----------------------------------------------------------------------------------------

Current result:

>>> a = torch.randn((), requires_grad=True)
>>> b = torch.tensor(False)
>>> c = torch.ones(())
>>> torch.where(b, a/0, c)
tensor(1., grad_fn=<WhereBackward0>)
>>> torch.autograd.grad(torch.where(b, a/0, c), a)
(tensor(nan),)

:class:`MaskedTensor` result:

>>> a = masked_tensor(torch.randn(()), torch.tensor(True), requires_grad=True)
>>> b = torch.tensor(False)
>>> c = torch.ones(())
>>> torch.where(b, a/0, c)
MaskedTensor( 1.0000, True)
>>> torch.autograd.grad(torch.where(b, a/0, c), a)
(MaskedTensor(--, False),)

`Issue 67180 <https://github.com/pytorch/pytorch/issues/67180>`__ -- :func:`torch.nansum` and :func:`torch.nanmean`
-------------------------------------------------------------------------------------------------------------------

Current result:

>>> a = torch.tensor([1., 2., float('nan')])
>>> b = torch.tensor(1.0, requires_grad=True)
>>> c = a * b
>>> c1 = torch.nansum(c)
>>> bgrad1, = torch.autograd.grad(c1, b, retain_graph=True)
>>> bgrad1
tensor(nan)

:class:`MaskedTensor` result:

>>> a = torch.tensor([1., 2., float('nan')])
>>> b = torch.tensor(1.0, requires_grad=True)
>>> mt = masked_tensor(a, ~torch.isnan(a))
>>> c = mt * b
>>> c1 = torch.sum(c)
>>> bgrad1, = torch.autograd.grad(c1, b, retain_graph=True)
>>> bgrad1
MaskedTensor( 3.0000, True)

`Issue 4132 <https://github.com/pytorch/pytorch/issues/4132>`__ -- when using mask, x/0 yields NaN grad
-------------------------------------------------------------------------------------------------------

Current result:

>>> x = torch.tensor([1., 1.], requires_grad=True)
>>> div = torch.tensor([0., 1.])
>>> y = x/div # => y is [inf, 1]
>>> mask = (div != 0) # => mask is [0, 1]
>>> y[mask].backward()
>>> x.grad # grad is [nan, 1], but expected [0, 1]
tensor([nan, 1.])

:class:`MaskedTensor` result:

>>> x = torch.tensor([1., 1.], requires_grad=True)
>>> div = torch.tensor([0., 1.])
>>> y = x/div # => y is [inf, 1]
>>>
>>> mask = (div != 0) # => mask is [0, 1]
>>> loss = as_masked_tensor(y, mask)
>>> loss.sum().backward()
>>> x.grad
MaskedTensor(
[ --, 1.0000]
)
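For comparison, the usual plain-tensor workaround for Issue 4132 is the "double where" trick: never evaluate the unsafe division, even for elements that will be masked out. This is a sketch of that workaround, not part of the MaskedTensor API:

```python
import torch

x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
mask = div != 0

# Substitute a safe denominator wherever the real one is zero, so the
# division never produces inf/nan; then select the intended branch.
safe_div = torch.where(mask, div, torch.ones_like(div))
y = torch.where(mask, x / safe_div, torch.zeros_like(x))

y.sum().backward()
print(x.grad)  # tensor([0., 1.])
```

MaskedTensor makes this boilerplate unnecessary because masked-out elements simply carry no gradient.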
32 changes: 32 additions & 0 deletions beginner_source/maskedtensor_missing_nan_ops.rst
Implementing missing torch.nan* operators
-----------------------------------------

There is a standing request to add additional operators to cover the various ``torch.nan*`` applications,
such as ``torch.nanmax``, ``torch.nanmin``, etc.

In general, these problems lend themselves more naturally to masked semantics, so instead of introducing additional
operators, we propose using MaskedTensors instead. Since
`nanmean has already landed <https://github.com/pytorch/pytorch/issues/21987>`__, we can use it as a comparison point:

>>> x = torch.arange(16).float()
>>> y = x * x.fmod(4)
>>> y = y.masked_fill(y == 0, float('nan'))
>>> y
tensor([nan, 1., 4., 9., nan, 5., 12., 21., nan, 9., 20., 33., nan, 13.,
28., 45.])
>>> y.nanmean()
tensor(16.6667)
>>> torch.mean(masked_tensor(y, ~torch.isnan(y)))
MaskedTensor( 16.6667, True)

:class:`MaskedTensor` can also support reductions when the data is fully masked out, which is equivalent
to the case above when the data Tensor is completely ``nan``. ``nanmean`` would return ``nan``
(an ambiguous return value), while MaskedTensor would more accurately indicate a masked out result.

>>> x = torch.empty(16).fill_(float('nan'))
>>> x
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan])
>>> torch.nanmean(x)
tensor(nan)
>>> torch.mean(masked_tensor(x, ~torch.isnan(x)))
MaskedTensor(--, False)
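To illustrate the general recipe, here is a hedged sketch of how a still-missing operator like ``nanmax`` can be emulated with masked semantics on a plain tensor: replace the masked-out (NaN) elements with the reduction's identity (``-inf`` for max) before reducing. ``nanmax_via_mask`` is a hypothetical helper for illustration, not a PyTorch API.

```python
import torch

def nanmax_via_mask(t):
    # Fill NaN (masked-out) elements with the identity of max, then reduce.
    return t.masked_fill(torch.isnan(t), float('-inf')).max()

x = torch.tensor([1., float('nan'), 3., float('nan')])
print(nanmax_via_mask(x))  # tensor(3.)
```

This identity-filling step is precisely what masked reductions do under the hood, which is why MaskedTensor covers the whole ``torch.nan*`` family without new operators.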