161
161
162
162
use std:: cell:: UnsafeCell ;
163
163
use std:: collections:: hash_map:: { Entry , HashMap } ;
164
- use std:: ops:: { Deref , Range } ;
164
+ use std:: ops:: Deref ;
165
165
166
166
use ndarray:: { ArrayView , ArrayViewMut , Dimension , Ix1 , Ix2 , Ix3 , Ix4 , Ix5 , Ix6 , IxDyn } ;
167
167
use num_integer:: gcd;
@@ -176,8 +176,11 @@ use crate::npyffi::{self, PyArrayObject, NPY_ARRAY_WRITEABLE};
176
176
177
177
#[ derive( PartialEq , Eq , Hash ) ]
178
178
struct BorrowKey {
179
- range : Range < usize > ,
179
+ /// inclusive range of lowest and highest address covered by array
180
+ range : ( usize , usize ) ,
181
+ /// the data address on which address computations are based
180
182
data_ptr : usize ,
183
+ /// the greatest common divisor of the strides of the array
181
184
gcd_strides : isize ,
182
185
}
183
186
@@ -200,10 +203,10 @@ impl BorrowKey {
200
203
}
201
204
202
205
fn conflicts ( & self , other : & Self ) -> bool {
203
- debug_assert ! ( self . range. start <= self . range. end ) ;
204
- debug_assert ! ( other. range. start <= other. range. end ) ;
206
+ debug_assert ! ( self . range. 0 <= self . range. 1 ) ;
207
+ debug_assert ! ( other. range. 0 <= other. range. 1 ) ;
205
208
206
- if other. range . start >= self . range . end || other . range . end <= self . range . start {
209
+ if other. range . 0 > self . range . 1 || self . range . 0 > other . range . 1 {
207
210
return false ;
208
211
}
209
212
@@ -646,7 +649,7 @@ fn base_address<T, D>(array: &PyArray<T, D>) -> usize {
646
649
}
647
650
}
648
651
649
- fn data_range < T , D > ( array : & PyArray < T , D > ) -> Range < usize >
652
+ fn data_range < T , D > ( array : & PyArray < T , D > ) -> ( usize , usize )
650
653
where
651
654
T : Element ,
652
655
D : Dimension ,
@@ -657,7 +660,7 @@ where
657
660
let mut end = 0 ;
658
661
659
662
for ( & dim, & stride) in array. shape ( ) . iter ( ) . zip ( array. strides ( ) ) {
660
- let offset = ( dim as isize ) * stride;
663
+ let offset = dim. saturating_sub ( 1 ) as isize * stride;
661
664
662
665
if offset >= 0 {
663
666
end += offset;
@@ -669,7 +672,7 @@ where
669
672
let start = unsafe { data. offset ( start) } as usize ;
670
673
let end = unsafe { data. offset ( end) } as usize ;
671
674
672
- Range { start, end }
675
+ ( start, end)
673
676
}
674
677
675
678
// FIXME(adamreichold): Use `usize::abs_diff` from std when that becomes stable.
@@ -713,8 +716,8 @@ mod tests {
713
716
assert_eq ! ( base_address, array as * const _ as usize ) ;
714
717
715
718
let data_range = data_range ( array) ;
716
- assert_eq ! ( data_range. start , array. data( ) as usize ) ;
717
- assert_eq ! ( data_range. end , unsafe { array. data( ) . add( 15 ) } as usize ) ;
719
+ assert_eq ! ( data_range. 0 , array. data( ) as usize ) ;
720
+ assert_eq ! ( data_range. 1 , unsafe { array. data( ) . add( 5 ) } as usize ) ;
718
721
} ) ;
719
722
}
720
723
@@ -731,8 +734,8 @@ mod tests {
731
734
assert_eq ! ( base_address, base as usize ) ;
732
735
733
736
let data_range = data_range ( array) ;
734
- assert_eq ! ( data_range. start , array. data( ) as usize ) ;
735
- assert_eq ! ( data_range. end , unsafe { array. data( ) . add( 15 ) } as usize ) ;
737
+ assert_eq ! ( data_range. 0 , array. data( ) as usize ) ;
738
+ assert_eq ! ( data_range. 1 , unsafe { array. data( ) . add( 5 ) } as usize ) ;
736
739
} ) ;
737
740
}
738
741
@@ -757,8 +760,8 @@ mod tests {
757
760
assert_eq ! ( base_address, base as usize ) ;
758
761
759
762
let data_range = data_range ( view) ;
760
- assert_eq ! ( data_range. start , view . data( ) as usize ) ;
761
- assert_eq ! ( data_range. end , unsafe { view . data( ) . add( 12 ) } as usize ) ;
763
+ assert_eq ! ( data_range. 0 , array . data( ) as usize ) ;
764
+ assert_eq ! ( data_range. 1 , unsafe { array . data( ) . add( 3 ) } as usize ) ;
762
765
} ) ;
763
766
}
764
767
@@ -787,8 +790,8 @@ mod tests {
787
790
assert_eq ! ( base_address, base as usize ) ;
788
791
789
792
let data_range = data_range ( view) ;
790
- assert_eq ! ( data_range. start , view . data( ) as usize ) ;
791
- assert_eq ! ( data_range. end , unsafe { view . data( ) . add( 12 ) } as usize ) ;
793
+ assert_eq ! ( data_range. 0 , array . data( ) as usize ) ;
794
+ assert_eq ! ( data_range. 1 , unsafe { array . data( ) . add( 3 ) } as usize ) ;
792
795
} ) ;
793
796
}
794
797
@@ -826,8 +829,8 @@ mod tests {
826
829
assert_eq ! ( base_address, base as usize ) ;
827
830
828
831
let data_range = data_range ( view2) ;
829
- assert_eq ! ( data_range. start , view2 . data( ) as usize ) ;
830
- assert_eq ! ( data_range. end , unsafe { view2 . data( ) . add ( 6 ) } as usize ) ;
832
+ assert_eq ! ( data_range. 0 , array . data( ) as usize ) ;
833
+ assert_eq ! ( data_range. 1 , array . data( ) as usize ) ;
831
834
} ) ;
832
835
}
833
836
@@ -869,8 +872,8 @@ mod tests {
869
872
assert_eq ! ( base_address, base as usize ) ;
870
873
871
874
let data_range = data_range ( view2) ;
872
- assert_eq ! ( data_range. start , view2 . data( ) as usize ) ;
873
- assert_eq ! ( data_range. end , unsafe { view2 . data( ) . add ( 6 ) } as usize ) ;
875
+ assert_eq ! ( data_range. 0 , array . data( ) as usize ) ;
876
+ assert_eq ! ( data_range. 1 , array . data( ) as usize ) ;
874
877
} ) ;
875
878
}
876
879
@@ -895,8 +898,9 @@ mod tests {
895
898
assert_eq ! ( base_address, base as usize ) ;
896
899
897
900
let data_range = data_range ( view) ;
898
- assert_eq ! ( data_range. start, unsafe { view. data( ) . offset( -9 ) } as usize ) ;
899
- assert_eq ! ( data_range. end, unsafe { view. data( ) . offset( 6 ) } as usize ) ;
901
+ assert_eq ! ( view. data( ) , unsafe { array. data( ) . offset( 2 ) } ) ;
902
+ assert_eq ! ( data_range. 0 , unsafe { view. data( ) . offset( -2 ) } as usize ) ;
903
+ assert_eq ! ( data_range. 1 , unsafe { view. data( ) . offset( 3 ) } as usize ) ;
900
904
} ) ;
901
905
}
902
906
0 commit comments