-
Notifications
You must be signed in to change notification settings - Fork 4.2k
[maskedtensor] Add safe softmax tutorial #2045
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
Safe Softmax | ||
------------ | ||
|
||
One of the issues that frequently comes up is the necessity for a safe softmax -- that is, if there is an entire | ||
batch that is "masked out" or consists entirely of padding (which, in the softmax case, translates to being set `-inf`), | ||
then this will result in NaNs, which can leading to training divergence. For more detail on why this functionality | ||
is necessary, please find refer to | ||
`Issue 55056 - Feature Request for Safe Softmax <https://github.com/pytorch/pytorch/issues/55056>`__. | ||
|
||
Luckily, :class:`MaskedTensor` has solved this issue: | ||
|
||
>>> data = torch.randn(3, 3) | ||
>>> mask = torch.tensor([[True, False, False], [True, False, True], [False, False, False]]) | ||
>>> x = data.masked_fill(~mask, float('-inf')) | ||
>>> mt = masked_tensor(data, mask) | ||
|
||
PyTorch result: | ||
|
||
>>> x.softmax(0) | ||
tensor([[0.3548, nan, 0.0000], | ||
[0.6452, nan, 1.0000], | ||
[0.0000, nan, 0.0000]]) | ||
|
||
:class:`MaskedTensor` result: | ||
|
||
>>> mt.softmax(0) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm confused on this result. The softmax above returns two different values that are getting masked here: 0.0 and nan. Do we really want to mask out both outputs? I assume the idea is that we want to mask all input padding (aka mask all -infs). but in situations of unstable training, do 0's make a difference? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think the 0's make a difference in unstable training, but maybe I'm not quite understanding the question. In general for MaskedTensor, if the user has deemed a particular index as "unspecified", then it will remain unspecified or masked out in both the forward as well as the backward. And so in this case, while the underlying data that is being masked out would have been different (i.e. 0.0 vs nan), they're both unspecified and so we still want to mask them out. For training specifically, I agree that leaving the 0's would be identical, but I think it would contradict MaskedTensor semantics and become confusing (e.g. wouldn't know when to expect masked out values vs not.) |
||
MaskedTensor( | ||
[ | ||
[ 0.3548, --, --], | ||
[ 0.6452, --, 1.0000], | ||
[ --, --, --] | ||
] | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would explain a bit more here or inline what you are doing:
Let's create a Tensor with an "unsafe" column, that is a column of all -inf padding values, for which PyTorch's traditional softmax will result in nans.
Now let's create a safeter MaskedTensor from this Tensor, and show that the resultant values are masked out