Skip to content

Commit ff89e2b

Browse files
committed
added generate::random_with_rank function and updated test to use it
1 parent 970232c commit ff89e2b

File tree

2 files changed

+66
-66
lines changed

2 files changed

+66
-66
lines changed

ndarray-linalg/src/generate.rs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
//! Generator functions for matrices
22
3+
use ndarray::linalg::general_mat_mul;
34
use ndarray::*;
45
use rand::prelude::*;
56

67
use super::convert::*;
78
use super::error::*;
89
use super::qr::*;
10+
use super::rank::Rank;
911
use super::types::*;
12+
use super::Scalar;
1013

1114
/// Hermite conjugate matrix
1215
pub fn conjugate<A, Si, So>(a: &ArrayBase<Si, Ix2>) -> ArrayBase<So, Ix2>
@@ -34,6 +37,53 @@ where
3437
ArrayBase::from_shape_fn(sh, |_| A::rand(&mut rng))
3538
}
3639

40+
/// Generate random array with a given rank
41+
///
42+
/// The rank must be less then or equal to the smallest dimension of array
43+
pub fn random_with_rank<A, Sh>(shape: Sh, rank: usize) -> Array2<A>
44+
where
45+
A: Scalar + Lapack,
46+
Sh: ShapeBuilder<Dim = Ix2> + Clone,
47+
{
48+
// handle zero-rank case
49+
if rank == 0 {
50+
return Array2::zeros(shape);
51+
}
52+
53+
let (n, m) = shape.clone().into_shape().raw_dim().into_pattern();
54+
let min_dim = usize::min(n, m);
55+
assert!(rank <= min_dim);
56+
57+
for _ in 0..10 {
58+
// handle full-rank case
59+
let out = if rank == min_dim {
60+
random(shape.clone())
61+
62+
// handle partial-rank case
63+
} else {
64+
// multiplying two full-rank arrays with dimensions `m × r` and `r × n` will
65+
// produce `an m × n` array with rank `r`
66+
// https://en.wikipedia.org/wiki/Rank_(linear_algebra)#Properties
67+
let mut out = Array2::zeros(shape.clone());
68+
let left: Array2<A> = random([out.nrows(), rank]);
69+
let right: Array2<A> = random([rank, out.ncols()]);
70+
general_mat_mul(A::one(), &left, &right, A::zero(), &mut out);
71+
out
72+
};
73+
74+
// check rank
75+
if let Ok(out_rank) = out.rank() {
76+
if out_rank == rank {
77+
return out;
78+
}
79+
}
80+
}
81+
82+
unreachable!(
83+
"Failed to generate random matrix of desired rank within 10 tries. This is very unlikely."
84+
);
85+
}
86+
3787
/// Generate random unitary matrix using QR decomposition
3888
///
3989
/// Be sure that this it **NOT** a uniform distribution. Use it only for test purpose.

ndarray-linalg/tests/pinv.rs

Lines changed: 16 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,90 +1,40 @@
11
use ndarray::arr2;
22
use ndarray::*;
3-
use ndarray_linalg::rank::Rank;
43
use ndarray_linalg::*;
5-
use rand::{seq::SliceRandom, thread_rng};
4+
use rand::{thread_rng, Rng};
65

7-
/// creates a zero matrix which always has rank zero
8-
pub fn zero_rank<A, S, Sh, D>(sh: Sh) -> ArrayBase<S, D>
6+
/// create a zero rank array
7+
pub fn zero_rank<A, Sh>(sh: Sh) -> Array2<A>
98
where
10-
A: Scalar,
11-
S: DataOwned<Elem = A>,
12-
D: Dimension,
13-
Sh: ShapeBuilder<Dim = D>,
9+
A: Scalar + Lapack,
10+
Sh: ShapeBuilder<Dim = Ix2> + Clone,
1411
{
15-
ArrayBase::zeros(sh)
12+
random_with_rank(sh, 0)
1613
}
1714

18-
/// creates a random matrix and repeatedly creates a linear dependency between rows until the
19-
/// rank drops.
15+
/// create a random matrix with a random partial rank.
2016
pub fn partial_rank<A, Sh>(sh: Sh) -> Array2<A>
2117
where
2218
A: Scalar + Lapack,
23-
Sh: ShapeBuilder<Dim = Ix2>,
19+
Sh: ShapeBuilder<Dim = Ix2> + Clone,
2420
{
2521
let mut rng = thread_rng();
26-
let mut result: Array2<A> = random(sh);
27-
println!("before: {:?}", result);
28-
29-
let (n, m) = result.dim();
30-
println!("(n, m) => ({:?},{:?})", n, m);
31-
32-
// create randomized row iterator
22+
let (m, n) = sh.clone().into_shape().raw_dim().into_pattern();
3323
let min_dim = n.min(m);
34-
let mut row_indexes = (0..min_dim).into_iter().collect::<Vec<usize>>();
35-
row_indexes.as_mut_slice().shuffle(&mut rng);
36-
let mut row_index_iter = row_indexes.iter().cycle();
37-
38-
for count in 1..=10 {
39-
println!("count: {}", count);
40-
let (&x, &y) = (
41-
row_index_iter.next().unwrap(),
42-
row_index_iter.next().unwrap(),
43-
);
44-
let (from_row_index, to_row_index) = if x < y { (x, y) } else { (y, x) };
45-
println!("(r_f, r_t) => ({:?},{:?})", from_row_index, to_row_index);
46-
47-
let mut it = result.outer_iter_mut();
48-
let from_row = it.nth(from_row_index).unwrap();
49-
let mut to_row = it.nth(to_row_index - (from_row_index + 1)).unwrap();
50-
51-
// set the to_row with the value of the from_row multiplied by rand_multiple
52-
let rand_multiple = A::rand(&mut rng);
53-
println!("rand_multiple: {:?}", rand_multiple);
54-
Zip::from(&mut to_row)
55-
.and(&from_row)
56-
.for_each(|r1, r2| *r1 = *r2 * rand_multiple);
57-
58-
if let Ok(rank) = result.rank() {
59-
println!("result: {:?}", result);
60-
println!("rank: {:?}", rank);
61-
if rank > 0 && rank < min_dim {
62-
return result;
63-
}
64-
}
65-
}
66-
unreachable!("unable to generate random partial rank matrix after making 10 mutations")
24+
let rank = rng.gen_range(1..min_dim);
25+
println!("desired rank = {}", rank);
26+
random_with_rank(sh, rank)
6727
}
6828

69-
/// creates a random matrix and insures it is full rank.
29+
/// create a random matrix and ensures it is full rank.
7030
pub fn full_rank<A, Sh>(sh: Sh) -> Array2<A>
7131
where
7232
A: Scalar + Lapack,
7333
Sh: ShapeBuilder<Dim = Ix2> + Clone,
7434
{
75-
for _ in 0..10 {
76-
let r: Array2<A> = random(sh.clone());
77-
let (n, m) = r.dim();
78-
let n = n.min(m);
79-
if let Ok(rank) = r.rank() {
80-
println!("result: {:?}", r);
81-
println!("rank: {:?}", rank);
82-
if rank == n {
83-
return r;
84-
}
85-
}
86-
}
87-
unreachable!("unable to generate random full rank matrix in 10 tries")
35+
let (m, n) = sh.clone().into_shape().raw_dim().into_pattern();
36+
let min_dim = n.min(m);
37+
random_with_rank(sh, min_dim)
8838
}
8939

9040
fn test<T: Scalar + Lapack>(a: &Array2<T>, tolerance: T::Real) {

0 commit comments

Comments
 (0)