Skip to content

Commit db97d59

Browse files
authored
Use dpctl.tensor.matmul in the backend of dpnp.matmul when inputs are integer (#2296)
resolves #2270 OneMath (OneMKL) routines (`gemm`, `gemv`, `gemm_batch`) for matrix multiplication only support floating point data types. If inputs are integer, to use OneMath we need to upcast them to floating point dtypes, perform the calculation and then cast the result back to integer dtypes, which is unsafe and may lose some information for large integers. In this PR, the logic for `dpnp.matmul` is updated to use `dpctl.tensor.matmul` when the result has an integer dtype.
1 parent 0c455a6 commit db97d59

File tree

7 files changed

+299
-269
lines changed

7 files changed

+299
-269
lines changed

.github/workflows/array-api-skips.txt

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,10 @@ array_api_tests/test_linalg.py::test_svd
2626
array_api_tests/test_linalg.py::test_qr
2727
array_api_tests/test_operators_and_elementwise_functions.py::test_clip
2828

29-
# unexpected result is returned
29+
# unexpected result is returned - unmute when dpctl-1986 is resolved
3030
array_api_tests/test_operators_and_elementwise_functions.py::test_asin
3131
array_api_tests/test_operators_and_elementwise_functions.py::test_asinh
3232

3333
# missing 'correction' keyword argument
3434
array_api_tests/test_signatures.py::test_func_signature[std]
3535
array_api_tests/test_signatures.py::test_func_signature[var]
36-
37-
# arrays have different values
38-
array_api_tests/test_linalg.py::test_linalg_tensordot

.github/workflows/check-mkl-interfaces.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ jobs:
216216
id: run_tests
217217
uses: nick-fields/retry@7152eba30c6575329ac0576536151aca5a72780e # v3.0.0
218218
with:
219-
timeout_minutes: 12
219+
timeout_minutes: 15
220220
max_attempts: ${{ env.RUN_TESTS_MAX_ATTEMPTS }}
221221
retry_on: any
222222
command: |

.github/workflows/conda-package.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ jobs:
218218
id: run_tests_linux
219219
uses: nick-fields/retry@7152eba30c6575329ac0576536151aca5a72780e # v3.0.0
220220
with:
221-
timeout_minutes: 12
221+
timeout_minutes: 15
222222
max_attempts: ${{ env.RUN_TESTS_MAX_ATTEMPTS }}
223223
retry_on: any
224224
command: |
@@ -355,7 +355,7 @@ jobs:
355355
id: run_tests_win
356356
uses: nick-fields/retry@7152eba30c6575329ac0576536151aca5a72780e # v3.0.0
357357
with:
358-
timeout_minutes: 15
358+
timeout_minutes: 17
359359
max_attempts: ${{ env.RUN_TESTS_MAX_ATTEMPTS }}
360360
retry_on: any
361361
command: |

.github/workflows/cron-run-tests.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ jobs:
126126
id: run_tests_linux
127127
uses: nick-fields/retry@7152eba30c6575329ac0576536151aca5a72780e # v3.0.0
128128
with:
129-
timeout_minutes: 12
129+
timeout_minutes: 15
130130
max_attempts: ${{ env.RUN_TESTS_MAX_ATTEMPTS }}
131131
retry_on: any
132132
command: |
@@ -143,7 +143,7 @@ jobs:
143143
id: run_tests_win
144144
uses: nick-fields/retry@7152eba30c6575329ac0576536151aca5a72780e # v3.0.0
145145
with:
146-
timeout_minutes: 15
146+
timeout_minutes: 17
147147
max_attempts: ${{ env.RUN_TESTS_MAX_ATTEMPTS }}
148148
retry_on: any
149149
command: |

dpnp/backend/extensions/blas/blas_py.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -142,15 +142,18 @@ PYBIND11_MODULE(_blas_impl, m)
142142
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("vectorX"),
143143
py::arg("vectorY"), py::arg("transpose"),
144144
py::arg("depends") = py::list());
145+
}
146+
147+
{
145148
m.def(
146-
"_row_major_is_available",
147-
[](void) {
148-
#if defined(USE_ONEMKL_CUBLAS)
149-
return false;
150-
#else
149+
"_using_onemkl_interfaces",
150+
[]() {
151+
#ifdef USE_ONEMKL_INTERFACES
151152
return true;
152-
#endif // USE_ONEMKL_CUBLAS
153+
#else
154+
return false;
155+
#endif
153156
},
154-
"Check if the onemkl::blas::row_major can be used.");
157+
"Check if the OneMKL interfaces are being used.");
155158
}
156159
}

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

Lines changed: 121 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -50,26 +50,23 @@
5050
]
5151

5252

53-
def _compute_res_dtype(*arrays, sycl_queue, dtype=None, casting="no"):
53+
def _compute_res_dtype(*arrays, sycl_queue, dtype=None, out=None, casting="no"):
5454
"""
55-
Determines the output array data type and an intermediate data type
56-
used in performing calculations related to a specific math function.
57-
If dtype is ``None``, the output array data type of the operation is
58-
determined based on the Promotion Type Rule and device capabilities.
59-
Otherwise, `dtype` is used as output array dtype, if input arrays
60-
can cast to it according to the casting rule determined. If casting
61-
cannot be done, a ``TypeError`` is raised.
62-
The intermediate data type is the data type used for performing the math
63-
function calculations. If output array dtype is a floating-point data type,
64-
it is also used for the intermediate data type. If output array dtype is an
65-
integral data type, the default floating point data type of the device where
66-
input arrays are allocated on are used for intermediate data type.
55+
Determines the output array data type.
56+
If `dtype` and `out` are ``None``, the output array data type of the
57+
operation is determined based on the Promotion Type Rule and device
58+
capabilities. If `out` is given, its data type is used as the output
59+
array dtype. Otherwise, `dtype` is used as output array dtype.
60+
If input arrays cannot be cast to the determined output array dtype,
61+
a ``TypeError`` is raised.
6762
6863
Parameters
6964
----------
7065
arrays : {dpnp.ndarray, usm_ndarray}
7166
Input arrays.
7267
dtype : dtype
68+
If not ``None`` and `out` is not defined, data type of the output array.
69+
out : {dpnp.ndarray, usm_ndarray}
7370
If not ``None``, data type of the output array.
7471
casting : {"no", "equiv", "safe", "same_kind", "unsafe"}, optional
7572
Controls what kind of data casting may occur.
@@ -78,17 +75,23 @@ def _compute_res_dtype(*arrays, sycl_queue, dtype=None, casting="no"):
7875
7976
Returns
8077
-------
81-
compute_dtype, res_dtype :
82-
`compute_dtype` is the data type used in performing math function calculations.
83-
The input arrays of the math function are cast to `compute_dtype` and then
84-
the calculations are performed.
85-
`res_dtype` is the output data type. When the result is obtained, it is cast
86-
to `res_dtype`.
78+
res_dtype :
79+
`res_dtype` is the output data type. When the result is obtained,
80+
it is cast to `res_dtype`.
8781
8882
"""
8983

9084
res_dtype = dpnp.result_type(*arrays)
91-
default_dtype = dpnp.default_float_type(sycl_queue=sycl_queue)
85+
86+
# If inputs are boolean and `out` is given and it is not boolean, the
87+
# calculation should be performed in boolean and at the end the result
88+
# is cast to out dtype. It is different from the general case where the inputs
89+
# are cast to out dtype and then calculation is performed. Even when inputs
90+
# are boolean and `dtype` is given, the casting is done first and then the
91+
# calculation is performed.
92+
if out is not None and res_dtype != dpnp.bool:
93+
# out dtype is prioritized over a given dtype
94+
dtype = out.dtype
9295

9396
if dtype is not None:
9497
if dpnp.can_cast(res_dtype, dtype, casting=casting):
@@ -98,11 +101,7 @@ def _compute_res_dtype(*arrays, sycl_queue, dtype=None, casting="no"):
98101
f"Cannot cast from dtype({res_dtype}) to dtype({dtype}) with casting rule {casting}"
99102
)
100103

101-
compute_dtype = (
102-
res_dtype if dpnp.issubdtype(res_dtype, dpnp.inexact) else default_dtype
103-
)
104-
105-
return compute_dtype, res_dtype
104+
return res_dtype
106105

107106

108107
def _copy_array(x, copy_flag=False, dtype=None, order="C"):
@@ -504,6 +503,23 @@ def _gemm_matmul(exec_q, x1, x2, res):
504503
return res
505504

506505

506+
def _gemm_special_case(x1, x2, res_dtype, call_flag):
507+
"""
508+
`gemm` and `gemm_batch` support these special cases of data types
509+
while `gemv` does not.
510+
511+
"""
512+
# TODO: replace with dpnp.int8 when it is added
513+
is_int8 = x1.dtype == numpy.int8 and x2.dtype == numpy.int8
514+
is_int32_or_f32 = res_dtype in [dpnp.int32, dpnp.float32]
515+
flag = is_int8 and is_int32_or_f32 and call_flag in ["gemm", "gemm_batch"]
516+
517+
# onemkl_interfaces does not support these data types
518+
onemkl_interfaces = bi._using_onemkl_interfaces()
519+
520+
return flag and not onemkl_interfaces
521+
522+
507523
def _shape_error(shape1, shape2, func, err_msg):
508524
"""Validate the shapes of input and output arrays."""
509525

@@ -749,17 +765,19 @@ def dpnp_dot(a, b, /, out=None, *, casting="same_kind", conjugate=False):
749765
_validate_out_array(out, exec_q)
750766

751767
# Determine the appropriate data types
752-
dot_dtype, res_dtype = _compute_res_dtype(a, b, sycl_queue=exec_q)
768+
res_dtype = _compute_res_dtype(
769+
a, b, out=out, casting=casting, sycl_queue=exec_q
770+
)
753771

754772
result = _create_result_array(
755-
a, b, out, (), dot_dtype, res_usm_type, exec_q
773+
a, b, out, (), res_dtype, res_usm_type, exec_q
756774
)
757775

758776
# input arrays should have the proper data type
759777
if dpnp.issubdtype(res_dtype, dpnp.inexact):
760778
# copying is needed if dtypes of input arrays are different
761-
a = _copy_array(a, dtype=dot_dtype)
762-
b = _copy_array(b, dtype=dot_dtype)
779+
a = _copy_array(a, dtype=res_dtype)
780+
b = _copy_array(b, dtype=res_dtype)
763781

764782
_manager = dpu.SequentialOrderManager[exec_q]
765783

@@ -777,14 +795,11 @@ def dpnp_dot(a, b, /, out=None, *, casting="same_kind", conjugate=False):
777795
)
778796
_manager.add_event_pair(ht_ev, dot_ev)
779797
else:
780-
# oneapi::mkl::blas::dot is slow for integer data type,
798+
# oneapi::mkl::blas::dot does not support integer dtypes,
781799
# so using dpctl.tensor.vecdot instead
782-
dpt_a = dpnp.get_usm_ndarray(a)
783-
dpt_b = dpnp.get_usm_ndarray(b)
784-
result = dpnp_array._create_from_usm_ndarray(dpt.vecdot(dpt_a, dpt_b))
785-
786-
if dot_dtype != res_dtype:
787-
result = result.astype(res_dtype, copy=False)
800+
a_usm = dpnp.get_usm_ndarray(a)
801+
b_usm = dpnp.get_usm_ndarray(b)
802+
result = dpnp_array._create_from_usm_ndarray(dpt.vecdot(a_usm, b_usm))
788803

789804
return dpnp.get_result_array(result, out, casting=casting)
790805

@@ -902,8 +917,8 @@ def dpnp_multiplication(
902917
axes_res = normalize_axis_tuple(axes_res, len(result_shape), "axes")
903918

904919
# Determine the appropriate data types
905-
compute_dtype, res_dtype = _compute_res_dtype(
906-
x1, x2, dtype=dtype, casting=casting, sycl_queue=exec_q
920+
res_dtype = _compute_res_dtype(
921+
x1, x2, dtype=dtype, out=out, casting=casting, sycl_queue=exec_q
907922
)
908923

909924
call_flag = None
@@ -998,7 +1013,7 @@ def dpnp_multiplication(
9981013
x2,
9991014
out,
10001015
res_shape,
1001-
compute_dtype,
1016+
res_dtype,
10021017
res_usm_type,
10031018
exec_q,
10041019
res_order,
@@ -1010,64 +1025,82 @@ def dpnp_multiplication(
10101025
elif x1.size == 0 or x2.size == 0:
10111026
result.fill(0)
10121027
else:
1013-
# input arrays should have the proper data type and
1014-
# their base (last 2-dimensions) to be c-contiguous or f-contiguous
1015-
x1 = _copy_array(
1016-
x1,
1017-
copy_flag=not x1_contig_flag,
1018-
dtype=compute_dtype,
1019-
order=res_order,
1020-
)
1021-
x2 = _copy_array(
1022-
x2,
1023-
copy_flag=not x2_contig_flag,
1024-
dtype=compute_dtype,
1025-
order=res_order,
1026-
)
1027-
1028-
if call_flag == "gemv":
1029-
if transpose:
1030-
a_usm = dpnp.get_usm_ndarray(x2)
1031-
x_usm = dpnp.get_usm_ndarray(x1)
1032-
else:
1033-
a_usm = dpnp.get_usm_ndarray(x1)
1034-
x_usm = dpnp.get_usm_ndarray(x2)
1035-
1036-
_manager = dpu.SequentialOrderManager[exec_q]
1037-
1038-
ht_ev, gemv_ev = bi._gemv(
1039-
exec_q,
1040-
a_usm,
1041-
x_usm,
1042-
dpnp.get_usm_ndarray(result),
1043-
transpose,
1044-
depends=_manager.submitted_events,
1028+
if _gemm_special_case(x1, x2, res_dtype, call_flag):
1029+
x1 = _copy_array(
1030+
x1, copy_flag=not x1_contig_flag, order=res_order
10451031
)
1046-
_manager.add_event_pair(ht_ev, gemv_ev)
1047-
elif call_flag == "gemm":
1048-
result = _gemm_matmul(
1049-
exec_q,
1050-
x1,
1051-
x2,
1052-
result,
1032+
x2 = _copy_array(
1033+
x2, copy_flag=not x2_contig_flag, order=res_order
10531034
)
1054-
else: # call_flag == "gemm_batch"
1055-
assert call_flag == "gemm_batch"
1056-
result = _gemm_batch_matmul(
1057-
exec_q,
1035+
if call_flag == "gemm":
1036+
result = _gemm_matmul(exec_q, x1, x2, result)
1037+
else:
1038+
assert call_flag == "gemm_batch"
1039+
result = _gemm_batch_matmul(exec_q, x1, x2, result)
1040+
elif dpnp.issubdtype(res_dtype, dpnp.inexact):
1041+
# copying is needed if dtypes of input arrays are different or
1042+
# their base (last 2-dimensions) is not c-contiguous or f-contiguous
1043+
x1 = _copy_array(
10581044
x1,
1045+
copy_flag=not x1_contig_flag,
1046+
dtype=res_dtype,
1047+
order=res_order,
1048+
)
1049+
x2 = _copy_array(
10591050
x2,
1060-
result,
1051+
copy_flag=not x2_contig_flag,
1052+
dtype=res_dtype,
1053+
order=res_order,
1054+
)
1055+
1056+
if call_flag == "gemv":
1057+
if transpose:
1058+
a_usm = dpnp.get_usm_ndarray(x2)
1059+
x_usm = dpnp.get_usm_ndarray(x1)
1060+
else:
1061+
a_usm = dpnp.get_usm_ndarray(x1)
1062+
x_usm = dpnp.get_usm_ndarray(x2)
1063+
1064+
_manager = dpu.SequentialOrderManager[exec_q]
1065+
1066+
ht_ev, gemv_ev = bi._gemv(
1067+
exec_q,
1068+
a_usm,
1069+
x_usm,
1070+
dpnp.get_usm_ndarray(result),
1071+
transpose,
1072+
depends=_manager.submitted_events,
1073+
)
1074+
_manager.add_event_pair(ht_ev, gemv_ev)
1075+
elif call_flag == "gemm":
1076+
result = _gemm_matmul(exec_q, x1, x2, result)
1077+
else:
1078+
assert call_flag == "gemm_batch"
1079+
result = _gemm_batch_matmul(exec_q, x1, x2, result)
1080+
else:
1081+
# oneapi::mkl::blas::gemm/gemv do not support integer dtypes,
1082+
# except for special cases determined in `_gemm_special_case`,
1083+
# use dpctl.tensor.matmul for unsupported cases
1084+
1085+
# `dpt.matmul` does not support `casting` kwarg.
1086+
# We may need to change input dtypes based on given `casting`.
1087+
# The possibility of casting is already validated in
1088+
# `_compute_res_dtype`.
1089+
x1 = _copy_array(x1, dtype=res_dtype, order=res_order)
1090+
x2 = _copy_array(x2, dtype=res_dtype, order=res_order)
1091+
1092+
x1_usm = dpnp.get_usm_ndarray(x1)
1093+
x2_usm = dpnp.get_usm_ndarray(x2)
1094+
out_usm = dpnp.get_usm_ndarray(result)
1095+
dpt.matmul(
1096+
x1_usm, x2_usm, out=out_usm, dtype=dtype, order=order
10611097
)
10621098

10631099
if NumPy_special_case:
10641100
result = dpnp.tile(result, out.shape)
10651101
elif res_shape != result_shape:
10661102
result = dpnp.reshape(result, result_shape)
10671103

1068-
if compute_dtype != res_dtype:
1069-
result = dpnp.astype(result, res_dtype, copy=False)
1070-
10711104
if out is None:
10721105
if axes is not None:
10731106
# Move the data back to the appropriate axes of the result array
@@ -1207,8 +1240,8 @@ def dpnp_vecdot(
12071240
)
12081241

12091242
# Determine the appropriate data types
1210-
_, res_dtype = _compute_res_dtype(
1211-
x1, x2, dtype=dtype, casting=casting, sycl_queue=exec_q
1243+
res_dtype = _compute_res_dtype(
1244+
x1, x2, dtype=dtype, out=out, casting=casting, sycl_queue=exec_q
12121245
)
12131246

12141247
_, x1_is_1D, _ = _define_dim_flags(x1, axis=-1)

0 commit comments

Comments
 (0)