@@ -475,6 +475,7 @@ def _batched_qr(a, mode="reduced"):
475
475
)
476
476
477
477
478
+ # pylint: disable=too-many-locals
478
479
def _batched_svd (
479
480
a ,
480
481
uv_type ,
@@ -532,29 +533,30 @@ def _batched_svd(
532
533
batch_shape_orig ,
533
534
)
534
535
535
- k = min (m , n )
536
- if compute_uv :
537
- if full_matrices :
538
- u_shape = (m , m ) + (batch_size ,)
539
- vt_shape = (n , n ) + (batch_size ,)
540
- jobu = ord ("A" )
541
- jobvt = ord ("A" )
542
- else :
543
- u_shape = (m , k ) + (batch_size ,)
544
- vt_shape = (k , n ) + (batch_size ,)
545
- jobu = ord ("S" )
546
- jobvt = ord ("S" )
536
+ # Transpose if m < n:
537
+ # 1. cuSolver gesvd supports only m >= n
538
+ # 2. Reducing a matrix with m >= n to bidiagonal form is more efficient
539
+ if m < n :
540
+ n , m = a .shape [- 2 :]
541
+ trans_flag = True
547
542
else :
548
- u_shape = vt_shape = ()
549
- jobu = ord ("N" )
550
- jobvt = ord ("N" )
543
+ trans_flag = False
544
+
545
+ u_shape , vt_shape , s_shape , jobu , jobvt = _get_svd_shapes_and_flags (
546
+ m , n , compute_uv , full_matrices , batch_size = batch_size
547
+ )
551
548
552
549
_manager = dpu .SequentialOrderManager [exec_q ]
553
550
dep_evs = _manager .submitted_events
554
551
555
552
# Reorder the elements by moving the last two axes of `a` to the front
556
553
# to match fortran-like array order which is assumed by gesvd.
557
- a = dpnp .moveaxis (a , (- 2 , - 1 ), (0 , 1 ))
554
+ if trans_flag :
555
+ # Transpose axes for cuSolver and to optimize reduction
556
+ # to bidiagonal form
557
+ a = dpnp .moveaxis (a , (- 1 , - 2 ), (0 , 1 ))
558
+ else :
559
+ a = dpnp .moveaxis (a , (- 2 , - 1 ), (0 , 1 ))
558
560
559
561
# oneMKL LAPACK gesvd destroys `a` and assumes fortran-like array
560
562
# as input.
@@ -583,7 +585,7 @@ def _batched_svd(
583
585
sycl_queue = exec_q ,
584
586
)
585
587
s_h = dpnp .empty (
586
- ( batch_size ,) + ( k ,) ,
588
+ s_shape ,
587
589
dtype = s_type ,
588
590
order = "C" ,
589
591
usm_type = usm_type ,
@@ -607,16 +609,23 @@ def _batched_svd(
607
609
# gesvd call writes `u_h` and `vt_h` in Fortran order;
608
610
# reorder the axes to match C order by moving the last axis
609
611
# to the front
610
- u = dpnp .moveaxis (u_h , - 1 , 0 )
611
- vt = dpnp .moveaxis (vt_h , - 1 , 0 )
612
+ if trans_flag :
613
+ # Transpose axes to restore U and V^T for the original matrix
614
+ u = dpnp .moveaxis (u_h , (0 , - 1 ), (- 1 , 0 ))
615
+ vt = dpnp .moveaxis (vt_h , (0 , - 1 ), (- 1 , 0 ))
616
+ else :
617
+ u = dpnp .moveaxis (u_h , - 1 , 0 )
618
+ vt = dpnp .moveaxis (vt_h , - 1 , 0 )
619
+
612
620
if a_ndim > 3 :
613
621
u = u .reshape (batch_shape_orig + u .shape [- 2 :])
614
622
vt = vt .reshape (batch_shape_orig + vt .shape [- 2 :])
615
623
# dpnp.moveaxis can make the array non-contiguous if it is not 2D
616
624
# Convert to contiguous to align with NumPy
617
625
u = dpnp .ascontiguousarray (u )
618
626
vt = dpnp .ascontiguousarray (vt )
619
- return u , s , vt
627
+ # Swap `u` and `vt` for transposed input to restore correct order
628
+ return (vt , s , u ) if trans_flag else (u , s , vt )
620
629
return s
621
630
622
631
@@ -759,6 +768,36 @@ def _common_inexact_type(default_dtype, *dtypes):
759
768
return dpnp .result_type (* inexact_dtypes )
760
769
761
770
771
+ def _get_svd_shapes_and_flags (m , n , compute_uv , full_matrices , batch_size = None ):
772
+ """Return the shapes and flags for SVD computations."""
773
+
774
+ k = min (m , n )
775
+ if compute_uv :
776
+ if full_matrices :
777
+ u_shape = (m , m )
778
+ vt_shape = (n , n )
779
+ jobu = ord ("A" )
780
+ jobvt = ord ("A" )
781
+ else :
782
+ u_shape = (m , k )
783
+ vt_shape = (k , n )
784
+ jobu = ord ("S" )
785
+ jobvt = ord ("S" )
786
+ else :
787
+ u_shape = vt_shape = ()
788
+ jobu = ord ("N" )
789
+ jobvt = ord ("N" )
790
+
791
+ s_shape = (k ,)
792
+ if batch_size is not None :
793
+ if compute_uv :
794
+ u_shape += (batch_size ,)
795
+ vt_shape += (batch_size ,)
796
+ s_shape = (batch_size ,) + s_shape
797
+
798
+ return u_shape , vt_shape , s_shape , jobu , jobvt
799
+
800
+
762
801
def _hermitian_svd (a , compute_uv ):
763
802
"""
764
803
_hermitian_svd(a, compute_uv)
@@ -2695,6 +2734,16 @@ def dpnp_svd(
2695
2734
a , uv_type , s_type , full_matrices , compute_uv , exec_q , usm_type
2696
2735
)
2697
2736
2737
+ # Transpose if m < n:
2738
+ # 1. cuSolver gesvd supports only m >= n
2739
+ # 2. Reducing a matrix with m >= n to bidiagonal form is more efficient
2740
+ if m < n :
2741
+ n , m = a .shape
2742
+ a = a .transpose ()
2743
+ trans_flag = True
2744
+ else :
2745
+ trans_flag = False
2746
+
2698
2747
# oneMKL LAPACK gesvd destroys `a` and assumes fortran-like array as input.
2699
2748
# Allocate 'F' order memory for dpnp arrays to comply with
2700
2749
# these requirements.
@@ -2716,22 +2765,9 @@ def dpnp_svd(
2716
2765
)
2717
2766
_manager .add_event_pair (ht_ev , copy_ev )
2718
2767
2719
- k = min (m , n )
2720
- if compute_uv :
2721
- if full_matrices :
2722
- u_shape = (m , m )
2723
- vt_shape = (n , n )
2724
- jobu = ord ("A" )
2725
- jobvt = ord ("A" )
2726
- else :
2727
- u_shape = (m , k )
2728
- vt_shape = (k , n )
2729
- jobu = ord ("S" )
2730
- jobvt = ord ("S" )
2731
- else :
2732
- u_shape = vt_shape = ()
2733
- jobu = ord ("N" )
2734
- jobvt = ord ("N" )
2768
+ u_shape , vt_shape , s_shape , jobu , jobvt = _get_svd_shapes_and_flags (
2769
+ m , n , compute_uv , full_matrices
2770
+ )
2735
2771
2736
2772
# oneMKL LAPACK assumes fortran-like array as input.
2737
2773
# Allocate 'F' order memory for dpnp output arrays to comply with
@@ -2746,7 +2782,7 @@ def dpnp_svd(
2746
2782
shape = vt_shape ,
2747
2783
order = "F" ,
2748
2784
)
2749
- s_h = dpnp .empty_like (a_h , shape = ( k ,) , dtype = s_type )
2785
+ s_h = dpnp .empty_like (a_h , shape = s_shape , dtype = s_type )
2750
2786
2751
2787
ht_ev , gesvd_ev = li ._gesvd (
2752
2788
exec_q ,
@@ -2761,6 +2797,11 @@ def dpnp_svd(
2761
2797
_manager .add_event_pair (ht_ev , gesvd_ev )
2762
2798
2763
2799
if compute_uv :
2800
+ # Transposing the input matrix swaps the roles of U and Vt:
2801
+ # For A^T = V S^T U^T, `u_h` becomes V and `vt_h` becomes U^T.
2802
+ # Transpose and swap them back to restore correct order for A.
2803
+ if trans_flag :
2804
+ return vt_h .T , s_h , u_h .T
2764
2805
# gesvd call writes `u_h` and `vt_h` in Fortran order;
2765
2806
# Convert to contiguous to align with NumPy
2766
2807
u_h = dpnp .ascontiguousarray (u_h )
0 commit comments