Skip to content

Commit c3fc9bb

Browse files
rohanjain101Rohan Jain
and
Rohan Jain
authored
raise error on unsafe decimal parse with pyarrow types (#56985)
* raise error on unsafe decimal parse with pyarrow types * fix min versions * restore typeerrro * success --------- Co-authored-by: Rohan Jain <[email protected]>
1 parent fbec88c commit c3fc9bb

File tree

2 files changed

+25
-5
lines changed

2 files changed

+25
-5
lines changed

pandas/core/arrays/arrow/array.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -527,11 +527,7 @@ def _box_pa_array(
527527
else:
528528
try:
529529
pa_array = pa_array.cast(pa_type)
530-
except (
531-
pa.ArrowInvalid,
532-
pa.ArrowTypeError,
533-
pa.ArrowNotImplementedError,
534-
):
530+
except (pa.ArrowNotImplementedError, pa.ArrowTypeError):
535531
if pa.types.is_string(pa_array.type) or pa.types.is_large_string(
536532
pa_array.type
537533
):

pandas/tests/extension/test_arrow.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3203,6 +3203,30 @@ def test_pow_missing_operand():
32033203
tm.assert_series_equal(result, expected)
32043204

32053205

3206+
@pytest.mark.skipif(
3207+
pa_version_under11p0, reason="Decimal128 to string cast implemented in pyarrow 11"
3208+
)
3209+
def test_decimal_parse_raises():
3210+
# GH 56984
3211+
ser = pd.Series(["1.2345"], dtype=ArrowDtype(pa.string()))
3212+
with pytest.raises(
3213+
pa.lib.ArrowInvalid, match="Rescaling Decimal128 value would cause data loss"
3214+
):
3215+
ser.astype(ArrowDtype(pa.decimal128(1, 0)))
3216+
3217+
3218+
@pytest.mark.skipif(
3219+
pa_version_under11p0, reason="Decimal128 to string cast implemented in pyarrow 11"
3220+
)
3221+
def test_decimal_parse_succeeds():
3222+
# GH 56984
3223+
ser = pd.Series(["1.2345"], dtype=ArrowDtype(pa.string()))
3224+
dtype = ArrowDtype(pa.decimal128(5, 4))
3225+
result = ser.astype(dtype)
3226+
expected = pd.Series([Decimal("1.2345")], dtype=dtype)
3227+
tm.assert_series_equal(result, expected)
3228+
3229+
32063230
@pytest.mark.parametrize("pa_type", tm.TIMEDELTA_PYARROW_DTYPES)
32073231
def test_duration_fillna_numpy(pa_type):
32083232
# GH 54707

0 commit comments

Comments
 (0)