-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
Refactor test_parquet.py to use check_round_trip at module level #19332
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -110,48 +110,82 @@ def df_full(): | |
pd.Timestamp('20130103')]}) | ||
|
||
|
||
def test_invalid_engine(df_compat): | ||
def check_round_trip(df, engine=None, path=None, | ||
write_kwargs=None, read_kwargs=None, | ||
expected=None, check_names=True, | ||
repeat=2): | ||
""" | ||
Verify parquet serialize and deserialize produce the same results. | ||
|
||
Performs a pandas to disk and disk to pandas round trip, | ||
then compares the 2 resulting DataFrames to verify full | ||
cycle is successful. | ||
|
||
:param df: Dataframe to be serialized to disk | ||
:param engine: str one of ['pyarrow', 'fastparquet'] | ||
:param path: str | ||
:param write_kwargs: dict(str:str) params to be passed to the serialization | ||
engine. | ||
:param read_kwargs: dict(str:str) params to be passed to the | ||
deserialization engine. | ||
:param expected: DataFrame If provided, deserialization will be | ||
compared against it. | ||
:param check_names: bool whether column/index names are compared | ||
:param repeat: No. of times to repeat the test. | ||
""" | ||
|
||
if write_kwargs is None: | ||
write_kwargs = {'compression': None} | ||
|
||
if read_kwargs is None: | ||
read_kwargs = {} | ||
|
||
if engine: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is there a case where engine is not required? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, mostly when testing for "defaults" handling (i.e. default engine, or engine being configured via |
||
write_kwargs['engine'] = engine | ||
read_kwargs['engine'] = engine | ||
|
||
if expected is None: | ||
expected = df | ||
|
||
def compare(): | ||
df.to_parquet(path, **write_kwargs) | ||
actual = read_parquet(path, **read_kwargs) | ||
tm.assert_frame_equal(expected, actual, | ||
check_names=check_names) | ||
|
||
if path is None: | ||
with tm.ensure_clean() as path: | ||
for _ in range(repeat): | ||
compare() | ||
else: | ||
for _ in range(repeat): | ||
compare() | ||
|
||
|
||
def test_invalid_engine(df_compat): | ||
with pytest.raises(ValueError): | ||
df_compat.to_parquet('foo', 'bar') | ||
check_round_trip(df_compat, 'foo', 'bar') | ||
|
||
|
||
def test_options_py(df_compat, pa): | ||
# use the set option | ||
|
||
df = df_compat | ||
with tm.ensure_clean() as path: | ||
|
||
with pd.option_context('io.parquet.engine', 'pyarrow'): | ||
df.to_parquet(path) | ||
|
||
result = read_parquet(path) | ||
tm.assert_frame_equal(result, df) | ||
with pd.option_context('io.parquet.engine', 'pyarrow'): | ||
check_round_trip(df_compat) | ||
|
||
|
||
def test_options_fp(df_compat, fp): | ||
# use the set option | ||
|
||
df = df_compat | ||
with tm.ensure_clean() as path: | ||
|
||
with pd.option_context('io.parquet.engine', 'fastparquet'): | ||
df.to_parquet(path, compression=None) | ||
|
||
result = read_parquet(path) | ||
tm.assert_frame_equal(result, df) | ||
with pd.option_context('io.parquet.engine', 'fastparquet'): | ||
check_round_trip(df_compat) | ||
|
||
|
||
def test_options_auto(df_compat, fp, pa): | ||
# use the set option | ||
|
||
df = df_compat | ||
with tm.ensure_clean() as path: | ||
|
||
with pd.option_context('io.parquet.engine', 'auto'): | ||
df.to_parquet(path) | ||
|
||
result = read_parquet(path) | ||
tm.assert_frame_equal(result, df) | ||
with pd.option_context('io.parquet.engine', 'auto'): | ||
check_round_trip(df_compat) | ||
|
||
|
||
def test_options_get_engine(fp, pa): | ||
|
@@ -228,53 +262,23 @@ def check_error_on_write(self, df, engine, exc): | |
with tm.ensure_clean() as path: | ||
to_parquet(df, path, engine, compression=None) | ||
|
||
def check_round_trip(self, df, engine, expected=None, path=None, | ||
write_kwargs=None, read_kwargs=None, | ||
check_names=True): | ||
|
||
if write_kwargs is None: | ||
write_kwargs = {'compression': None} | ||
|
||
if read_kwargs is None: | ||
read_kwargs = {} | ||
|
||
if expected is None: | ||
expected = df | ||
|
||
if path is None: | ||
with tm.ensure_clean() as path: | ||
check_round_trip_equals(df, path, engine, | ||
write_kwargs=write_kwargs, | ||
read_kwargs=read_kwargs, | ||
expected=expected, | ||
check_names=check_names) | ||
else: | ||
check_round_trip_equals(df, path, engine, | ||
write_kwargs=write_kwargs, | ||
read_kwargs=read_kwargs, | ||
expected=expected, | ||
check_names=check_names) | ||
|
||
|
||
class TestBasic(Base): | ||
|
||
def test_error(self, engine): | ||
|
||
for obj in [pd.Series([1, 2, 3]), 1, 'foo', pd.Timestamp('20130101'), | ||
np.array([1, 2, 3])]: | ||
self.check_error_on_write(obj, engine, ValueError) | ||
|
||
def test_columns_dtypes(self, engine): | ||
|
||
df = pd.DataFrame({'string': list('abc'), | ||
'int': list(range(1, 4))}) | ||
|
||
# unicode | ||
df.columns = [u'foo', u'bar'] | ||
self.check_round_trip(df, engine) | ||
check_round_trip(df, engine) | ||
|
||
def test_columns_dtypes_invalid(self, engine): | ||
|
||
df = pd.DataFrame({'string': list('abc'), | ||
'int': list(range(1, 4))}) | ||
|
||
|
@@ -302,17 +306,16 @@ def test_compression(self, engine, compression): | |
pytest.importorskip('brotli') | ||
|
||
df = pd.DataFrame({'A': [1, 2, 3]}) | ||
self.check_round_trip(df, engine, | ||
write_kwargs={'compression': compression}) | ||
check_round_trip(df, engine, write_kwargs={'compression': compression}) | ||
|
||
def test_read_columns(self, engine): | ||
# GH18154 | ||
df = pd.DataFrame({'string': list('abc'), | ||
'int': list(range(1, 4))}) | ||
|
||
expected = pd.DataFrame({'string': list('abc')}) | ||
self.check_round_trip(df, engine, expected=expected, | ||
read_kwargs={'columns': ['string']}) | ||
check_round_trip(df, engine, expected=expected, | ||
read_kwargs={'columns': ['string']}) | ||
|
||
def test_write_index(self, engine): | ||
check_names = engine != 'fastparquet' | ||
|
@@ -323,7 +326,7 @@ def test_write_index(self, engine): | |
pytest.skip("pyarrow is < 0.7.0") | ||
|
||
df = pd.DataFrame({'A': [1, 2, 3]}) | ||
self.check_round_trip(df, engine) | ||
check_round_trip(df, engine) | ||
|
||
indexes = [ | ||
[2, 3, 4], | ||
|
@@ -334,12 +337,12 @@ def test_write_index(self, engine): | |
# non-default index | ||
for index in indexes: | ||
df.index = index | ||
self.check_round_trip(df, engine, check_names=check_names) | ||
check_round_trip(df, engine, check_names=check_names) | ||
|
||
# index with meta-data | ||
df.index = [0, 1, 2] | ||
df.index.name = 'foo' | ||
self.check_round_trip(df, engine) | ||
check_round_trip(df, engine) | ||
|
||
def test_write_multiindex(self, pa_ge_070): | ||
# Not supported in fastparquet as of 0.1.3 or older pyarrow version | ||
|
@@ -348,7 +351,7 @@ def test_write_multiindex(self, pa_ge_070): | |
df = pd.DataFrame({'A': [1, 2, 3]}) | ||
index = pd.MultiIndex.from_tuples([('a', 1), ('a', 2), ('b', 1)]) | ||
df.index = index | ||
self.check_round_trip(df, engine) | ||
check_round_trip(df, engine) | ||
|
||
def test_write_column_multiindex(self, engine): | ||
# column multi-index | ||
|
@@ -357,7 +360,6 @@ def test_write_column_multiindex(self, engine): | |
self.check_error_on_write(df, engine, ValueError) | ||
|
||
def test_multiindex_with_columns(self, pa_ge_070): | ||
|
||
engine = pa_ge_070 | ||
dates = pd.date_range('01-Jan-2018', '01-Dec-2018', freq='MS') | ||
df = pd.DataFrame(np.random.randn(2 * len(dates), 3), | ||
|
@@ -368,14 +370,10 @@ def test_multiindex_with_columns(self, pa_ge_070): | |
index2 = index1.copy(names=None) | ||
for index in [index1, index2]: | ||
df.index = index | ||
with tm.ensure_clean() as path: | ||
df.to_parquet(path, engine) | ||
result = read_parquet(path, engine) | ||
expected = df | ||
tm.assert_frame_equal(result, expected) | ||
result = read_parquet(path, engine, columns=['A', 'B']) | ||
expected = df[['A', 'B']] | ||
tm.assert_frame_equal(result, expected) | ||
|
||
check_round_trip(df, engine) | ||
check_round_trip(df, engine, read_kwargs={'columns': ['A', 'B']}, | ||
expected=df[['A', 'B']]) | ||
|
||
|
||
class TestParquetPyArrow(Base): | ||
|
@@ -391,7 +389,7 @@ def test_basic(self, pa, df_full): | |
tz='Europe/Brussels') | ||
df['bool_with_none'] = [True, None, True] | ||
|
||
self.check_round_trip(df, pa) | ||
check_round_trip(df, pa) | ||
|
||
@pytest.mark.xfail(reason="pyarrow fails on this (ARROW-1883)") | ||
def test_basic_subset_columns(self, pa, df_full): | ||
|
@@ -402,8 +400,8 @@ def test_basic_subset_columns(self, pa, df_full): | |
df['datetime_tz'] = pd.date_range('20130101', periods=3, | ||
tz='Europe/Brussels') | ||
|
||
self.check_round_trip(df, pa, expected=df[['string', 'int']], | ||
read_kwargs={'columns': ['string', 'int']}) | ||
check_round_trip(df, pa, expected=df[['string', 'int']], | ||
read_kwargs={'columns': ['string', 'int']}) | ||
|
||
def test_duplicate_columns(self, pa): | ||
# not currently able to handle duplicate columns | ||
|
@@ -433,7 +431,7 @@ def test_categorical(self, pa_ge_070): | |
|
||
# de-serialized as object | ||
expected = df.assign(a=df.a.astype(object)) | ||
self.check_round_trip(df, pa, expected) | ||
check_round_trip(df, pa, expected=expected) | ||
|
||
def test_categorical_unsupported(self, pa_lt_070): | ||
pa = pa_lt_070 | ||
|
@@ -444,20 +442,19 @@ def test_categorical_unsupported(self, pa_lt_070): | |
|
||
def test_s3_roundtrip(self, df_compat, s3_resource, pa): | ||
# GH #19134 | ||
self.check_round_trip(df_compat, pa, | ||
path='s3://pandas-test/pyarrow.parquet') | ||
check_round_trip(df_compat, pa, | ||
path='s3://pandas-test/pyarrow.parquet') | ||
|
||
|
||
class TestParquetFastParquet(Base): | ||
|
||
def test_basic(self, fp, df_full): | ||
|
||
df = df_full | ||
|
||
# additional supported types for fastparquet | ||
df['timedelta'] = pd.timedelta_range('1 day', periods=3) | ||
|
||
self.check_round_trip(df, fp) | ||
check_round_trip(df, fp) | ||
|
||
@pytest.mark.skip(reason="not supported") | ||
def test_duplicate_columns(self, fp): | ||
|
@@ -470,7 +467,7 @@ def test_duplicate_columns(self, fp): | |
def test_bool_with_none(self, fp): | ||
df = pd.DataFrame({'a': [True, None, False]}) | ||
expected = pd.DataFrame({'a': [1.0, np.nan, 0.0]}, dtype='float16') | ||
self.check_round_trip(df, fp, expected=expected) | ||
check_round_trip(df, fp, expected=expected) | ||
|
||
def test_unsupported(self, fp): | ||
|
||
|
@@ -486,7 +483,7 @@ def test_categorical(self, fp): | |
if LooseVersion(fastparquet.__version__) < LooseVersion("0.1.3"): | ||
pytest.skip("CategoricalDtype not supported for older fp") | ||
df = pd.DataFrame({'a': pd.Categorical(list('abc'))}) | ||
self.check_round_trip(df, fp) | ||
check_round_trip(df, fp) | ||
|
||
def test_datetime_tz(self, fp): | ||
# doesn't preserve tz | ||
|
@@ -495,7 +492,7 @@ def test_datetime_tz(self, fp): | |
|
||
# warns on the coercion | ||
with catch_warnings(record=True): | ||
self.check_round_trip(df, fp, df.astype('datetime64[ns]')) | ||
check_round_trip(df, fp, expected=df.astype('datetime64[ns]')) | ||
|
||
def test_filter_row_groups(self, fp): | ||
d = {'a': list(range(0, 3))} | ||
|
@@ -508,5 +505,5 @@ def test_filter_row_groups(self, fp): | |
|
||
def test_s3_roundtrip(self, df_compat, s3_resource, fp): | ||
# GH #19134 | ||
self.check_round_trip(df_compat, fp, | ||
path='s3://pandas-test/fastparquet.parquet') | ||
check_round_trip(df_compat, fp, | ||
path='s3://pandas-test/fastparquet.parquet') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pls use the numpy-doc format (just at any other doc-string)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
include Parameters, Raises sections