Skip to content

Commit 4111835

Browse files
authored
Merge pull request #272 from janmarthedal/relax-lstsq-types
Relax type bounds for LeastSquaresSvd family
2 parents aee87d9 + a23224f commit 4111835

File tree

1 file changed

+67
-32
lines changed

1 file changed

+67
-32
lines changed

ndarray-linalg/src/least_squares.rs

Lines changed: 67 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -149,12 +149,13 @@ where
149149

150150
/// Solve least squares for immutable references and a single
151151
/// column vector as a right-hand side.
152-
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any
153-
/// valid representation for `ArrayBase`.
154-
impl<E, D> LeastSquaresSvd<D, E, Ix1> for ArrayBase<D, Ix2>
152+
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
153+
/// valid representation for `ArrayBase` (over `E`).
154+
impl<E, D1, D2> LeastSquaresSvd<D2, E, Ix1> for ArrayBase<D1, Ix2>
155155
where
156156
E: Scalar + Lapack,
157-
D: Data<Elem = E>,
157+
D1: Data<Elem = E>,
158+
D2: Data<Elem = E>,
158159
{
159160
/// Solve a least squares problem of the form `Ax = rhs`
160161
/// by calling `A.least_squares(&rhs)`, where `rhs` is a
@@ -163,7 +164,7 @@ where
163164
/// `A` and `rhs` must have the same layout, i.e. they must
164165
/// be both either row- or column-major format, otherwise a
165166
/// `IncompatibleShape` error is raised.
166-
fn least_squares(&self, rhs: &ArrayBase<D, Ix1>) -> Result<LeastSquaresResult<E, Ix1>> {
167+
fn least_squares(&self, rhs: &ArrayBase<D2, Ix1>) -> Result<LeastSquaresResult<E, Ix1>> {
167168
let a = self.to_owned();
168169
let b = rhs.to_owned();
169170
a.least_squares_into(b)
@@ -172,12 +173,13 @@ where
172173

173174
/// Solve least squares for immutable references and matrix
174175
/// (=mulitipe vectors) as a right-hand side.
175-
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any
176-
/// valid representation for `ArrayBase`.
177-
impl<E, D> LeastSquaresSvd<D, E, Ix2> for ArrayBase<D, Ix2>
176+
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
177+
/// valid representation for `ArrayBase` (over `E`).
178+
impl<E, D1, D2> LeastSquaresSvd<D2, E, Ix2> for ArrayBase<D1, Ix2>
178179
where
179180
E: Scalar + Lapack,
180-
D: Data<Elem = E>,
181+
D1: Data<Elem = E>,
182+
D2: Data<Elem = E>,
181183
{
182184
/// Solve a least squares problem of the form `Ax = rhs`
183185
/// by calling `A.least_squares(&rhs)`, where `rhs` is
@@ -186,7 +188,7 @@ where
186188
/// `A` and `rhs` must have the same layout, i.e. they must
187189
/// be both either row- or column-major format, otherwise a
188190
/// `IncompatibleShape` error is raised.
189-
fn least_squares(&self, rhs: &ArrayBase<D, Ix2>) -> Result<LeastSquaresResult<E, Ix2>> {
191+
fn least_squares(&self, rhs: &ArrayBase<D2, Ix2>) -> Result<LeastSquaresResult<E, Ix2>> {
190192
let a = self.to_owned();
191193
let b = rhs.to_owned();
192194
a.least_squares_into(b)
@@ -199,10 +201,11 @@ where
199201
///
200202
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any
201203
/// valid representation for `ArrayBase`.
202-
impl<E, D> LeastSquaresSvdInto<D, E, Ix1> for ArrayBase<D, Ix2>
204+
impl<E, D1, D2> LeastSquaresSvdInto<D2, E, Ix1> for ArrayBase<D1, Ix2>
203205
where
204206
E: Scalar + Lapack,
205-
D: DataMut<Elem = E>,
207+
D1: DataMut<Elem = E>,
208+
D2: DataMut<Elem = E>,
206209
{
207210
/// Solve a least squares problem of the form `Ax = rhs`
208211
/// by calling `A.least_squares(rhs)`, where `rhs` is a
@@ -213,7 +216,7 @@ where
213216
/// `IncompatibleShape` error is raised.
214217
fn least_squares_into(
215218
mut self,
216-
mut rhs: ArrayBase<D, Ix1>,
219+
mut rhs: ArrayBase<D2, Ix1>,
217220
) -> Result<LeastSquaresResult<E, Ix1>> {
218221
self.least_squares_in_place(&mut rhs)
219222
}
@@ -223,12 +226,13 @@ where
223226
/// as a right-hand side. The matrix and the RHS matrix
224227
/// are consumed.
225228
///
226-
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any
227-
/// valid representation for `ArrayBase`.
228-
impl<E, D> LeastSquaresSvdInto<D, E, Ix2> for ArrayBase<D, Ix2>
229+
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
230+
/// valid representation for `ArrayBase` (over `E`).
231+
impl<E, D1, D2> LeastSquaresSvdInto<D2, E, Ix2> for ArrayBase<D1, Ix2>
229232
where
230233
E: Scalar + Lapack,
231-
D: DataMut<Elem = E>,
234+
D1: DataMut<Elem = E>,
235+
D2: DataMut<Elem = E>,
232236
{
233237
/// Solve a least squares problem of the form `Ax = rhs`
234238
/// by calling `A.least_squares(rhs)`, where `rhs` is a
@@ -239,7 +243,7 @@ where
239243
/// `IncompatibleShape` error is raised.
240244
fn least_squares_into(
241245
mut self,
242-
mut rhs: ArrayBase<D, Ix2>,
246+
mut rhs: ArrayBase<D2, Ix2>,
243247
) -> Result<LeastSquaresResult<E, Ix2>> {
244248
self.least_squares_in_place(&mut rhs)
245249
}
@@ -249,12 +253,13 @@ where
249253
/// as a right-hand side. Both values are overwritten in the
250254
/// call.
251255
///
252-
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any
253-
/// valid representation for `ArrayBase`.
254-
impl<E, D> LeastSquaresSvdInPlace<D, E, Ix1> for ArrayBase<D, Ix2>
256+
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
257+
/// valid representation for `ArrayBase` (over `E`).
258+
impl<E, D1, D2> LeastSquaresSvdInPlace<D2, E, Ix1> for ArrayBase<D1, Ix2>
255259
where
256260
E: Scalar + Lapack,
257-
D: DataMut<Elem = E>,
261+
D1: DataMut<Elem = E>,
262+
D2: DataMut<Elem = E>,
258263
{
259264
/// Solve a least squares problem of the form `Ax = rhs`
260265
/// by calling `A.least_squares(rhs)`, where `rhs` is a
@@ -265,7 +270,7 @@ where
265270
/// `IncompatibleShape` error is raised.
266271
fn least_squares_in_place(
267272
&mut self,
268-
rhs: &mut ArrayBase<D, Ix1>,
273+
rhs: &mut ArrayBase<D2, Ix1>,
269274
) -> Result<LeastSquaresResult<E, Ix1>> {
270275
if self.shape()[0] != rhs.shape()[0] {
271276
return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape).into());
@@ -331,12 +336,13 @@ fn compute_residual_scalar<E: Scalar, D: Data<Elem = E>>(
331336
/// as a right-hand side. Both values are overwritten in the
332337
/// call.
333338
///
334-
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any
335-
/// valid representation for `ArrayBase`.
336-
impl<E, D> LeastSquaresSvdInPlace<D, E, Ix2> for ArrayBase<D, Ix2>
339+
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
340+
/// valid representation for `ArrayBase` (over `E`).
341+
impl<E, D1, D2> LeastSquaresSvdInPlace<D2, E, Ix2> for ArrayBase<D1, Ix2>
337342
where
338343
E: Scalar + Lapack + LeastSquaresSvdDivideConquer_,
339-
D: DataMut<Elem = E>,
344+
D1: DataMut<Elem = E>,
345+
D2: DataMut<Elem = E>,
340346
{
341347
/// Solve a least squares problem of the form `Ax = rhs`
342348
/// by calling `A.least_squares(rhs)`, where `rhs` is a
@@ -347,7 +353,7 @@ where
347353
/// `IncompatibleShape` error is raised.
348354
fn least_squares_in_place(
349355
&mut self,
350-
rhs: &mut ArrayBase<D, Ix2>,
356+
rhs: &mut ArrayBase<D2, Ix2>,
351357
) -> Result<LeastSquaresResult<E, Ix2>> {
352358
if self.shape()[0] != rhs.shape()[0] {
353359
return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape).into());
@@ -425,7 +431,7 @@ mod tests {
425431
use ndarray::*;
426432

427433
//
428-
// Test that the different lest squares traits work as intended on the
434+
// Test that the different least squares traits work as intended on the
429435
// different array types.
430436
//
431437
// | least_squares | ls_into | ls_in_place |
@@ -437,9 +443,9 @@ mod tests {
437443
// ArrayViewMut | yes | no | yes |
438444
//
439445

440-
fn assert_result<D: Data<Elem = f64>>(
441-
a: &ArrayBase<D, Ix2>,
442-
b: &ArrayBase<D, Ix1>,
446+
fn assert_result<D1: Data<Elem = f64>, D2: Data<Elem = f64>>(
447+
a: &ArrayBase<D1, Ix2>,
448+
b: &ArrayBase<D2, Ix1>,
443449
res: &LeastSquaresResult<f64, Ix1>,
444450
) {
445451
assert_eq!(res.rank, 2);
@@ -487,6 +493,15 @@ mod tests {
487493
assert_result(&av, &bv, &res);
488494
}
489495

496+
#[test]
497+
fn on_cow_view() {
498+
let a = CowArray::from(array![[1., 2.], [4., 5.], [3., 4.]]);
499+
let b: Array1<f64> = array![1., 2., 3.];
500+
let bv = b.view();
501+
let res = a.least_squares(&bv).unwrap();
502+
assert_result(&a, &bv, &res);
503+
}
504+
490505
#[test]
491506
fn into_on_owned() {
492507
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
@@ -517,6 +532,16 @@ mod tests {
517532
assert_result(&a, &b, &res);
518533
}
519534

535+
#[test]
536+
fn into_on_owned_cow() {
537+
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
538+
let b = CowArray::from(array![1., 2., 3.]);
539+
let ac = a.clone();
540+
let b2 = b.clone();
541+
let res = ac.least_squares_into(b2).unwrap();
542+
assert_result(&a, &b, &res);
543+
}
544+
520545
#[test]
521546
fn in_place_on_owned() {
522547
let a = array![[1., 2.], [4., 5.], [3., 4.]];
@@ -549,6 +574,16 @@ mod tests {
549574
assert_result(&a, &b, &res);
550575
}
551576

577+
#[test]
578+
fn in_place_on_owned_cow() {
579+
let a = array![[1., 2.], [4., 5.], [3., 4.]];
580+
let b = CowArray::from(array![1., 2., 3.]);
581+
let mut a2 = a.clone();
582+
let mut b2 = b.clone();
583+
let res = a2.least_squares_in_place(&mut b2).unwrap();
584+
assert_result(&a, &b, &res);
585+
}
586+
552587
//
553588
// Testing error cases
554589
//

0 commit comments

Comments
 (0)