1
1
# Data Parallel Control (dpctl)
2
2
#
3
- # Copyright 2020-2022 Intel Corporation
3
+ # Copyright 2020-2023 Intel Corporation
4
4
#
5
5
# Licensed under the Apache License, Version 2.0 (the "License");
6
6
# you may not use this file except in compliance with the License.
15
15
# limitations under the License.
16
16
17
17
import numbers
18
+ from cpython.buffer cimport PyObject_CheckBuffer
19
+
20
+
21
+ cdef bint _is_buffer(object o):
22
+ return PyObject_CheckBuffer(o)
18
23
19
24
20
25
cdef Py_ssize_t _slice_len(
@@ -36,14 +41,23 @@ cdef Py_ssize_t _slice_len(
36
41
37
42
cdef bint _is_integral(object x) except * :
38
43
""" Gives True if x is an integral slice spec"""
39
- if isinstance (x, (int , numbers.Integral)):
40
- return True
41
44
if isinstance (x, usm_ndarray):
42
45
if x.ndim > 0 :
43
46
return False
44
47
if x.dtype.kind not in " ui" :
45
48
return False
46
49
return True
50
+ if isinstance (x, bool ):
51
+ return False
52
+ if isinstance (x, int ):
53
+ return True
54
+ if _is_buffer(x):
55
+ mbuf = memoryview(x)
56
+ if mbuf.ndim == 0 :
57
+ f = mbuf.format
58
+ return f in " bBhHiIlLqQ"
59
+ else :
60
+ return False
47
61
if callable (getattr (x, " __index__" , None )):
48
62
try :
49
63
x.__index__()
@@ -53,6 +67,34 @@ cdef bint _is_integral(object x) except *:
53
67
return False
54
68
55
69
70
+ cdef bint _is_boolean(object x) except * :
71
+ """ Gives True if x is an integral slice spec"""
72
+ if isinstance (x, usm_ndarray):
73
+ if x.ndim > 0 :
74
+ return False
75
+ if x.dtype.kind not in " b" :
76
+ return False
77
+ return True
78
+ if isinstance (x, bool ):
79
+ return True
80
+ if isinstance (x, int ):
81
+ return False
82
+ if _is_buffer(x):
83
+ mbuf = memoryview(x)
84
+ if mbuf.ndim == 0 :
85
+ f = mbuf.format
86
+ return f in " ?"
87
+ else :
88
+ return False
89
+ if callable (getattr (x, " __bool__" , None )):
90
+ try :
91
+ x.__bool__()
92
+ except (TypeError , ValueError ):
93
+ return False
94
+ return True
95
+ return False
96
+
97
+
56
98
def _basic_slice_meta (ind , shape : tuple , strides : tuple , offset : int ):
57
99
"""
58
100
Give basic slicing index `ind` and array layout information produce
@@ -82,6 +124,11 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
82
124
_no_advanced_ind,
83
125
_no_advanced_pos
84
126
)
127
+ elif _is_boolean(ind):
128
+ if ind:
129
+ return ((1 ,) + shape, (0 ,) + strides, offset, _no_advanced_ind, _no_advanced_pos)
130
+ else :
131
+ return ((0 ,) + shape, (0 ,) + strides, offset, _no_advanced_ind, _no_advanced_pos)
85
132
elif _is_integral(ind):
86
133
ind = ind.__index__()
87
134
if 0 <= ind < shape[0 ]:
@@ -117,6 +164,10 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
117
164
axes_referenced += 1
118
165
if array_streak_started:
119
166
array_streak_interrupted = True
167
+ elif _is_boolean(i):
168
+ newaxis_count += 1
169
+ if array_streak_started:
170
+ array_streak_interrupted = True
120
171
elif _is_integral(i):
121
172
explicit_index += 1
122
173
axes_referenced += 1
@@ -133,9 +184,9 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
133
184
" separated by basic slicing specs."
134
185
)
135
186
dt_k = i.dtype.kind
136
- if dt_k == " b" :
187
+ if dt_k == " b" and i.ndim > 0 :
137
188
axes_referenced += i.ndim
138
- elif dt_k in " ui" :
189
+ elif dt_k in " ui" and i.ndim > 0 :
139
190
axes_referenced += 1
140
191
else :
141
192
raise IndexError (
@@ -186,6 +237,9 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
186
237
if sh_i == 0 :
187
238
is_empty = True
188
239
k = k_new
240
+ elif _is_boolean(ind_i):
241
+ new_shape.append(1 if ind_i else 0 )
242
+ new_strides.append(0 )
189
243
elif _is_integral(ind_i):
190
244
ind_i = ind_i.__index__()
191
245
if 0 <= ind_i < shape[k]:
0 commit comments