Skip to content

Commit 942b7d7

Browse files
committed
avoid reallocating arrays in loop
1 parent ff89e2b commit 942b7d7

File tree

1 file changed

+35
-19
lines changed

1 file changed

+35
-19
lines changed

ndarray-linalg/src/generate.rs

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -54,28 +54,44 @@ where
5454
let min_dim = usize::min(n, m);
5555
assert!(rank <= min_dim);
5656

57-
for _ in 0..10 {
58-
// handle full-rank case
59-
let out = if rank == min_dim {
60-
random(shape.clone())
61-
62-
// handle partial-rank case
63-
} else {
64-
// multiplying two full-rank arrays with dimensions `m × r` and `r × n` will
65-
// produce `an m × n` array with rank `r`
66-
// https://en.wikipedia.org/wiki/Rank_(linear_algebra)#Properties
67-
let mut out = Array2::zeros(shape.clone());
68-
let left: Array2<A> = random([out.nrows(), rank]);
69-
let right: Array2<A> = random([rank, out.ncols()]);
57+
let mut rng = thread_rng();
58+
59+
// handle full-rank case
60+
if rank == min_dim {
61+
let mut out = random(shape);
62+
for _ in 0..10 {
63+
// check rank
64+
if let Ok(out_rank) = out.rank() {
65+
if out_rank == rank {
66+
return out;
67+
}
68+
}
69+
70+
out.mapv_inplace(|_| A::rand(&mut rng));
71+
}
72+
73+
// handle partial-rank case
74+
//
75+
// multiplying two full-rank arrays with dimensions `m × r` and `r × n` will
76+
// produce `an m × n` array with rank `r`
77+
// https://en.wikipedia.org/wiki/Rank_(linear_algebra)#Properties
78+
} else {
79+
let mut out = Array2::zeros(shape);
80+
let mut left: Array2<A> = random([out.nrows(), rank]);
81+
let mut right: Array2<A> = random([rank, out.ncols()]);
82+
83+
for _ in 0..10 {
7084
general_mat_mul(A::one(), &left, &right, A::zero(), &mut out);
71-
out
72-
};
7385

74-
// check rank
75-
if let Ok(out_rank) = out.rank() {
76-
if out_rank == rank {
77-
return out;
86+
// check rank
87+
if let Ok(out_rank) = out.rank() {
88+
if out_rank == rank {
89+
return out;
90+
}
7891
}
92+
93+
left.mapv_inplace(|_| A::rand(&mut rng));
94+
right.mapv_inplace(|_| A::rand(&mut rng));
7995
}
8096
}
8197

0 commit comments

Comments
 (0)