Skip to content

Commit 544a0fa

Browse files
committed
Fix computation of data range of more-than-one-dimensional arrays.
1 parent f166964 commit 544a0fa

File tree

1 file changed

+26
-22
lines changed

1 file changed

+26
-22
lines changed

src/borrow.rs

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@
161161

162162
use std::cell::UnsafeCell;
163163
use std::collections::hash_map::{Entry, HashMap};
164-
use std::ops::{Deref, Range};
164+
use std::ops::Deref;
165165

166166
use ndarray::{ArrayView, ArrayViewMut, Dimension, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
167167
use num_integer::gcd;
@@ -176,8 +176,11 @@ use crate::npyffi::{self, PyArrayObject, NPY_ARRAY_WRITEABLE};
176176

177177
#[derive(PartialEq, Eq, Hash)]
178178
struct BorrowKey {
179-
range: Range<usize>,
179+
/// inclusive range of lowest and highest address covered by array
180+
range: (usize, usize),
181+
/// the data address on which address computations are based
180182
data_ptr: usize,
183+
/// the greatest common divisor of the strides of the array
181184
gcd_strides: isize,
182185
}
183186

@@ -200,10 +203,10 @@ impl BorrowKey {
200203
}
201204

202205
fn conflicts(&self, other: &Self) -> bool {
203-
debug_assert!(self.range.start <= self.range.end);
204-
debug_assert!(other.range.start <= other.range.end);
206+
debug_assert!(self.range.0 <= self.range.1);
207+
debug_assert!(other.range.0 <= other.range.1);
205208

206-
if other.range.start >= self.range.end || other.range.end <= self.range.start {
209+
if other.range.0 > self.range.1 || self.range.0 > other.range.1 {
207210
return false;
208211
}
209212

@@ -646,7 +649,7 @@ fn base_address<T, D>(array: &PyArray<T, D>) -> usize {
646649
}
647650
}
648651

649-
fn data_range<T, D>(array: &PyArray<T, D>) -> Range<usize>
652+
fn data_range<T, D>(array: &PyArray<T, D>) -> (usize, usize)
650653
where
651654
T: Element,
652655
D: Dimension,
@@ -657,7 +660,7 @@ where
657660
let mut end = 0;
658661

659662
for (&dim, &stride) in array.shape().iter().zip(array.strides()) {
660-
let offset = (dim as isize) * stride;
663+
let offset = dim.saturating_sub(1) as isize * stride;
661664

662665
if offset >= 0 {
663666
end += offset;
@@ -669,7 +672,7 @@ where
669672
let start = unsafe { data.offset(start) } as usize;
670673
let end = unsafe { data.offset(end) } as usize;
671674

672-
Range { start, end }
675+
(start, end)
673676
}
674677

675678
// FIXME(adamreichold): Use `usize::abs_diff` from std when that becomes stable.
@@ -713,8 +716,8 @@ mod tests {
713716
assert_eq!(base_address, array as *const _ as usize);
714717

715718
let data_range = data_range(array);
716-
assert_eq!(data_range.start, array.data() as usize);
717-
assert_eq!(data_range.end, unsafe { array.data().add(15) } as usize);
719+
assert_eq!(data_range.0, array.data() as usize);
720+
assert_eq!(data_range.1, unsafe { array.data().add(5) } as usize);
718721
});
719722
}
720723

@@ -731,8 +734,8 @@ mod tests {
731734
assert_eq!(base_address, base as usize);
732735

733736
let data_range = data_range(array);
734-
assert_eq!(data_range.start, array.data() as usize);
735-
assert_eq!(data_range.end, unsafe { array.data().add(15) } as usize);
737+
assert_eq!(data_range.0, array.data() as usize);
738+
assert_eq!(data_range.1, unsafe { array.data().add(5) } as usize);
736739
});
737740
}
738741

@@ -757,8 +760,8 @@ mod tests {
757760
assert_eq!(base_address, base as usize);
758761

759762
let data_range = data_range(view);
760-
assert_eq!(data_range.start, view.data() as usize);
761-
assert_eq!(data_range.end, unsafe { view.data().add(12) } as usize);
763+
assert_eq!(data_range.0, array.data() as usize);
764+
assert_eq!(data_range.1, unsafe { array.data().add(3) } as usize);
762765
});
763766
}
764767

@@ -787,8 +790,8 @@ mod tests {
787790
assert_eq!(base_address, base as usize);
788791

789792
let data_range = data_range(view);
790-
assert_eq!(data_range.start, view.data() as usize);
791-
assert_eq!(data_range.end, unsafe { view.data().add(12) } as usize);
793+
assert_eq!(data_range.0, array.data() as usize);
794+
assert_eq!(data_range.1, unsafe { array.data().add(3) } as usize);
792795
});
793796
}
794797

@@ -826,8 +829,8 @@ mod tests {
826829
assert_eq!(base_address, base as usize);
827830

828831
let data_range = data_range(view2);
829-
assert_eq!(data_range.start, view2.data() as usize);
830-
assert_eq!(data_range.end, unsafe { view2.data().add(6) } as usize);
832+
assert_eq!(data_range.0, array.data() as usize);
833+
assert_eq!(data_range.1, array.data() as usize);
831834
});
832835
}
833836

@@ -869,8 +872,8 @@ mod tests {
869872
assert_eq!(base_address, base as usize);
870873

871874
let data_range = data_range(view2);
872-
assert_eq!(data_range.start, view2.data() as usize);
873-
assert_eq!(data_range.end, unsafe { view2.data().add(6) } as usize);
875+
assert_eq!(data_range.0, array.data() as usize);
876+
assert_eq!(data_range.1, array.data() as usize);
874877
});
875878
}
876879

@@ -895,8 +898,9 @@ mod tests {
895898
assert_eq!(base_address, base as usize);
896899

897900
let data_range = data_range(view);
898-
assert_eq!(data_range.start, unsafe { view.data().offset(-9) } as usize);
899-
assert_eq!(data_range.end, unsafe { view.data().offset(6) } as usize);
901+
assert_eq!(view.data(), unsafe { array.data().offset(2) });
902+
assert_eq!(data_range.0, unsafe { view.data().offset(-2) } as usize);
903+
assert_eq!(data_range.1, unsafe { view.data().offset(3) } as usize);
900904
});
901905
}
902906

0 commit comments

Comments
 (0)