Skip to content

Commit 4ca116f

Browse files
committed
Add MP test
1 parent 088f0b1 commit 4ca116f

File tree

3 files changed

+32
-14
lines changed

3 files changed

+32
-14
lines changed

ndarray-linalg/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ paste = "1.0"
5050
criterion = "0.3"
5151
# Keep the same version as ndarray's dependency!
5252
approx = { version = "0.4", features = ["num-complex"] }
53-
rand_xoshiro = "0.6"
53+
rand_xoshiro = "0.4"
54+
ndarray-rand = "0.12"
5455

5556
[[bench]]
5657
name = "truncated_eig"

ndarray-linalg/src/lobpcg/eig.rs

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,17 @@ use num_traits::{Float, NumCast};
2020
/// # Example
2121
///
2222
/// ```rust
23+
/// use ndarray::{arr1, Array2};
24+
/// use ndarray_linalg::{TruncatedEig, TruncatedOrder};
25+
///
2326
/// let diag = arr1(&[1., 2., 3., 4., 5.]);
2427
/// let a = Array2::from_diag(&diag);
2528
///
26-
/// let eig = TruncatedEig::new(a, Order::Largest)
29+
/// let eig = TruncatedEig::new(a, TruncatedOrder::Largest)
2730
/// .precision(1e-5)
2831
/// .maxiter(500);
2932
///
30-
/// let res = eig.decompose();
33+
/// let res = eig.decompose(3);
3134
/// ```
3235
3336
pub struct TruncatedEig<A: Scalar> {
@@ -109,14 +112,17 @@ impl<A: Float + Scalar + ScalarOperand + Lapack + PartialOrd + Default> Truncate
109112
/// # Example
110113
///
111114
/// ```rust
115+
/// use ndarray::{arr1, Array2};
116+
/// use ndarray_linalg::{TruncatedEig, TruncatedOrder};
117+
///
112118
/// let diag = arr1(&[1., 2., 3., 4., 5.]);
113119
/// let a = Array2::from_diag(&diag);
114120
///
115-
/// let eig = TruncatedEig::new(a, Order::Largest)
121+
/// let eig = TruncatedEig::new(a, TruncatedOrder::Largest)
116122
/// .precision(1e-5)
117123
/// .maxiter(500);
118124
///
119-
/// let res = eig.decompose();
125+
/// let res = eig.decompose(3);
120126
/// ```
121127
pub fn decompose(&self, num: usize) -> LobpcgResult<A> {
122128
let x: Array2<f64> = generate::random((self.problem.len_of(Axis(0)), num));
@@ -169,15 +175,21 @@ impl<A: Float + Scalar + ScalarOperand + Lapack + PartialOrd + Default> IntoIter
169175
/// # Example
170176
///
171177
/// ```rust
172-
/// let teig = TruncatedEig::new(a, Order::Largest)
178+
/// use ndarray::{arr1, Array2};
179+
/// use ndarray_linalg::{TruncatedEig, TruncatedOrder};
180+
///
181+
/// let diag = arr1(&[1., 2., 3., 4., 5.]);
182+
/// let a = Array2::from_diag(&diag);
183+
///
184+
/// let teig = TruncatedEig::new(a, TruncatedOrder::Largest)
173185
/// .precision(1e-5)
174186
/// .maxiter(500);
175187
///
176188
/// // solve eigenproblem until eigenvalues get smaller than 0.5
177189
/// let res = teig.into_iter()
178190
/// .take_while(|x| x.0[0] > 0.5)
179-
/// .flat_map(|x| x.0)
180-
/// .collect();
191+
/// .flat_map(|x| x.0.to_vec())
192+
/// .collect::<Vec<_>>();
181193
/// ```
182194
pub struct TruncatedEigIterator<A: Scalar> {
183195
step_size: usize,

ndarray-linalg/src/lobpcg/svd.rs

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,11 @@ mod tests {
200200
use super::TruncatedSvd;
201201
use crate::{close_l2, generate};
202202

203-
use ndarray::{arr1, arr2, Array2};
203+
use rand::SeedableRng;
204204
use rand_xoshiro::Xoshiro256Plus;
205+
use ndarray::{arr1, arr2, Array1, Array2};
206+
use ndarray_rand::{rand_distr::StandardNormal, RandomExt};
207+
205208
use approx::assert_abs_diff_eq;
206209

207210
#[test]
@@ -246,14 +249,16 @@ mod tests {
246249
#[test]
247250
fn test_marchenko_pastur() {
248251
// create random number generator
249-
let mut rng = SmallRng::seed_from_u64(3);
252+
let mut rng = Xoshiro256Plus::seed_from_u64(3);
250253

251254
// generate normal distribution random data with N >> p
252-
let data = Array2::random_using((1000, 500), StandardNormal, &mut rng);
253-
let dataset = Dataset::from(data / 1000f64.sqrt());
255+
let data = Array2::random_using((1000, 500), StandardNormal, &mut rng) / 1000f64.sqrt();
256+
257+
let res = TruncatedSvd::new(data, Order::Largest)
258+
.decompose(500)
259+
.unwrap();
254260

255-
let model = Pca::params(500).fit(&dataset);
256-
let sv = model.singular_values().mapv(|x| x * x);
261+
let sv = res.values().mapv(|x: f64| x*x);
257262

258263
// we have created a random spectrum and can apply the Marchenko-Pastur law
259264
// with variance 1 and p/n = 0.5

0 commit comments

Comments
 (0)