Refactor test_parquet.py to use check_round_trip at module level #19332
@@ -110,48 +110,72 @@ def df_full():
                             pd.Timestamp('20130103')]})


def test_invalid_engine(df_compat):
def check_round_trip(df, engine=None, path=None,
                     write_kwargs=None, read_kwargs=None,
                     expected=None, check_names=True):

    with pytest.raises(ValueError):
        df_compat.to_parquet('foo', 'bar')
    if write_kwargs is None:
        write_kwargs = {'compression': None}

    if read_kwargs is None:
        read_kwargs = {}


def test_options_py(df_compat, pa):
    # use the set option
    if engine:
Review comment: is there a case where engine is not required?

Reply: Yes, mostly when testing for "defaults" handling (i.e. default engine, or engine being configured via …)
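As a minimal illustration of that reply (a hedged usage sketch, not part of the diff; it assumes the module's df_compat fixture and the io.parquet.engine option exercised by the test_options_* tests below):

# Default-engine path: no engine kwarg is injected, so to_parquet/read_parquet
# fall back to whatever io.parquet.engine resolves to.
with pd.option_context('io.parquet.engine', 'pyarrow'):
    check_round_trip(df_compat)

# Explicit-engine path: the engine is forwarded via write_kwargs/read_kwargs.
check_round_trip(df_compat, engine='fastparquet')

The diff continues below.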
        write_kwargs['engine'] = engine
        read_kwargs['engine'] = engine

    df = df_compat
    with tm.ensure_clean() as path:
    if expected is None:
        expected = df

        with pd.option_context('io.parquet.engine', 'pyarrow'):
            df.to_parquet(path)
    if path is None:
        with tm.ensure_clean() as path:
            df.to_parquet(path, **write_kwargs)
Review comment: can you make both of these into a loop with a kwarg in the signature

Reply: Thanks... Python is not native for me, I was looking for a nicer way to do that 👍
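One way the suggested loop could look (a sketch only, not the code in this commit; the repeat keyword and the inner compare helper are illustrative names):

def check_round_trip(df, engine=None, path=None, write_kwargs=None,
                     read_kwargs=None, expected=None, check_names=True,
                     repeat=2):
    if write_kwargs is None:
        write_kwargs = {'compression': None}
    if read_kwargs is None:
        read_kwargs = {}
    if engine:
        write_kwargs['engine'] = engine
        read_kwargs['engine'] = engine
    if expected is None:
        expected = df

    def compare(repeat):
        # run the write/read/compare cycle the requested number of times
        for _ in range(repeat):
            df.to_parquet(path, **write_kwargs)
            actual = read_parquet(path, **read_kwargs)
            tm.assert_frame_equal(expected, actual,
                                  check_names=check_names)

    if path is None:
        with tm.ensure_clean() as path:
            compare(repeat)
    else:
        compare(repeat)

The commit under review keeps the two duplicated blocks, which continue in the diff below.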
            actual = read_parquet(path, **read_kwargs)
            tm.assert_frame_equal(expected, actual,
                                  check_names=check_names)

            # repeat
            df.to_parquet(path, **write_kwargs)
            actual = read_parquet(path, **read_kwargs)
            tm.assert_frame_equal(expected, actual,
                                  check_names=check_names)
    else:
        df.to_parquet(path, **write_kwargs)
        actual = read_parquet(path, **read_kwargs)
        tm.assert_frame_equal(expected, actual,
                              check_names=check_names)

        # repeat
        df.to_parquet(path, **write_kwargs)
        actual = read_parquet(path, **read_kwargs)
        tm.assert_frame_equal(expected, actual,
                              check_names=check_names)

        result = read_parquet(path)
        tm.assert_frame_equal(result, df)


def test_invalid_engine(df_compat):
    with pytest.raises(ValueError):
        check_round_trip(df_compat, 'foo', 'bar')


def test_options_fp(df_compat, fp):

def test_options_py(df_compat, pa):
    # use the set option

    df = df_compat
    with tm.ensure_clean() as path:
    with pd.option_context('io.parquet.engine', 'pyarrow'):
        check_round_trip(df_compat)

        with pd.option_context('io.parquet.engine', 'fastparquet'):
            df.to_parquet(path, compression=None)

        result = read_parquet(path)
        tm.assert_frame_equal(result, df)
def test_options_fp(df_compat, fp):
    # use the set option

    with pd.option_context('io.parquet.engine', 'fastparquet'):
        check_round_trip(df_compat)


def test_options_auto(df_compat, fp, pa):

    df = df_compat
    with tm.ensure_clean() as path:

        with pd.option_context('io.parquet.engine', 'auto'):
            df.to_parquet(path)
def test_options_auto(df_compat, fp, pa):
    # use the set option

        result = read_parquet(path)
        tm.assert_frame_equal(result, df)
    with pd.option_context('io.parquet.engine', 'auto'):
        check_round_trip(df_compat)


def test_options_get_engine(fp, pa):

@@ -228,53 +252,23 @@ def check_error_on_write(self, df, engine, exc):
        with tm.ensure_clean() as path:
            to_parquet(df, path, engine, compression=None)

    def check_round_trip(self, df, engine, expected=None, path=None,
                         write_kwargs=None, read_kwargs=None,
                         check_names=True):

        if write_kwargs is None:
            write_kwargs = {'compression': None}

        if read_kwargs is None:
            read_kwargs = {}

        if expected is None:
            expected = df

        if path is None:
            with tm.ensure_clean() as path:
                check_round_trip_equals(df, path, engine,
                                        write_kwargs=write_kwargs,
                                        read_kwargs=read_kwargs,
                                        expected=expected,
                                        check_names=check_names)
        else:
            check_round_trip_equals(df, path, engine,
                                    write_kwargs=write_kwargs,
                                    read_kwargs=read_kwargs,
                                    expected=expected,
                                    check_names=check_names)


class TestBasic(Base):

    def test_error(self, engine):

        for obj in [pd.Series([1, 2, 3]), 1, 'foo', pd.Timestamp('20130101'),
                    np.array([1, 2, 3])]:
            self.check_error_on_write(obj, engine, ValueError)

    def test_columns_dtypes(self, engine):

        df = pd.DataFrame({'string': list('abc'),
                           'int': list(range(1, 4))})

        # unicode
        df.columns = [u'foo', u'bar']
        self.check_round_trip(df, engine)
        check_round_trip(df, engine)

    def test_columns_dtypes_invalid(self, engine):

        df = pd.DataFrame({'string': list('abc'),
                           'int': list(range(1, 4))})

@@ -302,17 +296,16 @@ def test_compression(self, engine, compression):
            pytest.importorskip('brotli')

        df = pd.DataFrame({'A': [1, 2, 3]})
        self.check_round_trip(df, engine,
                              write_kwargs={'compression': compression})
        check_round_trip(df, engine, write_kwargs={'compression': compression})

    def test_read_columns(self, engine):
        # GH18154
        df = pd.DataFrame({'string': list('abc'),
                           'int': list(range(1, 4))})

        expected = pd.DataFrame({'string': list('abc')})
        self.check_round_trip(df, engine, expected=expected,
                              read_kwargs={'columns': ['string']})
        check_round_trip(df, engine, expected=expected,
                         read_kwargs={'columns': ['string']})

    def test_write_index(self, engine):
        check_names = engine != 'fastparquet'

@@ -323,7 +316,7 @@ def test_write_index(self, engine):
            pytest.skip("pyarrow is < 0.7.0")

        df = pd.DataFrame({'A': [1, 2, 3]})
        self.check_round_trip(df, engine)
        check_round_trip(df, engine)

        indexes = [
            [2, 3, 4],
@@ -334,12 +327,12 @@ def test_write_index(self, engine):
        # non-default index
        for index in indexes:
            df.index = index
            self.check_round_trip(df, engine, check_names=check_names)
            check_round_trip(df, engine, check_names=check_names)

        # index with meta-data
        df.index = [0, 1, 2]
        df.index.name = 'foo'
        self.check_round_trip(df, engine)
        check_round_trip(df, engine)

    def test_write_multiindex(self, pa_ge_070):
        # Not suppoprted in fastparquet as of 0.1.3 or older pyarrow version
@@ -348,7 +341,7 @@ def test_write_multiindex(self, pa_ge_070):
        df = pd.DataFrame({'A': [1, 2, 3]})
        index = pd.MultiIndex.from_tuples([('a', 1), ('a', 2), ('b', 1)])
        df.index = index
        self.check_round_trip(df, engine)
        check_round_trip(df, engine)

    def test_write_column_multiindex(self, engine):
        # column multi-index
@@ -357,7 +350,6 @@ def test_write_column_multiindex(self, engine):
        self.check_error_on_write(df, engine, ValueError)

    def test_multiindex_with_columns(self, pa_ge_070):

        engine = pa_ge_070
        dates = pd.date_range('01-Jan-2018', '01-Dec-2018', freq='MS')
        df = pd.DataFrame(np.random.randn(2 * len(dates), 3),
@@ -368,14 +360,10 @@ def test_multiindex_with_columns(self, pa_ge_070):
        index2 = index1.copy(names=None)
        for index in [index1, index2]:
            df.index = index
            with tm.ensure_clean() as path:
                df.to_parquet(path, engine)
                result = read_parquet(path, engine)
                expected = df
                tm.assert_frame_equal(result, expected)
                result = read_parquet(path, engine, columns=['A', 'B'])
                expected = df[['A', 'B']]
                tm.assert_frame_equal(result, expected)

            check_round_trip(df, engine)
            check_round_trip(df, engine, read_kwargs={'columns': ['A', 'B']},
                             expected=df[['A', 'B']])


class TestParquetPyArrow(Base):

@@ -391,7 +379,7 @@ def test_basic(self, pa, df_full):
                                          tz='Europe/Brussels')
        df['bool_with_none'] = [True, None, True]

        self.check_round_trip(df, pa)
        check_round_trip(df, pa)

    @pytest.mark.xfail(reason="pyarrow fails on this (ARROW-1883)")
    def test_basic_subset_columns(self, pa, df_full):
@@ -402,8 +390,8 @@ def test_basic_subset_columns(self, pa, df_full):
        df['datetime_tz'] = pd.date_range('20130101', periods=3,
                                          tz='Europe/Brussels')

        self.check_round_trip(df, pa, expected=df[['string', 'int']],
                              read_kwargs={'columns': ['string', 'int']})
        check_round_trip(df, pa, expected=df[['string', 'int']],
                         read_kwargs={'columns': ['string', 'int']})

    def test_duplicate_columns(self, pa):
        # not currently able to handle duplicate columns
@@ -433,7 +421,7 @@ def test_categorical(self, pa_ge_070):

        # de-serialized as object
        expected = df.assign(a=df.a.astype(object))
        self.check_round_trip(df, pa, expected)
        check_round_trip(df, pa, expected=expected)

    def test_categorical_unsupported(self, pa_lt_070):
        pa = pa_lt_070
@@ -444,20 +432,19 @@ def test_categorical_unsupported(self, pa_lt_070):

    def test_s3_roundtrip(self, df_compat, s3_resource, pa):
        # GH #19134
        self.check_round_trip(df_compat, pa,
                              path='s3://pandas-test/pyarrow.parquet')
        check_round_trip(df_compat, pa,
                         path='s3://pandas-test/pyarrow.parquet')


class TestParquetFastParquet(Base):

    def test_basic(self, fp, df_full):

        df = df_full

        # additional supported types for fastparquet
        df['timedelta'] = pd.timedelta_range('1 day', periods=3)

        self.check_round_trip(df, fp)
        check_round_trip(df, fp)

    @pytest.mark.skip(reason="not supported")
    def test_duplicate_columns(self, fp):
@@ -470,7 +457,7 @@ def test_duplicate_columns(self, fp):
    def test_bool_with_none(self, fp):
        df = pd.DataFrame({'a': [True, None, False]})
        expected = pd.DataFrame({'a': [1.0, np.nan, 0.0]}, dtype='float16')
        self.check_round_trip(df, fp, expected=expected)
        check_round_trip(df, fp, expected=expected)

    def test_unsupported(self, fp):

@@ -486,7 +473,7 @@ def test_categorical(self, fp):
        if LooseVersion(fastparquet.__version__) < LooseVersion("0.1.3"):
            pytest.skip("CategoricalDtype not supported for older fp")
        df = pd.DataFrame({'a': pd.Categorical(list('abc'))})
        self.check_round_trip(df, fp)
        check_round_trip(df, fp)

    def test_datetime_tz(self, fp):
        # doesn't preserve tz
@@ -495,7 +482,7 @@ def test_datetime_tz(self, fp):

        # warns on the coercion
        with catch_warnings(record=True):
            self.check_round_trip(df, fp, df.astype('datetime64[ns]'))
            check_round_trip(df, fp, expected=df.astype('datetime64[ns]'))

    def test_filter_row_groups(self, fp):
        d = {'a': list(range(0, 3))}
@@ -508,5 +495,5 @@ def test_filter_row_groups(self, fp):

    def test_s3_roundtrip(self, df_compat, s3_resource, fp):
        # GH #19134
        self.check_round_trip(df_compat, fp,
                              path='s3://pandas-test/fastparquet.parquet')
        check_round_trip(df_compat, fp,
                         path='s3://pandas-test/fastparquet.parquet')
Review comment: can you add a proper doc-string here
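A docstring along the lines the reviewer asks for might read as follows (illustrative wording only, not part of this commit; the function body stays as in the diff above):

def check_round_trip(df, engine=None, path=None, write_kwargs=None,
                     read_kwargs=None, expected=None, check_names=True):
    """Round-trip df through to_parquet/read_parquet and compare the result.

    Parameters
    ----------
    df : DataFrame
        Frame to serialize.
    engine : str, optional
        'pyarrow' or 'fastparquet'; when omitted, the engine configured via
        the io.parquet.engine option is used.
    path : str, optional
        Target path; when omitted, a temporary file from tm.ensure_clean()
        is used.
    write_kwargs : dict, optional
        Keyword arguments forwarded to DataFrame.to_parquet
        (defaults to {'compression': None}).
    read_kwargs : dict, optional
        Keyword arguments forwarded to read_parquet.
    expected : DataFrame, optional
        Expected result of the round trip; defaults to df itself.
    check_names : bool, default True
        Passed through to tm.assert_frame_equal.
    """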