Skip to content

Commit 0b3efb1

Browse files
MarcoGorelliMaanasArora
authored andcommitted
TYP: Type MaskedArray.repeat, improve overloads for NDArray.repeat, generic.repeat, and np.repeat (numpy#28849)
1 parent 14ea82d commit 0b3efb1

File tree

6 files changed

+55
-13
lines changed

6 files changed

+55
-13
lines changed

numpy/__init__.pyi

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2421,10 +2421,17 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
24212421
mode: _ModeKind = ...,
24222422
) -> _ArrayT: ...
24232423

2424+
@overload
24242425
def repeat(
24252426
self,
24262427
repeats: _ArrayLikeInt_co,
2427-
axis: SupportsIndex | None = ...,
2428+
axis: None = None,
2429+
) -> ndarray[tuple[int], _DTypeT_co]: ...
2430+
@overload
2431+
def repeat(
2432+
self,
2433+
repeats: _ArrayLikeInt_co,
2434+
axis: SupportsIndex,
24282435
) -> ndarray[_Shape, _DTypeT_co]: ...
24292436

24302437
def flatten(self, /, order: _OrderKACF = "C") -> ndarray[tuple[int], _DTypeT_co]: ...
@@ -3685,7 +3692,7 @@ class generic(_ArrayOrScalarCommon, Generic[_ItemT_co]):
36853692
mode: _ModeKind = ...,
36863693
) -> _ArrayT: ...
36873694

3688-
def repeat(self, repeats: _ArrayLikeInt_co, axis: SupportsIndex | None = ...) -> NDArray[Self]: ...
3695+
def repeat(self, repeats: _ArrayLikeInt_co, axis: SupportsIndex | None = None) -> ndarray[tuple[int], dtype[Self]]: ...
36893696
def flatten(self, /, order: _OrderKACF = "C") -> ndarray[tuple[int], dtype[Self]]: ...
36903697
def ravel(self, /, order: _OrderKACF = "C") -> ndarray[tuple[int], dtype[Self]]: ...
36913698

numpy/_core/fromnumeric.pyi

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,13 +277,25 @@ def choose(
277277
def repeat(
278278
a: _ArrayLike[_ScalarT],
279279
repeats: _ArrayLikeInt_co,
280-
axis: SupportsIndex | None = ...,
280+
axis: None = None,
281+
) -> np.ndarray[tuple[int], np.dtype[_ScalarT]]: ...
282+
@overload
283+
def repeat(
284+
a: _ArrayLike[_ScalarT],
285+
repeats: _ArrayLikeInt_co,
286+
axis: SupportsIndex,
281287
) -> NDArray[_ScalarT]: ...
282288
@overload
283289
def repeat(
284290
a: ArrayLike,
285291
repeats: _ArrayLikeInt_co,
286-
axis: SupportsIndex | None = ...,
292+
axis: None = None,
293+
) -> np.ndarray[tuple[int], np.dtype[Any]]: ...
294+
@overload
295+
def repeat(
296+
a: ArrayLike,
297+
repeats: _ArrayLikeInt_co,
298+
axis: SupportsIndex,
287299
) -> NDArray[Any]: ...
288300

289301
def put(

numpy/ma/core.pyi

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -778,7 +778,20 @@ class MaskedArray(ndarray[_ShapeT_co, _DTypeT_co]):
778778
copy: Any
779779
diagonal: Any
780780
flatten: Any
781-
repeat: Any
781+
782+
@overload
783+
def repeat(
784+
self,
785+
repeats: _ArrayLikeInt_co,
786+
axis: None = None,
787+
) -> MaskedArray[tuple[int], _DTypeT_co]: ...
788+
@overload
789+
def repeat(
790+
self,
791+
repeats: _ArrayLikeInt_co,
792+
axis: SupportsIndex,
793+
) -> MaskedArray[_Shape, _DTypeT_co]: ...
794+
782795
squeeze: Any
783796
swapaxes: Any
784797
T: Any

numpy/typing/tests/data/reveal/fromnumeric.pyi

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,13 @@ assert_type(np.choose([1], [True, True]), npt.NDArray[Any])
4646
assert_type(np.choose([1], AR_b), npt.NDArray[np.bool])
4747
assert_type(np.choose([1], AR_b, out=AR_f4), npt.NDArray[np.float32])
4848

49-
assert_type(np.repeat(b, 1), npt.NDArray[np.bool])
50-
assert_type(np.repeat(f4, 1), npt.NDArray[np.float32])
51-
assert_type(np.repeat(f, 1), npt.NDArray[Any])
52-
assert_type(np.repeat(AR_b, 1), npt.NDArray[np.bool])
53-
assert_type(np.repeat(AR_f4, 1), npt.NDArray[np.float32])
49+
assert_type(np.repeat(b, 1), np.ndarray[tuple[int], np.dtype[np.bool]])
50+
assert_type(np.repeat(b, 1, axis=0), npt.NDArray[np.bool])
51+
assert_type(np.repeat(f4, 1), np.ndarray[tuple[int], np.dtype[np.float32]])
52+
assert_type(np.repeat(f, 1), np.ndarray[tuple[int], np.dtype[Any]])
53+
assert_type(np.repeat(AR_b, 1), np.ndarray[tuple[int], np.dtype[np.bool]])
54+
assert_type(np.repeat(AR_f4, 1), np.ndarray[tuple[int], np.dtype[np.float32]])
55+
assert_type(np.repeat(AR_f4, 1, axis=0), npt.NDArray[np.float32])
5456

5557
# TODO: array_bdd tests for np.put()
5658

numpy/typing/tests/data/reveal/ma.pyi

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,11 @@ assert_type(np.ma.filled([[1,2,3]]), NDArray[Any])
276276
# https://github.com/numpy/numpy/pull/28742#discussion_r2048968375
277277
assert_type(np.ma.filled(MAR_1d), np.ndarray[tuple[int], np.dtype]) # type: ignore[assert-type]
278278

279+
assert_type(MAR_b.repeat(3), np.ma.MaskedArray[tuple[int], np.dtype[np.bool]])
280+
assert_type(MAR_2d_f4.repeat(MAR_i8), np.ma.MaskedArray[tuple[int], np.dtype[np.float32]])
281+
assert_type(MAR_2d_f4.repeat(MAR_i8, axis=None), np.ma.MaskedArray[tuple[int], np.dtype[np.float32]])
282+
assert_type(MAR_2d_f4.repeat(MAR_i8, axis=0), MaskedNDArray[np.float32])
283+
279284
assert_type(np.ma.allequal(AR_f4, MAR_f4), bool)
280285
assert_type(np.ma.allequal(AR_f4, MAR_f4, fill_value=False), bool)
281286

numpy/typing/tests/data/reveal/ndarray_misc.pyi

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,12 @@ assert_type(f8.round(), np.float64)
126126
assert_type(AR_f8.round(), npt.NDArray[np.float64])
127127
assert_type(AR_f8.round(out=B), SubClass)
128128

129-
assert_type(f8.repeat(1), npt.NDArray[np.float64])
130-
assert_type(AR_f8.repeat(1), npt.NDArray[np.float64])
131-
assert_type(B.repeat(1), npt.NDArray[np.object_])
129+
assert_type(f8.repeat(1), np.ndarray[tuple[int], np.dtype[np.float64]])
130+
assert_type(f8.repeat(1, axis=0), np.ndarray[tuple[int], np.dtype[np.float64]])
131+
assert_type(AR_f8.repeat(1), np.ndarray[tuple[int], np.dtype[np.float64]])
132+
assert_type(AR_f8.repeat(1, axis=0), npt.NDArray[np.float64])
133+
assert_type(B.repeat(1), np.ndarray[tuple[int], np.dtype[np.object_]])
134+
assert_type(B.repeat(1, axis=0), npt.NDArray[np.object_])
132135

133136
assert_type(f8.std(), Any)
134137
assert_type(AR_f8.std(), Any)

0 commit comments

Comments
 (0)