@@ -54,28 +54,44 @@ where
54
54
let min_dim = usize:: min ( n, m) ;
55
55
assert ! ( rank <= min_dim) ;
56
56
57
- for _ in 0 ..10 {
58
- // handle full-rank case
59
- let out = if rank == min_dim {
60
- random ( shape. clone ( ) )
61
-
62
- // handle partial-rank case
63
- } else {
64
- // multiplying two full-rank arrays with dimensions `m × r` and `r × n` will
65
- // produce `an m × n` array with rank `r`
66
- // https://en.wikipedia.org/wiki/Rank_(linear_algebra)#Properties
67
- let mut out = Array2 :: zeros ( shape. clone ( ) ) ;
68
- let left: Array2 < A > = random ( [ out. nrows ( ) , rank] ) ;
69
- let right: Array2 < A > = random ( [ rank, out. ncols ( ) ] ) ;
57
+ let mut rng = thread_rng ( ) ;
58
+
59
+ // handle full-rank case
60
+ if rank == min_dim {
61
+ let mut out = random ( shape) ;
62
+ for _ in 0 ..10 {
63
+ // check rank
64
+ if let Ok ( out_rank) = out. rank ( ) {
65
+ if out_rank == rank {
66
+ return out;
67
+ }
68
+ }
69
+
70
+ out. mapv_inplace ( |_| A :: rand ( & mut rng) ) ;
71
+ }
72
+
73
+ // handle partial-rank case
74
+ //
75
+ // multiplying two full-rank arrays with dimensions `m × r` and `r × n` will
76
+ // produce `an m × n` array with rank `r`
77
+ // https://en.wikipedia.org/wiki/Rank_(linear_algebra)#Properties
78
+ } else {
79
+ let mut out = Array2 :: zeros ( shape) ;
80
+ let mut left: Array2 < A > = random ( [ out. nrows ( ) , rank] ) ;
81
+ let mut right: Array2 < A > = random ( [ rank, out. ncols ( ) ] ) ;
82
+
83
+ for _ in 0 ..10 {
70
84
general_mat_mul ( A :: one ( ) , & left, & right, A :: zero ( ) , & mut out) ;
71
- out
72
- } ;
73
85
74
- // check rank
75
- if let Ok ( out_rank) = out. rank ( ) {
76
- if out_rank == rank {
77
- return out;
86
+ // check rank
87
+ if let Ok ( out_rank) = out. rank ( ) {
88
+ if out_rank == rank {
89
+ return out;
90
+ }
78
91
}
92
+
93
+ left. mapv_inplace ( |_| A :: rand ( & mut rng) ) ;
94
+ right. mapv_inplace ( |_| A :: rand ( & mut rng) ) ;
79
95
}
80
96
}
81
97
0 commit comments