
Calling spspmm twice gives CUDA error: an illegal memory access was encountered #174

Open
@patmjen


Summary

Calling spspmm twice with the same inputs on a CUDA device raises RuntimeError: CUDA error: an illegal memory access was encountered. The same code runs fine on the CPU.

The following snippet shows the issue for me:

```python
import torch
from torch_sparse import spspmm

# device = torch.device('cpu')  # This works
device = torch.device('cuda')  # This will error

# Make two simple sparse matrices
A_idx = torch.tensor([[0, 1], [0, 1]])
A_val = torch.tensor([1, 1]).float()

B_idx = torch.tensor([[0, 0, 1], [0, 1, 1]])
B_val = torch.tensor([2, 3, 4]).float()

# Transfer to device
print(f'To {device}')
A_idx = A_idx.to(device)
A_val = A_val.to(device)
B_idx = B_idx.to(device)
B_val = B_val.to(device)

# Do matrix multiplies
print('spspmm 1')
spspmm(A_idx, A_val, B_idx, B_val, 2, 2, 2, coalesced=True)  # This works
print('spspmm 2')
spspmm(A_idx, A_val, B_idx, B_val, 2, 2, 2, coalesced=True)  # On CUDA, this errors
```
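
For reference, A here is the 2x2 identity and B = [[2, 3], [0, 4]], so both calls should return C = B, i.e. indices [[0, 0, 1], [0, 1, 1]] and values [2, 3, 4]. The pure-Python sketch below computes the same COO sparse-sparse product so the CPU output can be checked against it. The name reference_spspmm is hypothetical (not part of torch_sparse); it assumes spspmm's argument order (indexA, valueA, indexB, valueB, m, k, n) and returns entries sorted by (row, col), as a coalesced result would be.

```python
from collections import defaultdict


def reference_spspmm(index_a, value_a, index_b, value_b, m, k, n):
    """Hypothetical reference: C = A @ B for COO matrices A (m x k) and B (k x n).

    index_* are [rows, cols] pairs of lists, value_* are lists of numbers.
    The result is returned as ([rows, cols], values), sorted by (row, col).
    """
    # Group B's non-zeros by row so each A entry (i, j) only visits row j of B.
    b_rows = defaultdict(list)
    for r, c, v in zip(index_b[0], index_b[1], value_b):
        if not (0 <= r < k and 0 <= c < n):
            raise IndexError(f"B entry ({r}, {c}) out of bounds for {k}x{n}")
        b_rows[r].append((c, v))

    # Accumulate partial products A[i, j] * B[j, c] into C[i, c].
    acc = defaultdict(float)
    for i, j, va in zip(index_a[0], index_a[1], value_a):
        if not (0 <= i < m and 0 <= j < k):
            raise IndexError(f"A entry ({i}, {j}) out of bounds for {m}x{k}")
        for c, vb in b_rows.get(j, ()):
            acc[(i, c)] += va * vb

    # Emit C in row-major (coalesced) order.
    keys = sorted(acc)
    rows = [i for i, _ in keys]
    cols = [c for _, c in keys]
    vals = [acc[key] for key in keys]
    return [rows, cols], vals
```

Running it on the inputs from the snippet, reference_spspmm([[0, 1], [0, 1]], [1.0, 1.0], [[0, 0, 1], [0, 1, 1]], [2.0, 3.0, 4.0], 2, 2, 2), gives ([[0, 0, 1], [0, 1, 1]], [2.0, 3.0, 4.0]), which can be compared with C_idx.tolist() and C_val.tolist() from the working CPU call.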

When I run the above code, I get the following error:

```
To cuda
spspmm 1
spspmm 2
Traceback (most recent call last):
  File "sparsebug.py", line 25, in <module>
    spspmm(A_idx, A_val, B_idx, B_val, 2, 2, 2, )  # On CUDA, this errors
  File "venv/lib/python3.8/site-packages/torch_sparse/spspmm.py", line 30, in spspmm
    C = matmul(A, B)
  File "venv/lib/python3.8/site-packages/torch_sparse/matmul.py", line 139, in matmul
    return spspmm(src, other, reduce)
  File "venv/lib/python3.8/site-packages/torch_sparse/matmul.py", line 116, in spspmm
    return spspmm_sum(src, other)
  File "venv/lib/python3.8/site-packages/torch_sparse/matmul.py", line 101, in spspmm_sum
    rowptrC, colC, valueC = torch.ops.torch_sparse.spspmm_sum(
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
```
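
The message suggests CUDA_LAUNCH_BLOCKING=1. Because kernel errors are reported asynchronously, the faulting kernel may have been launched earlier than the call that raised (possibly by the first spspmm call). The sketch below re-runs the reproduction script in a fresh interpreter with blocking launches so the traceback should point at the real culprit. The helper names blocking_env and rerun_blocking are hypothetical, and it assumes the snippet is saved as sparsebug.py, the file name shown in the traceback.

```python
import os
import subprocess
import sys


def blocking_env(base=None):
    """Return a copy of the environment with synchronous CUDA kernel launches."""
    env = dict(os.environ if base is None else base)
    # CUDA_LAUNCH_BLOCKING=1 makes each kernel launch synchronous, so an error
    # is raised by the call that actually launched the faulting kernel.
    env["CUDA_LAUNCH_BLOCKING"] = "1"
    return env


def rerun_blocking(script="sparsebug.py"):
    """Re-run the reproduction script with blocking launches; return its exit code."""
    # A new process is used because the variable only takes effect if it is
    # set before CUDA is initialised in that process.
    result = subprocess.run([sys.executable, script], env=blocking_env())
    return result.returncode
```

Calling rerun_blocking() is equivalent to running CUDA_LAUNCH_BLOCKING=1 python sparsebug.py from a shell.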

Sorry if this is just me using the library wrongly! Is there something I should be doing between calls to spspmm, or is there another way to fix it?
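
Since the CPU path works, one possible stopgap until the CUDA bug is fixed is to run the product on the CPU and move the result back to the GPU. The sketch below assumes only the calling convention from the snippet, spspmm(indexA, valueA, indexB, valueB, m, k, n, coalesced=...), returning an (index, value) pair. spspmm_on_cpu is a hypothetical helper, not part of torch_sparse, and it gives up GPU speed and adds host-device copies on every call.

```python
def spspmm_on_cpu(spspmm_fn, index_a, value_a, index_b, value_b, m, k, n, **kwargs):
    """Hypothetical workaround: run a COO sparse-sparse matmul on the CPU.

    spspmm_fn is expected to follow torch_sparse.spspmm's calling convention
    and return (index, value); the result is moved back to the caller's device.
    """
    device = value_a.device  # remember where the caller's tensors live
    index_c, value_c = spspmm_fn(
        index_a.to("cpu"), value_a.to("cpu"),
        index_b.to("cpu"), value_b.to("cpu"),
        m, k, n, **kwargs,
    )
    return index_c.to(device), value_c.to(device)
```

With the snippet's tensors this would be called as C_idx, C_val = spspmm_on_cpu(spspmm, A_idx, A_val, B_idx, B_val, 2, 2, 2, coalesced=True).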

Environment

```
$ python collect_env.py
Collecting environment information...
PyTorch version: 1.9.0+cu111
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: Scientific Linux release 7.7 (Nitrogen) (x86_64)
GCC version: (GCC) 8.3.0
Clang version: Could not collect
CMake version: version 2.8.12.2
Libc version: glibc-2.17

Python version: 3.8.4 (default, Jul 16 2020, 09:01:13)  [GCC 8.4.0] (64-bit runtime)
Python platform: Linux-3.10.0-1160.36.2.el7.x86_64-x86_64-with-glibc2.2.5
Is CUDA available: True
CUDA runtime version: 11.1.74
GPU models and configuration:
GPU 0: Tesla V100-SXM2-32GB
GPU 1: Tesla V100-SXM2-32GB
GPU 2: Tesla V100-SXM2-32GB
GPU 3: Tesla V100-SXM2-32GB

Nvidia driver version: 470.42.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.2
[pip3] pytorch3d==0.5.0
[pip3] torch==1.9.0+cu111
[pip3] torch-scatter==2.0.8
[pip3] torch-sparse==0.6.12
[pip3] torchvision==0.10.0+cu111
[conda] Could not collect
```


Labels: bug, help wanted
