50
50
]
51
51
52
52
53
- def _compute_res_dtype (* arrays , sycl_queue , dtype = None , casting = "no" ):
53
+ def _compute_res_dtype (* arrays , sycl_queue , dtype = None , out = None , casting = "no" ):
54
54
"""
55
- Determines the output array data type and an intermediate data type
56
- used in performing calculations related to a specific math function.
57
- If dtype is ``None``, the output array data type of the operation is
58
- determined based on the Promotion Type Rule and device capabilities.
59
- Otherwise, `dtype` is used as output array dtype, if input arrays
60
- can cast to it according to the casting rule determined. If casting
61
- cannot be done, a ``TypeError`` is raised.
62
- The intermediate data type is the data type used for performing the math
63
- function calculations. If output array dtype is a floating-point data type,
64
- it is also used for the intermediate data type. If output array dtype is an
65
- integral data type, the default floating point data type of the device where
66
- input arrays are allocated on are used for intermediate data type.
55
+ Determines the output array data type.
56
+ If `dtype` and `out` are ``None``, the output array data type of the
57
+ operation is determined based on the Promotion Type Rule and device
58
+ capabilities. if `out` is given, its data type is used as the output
59
+ array dtypes. Otherwise, `dtype` is used as output array dtype.
60
+ If input arrays cannot be cast to the determined output array dtype,
61
+ a ``TypeError`` is raised.
67
62
68
63
Parameters
69
64
----------
70
65
arrays : {dpnp.ndarray, usm_ndarray}
71
66
Input arrays.
72
67
dtype : dtype
68
+ If not ``None`` and `out` is not defined, data type of the output array.
69
+ out : {dpnp.ndarray, usm_ndarray}
73
70
If not ``None``, data type of the output array.
74
71
casting : {"no", "equiv", "safe", "same_kind", "unsafe"}, optional
75
72
Controls what kind of data casting may occur.
@@ -78,17 +75,23 @@ def _compute_res_dtype(*arrays, sycl_queue, dtype=None, casting="no"):
78
75
79
76
Returns
80
77
-------
81
- compute_dtype, res_dtype :
82
- `compute_dtype` is the data type used in performing math function calculations.
83
- The input arrays of the math function are cast to `compute_dtype` and then
84
- the calculations are performed.
85
- `res_dtype` is the output data type. When the result is obtained, it is cast
86
- to `res_dtype`.
78
+ res_dtype :
79
+ `res_dtype` is the output data type. When the result is obtained,
80
+ it is cast to `res_dtype`.
87
81
88
82
"""
89
83
90
84
res_dtype = dpnp .result_type (* arrays )
91
- default_dtype = dpnp .default_float_type (sycl_queue = sycl_queue )
85
+
86
+ # If inputs are boolean and `out` is given and it is not boolean, the
87
+ # calculation should be performed in boolean and at the end the result
88
+ # is cast to out dtype. It is different than general case where the inputs
89
+ # are cast to out dtype and then calculation is performed. Even when inputs
90
+ # are boolean and `dtype` is given, the casting is done first and then the
91
+ # calculation is performed.
92
+ if out is not None and res_dtype != dpnp .bool :
93
+ # out dtype is prioritized over a given dtype
94
+ dtype = out .dtype
92
95
93
96
if dtype is not None :
94
97
if dpnp .can_cast (res_dtype , dtype , casting = casting ):
@@ -98,11 +101,7 @@ def _compute_res_dtype(*arrays, sycl_queue, dtype=None, casting="no"):
98
101
f"Cannot cast from dtype({ res_dtype } ) to dtype({ dtype } ) with casting rule { casting } "
99
102
)
100
103
101
- compute_dtype = (
102
- res_dtype if dpnp .issubdtype (res_dtype , dpnp .inexact ) else default_dtype
103
- )
104
-
105
- return compute_dtype , res_dtype
104
+ return res_dtype
106
105
107
106
108
107
def _copy_array (x , copy_flag = False , dtype = None , order = "C" ):
@@ -504,6 +503,23 @@ def _gemm_matmul(exec_q, x1, x2, res):
504
503
return res
505
504
506
505
506
+ def _gemm_special_case (x1 , x2 , res_dtype , call_flag ):
507
+ """
508
+ `gemm` and `gemm_batch` support these special cases of data types
509
+ while `gemv` does not.
510
+
511
+ """
512
+ # TODO: replace with dpnp.int8 when it is added
513
+ is_int8 = x1 .dtype == numpy .int8 and x2 .dtype == numpy .int8
514
+ is_int32_or_f32 = res_dtype in [dpnp .int32 , dpnp .float32 ]
515
+ flag = is_int8 and is_int32_or_f32 and call_flag in ["gemm" , "gemm_batch" ]
516
+
517
+ # onemkl_interfaces does not support these data types
518
+ onemkl_interfaces = bi ._using_onemkl_interfaces ()
519
+
520
+ return flag and not onemkl_interfaces
521
+
522
+
507
523
def _shape_error (shape1 , shape2 , func , err_msg ):
508
524
"""Validate the shapes of input and output arrays."""
509
525
@@ -749,17 +765,19 @@ def dpnp_dot(a, b, /, out=None, *, casting="same_kind", conjugate=False):
749
765
_validate_out_array (out , exec_q )
750
766
751
767
# Determine the appropriate data types
752
- dot_dtype , res_dtype = _compute_res_dtype (a , b , sycl_queue = exec_q )
768
+ res_dtype = _compute_res_dtype (
769
+ a , b , out = out , casting = casting , sycl_queue = exec_q
770
+ )
753
771
754
772
result = _create_result_array (
755
- a , b , out , (), dot_dtype , res_usm_type , exec_q
773
+ a , b , out , (), res_dtype , res_usm_type , exec_q
756
774
)
757
775
758
776
# input arrays should have the proper data type
759
777
if dpnp .issubdtype (res_dtype , dpnp .inexact ):
760
778
# copying is needed if dtypes of input arrays are different
761
- a = _copy_array (a , dtype = dot_dtype )
762
- b = _copy_array (b , dtype = dot_dtype )
779
+ a = _copy_array (a , dtype = res_dtype )
780
+ b = _copy_array (b , dtype = res_dtype )
763
781
764
782
_manager = dpu .SequentialOrderManager [exec_q ]
765
783
@@ -777,14 +795,11 @@ def dpnp_dot(a, b, /, out=None, *, casting="same_kind", conjugate=False):
777
795
)
778
796
_manager .add_event_pair (ht_ev , dot_ev )
779
797
else :
780
- # oneapi::mkl::blas::dot is slow for integer data type ,
798
+ # oneapi::mkl::blas::dot does not support integer dtypes ,
781
799
# so using dpctl.tensor.vecdot instead
782
- dpt_a = dpnp .get_usm_ndarray (a )
783
- dpt_b = dpnp .get_usm_ndarray (b )
784
- result = dpnp_array ._create_from_usm_ndarray (dpt .vecdot (dpt_a , dpt_b ))
785
-
786
- if dot_dtype != res_dtype :
787
- result = result .astype (res_dtype , copy = False )
800
+ a_usm = dpnp .get_usm_ndarray (a )
801
+ b_usm = dpnp .get_usm_ndarray (b )
802
+ result = dpnp_array ._create_from_usm_ndarray (dpt .vecdot (a_usm , b_usm ))
788
803
789
804
return dpnp .get_result_array (result , out , casting = casting )
790
805
@@ -902,8 +917,8 @@ def dpnp_multiplication(
902
917
axes_res = normalize_axis_tuple (axes_res , len (result_shape ), "axes" )
903
918
904
919
# Determine the appropriate data types
905
- compute_dtype , res_dtype = _compute_res_dtype (
906
- x1 , x2 , dtype = dtype , casting = casting , sycl_queue = exec_q
920
+ res_dtype = _compute_res_dtype (
921
+ x1 , x2 , dtype = dtype , out = out , casting = casting , sycl_queue = exec_q
907
922
)
908
923
909
924
call_flag = None
@@ -998,7 +1013,7 @@ def dpnp_multiplication(
998
1013
x2 ,
999
1014
out ,
1000
1015
res_shape ,
1001
- compute_dtype ,
1016
+ res_dtype ,
1002
1017
res_usm_type ,
1003
1018
exec_q ,
1004
1019
res_order ,
@@ -1010,64 +1025,82 @@ def dpnp_multiplication(
1010
1025
elif x1 .size == 0 or x2 .size == 0 :
1011
1026
result .fill (0 )
1012
1027
else :
1013
- # input arrays should have the proper data type and
1014
- # their base (last 2-dimensions) to be c-contiguous or f-contiguous
1015
- x1 = _copy_array (
1016
- x1 ,
1017
- copy_flag = not x1_contig_flag ,
1018
- dtype = compute_dtype ,
1019
- order = res_order ,
1020
- )
1021
- x2 = _copy_array (
1022
- x2 ,
1023
- copy_flag = not x2_contig_flag ,
1024
- dtype = compute_dtype ,
1025
- order = res_order ,
1026
- )
1027
-
1028
- if call_flag == "gemv" :
1029
- if transpose :
1030
- a_usm = dpnp .get_usm_ndarray (x2 )
1031
- x_usm = dpnp .get_usm_ndarray (x1 )
1032
- else :
1033
- a_usm = dpnp .get_usm_ndarray (x1 )
1034
- x_usm = dpnp .get_usm_ndarray (x2 )
1035
-
1036
- _manager = dpu .SequentialOrderManager [exec_q ]
1037
-
1038
- ht_ev , gemv_ev = bi ._gemv (
1039
- exec_q ,
1040
- a_usm ,
1041
- x_usm ,
1042
- dpnp .get_usm_ndarray (result ),
1043
- transpose ,
1044
- depends = _manager .submitted_events ,
1028
+ if _gemm_special_case (x1 , x2 , res_dtype , call_flag ):
1029
+ x1 = _copy_array (
1030
+ x1 , copy_flag = not x1_contig_flag , order = res_order
1045
1031
)
1046
- _manager .add_event_pair (ht_ev , gemv_ev )
1047
- elif call_flag == "gemm" :
1048
- result = _gemm_matmul (
1049
- exec_q ,
1050
- x1 ,
1051
- x2 ,
1052
- result ,
1032
+ x2 = _copy_array (
1033
+ x2 , copy_flag = not x2_contig_flag , order = res_order
1053
1034
)
1054
- else : # call_flag == "gemm_batch"
1055
- assert call_flag == "gemm_batch"
1056
- result = _gemm_batch_matmul (
1057
- exec_q ,
1035
+ if call_flag == "gemm" :
1036
+ result = _gemm_matmul (exec_q , x1 , x2 , result )
1037
+ else :
1038
+ assert call_flag == "gemm_batch"
1039
+ result = _gemm_batch_matmul (exec_q , x1 , x2 , result )
1040
+ elif dpnp .issubdtype (res_dtype , dpnp .inexact ):
1041
+ # copying is needed if dtypes of input arrays are different or
1042
+ # their base (last 2-dimensions) is not c-contiguous or f-contiguous
1043
+ x1 = _copy_array (
1058
1044
x1 ,
1045
+ copy_flag = not x1_contig_flag ,
1046
+ dtype = res_dtype ,
1047
+ order = res_order ,
1048
+ )
1049
+ x2 = _copy_array (
1059
1050
x2 ,
1060
- result ,
1051
+ copy_flag = not x2_contig_flag ,
1052
+ dtype = res_dtype ,
1053
+ order = res_order ,
1054
+ )
1055
+
1056
+ if call_flag == "gemv" :
1057
+ if transpose :
1058
+ a_usm = dpnp .get_usm_ndarray (x2 )
1059
+ x_usm = dpnp .get_usm_ndarray (x1 )
1060
+ else :
1061
+ a_usm = dpnp .get_usm_ndarray (x1 )
1062
+ x_usm = dpnp .get_usm_ndarray (x2 )
1063
+
1064
+ _manager = dpu .SequentialOrderManager [exec_q ]
1065
+
1066
+ ht_ev , gemv_ev = bi ._gemv (
1067
+ exec_q ,
1068
+ a_usm ,
1069
+ x_usm ,
1070
+ dpnp .get_usm_ndarray (result ),
1071
+ transpose ,
1072
+ depends = _manager .submitted_events ,
1073
+ )
1074
+ _manager .add_event_pair (ht_ev , gemv_ev )
1075
+ elif call_flag == "gemm" :
1076
+ result = _gemm_matmul (exec_q , x1 , x2 , result )
1077
+ else :
1078
+ assert call_flag == "gemm_batch"
1079
+ result = _gemm_batch_matmul (exec_q , x1 , x2 , result )
1080
+ else :
1081
+ # oneapi::mkl::blas::gemm/gemv do not support integer dtypes,
1082
+ # except for special cases determined in `_gemm_special_case`,
1083
+ # use dpctl.tensor.matmul for unsupported cases
1084
+
1085
+ # `dpt.matmul` does not support `casting` kwarg.
1086
+ # We may need to change input dtypes based on given `casting`.
1087
+ # The possibility of casting is already validated in
1088
+ # `_compute_res_dtype`.
1089
+ x1 = _copy_array (x1 , dtype = res_dtype , order = res_order )
1090
+ x2 = _copy_array (x2 , dtype = res_dtype , order = res_order )
1091
+
1092
+ x1_usm = dpnp .get_usm_ndarray (x1 )
1093
+ x2_usm = dpnp .get_usm_ndarray (x2 )
1094
+ out_usm = dpnp .get_usm_ndarray (result )
1095
+ dpt .matmul (
1096
+ x1_usm , x2_usm , out = out_usm , dtype = dtype , order = order
1061
1097
)
1062
1098
1063
1099
if NumPy_special_case :
1064
1100
result = dpnp .tile (result , out .shape )
1065
1101
elif res_shape != result_shape :
1066
1102
result = dpnp .reshape (result , result_shape )
1067
1103
1068
- if compute_dtype != res_dtype :
1069
- result = dpnp .astype (result , res_dtype , copy = False )
1070
-
1071
1104
if out is None :
1072
1105
if axes is not None :
1073
1106
# Move the data back to the appropriate axes of the result array
@@ -1207,8 +1240,8 @@ def dpnp_vecdot(
1207
1240
)
1208
1241
1209
1242
# Determine the appropriate data types
1210
- _ , res_dtype = _compute_res_dtype (
1211
- x1 , x2 , dtype = dtype , casting = casting , sycl_queue = exec_q
1243
+ res_dtype = _compute_res_dtype (
1244
+ x1 , x2 , dtype = dtype , out = out , casting = casting , sycl_queue = exec_q
1212
1245
)
1213
1246
1214
1247
_ , x1_is_1D , _ = _define_dim_flags (x1 , axis = - 1 )
0 commit comments