Skip to content

26302 add typing to assert star equal funcs #29364

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 88 additions & 55 deletions pandas/util/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from shutil import rmtree
import string
import tempfile
from typing import Union, cast
from typing import Optional, Union, cast
import warnings
import zipfile

Expand Down Expand Up @@ -53,6 +53,7 @@
Series,
bdate_range,
)
from pandas._typing import AnyArrayLike
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AnyArrayLike resolves to Any.

If it adds value as code documentation, that's OK, but mypy is effectively not checking these annotations.

from pandas.core.algorithms import take_1d
from pandas.core.arrays import (
DatetimeArray,
Expand Down Expand Up @@ -806,8 +807,12 @@ def assert_is_sorted(seq):


def assert_categorical_equal(
left, right, check_dtype=True, check_category_order=True, obj="Categorical"
):
left: Categorical,
right: Categorical,
check_dtype: bool = True,
check_category_order: bool = True,
obj: str = "Categorical",
) -> None:
"""Test that Categoricals are equivalent.

Parameters
Expand Down Expand Up @@ -852,7 +857,12 @@ def assert_categorical_equal(
assert_attr_equal("ordered", left, right, obj=obj)


def assert_interval_array_equal(left, right, exact="equiv", obj="IntervalArray"):
def assert_interval_array_equal(
left: IntervalArray,
right: IntervalArray,
exact: str = "equiv",
obj: str = "IntervalArray",
) -> None:
"""Test that two IntervalArrays are equivalent.

Parameters
Expand All @@ -878,7 +888,9 @@ def assert_interval_array_equal(left, right, exact="equiv", obj="IntervalArray")
assert_attr_equal("closed", left, right, obj=obj)


def assert_period_array_equal(left, right, obj="PeriodArray"):
def assert_period_array_equal(
left: PeriodArray, right: PeriodArray, obj: str = "PeriodArray"
) -> None:
_check_isinstance(left, right, PeriodArray)

assert_numpy_array_equal(
Expand All @@ -887,7 +899,9 @@ def assert_period_array_equal(left, right, obj="PeriodArray"):
assert_attr_equal("freq", left, right, obj=obj)


def assert_datetime_array_equal(left, right, obj="DatetimeArray"):
def assert_datetime_array_equal(
left: DatetimeArray, right: DatetimeArray, obj: str = "DatetimeArray"
) -> None:
__tracebackhide__ = True
_check_isinstance(left, right, DatetimeArray)

Expand All @@ -896,7 +910,9 @@ def assert_datetime_array_equal(left, right, obj="DatetimeArray"):
assert_attr_equal("tz", left, right, obj=obj)


def assert_timedelta_array_equal(left, right, obj="TimedeltaArray"):
def assert_timedelta_array_equal(
left: TimedeltaArray, right: TimedeltaArray, obj: str = "TimedeltaArray"
) -> None:
__tracebackhide__ = True
_check_isinstance(left, right, TimedeltaArray)
assert_numpy_array_equal(left._data, right._data, obj="{obj}._data".format(obj=obj))
Expand Down Expand Up @@ -931,14 +947,14 @@ def raise_assert_detail(obj, message, left, right, diff=None):


def assert_numpy_array_equal(
left,
right,
strict_nan=False,
check_dtype=True,
err_msg=None,
check_same=None,
obj="numpy array",
):
left: np.ndarray,
right: np.ndarray,
strict_nan: bool = False,
check_dtype: bool = True,
err_msg: Optional[str] = None,
check_same: Optional[str] = None,
obj: str = "numpy array",
) -> None:
""" Checks that 'np.ndarray' is equivalent

Parameters
Expand Down Expand Up @@ -1067,18 +1083,18 @@ def assert_extension_array_equal(

# This could be refactored to use the NDFrame.equals method
def assert_series_equal(
left,
right,
check_dtype=True,
check_index_type="equiv",
check_series_type=True,
check_less_precise=False,
check_names=True,
check_exact=False,
check_datetimelike_compat=False,
check_categorical=True,
obj="Series",
):
left: Series,
right: Series,
check_dtype: bool = True,
check_index_type: str = "equiv",
check_series_type: bool = True,
check_less_precise: bool = False,
check_names: bool = True,
check_exact: bool = False,
check_datetimelike_compat: bool = False,
check_categorical: bool = True,
obj: str = "Series",
) -> None:
"""
Check that left and right Series are equal.

Expand Down Expand Up @@ -1185,8 +1201,11 @@ def assert_series_equal(
right._internal_get_values(),
check_dtype=check_dtype,
)
elif is_interval_dtype(left) or is_interval_dtype(right):
assert_interval_array_equal(left.array, right.array)
elif is_interval_dtype(left) or is_interval_dtype(left):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is now effectively just `elif is_interval_dtype(left)`. Should the second condition not have been changed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, I think that change was my mistake. I have removed it.

# must cast to interval dtype to keep mypy happy
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What was the complaint here? I think this changes the actual assertions being done.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I don't cast to IntervalArray, I get these errors:

pandas/util/testing.py:1211: error: Argument 1 to "assert_interval_array_equal" has incompatible type "ExtensionArray"; expected "IntervalArray"
pandas/util/testing.py:1211: error: Argument 2 to "assert_interval_array_equal" has incompatible type "ExtensionArray"; expected "IntervalArray"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've removed the asserts, though; they weren't needed here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned by @jreback, this should use `cast` from the typing module here; we don't want to actually construct a new object via a call to IntervalArray.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can delete this comment

left_array = IntervalArray(left.array)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, we don't want to do this. A cast is OK, but don't actually coerce with a constructor.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ha, got it. I wasn't aware of typing.cast. I have updated the code.

right_array = IntervalArray(right.array)
assert_interval_array_equal(left_array, right_array)
elif is_extension_array_dtype(left.dtype) and is_datetime64tz_dtype(left.dtype):
# .values is an ndarray, but ._values is the ExtensionArray.
# TODO: Use .array
Expand Down Expand Up @@ -1221,21 +1240,21 @@ def assert_series_equal(

# This could be refactored to use the NDFrame.equals method
def assert_frame_equal(
left,
right,
check_dtype=True,
check_index_type="equiv",
check_column_type="equiv",
check_frame_type=True,
check_less_precise=False,
check_names=True,
by_blocks=False,
check_exact=False,
check_datetimelike_compat=False,
check_categorical=True,
check_like=False,
obj="DataFrame",
):
left: DataFrame,
right: DataFrame,
check_dtype: bool = True,
check_index_type: str = "equiv",
check_column_type: str = "equiv",
check_frame_type: bool = True,
check_less_precise: bool = False,
check_names: bool = True,
by_blocks: bool = False,
check_exact: bool = False,
check_datetimelike_compat: bool = False,
check_categorical: bool = True,
check_like: bool = False,
obj: str = "DataFrame",
) -> None:
"""
Check that left and right DataFrame are equal.

Expand Down Expand Up @@ -1403,7 +1422,11 @@ def assert_frame_equal(
)


def assert_equal(left, right, **kwargs):
def assert_equal(
left: Union[DataFrame, AnyArrayLike],
right: Union[DataFrame, AnyArrayLike],
**kwargs
) -> None:
"""
Wrapper for tm.assert_*_equal to dispatch to the appropriate test function.

Expand All @@ -1415,27 +1438,37 @@ def assert_equal(left, right, **kwargs):
"""
__tracebackhide__ = True

if isinstance(left, pd.Index):
if isinstance(left, Index):
_check_isinstance(left, right, Index)
right = Index(right)
assert_index_equal(left, right, **kwargs)
elif isinstance(left, pd.Series):
elif isinstance(left, Series):
assert isinstance(right, Series)
assert_series_equal(left, right, **kwargs)
elif isinstance(left, pd.DataFrame):
elif isinstance(left, DataFrame):
assert isinstance(right, DataFrame)
assert_frame_equal(left, right, **kwargs)
elif isinstance(left, IntervalArray):
assert isinstance(right, IntervalArray)
assert_interval_array_equal(left, right, **kwargs)
elif isinstance(left, PeriodArray):
assert isinstance(right, PeriodArray)
assert_period_array_equal(left, right, **kwargs)
elif isinstance(left, DatetimeArray):
assert isinstance(right, DatetimeArray)
assert_datetime_array_equal(left, right, **kwargs)
elif isinstance(left, TimedeltaArray):
assert isinstance(right, TimedeltaArray)
assert_timedelta_array_equal(left, right, **kwargs)
elif isinstance(left, ExtensionArray):
assert isinstance(right, ExtensionArray)
assert_extension_array_equal(left, right, **kwargs)
elif isinstance(left, np.ndarray):
assert isinstance(right, np.ndarray)
assert_numpy_array_equal(left, right, **kwargs)
elif isinstance(left, str):
assert kwargs == {}
return left == right
assert left == right
else:
raise NotImplementedError(type(left))

Expand Down Expand Up @@ -1497,12 +1530,12 @@ def to_array(obj):


def assert_sp_array_equal(
left,
right,
check_dtype=True,
check_kind=True,
check_fill_value=True,
consolidate_block_indices=False,
left: pd.SparseArray,
right: pd.SparseArray,
check_dtype: bool = True,
check_kind: bool = True,
check_fill_value: bool = True,
consolidate_block_indices: bool = False,
):
"""Check that the left and right SparseArray are equal.

Expand Down