1
1
from . import _array_module as xp
2
2
3
+
3
4
__all__ = [
4
- "dtype_mapping" ,
5
5
"promotion_table" ,
6
6
"dtype_nbits" ,
7
7
"dtype_signed" ,
14
14
"operators_to_functions" ,
15
15
]
16
16
17
- dtype_mapping = {
18
- 'int8' : xp .int8 ,
19
- 'int16' : xp .int16 ,
20
- 'int32' : xp .int32 ,
21
- 'int64' : xp .int64 ,
22
- 'uint8' : xp .uint8 ,
23
- 'uint16' : xp .uint16 ,
24
- 'uint32' : xp .uint32 ,
25
- 'uint64' : xp .uint64 ,
26
- 'float32' : xp .float32 ,
27
- 'float64' : xp .float64 ,
28
- 'bool' : xp .bool ,
29
- }
30
-
31
- reverse_dtype_mapping = {v : k for k , v in dtype_mapping .items ()}
32
17
33
18
def dtype_nbits (dtype ):
34
19
if dtype == xp .int8 :
@@ -54,79 +39,87 @@ def dtype_nbits(dtype):
54
39
else :
55
40
raise ValueError (f"dtype_nbits is not defined for { dtype } " )
56
41
42
+
57
43
def dtype_signed (dtype ):
58
44
if dtype in [xp .int8 , xp .int16 , xp .int32 , xp .int64 ]:
59
45
return True
60
46
elif dtype in [xp .uint8 , xp .uint16 , xp .uint32 , xp .uint64 ]:
61
47
return False
62
48
raise ValueError ("dtype_signed is only defined for integer dtypes" )
63
49
50
+
64
51
signed_integer_promotion_table = {
65
- (' int8' , ' int8' ): ' int8' ,
66
- (' int8' , ' int16' ): ' int16' ,
67
- (' int8' , ' int32' ): ' int32' ,
68
- (' int8' , ' int64' ): ' int64' ,
69
- (' int16' , ' int8' ): ' int16' ,
70
- (' int16' , ' int16' ): ' int16' ,
71
- (' int16' , ' int32' ): ' int32' ,
72
- (' int16' , ' int64' ): ' int64' ,
73
- (' int32' , ' int8' ): ' int32' ,
74
- (' int32' , ' int16' ): ' int32' ,
75
- (' int32' , ' int32' ): ' int32' ,
76
- (' int32' , ' int64' ): ' int64' ,
77
- (' int64' , ' int8' ): ' int64' ,
78
- (' int64' , ' int16' ): ' int64' ,
79
- (' int64' , ' int32' ): ' int64' ,
80
- (' int64' , ' int64' ): ' int64' ,
52
+ (xp . int8 , xp . int8 ): xp . int8 ,
53
+ (xp . int8 , xp . int16 ): xp . int16 ,
54
+ (xp . int8 , xp . int32 ): xp . int32 ,
55
+ (xp . int8 , xp . int64 ): xp . int64 ,
56
+ (xp . int16 , xp . int8 ): xp . int16 ,
57
+ (xp . int16 , xp . int16 ): xp . int16 ,
58
+ (xp . int16 , xp . int32 ): xp . int32 ,
59
+ (xp . int16 , xp . int64 ): xp . int64 ,
60
+ (xp . int32 , xp . int8 ): xp . int32 ,
61
+ (xp . int32 , xp . int16 ): xp . int32 ,
62
+ (xp . int32 , xp . int32 ): xp . int32 ,
63
+ (xp . int32 , xp . int64 ): xp . int64 ,
64
+ (xp . int64 , xp . int8 ): xp . int64 ,
65
+ (xp . int64 , xp . int16 ): xp . int64 ,
66
+ (xp . int64 , xp . int32 ): xp . int64 ,
67
+ (xp . int64 , xp . int64 ): xp . int64 ,
81
68
}
82
69
70
+
83
71
unsigned_integer_promotion_table = {
84
- (' uint8' , ' uint8' ): ' uint8' ,
85
- (' uint8' , ' uint16' ): ' uint16' ,
86
- (' uint8' , ' uint32' ): ' uint32' ,
87
- (' uint8' , ' uint64' ): ' uint64' ,
88
- (' uint16' , ' uint8' ): ' uint16' ,
89
- (' uint16' , ' uint16' ): ' uint16' ,
90
- (' uint16' , ' uint32' ): ' uint32' ,
91
- (' uint16' , ' uint64' ): ' uint64' ,
92
- (' uint32' , ' uint8' ): ' uint32' ,
93
- (' uint32' , ' uint16' ): ' uint32' ,
94
- (' uint32' , ' uint32' ): ' uint32' ,
95
- (' uint32' , ' uint64' ): ' uint64' ,
96
- (' uint64' , ' uint8' ): ' uint64' ,
97
- (' uint64' , ' uint16' ): ' uint64' ,
98
- (' uint64' , ' uint32' ): ' uint64' ,
99
- (' uint64' , ' uint64' ): ' uint64' ,
72
+ (xp . uint8 , xp . uint8 ): xp . uint8 ,
73
+ (xp . uint8 , xp . uint16 ): xp . uint16 ,
74
+ (xp . uint8 , xp . uint32 ): xp . uint32 ,
75
+ (xp . uint8 , xp . uint64 ): xp . uint64 ,
76
+ (xp . uint16 , xp . uint8 ): xp . uint16 ,
77
+ (xp . uint16 , xp . uint16 ): xp . uint16 ,
78
+ (xp . uint16 , xp . uint32 ): xp . uint32 ,
79
+ (xp . uint16 , xp . uint64 ): xp . uint64 ,
80
+ (xp . uint32 , xp . uint8 ): xp . uint32 ,
81
+ (xp . uint32 , xp . uint16 ): xp . uint32 ,
82
+ (xp . uint32 , xp . uint32 ): xp . uint32 ,
83
+ (xp . uint32 , xp . uint64 ): xp . uint64 ,
84
+ (xp . uint64 , xp . uint8 ): xp . uint64 ,
85
+ (xp . uint64 , xp . uint16 ): xp . uint64 ,
86
+ (xp . uint64 , xp . uint32 ): xp . uint64 ,
87
+ (xp . uint64 , xp . uint64 ): xp . uint64 ,
100
88
}
101
89
90
+
102
91
mixed_signed_unsigned_promotion_table = {
103
- (' int8' , ' uint8' ): ' int16' ,
104
- (' int8' , ' uint16' ): ' int32' ,
105
- (' int8' , ' uint32' ): ' int64' ,
106
- (' int16' , ' uint8' ): ' int16' ,
107
- (' int16' , ' uint16' ): ' int32' ,
108
- (' int16' , ' uint32' ): ' int64' ,
109
- (' int32' , ' uint8' ): ' int32' ,
110
- (' int32' , ' uint16' ): ' int32' ,
111
- (' int32' , ' uint32' ): ' int64' ,
112
- (' int64' , ' uint8' ): ' int64' ,
113
- (' int64' , ' uint16' ): ' int64' ,
114
- (' int64' , ' uint32' ): ' int64' ,
92
+ (xp . int8 , xp . uint8 ): xp . int16 ,
93
+ (xp . int8 , xp . uint16 ): xp . int32 ,
94
+ (xp . int8 , xp . uint32 ): xp . int64 ,
95
+ (xp . int16 , xp . uint8 ): xp . int16 ,
96
+ (xp . int16 , xp . uint16 ): xp . int32 ,
97
+ (xp . int16 , xp . uint32 ): xp . int64 ,
98
+ (xp . int32 , xp . uint8 ): xp . int32 ,
99
+ (xp . int32 , xp . uint16 ): xp . int32 ,
100
+ (xp . int32 , xp . uint32 ): xp . int64 ,
101
+ (xp . int64 , xp . uint8 ): xp . int64 ,
102
+ (xp . int64 , xp . uint16 ): xp . int64 ,
103
+ (xp . int64 , xp . uint32 ): xp . int64 ,
115
104
}
116
105
106
+
117
107
flipped_mixed_signed_unsigned_promotion_table = {(u , i ): p for (i , u ), p in mixed_signed_unsigned_promotion_table .items ()}
118
108
109
+
119
110
float_promotion_table = {
120
- (' float32' , ' float32' ): ' float32' ,
121
- (' float32' , ' float64' ): ' float64' ,
122
- (' float64' , ' float32' ): ' float64' ,
123
- (' float64' , ' float64' ): ' float64' ,
111
+ (xp . float32 , xp . float32 ): xp . float32 ,
112
+ (xp . float32 , xp . float64 ): xp . float64 ,
113
+ (xp . float64 , xp . float32 ): xp . float64 ,
114
+ (xp . float64 , xp . float64 ): xp . float64 ,
124
115
}
125
116
117
+
126
118
boolean_promotion_table = {
127
- (' bool' , ' bool' ): ' bool' ,
119
+ (xp . bool , xp . bool ): xp . bool ,
128
120
}
129
121
122
+
130
123
promotion_table = {
131
124
** signed_integer_promotion_table ,
132
125
** unsigned_integer_promotion_table ,
@@ -136,6 +129,7 @@ def dtype_signed(dtype):
136
129
** boolean_promotion_table ,
137
130
}
138
131
132
+
139
133
input_types = {
140
134
'any' : sorted (set (promotion_table .values ())),
141
135
'boolean' : sorted (set (boolean_promotion_table .values ())),
@@ -150,21 +144,23 @@ def dtype_signed(dtype):
150
144
** unsigned_integer_promotion_table }.values ())),
151
145
}
152
146
147
+
153
148
dtypes_to_scalars = {
154
- ' bool' : [bool ],
155
- ' int8' : [int ],
156
- ' int16' : [int ],
157
- ' int32' : [int ],
158
- ' int64' : [int ],
149
+ xp . bool : [bool ],
150
+ xp . int8 : [int ],
151
+ xp . int16 : [int ],
152
+ xp . int32 : [int ],
153
+ xp . int64 : [int ],
159
154
# Note: unsigned int dtypes only correspond to positive integers
160
- ' uint8' : [int ],
161
- ' uint16' : [int ],
162
- ' uint32' : [int ],
163
- ' uint64' : [int ],
164
- ' float32' : [int , float ],
165
- ' float64' : [int , float ],
155
+ xp . uint8 : [int ],
156
+ xp . uint16 : [int ],
157
+ xp . uint32 : [int ],
158
+ xp . uint64 : [int ],
159
+ xp . float32 : [int , float ],
160
+ xp . float64 : [int , float ],
166
161
}
167
162
163
+
168
164
elementwise_function_input_types = {
169
165
'abs' : 'numeric' ,
170
166
'acos' : 'floating' ,
@@ -224,6 +220,7 @@ def dtype_signed(dtype):
224
220
'trunc' : 'numeric' ,
225
221
}
226
222
223
+
227
224
elementwise_function_output_types = {
228
225
'abs' : 'promoted' ,
229
226
'acos' : 'promoted' ,
@@ -283,6 +280,7 @@ def dtype_signed(dtype):
283
280
'trunc' : 'promoted' ,
284
281
}
285
282
283
+
286
284
binary_operators = {
287
285
'__add__' : '+' ,
288
286
'__and__' : '&' ,
@@ -305,6 +303,7 @@ def dtype_signed(dtype):
305
303
'__xor__' : '^' ,
306
304
}
307
305
306
+
308
307
unary_operators = {
309
308
'__abs__' : 'abs()' ,
310
309
'__invert__' : '~' ,
0 commit comments