@@ -212,40 +212,38 @@ def check_error_on_write(self, df, engine, exc):
212
212
with tm .ensure_clean () as path :
213
213
to_parquet (df , path , engine , compression = None )
214
214
215
- def do_round_trip (self , df , path , engine_impl , expected = None ,
216
- write_kwargs = None , read_kwargs = None ,
217
- check_names = True ):
215
+ def check_round_trip (self , df , engine , expected = None , path = None ,
216
+ write_kwargs = None , read_kwargs = None ,
217
+ check_names = True ):
218
218
219
219
if write_kwargs is None :
220
220
write_kwargs = {'compression' : None }
221
221
222
222
if read_kwargs is None :
223
223
read_kwargs = {}
224
224
225
- df .to_parquet (path , engine_impl , ** write_kwargs )
226
- actual = read_parquet (path , engine_impl , ** read_kwargs )
227
-
228
225
if expected is None :
229
226
expected = df
230
227
231
- tm .assert_frame_equal (expected , actual , check_names = check_names )
232
-
233
- def check_round_trip (self , df , engine , expected = None ,
234
- write_kwargs = None , read_kwargs = None ,
235
- check_names = True ):
236
-
237
- with tm .ensure_clean () as path :
238
- self .do_round_trip (df , path , engine , expected ,
239
- write_kwargs = write_kwargs ,
240
- read_kwargs = read_kwargs ,
241
- check_names = check_names )
228
+ if path is None :
229
+ with tm .ensure_clean () as path :
230
+ df .to_parquet (path , engine , ** write_kwargs )
231
+ actual = read_parquet (path , engine , ** read_kwargs )
232
+ tm .assert_frame_equal (expected , actual , check_names = check_names )
233
+
234
+ # repeat
235
+ df .to_parquet (path , engine , ** write_kwargs )
236
+ actual = read_parquet (path , engine , ** read_kwargs )
237
+ tm .assert_frame_equal (expected , actual , check_names = check_names )
238
+ else :
239
+ df .to_parquet (path , engine , ** write_kwargs )
240
+ actual = read_parquet (path , engine , ** read_kwargs )
241
+ tm .assert_frame_equal (expected , actual , check_names = check_names )
242
242
243
243
# repeat
244
- self .do_round_trip (df , path , engine , expected ,
245
- write_kwargs = write_kwargs ,
246
- read_kwargs = read_kwargs ,
247
- check_names = check_names )
248
-
244
+ df .to_parquet (path , engine , ** write_kwargs )
245
+ actual = read_parquet (path , engine , ** read_kwargs )
246
+ tm .assert_frame_equal (expected , actual , check_names = check_names )
249
247
250
248
class TestBasic (Base ):
251
249
@@ -435,7 +433,7 @@ def test_categorical_unsupported(self, pa_lt_070):
435
433
436
434
def test_s3_roundtrip (self , df_compat , s3_resource , pa ):
437
435
# GH #19134
438
- self .do_round_trip (df_compat , 's3://pandas-test/pyarrow.parquet' , pa )
436
+ self .check_round_trip (df_compat , pa , path = 's3://pandas-test/pyarrow.parquet' )
439
437
440
438
441
439
class TestParquetFastParquet (Base ):
@@ -499,4 +497,4 @@ def test_filter_row_groups(self, fp):
499
497
def test_s3_roundtrip (self , df_compat , s3_resource , fp ):
500
498
# GH #19134
501
499
with pytest .raises (NotImplementedError ):
502
- self .do_round_trip (df_compat , 's3://pandas-test/fastparquet.parquet' , fp )
500
+ self .check_round_trip (df_compat , fp , path = 's3://pandas-test/fastparquet.parquet' )
0 commit comments