Skip to content

Commit 4875e59

Browse files
Update dpnp.linalg.svd() to run on CUDA (#2212)
This PR updates the `dpnp.linalg.svd()` implementation to support running on CUDA devices. Since cuSolver gesvd only supports m >= n, the previous implementation crashed with `Segmentation fault (core dumped)`. The fix adds a check for `m >= n` and transposes the input array otherwise. Passing the transposed array to `oneapi::mkl::lapack::gesvd` also improves the performance of `dpnp.linalg.svd()`, because reducing a matrix with `m >= n` to bidiagonal form (inside `lapack::gesvd`) is more efficient.
1 parent e0c9cf1 commit 4875e59

File tree

2 files changed

+78
-53
lines changed

2 files changed

+78
-53
lines changed

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 78 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,7 @@ def _batched_qr(a, mode="reduced"):
475475
)
476476

477477

478+
# pylint: disable=too-many-locals
478479
def _batched_svd(
479480
a,
480481
uv_type,
@@ -532,29 +533,30 @@ def _batched_svd(
532533
batch_shape_orig,
533534
)
534535

535-
k = min(m, n)
536-
if compute_uv:
537-
if full_matrices:
538-
u_shape = (m, m) + (batch_size,)
539-
vt_shape = (n, n) + (batch_size,)
540-
jobu = ord("A")
541-
jobvt = ord("A")
542-
else:
543-
u_shape = (m, k) + (batch_size,)
544-
vt_shape = (k, n) + (batch_size,)
545-
jobu = ord("S")
546-
jobvt = ord("S")
536+
# Transpose if m < n:
537+
# 1. cuSolver gesvd supports only m >= n
538+
# 2. Reducing a matrix with m >= n to bidiagonal form is more efficient
539+
if m < n:
540+
n, m = a.shape[-2:]
541+
trans_flag = True
547542
else:
548-
u_shape = vt_shape = ()
549-
jobu = ord("N")
550-
jobvt = ord("N")
543+
trans_flag = False
544+
545+
u_shape, vt_shape, s_shape, jobu, jobvt = _get_svd_shapes_and_flags(
546+
m, n, compute_uv, full_matrices, batch_size=batch_size
547+
)
551548

552549
_manager = dpu.SequentialOrderManager[exec_q]
553550
dep_evs = _manager.submitted_events
554551

555552
# Reorder the elements by moving the last two axes of `a` to the front
556553
# to match fortran-like array order which is assumed by gesvd.
557-
a = dpnp.moveaxis(a, (-2, -1), (0, 1))
554+
if trans_flag:
555+
# Transpose axes for cuSolver and to optimize reduction
556+
# to bidiagonal form
557+
a = dpnp.moveaxis(a, (-1, -2), (0, 1))
558+
else:
559+
a = dpnp.moveaxis(a, (-2, -1), (0, 1))
558560

559561
# oneMKL LAPACK gesvd destroys `a` and assumes fortran-like array
560562
# as input.
@@ -583,7 +585,7 @@ def _batched_svd(
583585
sycl_queue=exec_q,
584586
)
585587
s_h = dpnp.empty(
586-
(batch_size,) + (k,),
588+
s_shape,
587589
dtype=s_type,
588590
order="C",
589591
usm_type=usm_type,
@@ -607,16 +609,23 @@ def _batched_svd(
607609
# gesvd call writes `u_h` and `vt_h` in Fortran order;
608610
# reorder the axes to match C order by moving the last axis
609611
# to the front
610-
u = dpnp.moveaxis(u_h, -1, 0)
611-
vt = dpnp.moveaxis(vt_h, -1, 0)
612+
if trans_flag:
613+
# Transpose axes to restore U and V^T for the original matrix
614+
u = dpnp.moveaxis(u_h, (0, -1), (-1, 0))
615+
vt = dpnp.moveaxis(vt_h, (0, -1), (-1, 0))
616+
else:
617+
u = dpnp.moveaxis(u_h, -1, 0)
618+
vt = dpnp.moveaxis(vt_h, -1, 0)
619+
612620
if a_ndim > 3:
613621
u = u.reshape(batch_shape_orig + u.shape[-2:])
614622
vt = vt.reshape(batch_shape_orig + vt.shape[-2:])
615623
# dpnp.moveaxis can make the array non-contiguous if it is not 2D
616624
# Convert to contiguous to align with NumPy
617625
u = dpnp.ascontiguousarray(u)
618626
vt = dpnp.ascontiguousarray(vt)
619-
return u, s, vt
627+
# Swap `u` and `vt` for transposed input to restore correct order
628+
return (vt, s, u) if trans_flag else (u, s, vt)
620629
return s
621630

622631

@@ -759,6 +768,36 @@ def _common_inexact_type(default_dtype, *dtypes):
759768
return dpnp.result_type(*inexact_dtypes)
760769

761770

771+
def _get_svd_shapes_and_flags(m, n, compute_uv, full_matrices, batch_size=None):
772+
"""Return the shapes and flags for SVD computations."""
773+
774+
k = min(m, n)
775+
if compute_uv:
776+
if full_matrices:
777+
u_shape = (m, m)
778+
vt_shape = (n, n)
779+
jobu = ord("A")
780+
jobvt = ord("A")
781+
else:
782+
u_shape = (m, k)
783+
vt_shape = (k, n)
784+
jobu = ord("S")
785+
jobvt = ord("S")
786+
else:
787+
u_shape = vt_shape = ()
788+
jobu = ord("N")
789+
jobvt = ord("N")
790+
791+
s_shape = (k,)
792+
if batch_size is not None:
793+
if compute_uv:
794+
u_shape += (batch_size,)
795+
vt_shape += (batch_size,)
796+
s_shape = (batch_size,) + s_shape
797+
798+
return u_shape, vt_shape, s_shape, jobu, jobvt
799+
800+
762801
def _hermitian_svd(a, compute_uv):
763802
"""
764803
_hermitian_svd(a, compute_uv)
@@ -2695,6 +2734,16 @@ def dpnp_svd(
26952734
a, uv_type, s_type, full_matrices, compute_uv, exec_q, usm_type
26962735
)
26972736

2737+
# Transpose if m < n:
2738+
# 1. cuSolver gesvd supports only m >= n
2739+
# 2. Reducing a matrix with m >= n to bidiagonal form is more efficient
2740+
if m < n:
2741+
n, m = a.shape
2742+
a = a.transpose()
2743+
trans_flag = True
2744+
else:
2745+
trans_flag = False
2746+
26982747
# oneMKL LAPACK gesvd destroys `a` and assumes fortran-like array as input.
26992748
# Allocate 'F' order memory for dpnp arrays to comply with
27002749
# these requirements.
@@ -2716,22 +2765,9 @@ def dpnp_svd(
27162765
)
27172766
_manager.add_event_pair(ht_ev, copy_ev)
27182767

2719-
k = min(m, n)
2720-
if compute_uv:
2721-
if full_matrices:
2722-
u_shape = (m, m)
2723-
vt_shape = (n, n)
2724-
jobu = ord("A")
2725-
jobvt = ord("A")
2726-
else:
2727-
u_shape = (m, k)
2728-
vt_shape = (k, n)
2729-
jobu = ord("S")
2730-
jobvt = ord("S")
2731-
else:
2732-
u_shape = vt_shape = ()
2733-
jobu = ord("N")
2734-
jobvt = ord("N")
2768+
u_shape, vt_shape, s_shape, jobu, jobvt = _get_svd_shapes_and_flags(
2769+
m, n, compute_uv, full_matrices
2770+
)
27352771

27362772
# oneMKL LAPACK assumes fortran-like array as input.
27372773
# Allocate 'F' order memory for dpnp output arrays to comply with
@@ -2746,7 +2782,7 @@ def dpnp_svd(
27462782
shape=vt_shape,
27472783
order="F",
27482784
)
2749-
s_h = dpnp.empty_like(a_h, shape=(k,), dtype=s_type)
2785+
s_h = dpnp.empty_like(a_h, shape=s_shape, dtype=s_type)
27502786

27512787
ht_ev, gesvd_ev = li._gesvd(
27522788
exec_q,
@@ -2761,6 +2797,11 @@ def dpnp_svd(
27612797
_manager.add_event_pair(ht_ev, gesvd_ev)
27622798

27632799
if compute_uv:
2800+
# Transposing the input matrix swaps the roles of U and Vt:
2801+
# For A^T = V S^T U^T, `u_h` becomes V and `vt_h` becomes U^T.
2802+
# Transpose and swap them back to restore correct order for A.
2803+
if trans_flag:
2804+
return vt_h.T, s_h, u_h.T
27642805
# gesvd call writes `u_h` and `vt_h` in Fortran order;
27652806
# Convert to contiguous to align with NumPy
27662807
u_h = dpnp.ascontiguousarray(u_h)

dpnp/tests/third_party/cupy/linalg_tests/test_decomposition.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from dpnp.tests.helper import (
88
has_support_aspect64,
99
is_cpu_device,
10-
is_win_platform,
1110
)
1211
from dpnp.tests.third_party.cupy import testing
1312
from dpnp.tests.third_party.cupy.testing import _condition
@@ -280,12 +279,6 @@ def test_svd_rank2_empty_array_compute_uv_false(self, xp):
280279
array, full_matrices=self.full_matrices, compute_uv=False
281280
)
282281

283-
# The issue was expected to be resolved once CMPLRLLVM-53771 is available,
284-
# which has to be included in DPC++ 2024.1.0, but problem still exists
285-
# on Windows
286-
@pytest.mark.skipif(
287-
is_cpu_device() and is_win_platform(), reason="SAT-7145"
288-
)
289282
@_condition.repeat(3, 10)
290283
def test_svd_rank3(self):
291284
self.check_usv((2, 3, 4))
@@ -295,9 +288,6 @@ def test_svd_rank3(self):
295288
self.check_usv((2, 4, 3))
296289
self.check_usv((2, 32, 32))
297290

298-
@pytest.mark.skipif(
299-
is_cpu_device() and is_win_platform(), reason="SAT-7145"
300-
)
301291
@_condition.repeat(3, 10)
302292
def test_svd_rank3_loop(self):
303293
# This tests the loop-based batched gesvd on CUDA (_gesvd_batched)
@@ -345,9 +335,6 @@ def test_svd_rank3_empty_array_compute_uv_false2(self, xp):
345335
array, full_matrices=self.full_matrices, compute_uv=False
346336
)
347337

348-
@pytest.mark.skipif(
349-
is_cpu_device() and is_win_platform(), reason="SAT-7145"
350-
)
351338
@_condition.repeat(3, 10)
352339
def test_svd_rank4(self):
353340
self.check_usv((2, 2, 3, 4))
@@ -357,9 +344,6 @@ def test_svd_rank4(self):
357344
self.check_usv((2, 2, 4, 3))
358345
self.check_usv((2, 2, 32, 32))
359346

360-
@pytest.mark.skipif(
361-
is_cpu_device() and is_win_platform(), reason="SAT-7145"
362-
)
363347
@_condition.repeat(3, 10)
364348
def test_svd_rank4_loop(self):
365349
# This tests the loop-based batched gesvd on CUDA (_gesvd_batched)

0 commit comments

Comments
 (0)