Skip to content

Commit 2bee4e6

Browse files
committed
Add dpnp.broadcast_shapes implementation
1 parent 687b8ea commit 2bee4e6

File tree

3 files changed

+62
-2
lines changed

3 files changed

+62
-2
lines changed

dpnp/dpnp_iface_manipulation.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
"atleast_2d",
6464
"atleast_3d",
6565
"broadcast_arrays",
66+
"broadcast_shapes",
6667
"broadcast_to",
6768
"can_cast",
6869
"column_stack",
@@ -967,6 +968,41 @@ def broadcast_arrays(*args, subok=False):
967968
return [dpnp_array._create_from_usm_ndarray(a) for a in usm_arrays]
968969

969970

971+
def broadcast_shapes(*args):
972+
"""
973+
Broadcast the input shapes into a single shape.
974+
975+
For full documentation refer to :obj:`numpy.broadcast_shapes`.
976+
977+
Parameters
978+
----------
979+
*args : tuples of ints, or ints
980+
The shapes to be broadcast against each other.
981+
982+
Returns
983+
-------
984+
tuple
985+
Broadcasted shape.
986+
987+
See Also
988+
--------
989+
:obj:`dpnp.broadcast_arrays` : Broadcast any number of arrays against
990+
each other.
991+
:obj:`dpnp.broadcast_to` : Broadcast an array to a new shape.
992+
993+
Examples
994+
--------
995+
>>> import dpnp as np
996+
>>> np.broadcast_shapes((1, 2), (3, 1), (3, 2))
997+
(3, 2)
998+
>>> np.broadcast_shapes((6, 7), (5, 6, 1), (7,), (5, 1, 7))
999+
(5, 6, 7)
1000+
1001+
"""
1002+
1003+
return numpy.broadcast_shapes(*args)
1004+
1005+
9701006
# pylint: disable=redefined-outer-name
9711007
def broadcast_to(array, /, shape, subok=False):
9721008
"""

dpnp/dpnp_iface_mathematical.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -994,8 +994,7 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
994994
a_shape = a.shape
995995
b_shape = b.shape
996996

997-
# TODO: replace with dpnp.broadcast_shapes once implemented
998-
res_shape = numpy.broadcast_shapes(a_shape[:-1], b_shape[:-1])
997+
res_shape = dpnp.broadcast_shapes(a_shape[:-1], b_shape[:-1])
999998
if a_shape[:-1] != res_shape:
1000999
a = dpnp.broadcast_to(a, res_shape + (a_shape[-1],))
10011000
a_shape = a.shape

tests/test_manipulation.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,31 @@ def test_no_copy(self):
332332
assert_array_equal(b, a)
333333

334334

335+
class TestBroadcast:
336+
@pytest.mark.parametrize(
337+
"shape",
338+
[
339+
[(1,), (3,)],
340+
[(1, 3), (3, 3)],
341+
[(3, 1), (3, 3)],
342+
[(1, 3), (3, 1)],
343+
[(1, 1), (3, 3)],
344+
[(1, 1), (1, 3)],
345+
[(1, 1), (3, 1)],
346+
[(1, 0), (0, 0)],
347+
[(0, 1), (0, 0)],
348+
[(1, 0), (0, 1)],
349+
[(1, 1), (0, 0)],
350+
[(1, 1), (1, 0)],
351+
[(1, 1), (0, 1)],
352+
],
353+
)
354+
def test_broadcast_shapes(self, shape):
355+
expected = numpy.broadcast_shapes(*shape)
356+
result = dpnp.broadcast_shapes(*shape)
357+
assert_equal(result, expected)
358+
359+
335360
class TestDelete:
336361
@pytest.mark.parametrize(
337362
"obj", [slice(0, 4, 2), 3, [2, 3]], ids=["slice", "int", "list"]

0 commit comments

Comments
 (0)