
doc: update pytorch-on-xla-devices and troubleshoot doc for tensor synchronization issue #9258

Open · wants to merge 8 commits into master
66 changes: 66 additions & 0 deletions docs/source/learn/pytorch-on-xla-devices.md
@@ -346,6 +346,72 @@ device is unavailable the load will fail. PyTorch/XLA, like all of
PyTorch, is under active development and this behavior may change in the
future.

### Tensor Synchronization During Tracing

While tensor synchronization is normal in the JIT workflow, it is not expected during traced inference (for example, the traced workflow in AWS Neuron).
When working with traced graphs, developers may run into unexpected tensor synchronization during tracing, which can lead to additional graphs being compiled and to unexpected program behavior.
PyTorch/XLA's debugging flags can be used to identify when unexpected tensor synchronization happens, so that the code can be changed to avoid it.


A common issue occurs when tensor values are evaluated during model compilation (traced inference). Consider this example:
```python
def forward(self, tensor):
    # Reading tensor[0] forces the value to be evaluated (synchronized) during
    # tracing, and the branch that is taken gets baked into the traced graph.
    if tensor[0] == 1:
        return tensor
    else:
        return tensor * 2
```

While this code can compile and run, it may lead to unexpected behavior because:

* The tensor value is accessed during tracing (``tensor[0]``), which triggers a synchronization.
* The resulting graph becomes fixed based on the tensor value available during tracing.
* Developers might incorrectly assume the condition will be evaluated dynamically during inference.

The solution for the code above is to use the debugging flags below to catch the issue and then modify the code. One example is to feed the flag through the model configuration, as in the updated code below, which avoids the tensor synchronization:
```python
class TestModel(torch.nn.Module):
    def __init__(self, flag=1):
        super().__init__()
        # The flag is pre-determined from the model configuration;
        # it is not an input to the model at runtime.
        self.flag = flag

    def forward(self, tensor):
        if self.flag:
            return tensor
        else:
            return tensor * 2
```
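
As a quick usage sketch (not part of the original example; the configuration dict, tensor shape, and device setup below are illustrative assumptions), the flag is resolved from configuration before the model is traced, so the branch is decided in Python without touching any tensor value:

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

# Hypothetical configuration source; the key point is that `flag` is known
# before tracing and is not derived from a runtime tensor.
model_config = {"flag": 1}

device = xm.xla_device()
model = TestModel(flag=model_config["flag"]).to(device)

x = torch.ones(4, device=device)
out = model(x)      # the Python branch on self.flag is resolved during tracing
torch_xla.sync()    # materialize the graph; no tensor value was needed while tracing
```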


#### Debugging Flags
To help catch tensor synchronization issues, PyTorch/XLA provides two useful approaches:

1. Enable warning messages for tensor synchronization:
```python
import os

# Ask PyTorch/XLA to report where compilations/executions
# (tensor synchronizations) are triggered.
os.environ['PT_XLA_DEBUG_LEVEL'] = '2'
```

2. Disable graph execution to catch issues during development:
```python
import torch_xla

# Make any graph execution during tracing fail instead of running silently.
torch_xla._XLAC._set_allow_execution(False)
```
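
The snippet below is a combined, hedged sketch of how these two switches might be used together during development; the tensor shape, the broad exception handling, the environment-variable ordering, and re-enabling execution with `True` are assumptions rather than documented behavior, and the exact error type and message may differ between releases. With execution disabled, the problematic branch from the first example fails loudly instead of silently baking a value into the graph:

```python
import os

# Set the debug level before importing torch_xla (a conservative choice so
# the setting is picked up when the library initializes).
os.environ['PT_XLA_DEBUG_LEVEL'] = '2'

import torch
import torch_xla
import torch_xla.core.xla_model as xm

torch_xla._XLAC._set_allow_execution(False)     # synchronization during tracing becomes an error

device = xm.xla_device()
t = torch.ones(4, device=device)

try:
    # Reading t[0] needs the real value, which would force a graph execution.
    if t[0] == 1:
        out = t
    else:
        out = t * 2
except Exception as exc:                        # exact exception type may vary
    print("tensor synchronization detected during tracing:", exc)
finally:
    torch_xla._XLAC._set_allow_execution(True)  # re-enable execution (assumed symmetric)
```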

#### Recommendations

Using these flags during development can help identify potential issues early. The recommended approach is to:

* Use ``PT_XLA_DEBUG_LEVEL=2`` during initial development to identify potential synchronization points.
* Apply ``_set_allow_execution(False)`` when you want to ensure that no tensor synchronization occurs during tracing.
* When you see warnings or errors related to tensor synchronization, look into the code path and make appropriate changes. The example above moved the flag into the `__init__` function, so it no longer depends on the model input at runtime.

For more detailed debugging information, refer to the [PyTorch/XLA debugging tool section of the troubleshooting guide](https://github.com/pytorch/xla/blob/master/docs/source/learn/troubleshoot.md#pytorchxla-debugging-tool).


## Compilation Caching

The XLA compiler converts the traced HLO into an executable which runs
20 changes: 13 additions & 7 deletions docs/source/learn/troubleshoot.md
@@ -137,19 +137,25 @@ Execution Analysis: ------------------------------------------------------------
Execution Analysis: ================================================================================
```

Some common causes of compilation/execution are:

1. User manually calls `torch_xla.sync()`.
2. [Parallel loader](https://github.com/pytorch/xla/blob/fe4af0080af07f78ca2b614dd91b71885a3bbbb8/torch_xla/distributed/parallel_loader.py#L49-L51) calls `torch_xla.sync()` every x (configurable) batches.
3. Exiting a [profiler StepTrace region](https://github.com/pytorch/xla/blob/fe4af0080af07f78ca2b614dd91b71885a3bbbb8/torch_xla/debug/profiler.py#L165-L171).
4. Dynamo decides to compile/execute the graph.
5. User tries to access (often due to logging) the value of a tensor before the `torch_xla.sync()`.
6. User tries to access a tensor value before calling `mark_step`. See [PyTorch on XLA Devices](https://github.com/pytorch/xla/blob/master/docs/source/learn/pytorch-on-xla-devices.md) for more details.

Collaborator: Space needed after "access"

Contributor Author: updated. Thanks.

Collaborator: User tries to access a tensor value ...?

Contributor Author: updated. Thanks.

The op executions caused by items 1-4 are expected, and we want to avoid item 5 by
either reducing the frequency of accessing tensor values or manually adding a call to
`torch_xla.sync()` before accessing them.
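
As a minimal, hedged sketch of the second option above (the dummy computation and tensor shapes are made up for illustration):

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
loss = (torch.randn(8, 8, device=device) ** 2).mean()

# Calling loss.item() right away would trigger its own execution of the pending
# graph. Flushing pending work at a step boundary first keeps the logging access
# from introducing an extra compilation/execution.
torch_xla.sync()
print("loss:", loss.item())
```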

Users should expect to see these `Compilation Cause` +
`Execution Cause` pairs for the first couple of steps. After the model