Skip to content

Commit 6ab5c33

Browse files
committed
build separate cpu and cuda images
1 parent 39d463b commit 6ab5c33

File tree

16 files changed

+116
-93
lines changed

16 files changed

+116
-93
lines changed

.travis.yml

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -112,15 +112,16 @@ install:
112112
- conda install pytorch=${TORCH_VERSION} ${TOOLKIT} -c pytorch --yes
113113
- source script/torch.sh
114114
- pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-${TORCH_VERSION}+${IDX}.html
115-
- pip install flake8 codecov
115+
- pip install flake8
116+
- pip install codecov
116117
- pip install scipy==1.4.1
117-
- source script/install.sh
118+
- travis_wait 30 pip install -e .
118119
script:
119120
- flake8 .
120121
- python setup.py test
121122
after_success:
122-
- python setup.py bdist_wheel --dist-dir=dist/torch-${TORCH_VERSION}
123-
- python script/rename_wheel.py ${IDX}
123+
- python setup.py bdist_wheel --dist-dir=dist
124+
- ls -lah dist/
124125
- codecov
125126
deploy:
126127
provider: s3
@@ -129,8 +130,8 @@ deploy:
129130
access_key_id: ${S3_ACCESS_KEY}
130131
secret_access_key: ${S3_SECRET_ACCESS_KEY}
131132
bucket: pytorch-geometric.com
132-
local_dir: dist/torch-${TORCH_VERSION}
133-
upload_dir: whl/torch-${TORCH_VERSION}
133+
local_dir: dist
134+
upload_dir: whl/torch-${TORCH_VERSION}+${IDX}
134135
acl: public_read
135136
on:
136137
all_branches: true

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
cmake_minimum_required(VERSION 3.0)
22
project(torchsparse)
33
set(CMAKE_CXX_STANDARD 14)
4-
set(TORCHSPARSE_VERSION 0.6.8)
4+
set(TORCHSPARSE_VERSION 0.6.9)
55

66
option(WITH_CUDA "Enable CUDA support" OFF)
77

csrc/convert.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@
88
#endif
99

1010
#ifdef _WIN32
11-
PyMODINIT_FUNC PyInit__convert(void) { return NULL; }
11+
#ifdef WITH_CUDA
12+
PyMODINIT_FUNC PyInit__convert_cuda(void) { return NULL; }
13+
#else
14+
PyMODINIT_FUNC PyInit__convert_cpu(void) { return NULL; }
15+
#endif
1216
#endif
1317

1418
torch::Tensor ind2ptr(torch::Tensor ind, int64_t M) {

csrc/diag.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@
88
#endif
99

1010
#ifdef _WIN32
11-
PyMODINIT_FUNC PyInit__diag(void) { return NULL; }
11+
#ifdef WITH_CUDA
12+
PyMODINIT_FUNC PyInit__diag_cuda(void) { return NULL; }
13+
#else
14+
PyMODINIT_FUNC PyInit__diag_cpu(void) { return NULL; }
15+
#endif
1216
#endif
1317

1418
torch::Tensor non_diag_mask(torch::Tensor row, torch::Tensor col, int64_t M,

csrc/metis.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
#include "cpu/metis_cpu.h"
55

66
#ifdef _WIN32
7-
PyMODINIT_FUNC PyInit__metis(void) { return NULL; }
7+
#ifdef WITH_CUDA
8+
PyMODINIT_FUNC PyInit__metis_cuda(void) { return NULL; }
9+
#else
10+
PyMODINIT_FUNC PyInit__metis_cpu(void) { return NULL; }
11+
#endif
812
#endif
913

1014
torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,

csrc/relabel.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
#include "cpu/relabel_cpu.h"
55

66
#ifdef _WIN32
7-
PyMODINIT_FUNC PyInit__relabel(void) { return NULL; }
7+
#ifdef WITH_CUDA
8+
PyMODINIT_FUNC PyInit__relablel_cuda(void) { return NULL; }
9+
#else
10+
PyMODINIT_FUNC PyInit__relabel_cpu(void) { return NULL; }
11+
#endif
812
#endif
913

1014
std::tuple<torch::Tensor, torch::Tensor> relabel(torch::Tensor col,

csrc/rw.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@
88
#endif
99

1010
#ifdef _WIN32
11-
PyMODINIT_FUNC PyInit__rw(void) { return NULL; }
11+
#ifdef WITH_CUDA
12+
PyMODINIT_FUNC PyInit__rw_cuda(void) { return NULL; }
13+
#else
14+
PyMODINIT_FUNC PyInit__rw_cpu(void) { return NULL; }
15+
#endif
1216
#endif
1317

1418
torch::Tensor random_walk(torch::Tensor rowptr, torch::Tensor col,

csrc/saint.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
#include "cpu/saint_cpu.h"
55

66
#ifdef _WIN32
7-
PyMODINIT_FUNC PyInit__saint(void) { return NULL; }
7+
#ifdef WITH_CUDA
8+
PyMODINIT_FUNC PyInit__saint_cuda(void) { return NULL; }
9+
#else
10+
PyMODINIT_FUNC PyInit__saint_cpu(void) { return NULL; }
11+
#endif
812
#endif
913

1014
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>

csrc/sample.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
#include "cpu/sample_cpu.h"
55

66
#ifdef _WIN32
7-
PyMODINIT_FUNC PyInit__sample(void) { return NULL; }
7+
#ifdef WITH_CUDA
8+
PyMODINIT_FUNC PyInit__sample_cuda(void) { return NULL; }
9+
#else
10+
PyMODINIT_FUNC PyInit__sample_cpu(void) { return NULL; }
11+
#endif
812
#endif
913

1014
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>

csrc/spmm.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@
88
#endif
99

1010
#ifdef _WIN32
11-
PyMODINIT_FUNC PyInit__spmm(void) { return NULL; }
11+
#ifdef WITH_CUDA
12+
PyMODINIT_FUNC PyInit__spmm_cuda(void) { return NULL; }
13+
#else
14+
PyMODINIT_FUNC PyInit__spmm_cpu(void) { return NULL; }
15+
#endif
1216
#endif
1317

1418
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>

csrc/spspmm.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@
88
#endif
99

1010
#ifdef _WIN32
11-
PyMODINIT_FUNC PyInit__spspmm(void) { return NULL; }
11+
#ifdef WITH_CUDA
12+
PyMODINIT_FUNC PyInit__spspmm_cuda(void) { return NULL; }
13+
#else
14+
PyMODINIT_FUNC PyInit__spspmm_cpu(void) { return NULL; }
15+
#endif
1216
#endif
1317

1418
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>

csrc/version.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66
#endif
77

88
#ifdef _WIN32
9-
PyMODINIT_FUNC PyInit__version(void) { return NULL; }
9+
#ifdef WITH_CUDA
10+
PyMODINIT_FUNC PyInit__version_cuda(void) { return NULL; }
11+
#else
12+
PyMODINIT_FUNC PyInit__version_cpu(void) { return NULL; }
13+
#endif
1014
#endif
1115

1216
int64_t cuda_version() {

script/cuda.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ if [ "${TRAVIS_OS_NAME}" = "osx" ] && [ "$IDX" = "cpu" ]; then
6969
fi
7070

7171
if [ "${IDX}" = "cpu" ]; then
72-
export FORCE_CPU=1
72+
export FORCE_ONLY_CPU=1
7373
else
7474
export FORCE_CUDA=1
7575
fi

script/rename_wheel.py

Lines changed: 0 additions & 24 deletions
This file was deleted.

setup.py

Lines changed: 52 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import os
2-
import os.path as osp
32
import sys
43
import glob
4+
import os.path as osp
5+
from itertools import product
56
from setuptools import setup, find_packages
67

78
import torch
@@ -10,10 +11,13 @@
1011
from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
1112

1213
WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
14+
suffices = ['cpu', 'cuda'] if WITH_CUDA else ['cpu']
1315
if os.getenv('FORCE_CUDA', '0') == '1':
14-
WITH_CUDA = True
15-
if os.getenv('FORCE_CPU', '0') == '1':
16-
WITH_CUDA = False
16+
suffices = ['cuda', 'cpu']
17+
if os.getenv('FORCE_ONLY_CUDA', '0') == '1':
18+
suffices = ['cuda']
19+
if os.getenv('FORCE_ONLY_CPU', '0') == '1':
20+
suffices = ['cpu']
1721

1822
BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1'
1923

@@ -22,63 +26,63 @@
2226

2327

2428
def get_extensions():
25-
Extension = CppExtension
26-
define_macros = []
27-
libraries = []
28-
if WITH_METIS:
29-
define_macros += [('WITH_METIS', None)]
30-
libraries += ['metis']
31-
if WITH_MTMETIS:
32-
define_macros += [('WITH_MTMETIS', None)]
33-
define_macros += [('MTMETIS_64BIT_VERTICES', None)]
34-
define_macros += [('MTMETIS_64BIT_EDGES', None)]
35-
define_macros += [('MTMETIS_64BIT_WEIGHTS', None)]
36-
define_macros += [('MTMETIS_64BIT_PARTITIONS', None)]
37-
libraries += ['mtmetis', 'wildriver']
38-
extra_compile_args = {'cxx': ['-O2']}
39-
extra_link_args = ['-s']
40-
41-
info = parallel_info()
42-
if 'parallel backend: OpenMP' in info and 'OpenMP not found' not in info:
43-
extra_compile_args['cxx'] += ['-DAT_PARALLEL_OPENMP']
44-
if sys.platform == 'win32':
45-
extra_compile_args['cxx'] += ['/openmp']
46-
else:
47-
extra_compile_args['cxx'] += ['-fopenmp']
48-
else:
49-
print('Compiling without OpenMP...')
50-
51-
if WITH_CUDA:
52-
Extension = CUDAExtension
53-
define_macros += [('WITH_CUDA', None)]
54-
nvcc_flags = os.getenv('NVCC_FLAGS', '')
55-
nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
56-
nvcc_flags += ['-arch=sm_35', '--expt-relaxed-constexpr', '-O2']
57-
extra_compile_args['nvcc'] = nvcc_flags
58-
59-
if sys.platform == 'win32':
60-
extra_link_args += ['cusparse.lib']
61-
else:
62-
extra_link_args += ['-lcusparse', '-l', 'cusparse']
29+
extensions = []
6330

6431
extensions_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'csrc')
6532
main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
66-
extensions = []
67-
for main in main_files:
68-
name = main.split(os.sep)[-1][:-4]
6933

34+
for main, suffix in product(main_files, suffices):
35+
define_macros = []
36+
libraries = []
37+
if WITH_METIS:
38+
define_macros += [('WITH_METIS', None)]
39+
libraries += ['metis']
40+
if WITH_MTMETIS:
41+
define_macros += [('WITH_MTMETIS', None)]
42+
define_macros += [('MTMETIS_64BIT_VERTICES', None)]
43+
define_macros += [('MTMETIS_64BIT_EDGES', None)]
44+
define_macros += [('MTMETIS_64BIT_WEIGHTS', None)]
45+
define_macros += [('MTMETIS_64BIT_PARTITIONS', None)]
46+
libraries += ['mtmetis', 'wildriver']
47+
extra_compile_args = {'cxx': ['-O2']}
48+
extra_link_args = ['-s']
49+
50+
info = parallel_info()
51+
if 'backend: OpenMP' in info and 'OpenMP not found' not in info:
52+
extra_compile_args['cxx'] += ['-DAT_PARALLEL_OPENMP']
53+
if sys.platform == 'win32':
54+
extra_compile_args['cxx'] += ['/openmp']
55+
else:
56+
extra_compile_args['cxx'] += ['-fopenmp']
57+
else:
58+
print('Compiling without OpenMP...')
59+
60+
if suffix == 'cuda':
61+
define_macros += [('WITH_CUDA', None)]
62+
nvcc_flags = os.getenv('NVCC_FLAGS', '')
63+
nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
64+
nvcc_flags += ['-arch=sm_35', '--expt-relaxed-constexpr', '-O2']
65+
extra_compile_args['nvcc'] = nvcc_flags
66+
67+
if sys.platform == 'win32':
68+
extra_link_args += ['cusparse.lib']
69+
else:
70+
extra_link_args += ['-lcusparse', '-l', 'cusparse']
71+
72+
name = main.split(os.sep)[-1][:-4]
7073
sources = [main]
7174

7275
path = osp.join(extensions_dir, 'cpu', f'{name}_cpu.cpp')
7376
if osp.exists(path):
7477
sources += [path]
7578

7679
path = osp.join(extensions_dir, 'cuda', f'{name}_cuda.cu')
77-
if WITH_CUDA and osp.exists(path):
80+
if suffix == 'cuda' and osp.exists(path):
7881
sources += [path]
7982

83+
Extension = CppExtension if suffix == 'cpu' else CUDAExtension
8084
extension = Extension(
81-
'torch_sparse._' + name,
85+
f'torch_sparse._{name}_{suffix}',
8286
sources,
8387
include_dirs=[extensions_dir],
8488
define_macros=define_macros,
@@ -97,7 +101,7 @@ def get_extensions():
97101

98102
setup(
99103
name='torch_sparse',
100-
version='0.6.8',
104+
version='0.6.9',
101105
author='Matthias Fey',
102106
author_email='[email protected]',
103107
url='https://github.com/rusty1s/pytorch_sparse',

torch_sparse/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,18 @@
33

44
import torch
55

6-
__version__ = '0.6.8'
6+
__version__ = '0.6.9'
7+
8+
suffix = 'cuda' if torch.cuda.is_available() else 'cpu'
79

810
for library in [
911
'_version', '_convert', '_diag', '_spmm', '_spspmm', '_metis', '_rw',
1012
'_saint', '_sample', '_relabel'
1113
]:
1214
torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
13-
library, [osp.dirname(__file__)]).origin)
15+
f'{library}_{suffix}', [osp.dirname(__file__)]).origin)
1416

15-
if torch.cuda.is_available() and torch.version.cuda: # pragma: no cover
17+
if torch.cuda.is_available(): # pragma: no cover
1618
cuda_version = torch.ops.torch_sparse.cuda_version()
1719

1820
if cuda_version == -1:

0 commit comments

Comments
 (0)