Skip to content

Commit ab2d214

Browse files
authored
Merge pull request #251 from PyO3/parallel-example
Add new example showing how to use ndarray's parallel and blas-src features
2 parents af698f7 + 28183f1 commit ab2d214

File tree

17 files changed

+122
-34
lines changed

17 files changed

+122
-34
lines changed

.github/workflows/ci.yml

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,14 @@ jobs:
2020
default: true
2121
- uses: Swatinem/rust-cache@v1
2222
continue-on-error: true
23-
- env:
24-
CLIPPYFLAGS: --deny warnings --allow clippy::needless-lifetimes
25-
run: |
23+
- run: |
2624
cargo fmt --all -- --check
27-
cargo clippy --tests -- $CLIPPYFLAGS
28-
for example in examples/*; do (cd $example/; cargo clippy -- $CLIPPYFLAGS) || exit 1; done
25+
cargo clippy --workspace --tests -- --deny warnings
2926
3027
test:
3128
name: python${{ matrix.python-version }}-${{ matrix.platform.python-architecture }} ${{ matrix.platform.os }}
3229
runs-on: ${{ matrix.platform.os }}
33-
needs: [lint, check-msrv, linalg-example]
30+
needs: [lint, check-msrv, examples]
3431
strategy:
3532
fail-fast: false
3633
matrix:
@@ -84,8 +81,7 @@ jobs:
8481
- name: Test example
8582
run: |
8683
pip install tox
87-
tox
88-
working-directory: examples/simple-extension
84+
tox -c examples/simple-extension
8985
env:
9086
CARGO_TERM_VERBOSE: true
9187
CARGO_BUILD_TARGET: ${{ matrix.platform.rust-target }}
@@ -143,7 +139,7 @@ jobs:
143139
tox
144140
working-directory: examples/simple-extension
145141

146-
linalg-example:
142+
examples:
147143
runs-on: ubuntu-latest
148144
steps:
149145
- uses: actions/checkout@v2
@@ -161,8 +157,8 @@ jobs:
161157
default: true
162158
- uses: Swatinem/rust-cache@v1
163159
continue-on-error: true
164-
- name: Test example
160+
- name: Test examples
165161
run: |
166162
pip install tox
167-
tox
168-
working-directory: examples/linalg
163+
tox -c examples/linalg
164+
tox -c examples/parallel

examples/linalg/README.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22

33
An example extension with [ndarray-linalg](https://github.com/rust-ndarray/ndarray-linalg).
44

5-
Needs a Fortran compiler (e.g., `gfortran`) for building.
5+
Will link against a system-provided OpenBLAS.
66

77
See [simple-extension's README](https://github.com/PyO3/rust-numpy/blob/main/examples/simple-extension/README.md)
88
for an introduction.
9-

examples/linalg/src/lib.rs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
11
use ndarray_linalg::solve::Inverse;
22
use numpy::{IntoPyArray, PyArray2, PyReadonlyArray2};
3-
use pyo3::exceptions::PyRuntimeError;
4-
use pyo3::prelude::{pymodule, PyErr, PyModule, PyResult, Python};
3+
use pyo3::{exceptions::PyRuntimeError, pymodule, types::PyModule, PyErr, PyResult, Python};
54

65
#[pymodule]
76
fn rust_linalg(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
87
#[pyfn(m)]
98
fn inv<'py>(py: Python<'py>, x: PyReadonlyArray2<'py, f64>) -> PyResult<&'py PyArray2<f64>> {
10-
let x = x
11-
.as_array()
9+
let x = x.as_array();
10+
let y = x
1211
.inv()
1312
.map_err(|e| PyErr::new::<PyRuntimeError, _>(format!("[rust_linalg] {}", e)))?;
14-
Ok(x.into_pyarray(py))
13+
Ok(y.into_pyarray(py))
1514
}
1615
Ok(())
1716
}

examples/linalg/tox.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,5 @@ deps =
77
numpy
88
pytest
99
commands =
10-
pip install .
10+
pip install . -v
1111
pytest {posargs}

examples/parallel/Cargo.toml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
[package]
2+
name = "numpy-parallel-example"
3+
version = "0.1.0"
4+
authors = ["Yuji Kanagawa <[email protected]>"]
5+
edition = "2018"
6+
7+
[lib]
8+
name = "rust_parallel"
9+
crate-type = ["cdylib"]
10+
11+
[dependencies]
12+
pyo3 = { version = "0.15", features = ["extension-module"] }
13+
numpy = { path = "../.." }
14+
ndarray = { version = "0.15", features = ["rayon", "blas"] }
15+
blas-src = { version = "0.8", features = ["openblas"] }
16+
openblas-src = { version = "0.10", features = ["cblas", "system"] }

examples/parallel/README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# rust-numpy example extension using optional ndarray features
2+
3+
An example extension using [optional ndarray features](https://docs.rs/ndarray/latest/ndarray/doc/crate_feature_flags/index.html), parallel execution using Rayon and optimized kernels using BLAS in this case.
4+
5+
See [simple-extension's README](https://github.com/PyO3/rust-numpy/blob/main/examples/simple-extension/README.md)
6+
for an introduction.
7+

examples/parallel/pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[build-system]
2+
build-backend = "maturin"
3+
requires = ["maturin>=0.12,<0.13"]

examples/parallel/src/lib.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// We need to link `blas_src` directly, c.f. https://github.com/rust-ndarray/ndarray#how-to-enable-blas-integration
2+
extern crate blas_src;
3+
4+
use ndarray::Zip;
5+
use numpy::{IntoPyArray, PyArray1, PyReadonlyArray1, PyReadonlyArray2};
6+
use pyo3::{pymodule, types::PyModule, PyResult, Python};
7+
8+
#[pymodule]
9+
fn rust_parallel(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
10+
#[pyfn(m)]
11+
fn rows_dot<'py>(
12+
py: Python<'py>,
13+
x: PyReadonlyArray2<'py, f64>,
14+
y: PyReadonlyArray1<'py, f64>,
15+
) -> &'py PyArray1<f64> {
16+
let x = x.as_array();
17+
let y = y.as_array();
18+
let z = Zip::from(x.rows()).par_map_collect(|row| row.dot(&y));
19+
z.into_pyarray(py)
20+
}
21+
Ok(())
22+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import numpy as np
2+
import rust_parallel
3+
4+
5+
def test_rows_dot():
6+
x = np.ones((128, 1024), dtype=np.float64)
7+
y = np.ones((1024,), dtype=np.float64)
8+
z = rust_parallel.rows_dot(x, y)
9+
np.testing.assert_array_almost_equal(z, 1024 * np.ones((128,), dtype=np.float64))

examples/parallel/tox.ini

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
[tox]
2+
skipsdist = True
3+
4+
[testenv]
5+
deps =
6+
pip
7+
numpy
8+
pytest
9+
commands =
10+
pip install . -v
11+
pytest {posargs}

examples/simple-extension/Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,3 @@ crate-type = ["cdylib"]
1111
[dependencies]
1212
pyo3 = { version = "0.15", features = ["extension-module"] }
1313
numpy = { path = "../.." }
14-
num-complex = "0.4.0"

examples/simple-extension/src/lib.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use numpy::ndarray::{ArrayD, ArrayViewD, ArrayViewMutD};
22
use numpy::{Complex64, IntoPyArray, PyArrayDyn, PyReadonlyArrayDyn};
3-
use pyo3::prelude::{pymodule, PyModule, PyResult, Python};
3+
use pyo3::{pymodule, types::PyModule, PyResult, Python};
44

55
#[pymodule]
66
fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
@@ -30,7 +30,8 @@ fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
3030
) -> &'py PyArrayDyn<f64> {
3131
let x = x.as_array();
3232
let y = y.as_array();
33-
axpy(a, x, y).into_pyarray(py)
33+
let z = axpy(a, x, y);
34+
z.into_pyarray(py)
3435
}
3536

3637
// wrapper of `mult`

examples/simple-extension/tox.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,5 @@ deps =
77
numpy
88
pytest
99
commands =
10-
python -m pip install .
10+
pip install . -v
1111
pytest {posargs}

src/array.rs

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -244,13 +244,24 @@ impl<T, D> PyArray<T, D> {
244244
unsafe { Py::from_borrowed_ptr(self.py(), self.as_ptr()) }
245245
}
246246

247-
/// Constructs `PyArray` from raw python object without incrementing reference counts.
247+
/// Constructs `PyArray` from raw Python object without incrementing reference counts.
248+
///
249+
/// # Safety
250+
///
251+
/// Implementations must ensure the object does not get freed during `'py`
252+
/// and ensure that `ptr` is of the correct type.
248253
pub unsafe fn from_owned_ptr(py: Python<'_>, ptr: *mut ffi::PyObject) -> &Self {
249254
py.from_owned_ptr(ptr)
250255
}
251256

252-
/// Constructs PyArray from raw python object and increments reference counts.
253-
pub unsafe fn from_borrowed_ptr(py: Python<'_>, ptr: *mut ffi::PyObject) -> &Self {
257+
/// Constructs PyArray from raw Python object and increments reference counts.
258+
///
259+
/// # Safety
260+
///
261+
/// Implementations must ensure the object does not get freed during `'py`
262+
/// and ensure that `ptr` is of the correct type.
263+
/// Note that it must be safe to decrement the reference count of ptr.
264+
pub unsafe fn from_borrowed_ptr<'py>(py: Python<'py>, ptr: *mut ffi::PyObject) -> &'py Self {
254265
py.from_borrowed_ptr(ptr)
255266
}
256267

@@ -673,7 +684,10 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
673684
///
674685
/// See [NpyIndex](../convert/trait.NpyIndex.html) for what types you can use as index.
675686
///
676-
/// Passing an invalid index can cause undefined behavior(mostly SIGSEGV).
687+
/// # Safety
688+
///
689+
/// Passing an invalid index is undefined behavior. The element must also have been initialized.
690+
/// The elemet must also not be modified by Python code.
677691
///
678692
/// # Example
679693
/// ```
@@ -693,6 +707,11 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
693707
}
694708

695709
/// Same as [uget](#method.uget), but returns `&mut T`.
710+
///
711+
/// # Safety
712+
///
713+
/// Passing an invalid index is undefined behavior. The element must also have been initialized.
714+
/// The element must also not be accessed by Python code.
696715
#[inline(always)]
697716
#[allow(clippy::mut_from_ref)]
698717
pub unsafe fn uget_mut<Idx>(&self, index: Idx) -> &mut T
@@ -704,6 +723,9 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
704723
}
705724

706725
/// Same as [uget](#method.uget), but returns `*mut T`.
726+
///
727+
/// # Safety
728+
/// Passing an invalid index is undefined behavior.
707729
#[inline(always)]
708730
pub unsafe fn uget_raw<Idx>(&self, index: Idx) -> *mut T
709731
where

src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
#![allow(clippy::missing_safety_doc)] // FIXME
2-
31
//! `rust-numpy` provides Rust interfaces for [NumPy C APIs](https://numpy.org/doc/stable/reference/c-api),
42
//! especially for [ndarray](https://numpy.org/doc/stable/reference/arrays.ndarray.html) class.
53
//!
@@ -29,6 +27,8 @@
2927
//! })
3028
//! }
3129
//! ```
30+
#![allow(clippy::needless_lifetimes)] // We often want to make the GIL lifetime explicit.
31+
3232
pub mod array;
3333
pub mod convert;
3434
mod dtype;

src/npyffi/mod.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
//! Low-Level bindings for NumPy C API.
22
//!
33
//! <https://numpy.org/doc/stable/reference/c-api>
4-
#![allow(non_camel_case_types, clippy::too_many_arguments)]
4+
#![allow(
5+
non_camel_case_types,
6+
clippy::too_many_arguments,
7+
clippy::missing_safety_doc
8+
)]
59

610
use pyo3::{ffi, Python};
711
use std::ffi::CString;

tests/array.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@ use pyo3::{
66
types::{IntoPyDict, PyDict},
77
};
88

9-
fn get_np_locals(py: Python<'_>) -> &'_ PyDict {
9+
fn get_np_locals(py: Python) -> &PyDict {
1010
[("np", get_array_module(py).unwrap())].into_py_dict(py)
1111
}
1212

13-
fn not_contiguous_array<'py>(py: Python<'py>) -> &'py PyArray1<i32> {
13+
fn not_contiguous_array(py: Python) -> &PyArray1<i32> {
1414
py.eval(
1515
"np.array([1, 2, 3, 4], dtype='int32')[::2]",
1616
Some(get_np_locals(py)),
@@ -266,7 +266,7 @@ fn borrow_from_array() {
266266
#[pymethods]
267267
impl Owner {
268268
#[getter]
269-
fn array<'py>(this: &'py PyCell<Self>) -> &'py PyArray1<f64> {
269+
fn array(this: &PyCell<Self>) -> &PyArray1<f64> {
270270
let array = &this.borrow().array;
271271

272272
unsafe { PyArray1::borrow_from_array(array, this) }

0 commit comments

Comments
 (0)