Skip to content

Commit 48fd544

Browse files
committed
Convert dtype_signed and dtype_nbits to dicts
1 parent e4461b3 commit 48fd544

File tree

3 files changed

+21
-40
lines changed

3 files changed

+21
-40
lines changed

array_api_tests/dtype_helpers.py

Lines changed: 12 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -15,37 +15,18 @@
1515
]
1616

1717

18-
def dtype_nbits(dtype):
19-
if dtype == xp.int8:
20-
return 8
21-
elif dtype == xp.int16:
22-
return 16
23-
elif dtype == xp.int32:
24-
return 32
25-
elif dtype == xp.int64:
26-
return 64
27-
elif dtype == xp.uint8:
28-
return 8
29-
elif dtype == xp.uint16:
30-
return 16
31-
elif dtype == xp.uint32:
32-
return 32
33-
elif dtype == xp.uint64:
34-
return 64
35-
elif dtype == xp.float32:
36-
return 32
37-
elif dtype == xp.float64:
38-
return 64
39-
else:
40-
raise ValueError(f"dtype_nbits is not defined for {dtype}")
41-
42-
43-
def dtype_signed(dtype):
44-
if dtype in [xp.int8, xp.int16, xp.int32, xp.int64]:
45-
return True
46-
elif dtype in [xp.uint8, xp.uint16, xp.uint32, xp.uint64]:
47-
return False
48-
raise ValueError("dtype_signed is only defined for integer dtypes")
18+
dtype_nbits = {
19+
**{d: 8 for d in [xp.int8, xp.uint8]},
20+
**{d: 16 for d in [xp.int16, xp.uint16]},
21+
**{d: 32 for d in [xp.int32, xp.uint32, xp.float32]},
22+
**{d: 64 for d in [xp.int64, xp.uint64, xp.float64]},
23+
}
24+
25+
26+
dtype_signed = {
27+
**{d: True for d in [xp.int8, xp.int16, xp.int32, xp.int64]},
28+
**{d: False for d in [xp.uint8, xp.uint16, xp.uint32, xp.uint64]},
29+
}
4930

5031

5132
signed_integer_promotion_table = {

array_api_tests/meta_tests/test_array_helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ def test_notequal():
2525

2626
@given(integers(), integer_dtypes)
2727
def test_int_to_dtype(x, dtype):
28-
n = dtype_nbits(dtype)
29-
signed = dtype_signed(dtype)
28+
n = dtype_nbits[dtype]
29+
signed = dtype_signed[dtype]
3030
try:
3131
d = xp.asarray(x, dtype=dtype)
3232
except OverflowError:

array_api_tests/test_elementwise_functions.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def test_bitwise_and(args):
223223
x = int(x1)
224224
y = int(x2)
225225
res = int(a)
226-
ans = int_to_dtype(x & y, dtype_nbits(a.dtype), dtype_signed(a.dtype))
226+
ans = int_to_dtype(x & y, dtype_nbits[a.dtype], dtype_signed[a.dtype])
227227
assert ans == res
228228

229229
@given(two_integer_dtypes.flatmap(lambda i: two_array_scalars(*i)))
@@ -240,12 +240,12 @@ def test_bitwise_left_shift(args):
240240
raise RuntimeError("Error: test_bitwise_left_shift needs to be updated for nonscalar array inputs")
241241
x = int(x1)
242242
y = int(x2)
243-
if y >= dtype_nbits(a.dtype):
243+
if y >= dtype_nbits[a.dtype]:
244244
# Avoid shifting very large y in Python ints
245245
ans = 0
246246
else:
247247
ans = x << y
248-
ans = int_to_dtype(ans, dtype_nbits(a.dtype), dtype_signed(a.dtype))
248+
ans = int_to_dtype(ans, dtype_nbits[a.dtype], dtype_signed[a.dtype])
249249
res = int(a)
250250
assert ans == res
251251

@@ -263,7 +263,7 @@ def test_bitwise_invert(x):
263263
else:
264264
x = int(x)
265265
res = int(a)
266-
ans = int_to_dtype(~x, dtype_nbits(a.dtype), dtype_signed(a.dtype))
266+
ans = int_to_dtype(~x, dtype_nbits[a.dtype], dtype_signed[a.dtype])
267267
assert ans == res
268268

269269
@given(two_integer_or_boolean_dtypes.flatmap(lambda i: two_array_scalars(*i)))
@@ -284,7 +284,7 @@ def test_bitwise_or(args):
284284
x = int(x1)
285285
y = int(x2)
286286
res = int(a)
287-
ans = int_to_dtype(x | y, dtype_nbits(a.dtype), dtype_signed(a.dtype))
287+
ans = int_to_dtype(x | y, dtype_nbits[a.dtype], dtype_signed[a.dtype])
288288
assert ans == res
289289

290290
@given(two_integer_dtypes.flatmap(lambda i: two_array_scalars(*i)))
@@ -301,7 +301,7 @@ def test_bitwise_right_shift(args):
301301
raise RuntimeError("Error: test_bitwise_right_shift needs to be updated for nonscalar array inputs")
302302
x = int(x1)
303303
y = int(x2)
304-
ans = int_to_dtype(x >> y, dtype_nbits(a.dtype), dtype_signed(a.dtype))
304+
ans = int_to_dtype(x >> y, dtype_nbits[a.dtype], dtype_signed[a.dtype])
305305
res = int(a)
306306
assert ans == res
307307

@@ -323,7 +323,7 @@ def test_bitwise_xor(args):
323323
x = int(x1)
324324
y = int(x2)
325325
res = int(a)
326-
ans = int_to_dtype(x ^ y, dtype_nbits(a.dtype), dtype_signed(a.dtype))
326+
ans = int_to_dtype(x ^ y, dtype_nbits[a.dtype], dtype_signed[a.dtype])
327327
assert ans == res
328328

329329
@given(numeric_scalars)

0 commit comments

Comments
 (0)