@@ -149,12 +149,13 @@ where
149
149
150
150
/// Solve least squares for immutable references and a single
151
151
/// column vector as a right-hand side.
152
- /// `E` is one of `f32`, `f64`, `c32`, `c64`. `D ` can be any
153
- /// valid representation for `ArrayBase`.
154
- impl < E , D > LeastSquaresSvd < D , E , Ix1 > for ArrayBase < D , Ix2 >
152
+ /// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2 ` can be any
153
+ /// valid representation for `ArrayBase` (over `E`) .
154
+ impl < E , D1 , D2 > LeastSquaresSvd < D2 , E , Ix1 > for ArrayBase < D1 , Ix2 >
155
155
where
156
156
E : Scalar + Lapack ,
157
- D : Data < Elem = E > ,
157
+ D1 : Data < Elem = E > ,
158
+ D2 : Data < Elem = E > ,
158
159
{
159
160
/// Solve a least squares problem of the form `Ax = rhs`
160
161
/// by calling `A.least_squares(&rhs)`, where `rhs` is a
@@ -163,7 +164,7 @@ where
163
164
/// `A` and `rhs` must have the same layout, i.e. they must
164
165
/// be both either row- or column-major format, otherwise a
165
166
/// `IncompatibleShape` error is raised.
166
- fn least_squares ( & self , rhs : & ArrayBase < D , Ix1 > ) -> Result < LeastSquaresResult < E , Ix1 > > {
167
+ fn least_squares ( & self , rhs : & ArrayBase < D2 , Ix1 > ) -> Result < LeastSquaresResult < E , Ix1 > > {
167
168
let a = self . to_owned ( ) ;
168
169
let b = rhs. to_owned ( ) ;
169
170
a. least_squares_into ( b)
@@ -172,12 +173,13 @@ where
172
173
173
174
/// Solve least squares for immutable references and matrix
174
175
/// (=mulitipe vectors) as a right-hand side.
175
- /// `E` is one of `f32`, `f64`, `c32`, `c64`. `D ` can be any
176
- /// valid representation for `ArrayBase`.
177
- impl < E , D > LeastSquaresSvd < D , E , Ix2 > for ArrayBase < D , Ix2 >
176
+ /// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2 ` can be any
177
+ /// valid representation for `ArrayBase` (over `E`) .
178
+ impl < E , D1 , D2 > LeastSquaresSvd < D2 , E , Ix2 > for ArrayBase < D1 , Ix2 >
178
179
where
179
180
E : Scalar + Lapack ,
180
- D : Data < Elem = E > ,
181
+ D1 : Data < Elem = E > ,
182
+ D2 : Data < Elem = E > ,
181
183
{
182
184
/// Solve a least squares problem of the form `Ax = rhs`
183
185
/// by calling `A.least_squares(&rhs)`, where `rhs` is
@@ -186,7 +188,7 @@ where
186
188
/// `A` and `rhs` must have the same layout, i.e. they must
187
189
/// be both either row- or column-major format, otherwise a
188
190
/// `IncompatibleShape` error is raised.
189
- fn least_squares ( & self , rhs : & ArrayBase < D , Ix2 > ) -> Result < LeastSquaresResult < E , Ix2 > > {
191
+ fn least_squares ( & self , rhs : & ArrayBase < D2 , Ix2 > ) -> Result < LeastSquaresResult < E , Ix2 > > {
190
192
let a = self . to_owned ( ) ;
191
193
let b = rhs. to_owned ( ) ;
192
194
a. least_squares_into ( b)
@@ -199,10 +201,11 @@ where
199
201
///
200
202
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any
201
203
/// valid representation for `ArrayBase`.
202
- impl < E , D > LeastSquaresSvdInto < D , E , Ix1 > for ArrayBase < D , Ix2 >
204
+ impl < E , D1 , D2 > LeastSquaresSvdInto < D2 , E , Ix1 > for ArrayBase < D1 , Ix2 >
203
205
where
204
206
E : Scalar + Lapack ,
205
- D : DataMut < Elem = E > ,
207
+ D1 : DataMut < Elem = E > ,
208
+ D2 : DataMut < Elem = E > ,
206
209
{
207
210
/// Solve a least squares problem of the form `Ax = rhs`
208
211
/// by calling `A.least_squares(rhs)`, where `rhs` is a
@@ -213,7 +216,7 @@ where
213
216
/// `IncompatibleShape` error is raised.
214
217
fn least_squares_into (
215
218
mut self ,
216
- mut rhs : ArrayBase < D , Ix1 > ,
219
+ mut rhs : ArrayBase < D2 , Ix1 > ,
217
220
) -> Result < LeastSquaresResult < E , Ix1 > > {
218
221
self . least_squares_in_place ( & mut rhs)
219
222
}
@@ -223,12 +226,13 @@ where
223
226
/// as a right-hand side. The matrix and the RHS matrix
224
227
/// are consumed.
225
228
///
226
- /// `E` is one of `f32`, `f64`, `c32`, `c64`. `D ` can be any
227
- /// valid representation for `ArrayBase`.
228
- impl < E , D > LeastSquaresSvdInto < D , E , Ix2 > for ArrayBase < D , Ix2 >
229
+ /// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2 ` can be any
230
+ /// valid representation for `ArrayBase` (over `E`) .
231
+ impl < E , D1 , D2 > LeastSquaresSvdInto < D2 , E , Ix2 > for ArrayBase < D1 , Ix2 >
229
232
where
230
233
E : Scalar + Lapack ,
231
- D : DataMut < Elem = E > ,
234
+ D1 : DataMut < Elem = E > ,
235
+ D2 : DataMut < Elem = E > ,
232
236
{
233
237
/// Solve a least squares problem of the form `Ax = rhs`
234
238
/// by calling `A.least_squares(rhs)`, where `rhs` is a
@@ -239,7 +243,7 @@ where
239
243
/// `IncompatibleShape` error is raised.
240
244
fn least_squares_into (
241
245
mut self ,
242
- mut rhs : ArrayBase < D , Ix2 > ,
246
+ mut rhs : ArrayBase < D2 , Ix2 > ,
243
247
) -> Result < LeastSquaresResult < E , Ix2 > > {
244
248
self . least_squares_in_place ( & mut rhs)
245
249
}
@@ -249,12 +253,13 @@ where
249
253
/// as a right-hand side. Both values are overwritten in the
250
254
/// call.
251
255
///
252
- /// `E` is one of `f32`, `f64`, `c32`, `c64`. `D ` can be any
253
- /// valid representation for `ArrayBase`.
254
- impl < E , D > LeastSquaresSvdInPlace < D , E , Ix1 > for ArrayBase < D , Ix2 >
256
+ /// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2 ` can be any
257
+ /// valid representation for `ArrayBase` (over `E`) .
258
+ impl < E , D1 , D2 > LeastSquaresSvdInPlace < D2 , E , Ix1 > for ArrayBase < D1 , Ix2 >
255
259
where
256
260
E : Scalar + Lapack ,
257
- D : DataMut < Elem = E > ,
261
+ D1 : DataMut < Elem = E > ,
262
+ D2 : DataMut < Elem = E > ,
258
263
{
259
264
/// Solve a least squares problem of the form `Ax = rhs`
260
265
/// by calling `A.least_squares(rhs)`, where `rhs` is a
@@ -265,7 +270,7 @@ where
265
270
/// `IncompatibleShape` error is raised.
266
271
fn least_squares_in_place (
267
272
& mut self ,
268
- rhs : & mut ArrayBase < D , Ix1 > ,
273
+ rhs : & mut ArrayBase < D2 , Ix1 > ,
269
274
) -> Result < LeastSquaresResult < E , Ix1 > > {
270
275
if self . shape ( ) [ 0 ] != rhs. shape ( ) [ 0 ] {
271
276
return Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleShape ) . into ( ) ) ;
@@ -331,12 +336,13 @@ fn compute_residual_scalar<E: Scalar, D: Data<Elem = E>>(
331
336
/// as a right-hand side. Both values are overwritten in the
332
337
/// call.
333
338
///
334
- /// `E` is one of `f32`, `f64`, `c32`, `c64`. `D ` can be any
335
- /// valid representation for `ArrayBase`.
336
- impl < E , D > LeastSquaresSvdInPlace < D , E , Ix2 > for ArrayBase < D , Ix2 >
339
+ /// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2 ` can be any
340
+ /// valid representation for `ArrayBase` (over `E`) .
341
+ impl < E , D1 , D2 > LeastSquaresSvdInPlace < D2 , E , Ix2 > for ArrayBase < D1 , Ix2 >
337
342
where
338
343
E : Scalar + Lapack + LeastSquaresSvdDivideConquer_ ,
339
- D : DataMut < Elem = E > ,
344
+ D1 : DataMut < Elem = E > ,
345
+ D2 : DataMut < Elem = E > ,
340
346
{
341
347
/// Solve a least squares problem of the form `Ax = rhs`
342
348
/// by calling `A.least_squares(rhs)`, where `rhs` is a
@@ -347,7 +353,7 @@ where
347
353
/// `IncompatibleShape` error is raised.
348
354
fn least_squares_in_place (
349
355
& mut self ,
350
- rhs : & mut ArrayBase < D , Ix2 > ,
356
+ rhs : & mut ArrayBase < D2 , Ix2 > ,
351
357
) -> Result < LeastSquaresResult < E , Ix2 > > {
352
358
if self . shape ( ) [ 0 ] != rhs. shape ( ) [ 0 ] {
353
359
return Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleShape ) . into ( ) ) ;
@@ -425,7 +431,7 @@ mod tests {
425
431
use ndarray:: * ;
426
432
427
433
//
428
- // Test that the different lest squares traits work as intended on the
434
+ // Test that the different least squares traits work as intended on the
429
435
// different array types.
430
436
//
431
437
// | least_squares | ls_into | ls_in_place |
@@ -437,9 +443,9 @@ mod tests {
437
443
// ArrayViewMut | yes | no | yes |
438
444
//
439
445
440
- fn assert_result < D : Data < Elem = f64 > > (
441
- a : & ArrayBase < D , Ix2 > ,
442
- b : & ArrayBase < D , Ix1 > ,
446
+ fn assert_result < D1 : Data < Elem = f64 > , D2 : Data < Elem = f64 > > (
447
+ a : & ArrayBase < D1 , Ix2 > ,
448
+ b : & ArrayBase < D2 , Ix1 > ,
443
449
res : & LeastSquaresResult < f64 , Ix1 > ,
444
450
) {
445
451
assert_eq ! ( res. rank, 2 ) ;
@@ -487,6 +493,15 @@ mod tests {
487
493
assert_result ( & av, & bv, & res) ;
488
494
}
489
495
496
+ #[ test]
497
+ fn on_cow_view ( ) {
498
+ let a = CowArray :: from ( array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ) ;
499
+ let b: Array1 < f64 > = array ! [ 1. , 2. , 3. ] ;
500
+ let bv = b. view ( ) ;
501
+ let res = a. least_squares ( & bv) . unwrap ( ) ;
502
+ assert_result ( & a, & bv, & res) ;
503
+ }
504
+
490
505
#[ test]
491
506
fn into_on_owned ( ) {
492
507
let a: Array2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
@@ -517,6 +532,16 @@ mod tests {
517
532
assert_result ( & a, & b, & res) ;
518
533
}
519
534
535
+ #[ test]
536
+ fn into_on_owned_cow ( ) {
537
+ let a: Array2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
538
+ let b = CowArray :: from ( array ! [ 1. , 2. , 3. ] ) ;
539
+ let ac = a. clone ( ) ;
540
+ let b2 = b. clone ( ) ;
541
+ let res = ac. least_squares_into ( b2) . unwrap ( ) ;
542
+ assert_result ( & a, & b, & res) ;
543
+ }
544
+
520
545
#[ test]
521
546
fn in_place_on_owned ( ) {
522
547
let a = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
@@ -549,6 +574,16 @@ mod tests {
549
574
assert_result ( & a, & b, & res) ;
550
575
}
551
576
577
+ #[ test]
578
+ fn in_place_on_owned_cow ( ) {
579
+ let a = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
580
+ let b = CowArray :: from ( array ! [ 1. , 2. , 3. ] ) ;
581
+ let mut a2 = a. clone ( ) ;
582
+ let mut b2 = b. clone ( ) ;
583
+ let res = a2. least_squares_in_place ( & mut b2) . unwrap ( ) ;
584
+ assert_result ( & a, & b, & res) ;
585
+ }
586
+
552
587
//
553
588
// Testing error cases
554
589
//
0 commit comments