Skip to content

Commit 9ce12bf

Browse files
committed
Flesh out borrow checking examples and tests and remove unreachable case in implementation.
1 parent 028286f commit 9ce12bf

File tree

2 files changed

+263
-7
lines changed

2 files changed

+263
-7
lines changed

src/borrow.rs

+252-7
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,30 @@
8686
//! });
8787
//! ```
8888
//!
89+
//! The third example shows that some views are incorrectly rejected since the borrows are over-approximated.
90+
//!
91+
//! ```rust
92+
//! # use std::panic::{catch_unwind, AssertUnwindSafe};
93+
//! #
94+
//! use numpy::PyArray2;
95+
//! use pyo3::{types::IntoPyDict, Python};
96+
//!
97+
//! Python::with_gil(|py| {
98+
//! let array = PyArray2::<f64>::zeros(py, (10, 10), false);
99+
//! let locals = [("array", array)].into_py_dict(py);
100+
//!
101+
//! let view1 = py.eval("array[:, ::3]", None, Some(locals)).unwrap().downcast::<PyArray2<f64>>().unwrap();
102+
//! let view2 = py.eval("array[:, 1::3]", None, Some(locals)).unwrap().downcast::<PyArray2<f64>>().unwrap();
103+
//!
104+
//! // A false conflict as the views do not actually share any elements.
105+
//! let res = catch_unwind(AssertUnwindSafe(|| {
106+
//! let _view1 = view1.readwrite();
107+
//! let _view2 = view2.readwrite();
108+
//! }));
109+
//! assert!(res.is_err());
110+
//! });
111+
//! ```
112+
//!
89113
//! # Rationale
90114
//!
91115
//! Rust references require aliasing discipline to be maintained, i.e. there must always
@@ -235,6 +259,9 @@ impl BorrowFlags {
235259
let same_base = entry.into_mut();
236260

237261
if let Some(readers) = same_base.get_mut(&key) {
262+
// Zero flags are removed during release.
263+
assert_ne!(*readers, 0);
264+
238265
let new_readers = readers.wrapping_add(1);
239266

240267
if new_readers <= 0 {
@@ -309,12 +336,11 @@ impl BorrowFlags {
309336
let same_base = entry.into_mut();
310337

311338
if let Some(writers) = same_base.get_mut(&key) {
312-
if *writers != 0 {
313-
cold();
314-
return Err(BorrowError::AlreadyBorrowed);
315-
}
339+
// Zero flags are removed during release.
340+
assert_ne!(*writers, 0);
316341

317-
*writers = -1;
342+
cold();
343+
return Err(BorrowError::AlreadyBorrowed);
318344
} else {
319345
if same_base
320346
.iter()
@@ -620,8 +646,6 @@ fn base_address<T, D>(array: &PyArray<T, D>) -> usize {
620646
}
621647
}
622648

623-
// FIXME(adamreichold): This is a coarse approximation and needs to be refined,
624-
// i.e. borrows of interleaved views into the same base should not be considered conflicting.
625649
fn data_range<T, D>(array: &PyArray<T, D>) -> Range<usize>
626650
where
627651
T: Element,
@@ -934,4 +958,225 @@ mod tests {
934958
assert!(key1.conflicts(&key2));
935959
});
936960
}
961+
962+
#[test]
963+
fn borrow_multiple_arrays() {
964+
Python::with_gil(|py| {
965+
let array1 = PyArray::<f64, _>::zeros(py, 10, false);
966+
let array2 = PyArray::<f64, _>::zeros(py, 10, false);
967+
968+
let base1 = base_address(array1);
969+
let base2 = base_address(array2);
970+
971+
let key1 = BorrowKey::from_array(array1);
972+
let _exclusive1 = array1.readwrite();
973+
974+
{
975+
let borrow_flags = unsafe { BORROW_FLAGS.get() };
976+
assert_eq!(borrow_flags.len(), 1);
977+
978+
let same_base = &borrow_flags[&base1];
979+
assert_eq!(same_base.len(), 1);
980+
981+
let flag = same_base[&key1];
982+
assert_eq!(flag, -1);
983+
}
984+
985+
let key2 = BorrowKey::from_array(array2);
986+
let _shared2 = array2.readonly();
987+
988+
{
989+
let borrow_flags = unsafe { BORROW_FLAGS.get() };
990+
assert_eq!(borrow_flags.len(), 2);
991+
992+
let same_base = &borrow_flags[&base1];
993+
assert_eq!(same_base.len(), 1);
994+
995+
let flag = same_base[&key1];
996+
assert_eq!(flag, -1);
997+
998+
let same_base = &borrow_flags[&base2];
999+
assert_eq!(same_base.len(), 1);
1000+
1001+
let flag = same_base[&key2];
1002+
assert_eq!(flag, 1);
1003+
}
1004+
});
1005+
}
1006+
1007+
#[test]
1008+
fn borrow_multiple_views() {
1009+
Python::with_gil(|py| {
1010+
let array = PyArray::<f64, _>::zeros(py, 10, false);
1011+
let base = base_address(array);
1012+
1013+
let locals = [("array", array)].into_py_dict(py);
1014+
1015+
let view1 = py
1016+
.eval("array[:5]", None, Some(locals))
1017+
.unwrap()
1018+
.downcast::<PyArray1<f64>>()
1019+
.unwrap();
1020+
1021+
let key1 = BorrowKey::from_array(view1);
1022+
let exclusive1 = view1.readwrite();
1023+
1024+
{
1025+
let borrow_flags = unsafe { BORROW_FLAGS.get() };
1026+
assert_eq!(borrow_flags.len(), 1);
1027+
1028+
let same_base = &borrow_flags[&base];
1029+
assert_eq!(same_base.len(), 1);
1030+
1031+
let flag = same_base[&key1];
1032+
assert_eq!(flag, -1);
1033+
}
1034+
1035+
let view2 = py
1036+
.eval("array[5:]", None, Some(locals))
1037+
.unwrap()
1038+
.downcast::<PyArray1<f64>>()
1039+
.unwrap();
1040+
1041+
let key2 = BorrowKey::from_array(view2);
1042+
let shared2 = view2.readonly();
1043+
1044+
{
1045+
let borrow_flags = unsafe { BORROW_FLAGS.get() };
1046+
assert_eq!(borrow_flags.len(), 1);
1047+
1048+
let same_base = &borrow_flags[&base];
1049+
assert_eq!(same_base.len(), 2);
1050+
1051+
let flag = same_base[&key1];
1052+
assert_eq!(flag, -1);
1053+
1054+
let flag = same_base[&key2];
1055+
assert_eq!(flag, 1);
1056+
}
1057+
1058+
let view3 = py
1059+
.eval("array[5:]", None, Some(locals))
1060+
.unwrap()
1061+
.downcast::<PyArray1<f64>>()
1062+
.unwrap();
1063+
1064+
let key3 = BorrowKey::from_array(view3);
1065+
let shared3 = view3.readonly();
1066+
1067+
{
1068+
let borrow_flags = unsafe { BORROW_FLAGS.get() };
1069+
assert_eq!(borrow_flags.len(), 1);
1070+
1071+
let same_base = &borrow_flags[&base];
1072+
assert_eq!(same_base.len(), 2);
1073+
1074+
let flag = same_base[&key1];
1075+
assert_eq!(flag, -1);
1076+
1077+
let flag = same_base[&key2];
1078+
assert_eq!(flag, 2);
1079+
1080+
let flag = same_base[&key3];
1081+
assert_eq!(flag, 2);
1082+
}
1083+
1084+
let view4 = py
1085+
.eval("array[7:]", None, Some(locals))
1086+
.unwrap()
1087+
.downcast::<PyArray1<f64>>()
1088+
.unwrap();
1089+
1090+
let key4 = BorrowKey::from_array(view4);
1091+
let shared4 = view4.readonly();
1092+
1093+
{
1094+
let borrow_flags = unsafe { BORROW_FLAGS.get() };
1095+
assert_eq!(borrow_flags.len(), 1);
1096+
1097+
let same_base = &borrow_flags[&base];
1098+
assert_eq!(same_base.len(), 3);
1099+
1100+
let flag = same_base[&key1];
1101+
assert_eq!(flag, -1);
1102+
1103+
let flag = same_base[&key2];
1104+
assert_eq!(flag, 2);
1105+
1106+
let flag = same_base[&key3];
1107+
assert_eq!(flag, 2);
1108+
1109+
let flag = same_base[&key4];
1110+
assert_eq!(flag, 1);
1111+
}
1112+
1113+
drop(shared2);
1114+
1115+
{
1116+
let borrow_flags = unsafe { BORROW_FLAGS.get() };
1117+
assert_eq!(borrow_flags.len(), 1);
1118+
1119+
let same_base = &borrow_flags[&base];
1120+
assert_eq!(same_base.len(), 3);
1121+
1122+
let flag = same_base[&key1];
1123+
assert_eq!(flag, -1);
1124+
1125+
let flag = same_base[&key2];
1126+
assert_eq!(flag, 1);
1127+
1128+
let flag = same_base[&key3];
1129+
assert_eq!(flag, 1);
1130+
1131+
let flag = same_base[&key4];
1132+
assert_eq!(flag, 1);
1133+
}
1134+
1135+
drop(shared3);
1136+
1137+
{
1138+
let borrow_flags = unsafe { BORROW_FLAGS.get() };
1139+
assert_eq!(borrow_flags.len(), 1);
1140+
1141+
let same_base = &borrow_flags[&base];
1142+
assert_eq!(same_base.len(), 2);
1143+
1144+
let flag = same_base[&key1];
1145+
assert_eq!(flag, -1);
1146+
1147+
assert!(!same_base.contains_key(&key2));
1148+
1149+
assert!(!same_base.contains_key(&key3));
1150+
1151+
let flag = same_base[&key4];
1152+
assert_eq!(flag, 1);
1153+
}
1154+
1155+
drop(exclusive1);
1156+
1157+
{
1158+
let borrow_flags = unsafe { BORROW_FLAGS.get() };
1159+
assert_eq!(borrow_flags.len(), 1);
1160+
1161+
let same_base = &borrow_flags[&base];
1162+
assert_eq!(same_base.len(), 1);
1163+
1164+
assert!(!same_base.contains_key(&key1));
1165+
1166+
assert!(!same_base.contains_key(&key2));
1167+
1168+
assert!(!same_base.contains_key(&key3));
1169+
1170+
let flag = same_base[&key4];
1171+
assert_eq!(flag, 1);
1172+
}
1173+
1174+
drop(shared4);
1175+
1176+
{
1177+
let borrow_flags = unsafe { BORROW_FLAGS.get() };
1178+
assert_eq!(borrow_flags.len(), 0);
1179+
}
1180+
});
1181+
}
9371182
}

tests/borrow.rs

+11
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,17 @@ fn exclusive_and_shared_borrows() {
4343
});
4444
}
4545

46+
#[test]
47+
#[should_panic(expected = "AlreadyBorrowed")]
48+
fn shared_and_exclusive_borrows() {
49+
Python::with_gil(|py| {
50+
let array = PyArray::<f64, _>::zeros(py, (1, 2, 3), false);
51+
52+
let _shared = array.readonly();
53+
let _exclusive = array.readwrite();
54+
});
55+
}
56+
4657
#[test]
4758
#[should_panic(expected = "AlreadyBorrowed")]
4859
fn multiple_exclusive_borrows() {

0 commit comments

Comments
 (0)