|
86 | 86 | //! });
|
87 | 87 | //! ```
|
88 | 88 | //!
|
| 89 | +//! The third example shows that some views are incorrectly rejected since the borrows are over-approximated. |
| 90 | +//! |
| 91 | +//! ```rust |
| 92 | +//! # use std::panic::{catch_unwind, AssertUnwindSafe}; |
| 93 | +//! # |
| 94 | +//! use numpy::PyArray2; |
| 95 | +//! use pyo3::{types::IntoPyDict, Python}; |
| 96 | +//! |
| 97 | +//! Python::with_gil(|py| { |
| 98 | +//! let array = PyArray2::<f64>::zeros(py, (10, 10), false); |
| 99 | +//! let locals = [("array", array)].into_py_dict(py); |
| 100 | +//! |
| 101 | +//! let view1 = py.eval("array[:, ::3]", None, Some(locals)).unwrap().downcast::<PyArray2<f64>>().unwrap(); |
| 102 | +//! let view2 = py.eval("array[:, 1::3]", None, Some(locals)).unwrap().downcast::<PyArray2<f64>>().unwrap(); |
| 103 | +//! |
| 104 | +//! // A false conflict as the views do not actually share any elements. |
| 105 | +//! let res = catch_unwind(AssertUnwindSafe(|| { |
| 106 | +//! let _view1 = view1.readwrite(); |
| 107 | +//! let _view2 = view2.readwrite(); |
| 108 | +//! })); |
| 109 | +//! assert!(res.is_err()); |
| 110 | +//! }); |
| 111 | +//! ``` |
| 112 | +//! |
89 | 113 | //! # Rationale
|
90 | 114 | //!
|
91 | 115 | //! Rust references require aliasing discipline to be maintained, i.e. there must always
|
@@ -235,6 +259,9 @@ impl BorrowFlags {
|
235 | 259 | let same_base = entry.into_mut();
|
236 | 260 |
|
237 | 261 | if let Some(readers) = same_base.get_mut(&key) {
|
| 262 | + // Zero flags are removed during release. |
| 263 | + assert_ne!(*readers, 0); |
| 264 | + |
238 | 265 | let new_readers = readers.wrapping_add(1);
|
239 | 266 |
|
240 | 267 | if new_readers <= 0 {
|
@@ -309,12 +336,11 @@ impl BorrowFlags {
|
309 | 336 | let same_base = entry.into_mut();
|
310 | 337 |
|
311 | 338 | if let Some(writers) = same_base.get_mut(&key) {
|
312 |
| - if *writers != 0 { |
313 |
| - cold(); |
314 |
| - return Err(BorrowError::AlreadyBorrowed); |
315 |
| - } |
| 339 | + // Zero flags are removed during release. |
| 340 | + assert_ne!(*writers, 0); |
316 | 341 |
|
317 |
| - *writers = -1; |
| 342 | + cold(); |
| 343 | + return Err(BorrowError::AlreadyBorrowed); |
318 | 344 | } else {
|
319 | 345 | if same_base
|
320 | 346 | .iter()
|
@@ -620,8 +646,6 @@ fn base_address<T, D>(array: &PyArray<T, D>) -> usize {
|
620 | 646 | }
|
621 | 647 | }
|
622 | 648 |
|
623 |
| -// FIXME(adamreichold): This is a coarse approximation and needs to be refined, |
624 |
| -// i.e. borrows of interleaved views into the same base should not be considered conflicting. |
625 | 649 | fn data_range<T, D>(array: &PyArray<T, D>) -> Range<usize>
|
626 | 650 | where
|
627 | 651 | T: Element,
|
@@ -934,4 +958,225 @@ mod tests {
|
934 | 958 | assert!(key1.conflicts(&key2));
|
935 | 959 | });
|
936 | 960 | }
|
| 961 | + |
| 962 | + #[test] |
| 963 | + fn borrow_multiple_arrays() { |
| 964 | + Python::with_gil(|py| { |
| 965 | + let array1 = PyArray::<f64, _>::zeros(py, 10, false); |
| 966 | + let array2 = PyArray::<f64, _>::zeros(py, 10, false); |
| 967 | + |
| 968 | + let base1 = base_address(array1); |
| 969 | + let base2 = base_address(array2); |
| 970 | + |
| 971 | + let key1 = BorrowKey::from_array(array1); |
| 972 | + let _exclusive1 = array1.readwrite(); |
| 973 | + |
| 974 | + { |
| 975 | + let borrow_flags = unsafe { BORROW_FLAGS.get() }; |
| 976 | + assert_eq!(borrow_flags.len(), 1); |
| 977 | + |
| 978 | + let same_base = &borrow_flags[&base1]; |
| 979 | + assert_eq!(same_base.len(), 1); |
| 980 | + |
| 981 | + let flag = same_base[&key1]; |
| 982 | + assert_eq!(flag, -1); |
| 983 | + } |
| 984 | + |
| 985 | + let key2 = BorrowKey::from_array(array2); |
| 986 | + let _shared2 = array2.readonly(); |
| 987 | + |
| 988 | + { |
| 989 | + let borrow_flags = unsafe { BORROW_FLAGS.get() }; |
| 990 | + assert_eq!(borrow_flags.len(), 2); |
| 991 | + |
| 992 | + let same_base = &borrow_flags[&base1]; |
| 993 | + assert_eq!(same_base.len(), 1); |
| 994 | + |
| 995 | + let flag = same_base[&key1]; |
| 996 | + assert_eq!(flag, -1); |
| 997 | + |
| 998 | + let same_base = &borrow_flags[&base2]; |
| 999 | + assert_eq!(same_base.len(), 1); |
| 1000 | + |
| 1001 | + let flag = same_base[&key2]; |
| 1002 | + assert_eq!(flag, 1); |
| 1003 | + } |
| 1004 | + }); |
| 1005 | + } |
| 1006 | + |
| 1007 | + #[test] |
| 1008 | + fn borrow_multiple_views() { |
| 1009 | + Python::with_gil(|py| { |
| 1010 | + let array = PyArray::<f64, _>::zeros(py, 10, false); |
| 1011 | + let base = base_address(array); |
| 1012 | + |
| 1013 | + let locals = [("array", array)].into_py_dict(py); |
| 1014 | + |
| 1015 | + let view1 = py |
| 1016 | + .eval("array[:5]", None, Some(locals)) |
| 1017 | + .unwrap() |
| 1018 | + .downcast::<PyArray1<f64>>() |
| 1019 | + .unwrap(); |
| 1020 | + |
| 1021 | + let key1 = BorrowKey::from_array(view1); |
| 1022 | + let exclusive1 = view1.readwrite(); |
| 1023 | + |
| 1024 | + { |
| 1025 | + let borrow_flags = unsafe { BORROW_FLAGS.get() }; |
| 1026 | + assert_eq!(borrow_flags.len(), 1); |
| 1027 | + |
| 1028 | + let same_base = &borrow_flags[&base]; |
| 1029 | + assert_eq!(same_base.len(), 1); |
| 1030 | + |
| 1031 | + let flag = same_base[&key1]; |
| 1032 | + assert_eq!(flag, -1); |
| 1033 | + } |
| 1034 | + |
| 1035 | + let view2 = py |
| 1036 | + .eval("array[5:]", None, Some(locals)) |
| 1037 | + .unwrap() |
| 1038 | + .downcast::<PyArray1<f64>>() |
| 1039 | + .unwrap(); |
| 1040 | + |
| 1041 | + let key2 = BorrowKey::from_array(view2); |
| 1042 | + let shared2 = view2.readonly(); |
| 1043 | + |
| 1044 | + { |
| 1045 | + let borrow_flags = unsafe { BORROW_FLAGS.get() }; |
| 1046 | + assert_eq!(borrow_flags.len(), 1); |
| 1047 | + |
| 1048 | + let same_base = &borrow_flags[&base]; |
| 1049 | + assert_eq!(same_base.len(), 2); |
| 1050 | + |
| 1051 | + let flag = same_base[&key1]; |
| 1052 | + assert_eq!(flag, -1); |
| 1053 | + |
| 1054 | + let flag = same_base[&key2]; |
| 1055 | + assert_eq!(flag, 1); |
| 1056 | + } |
| 1057 | + |
| 1058 | + let view3 = py |
| 1059 | + .eval("array[5:]", None, Some(locals)) |
| 1060 | + .unwrap() |
| 1061 | + .downcast::<PyArray1<f64>>() |
| 1062 | + .unwrap(); |
| 1063 | + |
| 1064 | + let key3 = BorrowKey::from_array(view3); |
| 1065 | + let shared3 = view3.readonly(); |
| 1066 | + |
| 1067 | + { |
| 1068 | + let borrow_flags = unsafe { BORROW_FLAGS.get() }; |
| 1069 | + assert_eq!(borrow_flags.len(), 1); |
| 1070 | + |
| 1071 | + let same_base = &borrow_flags[&base]; |
| 1072 | + assert_eq!(same_base.len(), 2); |
| 1073 | + |
| 1074 | + let flag = same_base[&key1]; |
| 1075 | + assert_eq!(flag, -1); |
| 1076 | + |
| 1077 | + let flag = same_base[&key2]; |
| 1078 | + assert_eq!(flag, 2); |
| 1079 | + |
| 1080 | + let flag = same_base[&key3]; |
| 1081 | + assert_eq!(flag, 2); |
| 1082 | + } |
| 1083 | + |
| 1084 | + let view4 = py |
| 1085 | + .eval("array[7:]", None, Some(locals)) |
| 1086 | + .unwrap() |
| 1087 | + .downcast::<PyArray1<f64>>() |
| 1088 | + .unwrap(); |
| 1089 | + |
| 1090 | + let key4 = BorrowKey::from_array(view4); |
| 1091 | + let shared4 = view4.readonly(); |
| 1092 | + |
| 1093 | + { |
| 1094 | + let borrow_flags = unsafe { BORROW_FLAGS.get() }; |
| 1095 | + assert_eq!(borrow_flags.len(), 1); |
| 1096 | + |
| 1097 | + let same_base = &borrow_flags[&base]; |
| 1098 | + assert_eq!(same_base.len(), 3); |
| 1099 | + |
| 1100 | + let flag = same_base[&key1]; |
| 1101 | + assert_eq!(flag, -1); |
| 1102 | + |
| 1103 | + let flag = same_base[&key2]; |
| 1104 | + assert_eq!(flag, 2); |
| 1105 | + |
| 1106 | + let flag = same_base[&key3]; |
| 1107 | + assert_eq!(flag, 2); |
| 1108 | + |
| 1109 | + let flag = same_base[&key4]; |
| 1110 | + assert_eq!(flag, 1); |
| 1111 | + } |
| 1112 | + |
| 1113 | + drop(shared2); |
| 1114 | + |
| 1115 | + { |
| 1116 | + let borrow_flags = unsafe { BORROW_FLAGS.get() }; |
| 1117 | + assert_eq!(borrow_flags.len(), 1); |
| 1118 | + |
| 1119 | + let same_base = &borrow_flags[&base]; |
| 1120 | + assert_eq!(same_base.len(), 3); |
| 1121 | + |
| 1122 | + let flag = same_base[&key1]; |
| 1123 | + assert_eq!(flag, -1); |
| 1124 | + |
| 1125 | + let flag = same_base[&key2]; |
| 1126 | + assert_eq!(flag, 1); |
| 1127 | + |
| 1128 | + let flag = same_base[&key3]; |
| 1129 | + assert_eq!(flag, 1); |
| 1130 | + |
| 1131 | + let flag = same_base[&key4]; |
| 1132 | + assert_eq!(flag, 1); |
| 1133 | + } |
| 1134 | + |
| 1135 | + drop(shared3); |
| 1136 | + |
| 1137 | + { |
| 1138 | + let borrow_flags = unsafe { BORROW_FLAGS.get() }; |
| 1139 | + assert_eq!(borrow_flags.len(), 1); |
| 1140 | + |
| 1141 | + let same_base = &borrow_flags[&base]; |
| 1142 | + assert_eq!(same_base.len(), 2); |
| 1143 | + |
| 1144 | + let flag = same_base[&key1]; |
| 1145 | + assert_eq!(flag, -1); |
| 1146 | + |
| 1147 | + assert!(!same_base.contains_key(&key2)); |
| 1148 | + |
| 1149 | + assert!(!same_base.contains_key(&key3)); |
| 1150 | + |
| 1151 | + let flag = same_base[&key4]; |
| 1152 | + assert_eq!(flag, 1); |
| 1153 | + } |
| 1154 | + |
| 1155 | + drop(exclusive1); |
| 1156 | + |
| 1157 | + { |
| 1158 | + let borrow_flags = unsafe { BORROW_FLAGS.get() }; |
| 1159 | + assert_eq!(borrow_flags.len(), 1); |
| 1160 | + |
| 1161 | + let same_base = &borrow_flags[&base]; |
| 1162 | + assert_eq!(same_base.len(), 1); |
| 1163 | + |
| 1164 | + assert!(!same_base.contains_key(&key1)); |
| 1165 | + |
| 1166 | + assert!(!same_base.contains_key(&key2)); |
| 1167 | + |
| 1168 | + assert!(!same_base.contains_key(&key3)); |
| 1169 | + |
| 1170 | + let flag = same_base[&key4]; |
| 1171 | + assert_eq!(flag, 1); |
| 1172 | + } |
| 1173 | + |
| 1174 | + drop(shared4); |
| 1175 | + |
| 1176 | + { |
| 1177 | + let borrow_flags = unsafe { BORROW_FLAGS.get() }; |
| 1178 | + assert_eq!(borrow_flags.len(), 0); |
| 1179 | + } |
| 1180 | + }); |
| 1181 | + } |
937 | 1182 | }
|
0 commit comments