Skip to content

Commit 9ec75a5

Browse files
committed
Use dtype object instead of name for dtype helpers
1 parent 29cec27 commit 9ec75a5

File tree

5 files changed

+117
-143
lines changed

5 files changed

+117
-143
lines changed

array_api_tests/array_helpers.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
_numeric_dtypes, _boolean_dtypes, _dtypes,
1010
asarray)
1111
from . import _array_module
12-
from .dtype_helpers import dtype_mapping, promotion_table
12+
from .dtype_helpers import promotion_table
1313

1414
# These are exported here so that they can be included in the special cases
1515
# tests from this file.
@@ -371,14 +371,10 @@ def promote_dtypes(dtype1, dtype2):
371371
Special case of result_type() which uses the exact type promotion table
372372
from the spec.
373373
"""
374-
# Equivalent to this, but some libraries may not work properly with using
375-
# dtype objects as dict keys
376-
#
377-
# d1, d2 = reverse_dtype_mapping[dtype1], reverse_dtype_mapping[dtype2]
378-
379-
d1 = [i for i in dtype_mapping if dtype_mapping[i] == dtype1][0]
380-
d2 = [i for i in dtype_mapping if dtype_mapping[i] == dtype2][0]
381-
382-
if (d1, d2) not in promotion_table:
383-
raise ValueError(f"{d1} and {d2} are not type promotable according to the spec (this may indicate a bug in the test suite).")
384-
return dtype_mapping[promotion_table[d1, d2]]
374+
try:
375+
return promotion_table[(dtype2, dtype2)]
376+
except KeyError as e:
377+
raise ValueError(
378+
f"{dtype1} and {dtype2} are not type promotable according to the spec"
379+
f"(this may indicate a bug in the test suite)."
380+
) from e

array_api_tests/dtype_helpers.py

Lines changed: 75 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from . import _array_module as xp
22

3+
34
__all__ = [
4-
"dtype_mapping",
55
"promotion_table",
66
"dtype_nbits",
77
"dtype_signed",
@@ -14,21 +14,6 @@
1414
"operators_to_functions",
1515
]
1616

17-
dtype_mapping = {
18-
'int8': xp.int8,
19-
'int16': xp.int16,
20-
'int32': xp.int32,
21-
'int64': xp.int64,
22-
'uint8': xp.uint8,
23-
'uint16': xp.uint16,
24-
'uint32': xp.uint32,
25-
'uint64': xp.uint64,
26-
'float32': xp.float32,
27-
'float64': xp.float64,
28-
'bool': xp.bool,
29-
}
30-
31-
reverse_dtype_mapping = {v: k for k, v in dtype_mapping.items()}
3217

3318
def dtype_nbits(dtype):
3419
if dtype == xp.int8:
@@ -54,79 +39,87 @@ def dtype_nbits(dtype):
5439
else:
5540
raise ValueError(f"dtype_nbits is not defined for {dtype}")
5641

42+
5743
def dtype_signed(dtype):
5844
if dtype in [xp.int8, xp.int16, xp.int32, xp.int64]:
5945
return True
6046
elif dtype in [xp.uint8, xp.uint16, xp.uint32, xp.uint64]:
6147
return False
6248
raise ValueError("dtype_signed is only defined for integer dtypes")
6349

50+
6451
signed_integer_promotion_table = {
65-
('int8', 'int8'): 'int8',
66-
('int8', 'int16'): 'int16',
67-
('int8', 'int32'): 'int32',
68-
('int8', 'int64'): 'int64',
69-
('int16', 'int8'): 'int16',
70-
('int16', 'int16'): 'int16',
71-
('int16', 'int32'): 'int32',
72-
('int16', 'int64'): 'int64',
73-
('int32', 'int8'): 'int32',
74-
('int32', 'int16'): 'int32',
75-
('int32', 'int32'): 'int32',
76-
('int32', 'int64'): 'int64',
77-
('int64', 'int8'): 'int64',
78-
('int64', 'int16'): 'int64',
79-
('int64', 'int32'): 'int64',
80-
('int64', 'int64'): 'int64',
52+
(xp.int8, xp.int8): xp.int8,
53+
(xp.int8, xp.int16): xp.int16,
54+
(xp.int8, xp.int32): xp.int32,
55+
(xp.int8, xp.int64): xp.int64,
56+
(xp.int16, xp.int8): xp.int16,
57+
(xp.int16, xp.int16): xp.int16,
58+
(xp.int16, xp.int32): xp.int32,
59+
(xp.int16, xp.int64): xp.int64,
60+
(xp.int32, xp.int8): xp.int32,
61+
(xp.int32, xp.int16): xp.int32,
62+
(xp.int32, xp.int32): xp.int32,
63+
(xp.int32, xp.int64): xp.int64,
64+
(xp.int64, xp.int8): xp.int64,
65+
(xp.int64, xp.int16): xp.int64,
66+
(xp.int64, xp.int32): xp.int64,
67+
(xp.int64, xp.int64): xp.int64,
8168
}
8269

70+
8371
unsigned_integer_promotion_table = {
84-
('uint8', 'uint8'): 'uint8',
85-
('uint8', 'uint16'): 'uint16',
86-
('uint8', 'uint32'): 'uint32',
87-
('uint8', 'uint64'): 'uint64',
88-
('uint16', 'uint8'): 'uint16',
89-
('uint16', 'uint16'): 'uint16',
90-
('uint16', 'uint32'): 'uint32',
91-
('uint16', 'uint64'): 'uint64',
92-
('uint32', 'uint8'): 'uint32',
93-
('uint32', 'uint16'): 'uint32',
94-
('uint32', 'uint32'): 'uint32',
95-
('uint32', 'uint64'): 'uint64',
96-
('uint64', 'uint8'): 'uint64',
97-
('uint64', 'uint16'): 'uint64',
98-
('uint64', 'uint32'): 'uint64',
99-
('uint64', 'uint64'): 'uint64',
72+
(xp.uint8, xp.uint8): xp.uint8,
73+
(xp.uint8, xp.uint16): xp.uint16,
74+
(xp.uint8, xp.uint32): xp.uint32,
75+
(xp.uint8, xp.uint64): xp.uint64,
76+
(xp.uint16, xp.uint8): xp.uint16,
77+
(xp.uint16, xp.uint16): xp.uint16,
78+
(xp.uint16, xp.uint32): xp.uint32,
79+
(xp.uint16, xp.uint64): xp.uint64,
80+
(xp.uint32, xp.uint8): xp.uint32,
81+
(xp.uint32, xp.uint16): xp.uint32,
82+
(xp.uint32, xp.uint32): xp.uint32,
83+
(xp.uint32, xp.uint64): xp.uint64,
84+
(xp.uint64, xp.uint8): xp.uint64,
85+
(xp.uint64, xp.uint16): xp.uint64,
86+
(xp.uint64, xp.uint32): xp.uint64,
87+
(xp.uint64, xp.uint64): xp.uint64,
10088
}
10189

90+
10291
mixed_signed_unsigned_promotion_table = {
103-
('int8', 'uint8'): 'int16',
104-
('int8', 'uint16'): 'int32',
105-
('int8', 'uint32'): 'int64',
106-
('int16', 'uint8'): 'int16',
107-
('int16', 'uint16'): 'int32',
108-
('int16', 'uint32'): 'int64',
109-
('int32', 'uint8'): 'int32',
110-
('int32', 'uint16'): 'int32',
111-
('int32', 'uint32'): 'int64',
112-
('int64', 'uint8'): 'int64',
113-
('int64', 'uint16'): 'int64',
114-
('int64', 'uint32'): 'int64',
92+
(xp.int8, xp.uint8): xp.int16,
93+
(xp.int8, xp.uint16): xp.int32,
94+
(xp.int8, xp.uint32): xp.int64,
95+
(xp.int16, xp.uint8): xp.int16,
96+
(xp.int16, xp.uint16): xp.int32,
97+
(xp.int16, xp.uint32): xp.int64,
98+
(xp.int32, xp.uint8): xp.int32,
99+
(xp.int32, xp.uint16): xp.int32,
100+
(xp.int32, xp.uint32): xp.int64,
101+
(xp.int64, xp.uint8): xp.int64,
102+
(xp.int64, xp.uint16): xp.int64,
103+
(xp.int64, xp.uint32): xp.int64,
115104
}
116105

106+
117107
flipped_mixed_signed_unsigned_promotion_table = {(u, i): p for (i, u), p in mixed_signed_unsigned_promotion_table.items()}
118108

109+
119110
float_promotion_table = {
120-
('float32', 'float32'): 'float32',
121-
('float32', 'float64'): 'float64',
122-
('float64', 'float32'): 'float64',
123-
('float64', 'float64'): 'float64',
111+
(xp.float32, xp.float32): xp.float32,
112+
(xp.float32, xp.float64): xp.float64,
113+
(xp.float64, xp.float32): xp.float64,
114+
(xp.float64, xp.float64): xp.float64,
124115
}
125116

117+
126118
boolean_promotion_table = {
127-
('bool', 'bool'): 'bool',
119+
(xp.bool, xp.bool): xp.bool,
128120
}
129121

122+
130123
promotion_table = {
131124
**signed_integer_promotion_table,
132125
**unsigned_integer_promotion_table,
@@ -136,6 +129,7 @@ def dtype_signed(dtype):
136129
**boolean_promotion_table,
137130
}
138131

132+
139133
input_types = {
140134
'any': sorted(set(promotion_table.values())),
141135
'boolean': sorted(set(boolean_promotion_table.values())),
@@ -150,21 +144,23 @@ def dtype_signed(dtype):
150144
**unsigned_integer_promotion_table}.values())),
151145
}
152146

147+
153148
dtypes_to_scalars = {
154-
'bool': [bool],
155-
'int8': [int],
156-
'int16': [int],
157-
'int32': [int],
158-
'int64': [int],
149+
xp.bool: [bool],
150+
xp.int8: [int],
151+
xp.int16: [int],
152+
xp.int32: [int],
153+
xp.int64: [int],
159154
# Note: unsigned int dtypes only correspond to positive integers
160-
'uint8': [int],
161-
'uint16': [int],
162-
'uint32': [int],
163-
'uint64': [int],
164-
'float32': [int, float],
165-
'float64': [int, float],
155+
xp.uint8: [int],
156+
xp.uint16: [int],
157+
xp.uint32: [int],
158+
xp.uint64: [int],
159+
xp.float32: [int, float],
160+
xp.float64: [int, float],
166161
}
167162

163+
168164
elementwise_function_input_types = {
169165
'abs': 'numeric',
170166
'acos': 'floating',
@@ -224,6 +220,7 @@ def dtype_signed(dtype):
224220
'trunc': 'numeric',
225221
}
226222

223+
227224
elementwise_function_output_types = {
228225
'abs': 'promoted',
229226
'acos': 'promoted',
@@ -283,6 +280,7 @@ def dtype_signed(dtype):
283280
'trunc': 'promoted',
284281
}
285282

283+
286284
binary_operators = {
287285
'__add__': '+',
288286
'__and__': '&',
@@ -305,6 +303,7 @@ def dtype_signed(dtype):
305303
'__xor__': '^',
306304
}
307305

306+
308307
unary_operators = {
309308
'__abs__': 'abs()',
310309
'__invert__': '~',

array_api_tests/hypothesis_helpers.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
boolean_dtype_objects,
1515
integer_or_boolean_dtype_objects, dtype_objects)
1616
from ._array_module import full, float32, float64, bool as bool_dtype, _UndefinedStub
17-
from .dtype_helpers import dtype_mapping, promotion_table
17+
from .dtype_helpers import promotion_table
1818
from . import _array_module
1919
from . import _array_module as xp
2020

@@ -54,9 +54,8 @@
5454
sorted_table = sorted(
5555
sorted_table, key=lambda ij: -1 if ij[0] == ij[1] else sorted_table.index(ij)
5656
)
57-
dtype_pairs = [(dtype_mapping[i], dtype_mapping[j]) for i, j in sorted_table]
5857
if FILTER_UNDEFINED_DTYPES:
59-
dtype_pairs = [(i, j) for i, j in dtype_pairs
58+
sorted_table = [(i, j) for i, j in sorted_table
6059
if not isinstance(i, _UndefinedStub)
6160
and not isinstance(j, _UndefinedStub)]
6261

@@ -70,7 +69,7 @@ def mutually_promotable_dtypes(dtype_objects=dtype_objects):
7069
# pairs (XXX: Can we redesign the strategies so that they can prefer
7170
# shrinking dtypes over values?)
7271
return sampled_from(
73-
[(i, j) for i, j in dtype_pairs if i in dtype_objects and j in dtype_objects]
72+
[(i, j) for i, j in sorted_table if i in dtype_objects and j in dtype_objects]
7473
)
7574

7675
shared_mutually_promotable_dtype_pairs = shared(

array_api_tests/test_broadcasting.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010
from .hypothesis_helpers import shapes, FILTER_UNDEFINED_DTYPES
1111
from .pytest_helpers import raises, doesnt_raise, nargs
1212

13-
from .dtype_helpers import (elementwise_function_input_types,
14-
input_types, dtype_mapping)
13+
from .dtype_helpers import elementwise_function_input_types, input_types
1514
from .function_stubs import elementwise_functions
1615
from . import _array_module
1716
from ._array_module import ones, _UndefinedStub
@@ -111,14 +110,14 @@ def test_broadcast_shapes_explicit_spec():
111110
@pytest.mark.parametrize('func_name', [i for i in
112111
elementwise_functions.__all__ if
113112
nargs(i) > 1])
114-
@given(shape1=shapes, shape2=shapes, dtype=data())
115-
def test_broadcasting_hypothesis(func_name, shape1, shape2, dtype):
113+
@given(shape1=shapes, shape2=shapes, data=data())
114+
def test_broadcasting_hypothesis(func_name, shape1, shape2, data):
116115
# Internal consistency checks
117116
assert nargs(func_name) == 2
118117

119-
dtype = dtype_mapping[dtype.draw(sampled_from(input_types[elementwise_function_input_types[func_name]]))]
120-
if FILTER_UNDEFINED_DTYPES and isinstance(dtype, _UndefinedStub):
121-
assume(False)
118+
dtype = data.draw(sampled_from(input_types[elementwise_function_input_types[func_name]]))
119+
if FILTER_UNDEFINED_DTYPES:
120+
assume(not isinstance(dtype, _UndefinedStub))
122121
func = getattr(_array_module, func_name)
123122

124123
if isinstance(func, _array_module._UndefinedStub):

0 commit comments

Comments
 (0)