|
1 | 1 | use ndarray::arr2;
|
2 | 2 | use ndarray::*;
|
3 |
| -use ndarray_linalg::rank::Rank; |
4 | 3 | use ndarray_linalg::*;
|
5 |
| -use rand::{seq::SliceRandom, thread_rng}; |
| 4 | +use rand::{thread_rng, Rng}; |
6 | 5 |
|
7 |
| -/// creates a zero matrix which always has rank zero |
8 |
| -pub fn zero_rank<A, S, Sh, D>(sh: Sh) -> ArrayBase<S, D> |
| 6 | +/// create a zero rank array |
| 7 | +pub fn zero_rank<A, Sh>(sh: Sh) -> Array2<A> |
9 | 8 | where
|
10 |
| - A: Scalar, |
11 |
| - S: DataOwned<Elem = A>, |
12 |
| - D: Dimension, |
13 |
| - Sh: ShapeBuilder<Dim = D>, |
| 9 | + A: Scalar + Lapack, |
| 10 | + Sh: ShapeBuilder<Dim = Ix2> + Clone, |
14 | 11 | {
|
15 |
| - ArrayBase::zeros(sh) |
| 12 | + random_with_rank(sh, 0) |
16 | 13 | }
|
17 | 14 |
|
18 |
| -/// creates a random matrix and repeatedly creates a linear dependency between rows until the |
19 |
| -/// rank drops. |
| 15 | +/// create a random matrix with a random partial rank. |
20 | 16 | pub fn partial_rank<A, Sh>(sh: Sh) -> Array2<A>
|
21 | 17 | where
|
22 | 18 | A: Scalar + Lapack,
|
23 |
| - Sh: ShapeBuilder<Dim = Ix2>, |
| 19 | + Sh: ShapeBuilder<Dim = Ix2> + Clone, |
24 | 20 | {
|
25 | 21 | let mut rng = thread_rng();
|
26 |
| - let mut result: Array2<A> = random(sh); |
27 |
| - println!("before: {:?}", result); |
28 |
| - |
29 |
| - let (n, m) = result.dim(); |
30 |
| - println!("(n, m) => ({:?},{:?})", n, m); |
31 |
| - |
32 |
| - // create randomized row iterator |
| 22 | + let (m, n) = sh.clone().into_shape().raw_dim().into_pattern(); |
33 | 23 | let min_dim = n.min(m);
|
34 |
| - let mut row_indexes = (0..min_dim).into_iter().collect::<Vec<usize>>(); |
35 |
| - row_indexes.as_mut_slice().shuffle(&mut rng); |
36 |
| - let mut row_index_iter = row_indexes.iter().cycle(); |
37 |
| - |
38 |
| - for count in 1..=10 { |
39 |
| - println!("count: {}", count); |
40 |
| - let (&x, &y) = ( |
41 |
| - row_index_iter.next().unwrap(), |
42 |
| - row_index_iter.next().unwrap(), |
43 |
| - ); |
44 |
| - let (from_row_index, to_row_index) = if x < y { (x, y) } else { (y, x) }; |
45 |
| - println!("(r_f, r_t) => ({:?},{:?})", from_row_index, to_row_index); |
46 |
| - |
47 |
| - let mut it = result.outer_iter_mut(); |
48 |
| - let from_row = it.nth(from_row_index).unwrap(); |
49 |
| - let mut to_row = it.nth(to_row_index - (from_row_index + 1)).unwrap(); |
50 |
| - |
51 |
| - // set the to_row with the value of the from_row multiplied by rand_multiple |
52 |
| - let rand_multiple = A::rand(&mut rng); |
53 |
| - println!("rand_multiple: {:?}", rand_multiple); |
54 |
| - Zip::from(&mut to_row) |
55 |
| - .and(&from_row) |
56 |
| - .for_each(|r1, r2| *r1 = *r2 * rand_multiple); |
57 |
| - |
58 |
| - if let Ok(rank) = result.rank() { |
59 |
| - println!("result: {:?}", result); |
60 |
| - println!("rank: {:?}", rank); |
61 |
| - if rank > 0 && rank < min_dim { |
62 |
| - return result; |
63 |
| - } |
64 |
| - } |
65 |
| - } |
66 |
| - unreachable!("unable to generate random partial rank matrix after making 10 mutations") |
| 24 | + let rank = rng.gen_range(1..min_dim); |
| 25 | + println!("desired rank = {}", rank); |
| 26 | + random_with_rank(sh, rank) |
67 | 27 | }
|
68 | 28 |
|
69 |
| -/// creates a random matrix and insures it is full rank. |
| 29 | +/// create a random matrix and ensures it is full rank. |
70 | 30 | pub fn full_rank<A, Sh>(sh: Sh) -> Array2<A>
|
71 | 31 | where
|
72 | 32 | A: Scalar + Lapack,
|
73 | 33 | Sh: ShapeBuilder<Dim = Ix2> + Clone,
|
74 | 34 | {
|
75 |
| - for _ in 0..10 { |
76 |
| - let r: Array2<A> = random(sh.clone()); |
77 |
| - let (n, m) = r.dim(); |
78 |
| - let n = n.min(m); |
79 |
| - if let Ok(rank) = r.rank() { |
80 |
| - println!("result: {:?}", r); |
81 |
| - println!("rank: {:?}", rank); |
82 |
| - if rank == n { |
83 |
| - return r; |
84 |
| - } |
85 |
| - } |
86 |
| - } |
87 |
| - unreachable!("unable to generate random full rank matrix in 10 tries") |
| 35 | + let (m, n) = sh.clone().into_shape().raw_dim().into_pattern(); |
| 36 | + let min_dim = n.min(m); |
| 37 | + random_with_rank(sh, min_dim) |
88 | 38 | }
|
89 | 39 |
|
90 | 40 | fn test<T: Scalar + Lapack>(a: &Array2<T>, tolerance: T::Real) {
|
|
0 commit comments