Skip to content

Relax type bounds for LeastSquaresSvd family #272

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 17, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 67 additions & 32 deletions ndarray-linalg/src/least_squares.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,13 @@ where

/// Solve least squares for immutable references and a single
/// column vector as a right-hand side.
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any
/// valid representation for `ArrayBase`.
impl<E, D> LeastSquaresSvd<D, E, Ix1> for ArrayBase<D, Ix2>
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
/// valid representation for `ArrayBase` (over `E`).
impl<E, D1, D2> LeastSquaresSvd<D2, E, Ix1> for ArrayBase<D1, Ix2>
where
E: Scalar + Lapack,
D: Data<Elem = E>,
D1: Data<Elem = E>,
D2: Data<Elem = E>,
{
/// Solve a least squares problem of the form `Ax = rhs`
/// by calling `A.least_squares(&rhs)`, where `rhs` is a
Expand All @@ -163,7 +164,7 @@ where
/// `A` and `rhs` must have the same layout, i.e. they must
/// be both either row- or column-major format, otherwise a
/// `IncompatibleShape` error is raised.
fn least_squares(&self, rhs: &ArrayBase<D, Ix1>) -> Result<LeastSquaresResult<E, Ix1>> {
fn least_squares(&self, rhs: &ArrayBase<D2, Ix1>) -> Result<LeastSquaresResult<E, Ix1>> {
let a = self.to_owned();
let b = rhs.to_owned();
a.least_squares_into(b)
Expand All @@ -172,12 +173,13 @@ where

/// Solve least squares for immutable references and matrix
/// (=mulitipe vectors) as a right-hand side.
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any
/// valid representation for `ArrayBase`.
impl<E, D> LeastSquaresSvd<D, E, Ix2> for ArrayBase<D, Ix2>
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
/// valid representation for `ArrayBase` (over `E`).
impl<E, D1, D2> LeastSquaresSvd<D2, E, Ix2> for ArrayBase<D1, Ix2>
where
E: Scalar + Lapack,
D: Data<Elem = E>,
D1: Data<Elem = E>,
D2: Data<Elem = E>,
{
/// Solve a least squares problem of the form `Ax = rhs`
/// by calling `A.least_squares(&rhs)`, where `rhs` is
Expand All @@ -186,7 +188,7 @@ where
/// `A` and `rhs` must have the same layout, i.e. they must
/// be both either row- or column-major format, otherwise a
/// `IncompatibleShape` error is raised.
fn least_squares(&self, rhs: &ArrayBase<D, Ix2>) -> Result<LeastSquaresResult<E, Ix2>> {
fn least_squares(&self, rhs: &ArrayBase<D2, Ix2>) -> Result<LeastSquaresResult<E, Ix2>> {
let a = self.to_owned();
let b = rhs.to_owned();
a.least_squares_into(b)
Expand All @@ -199,10 +201,11 @@ where
///
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any
/// valid representation for `ArrayBase`.
impl<E, D> LeastSquaresSvdInto<D, E, Ix1> for ArrayBase<D, Ix2>
impl<E, D1, D2> LeastSquaresSvdInto<D2, E, Ix1> for ArrayBase<D1, Ix2>
where
E: Scalar + Lapack,
D: DataMut<Elem = E>,
D1: DataMut<Elem = E>,
D2: DataMut<Elem = E>,
{
/// Solve a least squares problem of the form `Ax = rhs`
/// by calling `A.least_squares(rhs)`, where `rhs` is a
Expand All @@ -213,7 +216,7 @@ where
/// `IncompatibleShape` error is raised.
fn least_squares_into(
mut self,
mut rhs: ArrayBase<D, Ix1>,
mut rhs: ArrayBase<D2, Ix1>,
) -> Result<LeastSquaresResult<E, Ix1>> {
self.least_squares_in_place(&mut rhs)
}
Expand All @@ -223,12 +226,13 @@ where
/// as a right-hand side. The matrix and the RHS matrix
/// are consumed.
///
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any
/// valid representation for `ArrayBase`.
impl<E, D> LeastSquaresSvdInto<D, E, Ix2> for ArrayBase<D, Ix2>
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
/// valid representation for `ArrayBase` (over `E`).
impl<E, D1, D2> LeastSquaresSvdInto<D2, E, Ix2> for ArrayBase<D1, Ix2>
where
E: Scalar + Lapack,
D: DataMut<Elem = E>,
D1: DataMut<Elem = E>,
D2: DataMut<Elem = E>,
{
/// Solve a least squares problem of the form `Ax = rhs`
/// by calling `A.least_squares(rhs)`, where `rhs` is a
Expand All @@ -239,7 +243,7 @@ where
/// `IncompatibleShape` error is raised.
fn least_squares_into(
mut self,
mut rhs: ArrayBase<D, Ix2>,
mut rhs: ArrayBase<D2, Ix2>,
) -> Result<LeastSquaresResult<E, Ix2>> {
self.least_squares_in_place(&mut rhs)
}
Expand All @@ -249,12 +253,13 @@ where
/// as a right-hand side. Both values are overwritten in the
/// call.
///
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any
/// valid representation for `ArrayBase`.
impl<E, D> LeastSquaresSvdInPlace<D, E, Ix1> for ArrayBase<D, Ix2>
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
/// valid representation for `ArrayBase` (over `E`).
impl<E, D1, D2> LeastSquaresSvdInPlace<D2, E, Ix1> for ArrayBase<D1, Ix2>
where
E: Scalar + Lapack,
D: DataMut<Elem = E>,
D1: DataMut<Elem = E>,
D2: DataMut<Elem = E>,
{
/// Solve a least squares problem of the form `Ax = rhs`
/// by calling `A.least_squares(rhs)`, where `rhs` is a
Expand All @@ -265,7 +270,7 @@ where
/// `IncompatibleShape` error is raised.
fn least_squares_in_place(
&mut self,
rhs: &mut ArrayBase<D, Ix1>,
rhs: &mut ArrayBase<D2, Ix1>,
) -> Result<LeastSquaresResult<E, Ix1>> {
if self.shape()[0] != rhs.shape()[0] {
return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape).into());
Expand Down Expand Up @@ -331,12 +336,13 @@ fn compute_residual_scalar<E: Scalar, D: Data<Elem = E>>(
/// as a right-hand side. Both values are overwritten in the
/// call.
///
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any
/// valid representation for `ArrayBase`.
impl<E, D> LeastSquaresSvdInPlace<D, E, Ix2> for ArrayBase<D, Ix2>
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
/// valid representation for `ArrayBase` (over `E`).
impl<E, D1, D2> LeastSquaresSvdInPlace<D2, E, Ix2> for ArrayBase<D1, Ix2>
where
E: Scalar + Lapack + LeastSquaresSvdDivideConquer_,
D: DataMut<Elem = E>,
D1: DataMut<Elem = E>,
D2: DataMut<Elem = E>,
{
/// Solve a least squares problem of the form `Ax = rhs`
/// by calling `A.least_squares(rhs)`, where `rhs` is a
Expand All @@ -347,7 +353,7 @@ where
/// `IncompatibleShape` error is raised.
fn least_squares_in_place(
&mut self,
rhs: &mut ArrayBase<D, Ix2>,
rhs: &mut ArrayBase<D2, Ix2>,
) -> Result<LeastSquaresResult<E, Ix2>> {
if self.shape()[0] != rhs.shape()[0] {
return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape).into());
Expand Down Expand Up @@ -425,7 +431,7 @@ mod tests {
use ndarray::*;

//
// Test that the different lest squares traits work as intended on the
// Test that the different least squares traits work as intended on the
// different array types.
//
// | least_squares | ls_into | ls_in_place |
Expand All @@ -437,9 +443,9 @@ mod tests {
// ArrayViewMut | yes | no | yes |
//

fn assert_result<D: Data<Elem = f64>>(
a: &ArrayBase<D, Ix2>,
b: &ArrayBase<D, Ix1>,
fn assert_result<D1: Data<Elem = f64>, D2: Data<Elem = f64>>(
a: &ArrayBase<D1, Ix2>,
b: &ArrayBase<D2, Ix1>,
res: &LeastSquaresResult<f64, Ix1>,
) {
assert_eq!(res.rank, 2);
Expand Down Expand Up @@ -487,6 +493,15 @@ mod tests {
assert_result(&av, &bv, &res);
}

#[test]
fn on_cow_view() {
let a = CowArray::from(array![[1., 2.], [4., 5.], [3., 4.]]);
let b: Array1<f64> = array![1., 2., 3.];
let bv = b.view();
let res = a.least_squares(&bv).unwrap();
assert_result(&a, &bv, &res);
}

#[test]
fn into_on_owned() {
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
Expand Down Expand Up @@ -517,6 +532,16 @@ mod tests {
assert_result(&a, &b, &res);
}

#[test]
fn into_on_owned_cow() {
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
let b = CowArray::from(array![1., 2., 3.]);
let ac = a.clone();
let b2 = b.clone();
let res = ac.least_squares_into(b2).unwrap();
assert_result(&a, &b, &res);
}

#[test]
fn in_place_on_owned() {
let a = array![[1., 2.], [4., 5.], [3., 4.]];
Expand Down Expand Up @@ -549,6 +574,16 @@ mod tests {
assert_result(&a, &b, &res);
}

#[test]
fn in_place_on_owned_cow() {
let a = array![[1., 2.], [4., 5.], [3., 4.]];
let b = CowArray::from(array![1., 2., 3.]);
let mut a2 = a.clone();
let mut b2 = b.clone();
let res = a2.least_squares_in_place(&mut b2).unwrap();
assert_result(&a, &b, &res);
}

//
// Testing error cases
//
Expand Down