Skip to content

Commit 3a22dd9

Browse files
Merge pull request #1136 from IntelPython/slicing-bug-gh-1135
Slicing bug gh 1135
2 parents 18bb612 + 8c48886 commit 3a22dd9

File tree

2 files changed

+85
-5
lines changed

2 files changed

+85
-5
lines changed

dpctl/tensor/_slicing.pxi

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Data Parallel Control (dpctl)
22
#
3-
# Copyright 2020-2022 Intel Corporation
3+
# Copyright 2020-2023 Intel Corporation
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
66
# you may not use this file except in compliance with the License.
@@ -15,6 +15,11 @@
1515
# limitations under the License.
1616

1717
import numbers
18+
from cpython.buffer cimport PyObject_CheckBuffer
19+
20+
21+
cdef bint _is_buffer(object o):
22+
return PyObject_CheckBuffer(o)
1823

1924

2025
cdef Py_ssize_t _slice_len(
@@ -36,14 +41,23 @@ cdef Py_ssize_t _slice_len(
3641

3742
cdef bint _is_integral(object x) except *:
3843
"""Gives True if x is an integral slice spec"""
39-
if isinstance(x, (int, numbers.Integral)):
40-
return True
4144
if isinstance(x, usm_ndarray):
4245
if x.ndim > 0:
4346
return False
4447
if x.dtype.kind not in "ui":
4548
return False
4649
return True
50+
if isinstance(x, bool):
51+
return False
52+
if isinstance(x, int):
53+
return True
54+
if _is_buffer(x):
55+
mbuf = memoryview(x)
56+
if mbuf.ndim == 0:
57+
f = mbuf.format
58+
return f in "bBhHiIlLqQ"
59+
else:
60+
return False
4761
if callable(getattr(x, "__index__", None)):
4862
try:
4963
x.__index__()
@@ -53,6 +67,34 @@ cdef bint _is_integral(object x) except *:
5367
return False
5468

5569

70+
cdef bint _is_boolean(object x) except *:
71+
"""Gives True if x is an integral slice spec"""
72+
if isinstance(x, usm_ndarray):
73+
if x.ndim > 0:
74+
return False
75+
if x.dtype.kind not in "b":
76+
return False
77+
return True
78+
if isinstance(x, bool):
79+
return True
80+
if isinstance(x, int):
81+
return False
82+
if _is_buffer(x):
83+
mbuf = memoryview(x)
84+
if mbuf.ndim == 0:
85+
f = mbuf.format
86+
return f in "?"
87+
else:
88+
return False
89+
if callable(getattr(x, "__bool__", None)):
90+
try:
91+
x.__bool__()
92+
except (TypeError, ValueError):
93+
return False
94+
return True
95+
return False
96+
97+
5698
def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
5799
"""
58100
Give basic slicing index `ind` and array layout information produce
@@ -82,6 +124,11 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
82124
_no_advanced_ind,
83125
_no_advanced_pos
84126
)
127+
elif _is_boolean(ind):
128+
if ind:
129+
return ((1,) + shape, (0,) + strides, offset, _no_advanced_ind, _no_advanced_pos)
130+
else:
131+
return ((0,) + shape, (0,) + strides, offset, _no_advanced_ind, _no_advanced_pos)
85132
elif _is_integral(ind):
86133
ind = ind.__index__()
87134
if 0 <= ind < shape[0]:
@@ -117,6 +164,10 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
117164
axes_referenced += 1
118165
if array_streak_started:
119166
array_streak_interrupted = True
167+
elif _is_boolean(i):
168+
newaxis_count += 1
169+
if array_streak_started:
170+
array_streak_interrupted = True
120171
elif _is_integral(i):
121172
explicit_index += 1
122173
axes_referenced += 1
@@ -133,9 +184,9 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
133184
"separated by basic slicing specs."
134185
)
135186
dt_k = i.dtype.kind
136-
if dt_k == "b":
187+
if dt_k == "b" and i.ndim > 0:
137188
axes_referenced += i.ndim
138-
elif dt_k in "ui":
189+
elif dt_k in "ui" and i.ndim > 0:
139190
axes_referenced += 1
140191
else:
141192
raise IndexError(
@@ -186,6 +237,9 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
186237
if sh_i == 0:
187238
is_empty = True
188239
k = k_new
240+
elif _is_boolean(ind_i):
241+
new_shape.append(1 if ind_i else 0)
242+
new_strides.append(0)
189243
elif _is_integral(ind_i):
190244
ind_i = ind_i.__index__()
191245
if 0 <= ind_i < shape[k]:

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,32 @@ def test_integer_strided_indexing():
455455
assert (dpt.asnumpy(y) == dpt.asnumpy(yc)).all()
456456

457457

458+
def test_TrueFalse_indexing():
459+
get_queue_or_skip()
460+
n0, n1 = 2, 3
461+
x = dpt.ones((n0, n1))
462+
for ind in [True, dpt.asarray(True)]:
463+
y1 = x[ind]
464+
assert y1.shape == (1, n0, n1)
465+
assert y1._pointer == x._pointer
466+
y2 = x[:, ind]
467+
assert y2.shape == (n0, 1, n1)
468+
assert y2._pointer == x._pointer
469+
y3 = x[..., ind]
470+
assert y3.shape == (n0, n1, 1)
471+
assert y3._pointer == x._pointer
472+
for ind in [False, dpt.asarray(False)]:
473+
y1 = x[ind]
474+
assert y1.shape == (0, n0, n1)
475+
assert y1._pointer == x._pointer
476+
y2 = x[:, ind]
477+
assert y2.shape == (n0, 0, n1)
478+
assert y2._pointer == x._pointer
479+
y3 = x[..., ind]
480+
assert y3.shape == (n0, n1, 0)
481+
assert y3._pointer == x._pointer
482+
483+
458484
@pytest.mark.parametrize(
459485
"data_dt",
460486
_all_dtypes,

0 commit comments

Comments
 (0)