Skip to content

Commit ee084db

Browse files
committed
Extend integration tests to include the scalar cases for dot and inner.
1 parent 00b035d commit ee084db

File tree

1 file changed

+21
-5
lines changed

1 file changed

+21
-5
lines changed

tests/sum_products.rs

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,18 @@ fn test_dot() {
1212
let a = pyarray![py, 1, 2, 3];
1313
let err = dot::<_, _, _, &PyArray2<_>>(a, b).unwrap_err();
1414
assert!(err.to_string().contains("not aligned"), "{}", err);
15+
16+
let a = pyarray![py, 1, 2, 3];
17+
let b = pyarray![py, 0, 1, 0];
18+
let c: &PyArray0<_> = dot(a, b).unwrap();
19+
assert_eq!(c.item(), 2);
20+
let c: i32 = dot(a, b).unwrap();
21+
assert_eq!(c, 2);
22+
23+
let a = pyarray![py, 1.0, 2.0, 3.0];
24+
let b = pyarray![py, 0.0, 0.0, 0.0];
25+
let c: f64 = dot(a, b).unwrap();
26+
assert_eq!(c, 0.0);
1527
});
1628
}
1729

@@ -21,7 +33,14 @@ fn test_inner() {
2133
let a = pyarray![py, 1, 2, 3];
2234
let b = pyarray![py, 0, 1, 0];
2335
let c: &PyArray0<_> = inner(a, b).unwrap();
24-
assert_eq!(c.readonly().as_array(), ndarray::arr0(2));
36+
assert_eq!(c.item(), 2);
37+
let c: i32 = inner(a, b).unwrap();
38+
assert_eq!(c, 2);
39+
40+
let a = pyarray![py, 1.0, 2.0, 3.0];
41+
let b = pyarray![py, 0.0, 0.0, 0.0];
42+
let c: f64 = inner(a, b).unwrap();
43+
assert_eq!(c, 0.0);
2544

2645
let a = pyarray![py, [1, 0], [0, 1]];
2746
let b = pyarray![py, [4, 1], [2, 2]];
@@ -43,10 +62,7 @@ fn test_einsum() {
4362
let b = pyarray![py, 0, 1, 2, 3, 4];
4463
let c = pyarray![py, [0, 1, 2], [3, 4, 5]];
4564

46-
assert_eq!(
47-
einsum!("ii", a).unwrap().readonly().as_array(),
48-
ndarray::arr0(60)
49-
);
65+
assert_eq!(einsum!("ii", a).unwrap().item(), 60);
5066
assert_eq!(
5167
einsum!("ii->i", a).unwrap().readonly().as_array(),
5268
array![0, 6, 12, 18, 24],

0 commit comments

Comments
 (0)