Skip to content

Commit ff134cc

Browse files
[ArrowStringArray] TST: parametrize str.extractall tests (#41419)
1 parent b472080 commit ff134cc

File tree

1 file changed

+140
-112
lines changed

1 file changed

+140
-112
lines changed

pandas/tests/strings/test_extract.py

Lines changed: 140 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -358,8 +358,8 @@ def test_extract_single_group_returns_frame():
358358
tm.assert_frame_equal(r, e)
359359

360360

361-
def test_extractall():
362-
subject_list = [
361+
def test_extractall(any_string_dtype):
362+
data = [
363363
364364
365365
@@ -378,28 +378,30 @@ def test_extractall():
378378
("c", "d", "com"),
379379
("e", "f", "com"),
380380
]
381-
named_pattern = r"""
381+
pat = r"""
382382
(?P<user>[a-z0-9]+)
383383
@
384384
(?P<domain>[a-z]+)
385385
\.
386386
(?P<tld>[a-z]{2,4})
387387
"""
388388
expected_columns = ["user", "domain", "tld"]
389-
S = Series(subject_list)
390-
# extractall should return a DataFrame with one row for each
391-
# match, indexed by the subject from which the match came.
389+
s = Series(data, dtype=any_string_dtype)
390+
# extractall should return a DataFrame with one row for each match, indexed by the
391+
# subject from which the match came.
392392
expected_index = MultiIndex.from_tuples(
393393
[(0, 0), (1, 0), (2, 0), (3, 0), (3, 1), (4, 0), (4, 1), (4, 2)],
394394
names=(None, "match"),
395395
)
396-
expected_df = DataFrame(expected_tuples, expected_index, expected_columns)
397-
computed_df = S.str.extractall(named_pattern, re.VERBOSE)
398-
tm.assert_frame_equal(computed_df, expected_df)
396+
expected = DataFrame(
397+
expected_tuples, expected_index, expected_columns, dtype=any_string_dtype
398+
)
399+
result = s.str.extractall(pat, flags=re.VERBOSE)
400+
tm.assert_frame_equal(result, expected)
399401

400-
# The index of the input Series should be used to construct
401-
# the index of the output DataFrame:
402-
series_index = MultiIndex.from_tuples(
402+
# The index of the input Series should be used to construct the index of the output
403+
# DataFrame:
404+
mi = MultiIndex.from_tuples(
403405
[
404406
("single", "Dave"),
405407
("single", "Toby"),
@@ -410,7 +412,7 @@ def test_extractall():
410412
("none", "empty"),
411413
]
412414
)
413-
Si = Series(subject_list, series_index)
415+
s = Series(data, index=mi, dtype=any_string_dtype)
414416
expected_index = MultiIndex.from_tuples(
415417
[
416418
("single", "Dave", 0),
@@ -424,67 +426,80 @@ def test_extractall():
424426
],
425427
names=(None, None, "match"),
426428
)
427-
expected_df = DataFrame(expected_tuples, expected_index, expected_columns)
428-
computed_df = Si.str.extractall(named_pattern, re.VERBOSE)
429-
tm.assert_frame_equal(computed_df, expected_df)
429+
expected = DataFrame(
430+
expected_tuples, expected_index, expected_columns, dtype=any_string_dtype
431+
)
432+
result = s.str.extractall(pat, flags=re.VERBOSE)
433+
tm.assert_frame_equal(result, expected)
430434

431435
# MultiIndexed subject with names.
432-
Sn = Series(subject_list, series_index)
433-
Sn.index.names = ("matches", "description")
436+
s = Series(data, index=mi, dtype=any_string_dtype)
437+
s.index.names = ("matches", "description")
434438
expected_index.names = ("matches", "description", "match")
435-
expected_df = DataFrame(expected_tuples, expected_index, expected_columns)
436-
computed_df = Sn.str.extractall(named_pattern, re.VERBOSE)
437-
tm.assert_frame_equal(computed_df, expected_df)
438-
439-
# optional groups.
440-
subject_list = ["", "A1", "32"]
441-
named_pattern = "(?P<letter>[AB])?(?P<number>[123])"
442-
computed_df = Series(subject_list).str.extractall(named_pattern)
443-
expected_index = MultiIndex.from_tuples(
444-
[(1, 0), (2, 0), (2, 1)], names=(None, "match")
445-
)
446-
expected_df = DataFrame(
447-
[("A", "1"), (np.nan, "3"), (np.nan, "2")],
448-
expected_index,
449-
columns=["letter", "number"],
439+
expected = DataFrame(
440+
expected_tuples, expected_index, expected_columns, dtype=any_string_dtype
450441
)
451-
tm.assert_frame_equal(computed_df, expected_df)
442+
result = s.str.extractall(pat, flags=re.VERBOSE)
443+
tm.assert_frame_equal(result, expected)
444+
445+
446+
@pytest.mark.parametrize(
447+
"pat,expected_names",
448+
[
449+
# optional groups.
450+
("(?P<letter>[AB])?(?P<number>[123])", ["letter", "number"]),
451+
# only one of two groups has a name.
452+
("([AB])?(?P<number>[123])", [0, "number"]),
453+
],
454+
)
455+
def test_extractall_column_names(pat, expected_names, any_string_dtype):
456+
s = Series(["", "A1", "32"], dtype=any_string_dtype)
452457

453-
# only one of two groups has a name.
454-
pattern = "([AB])?(?P<number>[123])"
455-
computed_df = Series(subject_list).str.extractall(pattern)
456-
expected_df = DataFrame(
458+
result = s.str.extractall(pat)
459+
expected = DataFrame(
457460
[("A", "1"), (np.nan, "3"), (np.nan, "2")],
458-
expected_index,
459-
columns=[0, "number"],
461+
index=MultiIndex.from_tuples([(1, 0), (2, 0), (2, 1)], names=(None, "match")),
462+
columns=expected_names,
463+
dtype=any_string_dtype,
460464
)
461-
tm.assert_frame_equal(computed_df, expected_df)
465+
tm.assert_frame_equal(result, expected)
462466

463467

464-
def test_extractall_single_group():
465-
# extractall(one named group) returns DataFrame with one named
466-
# column.
467-
s = Series(["a3", "b3", "d4c2"], name="series_name")
468-
r = s.str.extractall(r"(?P<letter>[a-z])")
469-
i = MultiIndex.from_tuples([(0, 0), (1, 0), (2, 0), (2, 1)], names=(None, "match"))
470-
e = DataFrame({"letter": ["a", "b", "d", "c"]}, i)
471-
tm.assert_frame_equal(r, e)
468+
def test_extractall_single_group(any_string_dtype):
469+
s = Series(["a3", "b3", "d4c2"], name="series_name", dtype=any_string_dtype)
470+
expected_index = MultiIndex.from_tuples(
471+
[(0, 0), (1, 0), (2, 0), (2, 1)], names=(None, "match")
472+
)
472473

473-
# extractall(one un-named group) returns DataFrame with one
474-
# un-named column.
475-
r = s.str.extractall(r"([a-z])")
476-
e = DataFrame(["a", "b", "d", "c"], i)
477-
tm.assert_frame_equal(r, e)
474+
# extractall(one named group) returns DataFrame with one named column.
475+
result = s.str.extractall(r"(?P<letter>[a-z])")
476+
expected = DataFrame(
477+
{"letter": ["a", "b", "d", "c"]}, index=expected_index, dtype=any_string_dtype
478+
)
479+
tm.assert_frame_equal(result, expected)
480+
481+
# extractall(one un-named group) returns DataFrame with one un-named column.
482+
result = s.str.extractall(r"([a-z])")
483+
expected = DataFrame(
484+
["a", "b", "d", "c"], index=expected_index, dtype=any_string_dtype
485+
)
486+
tm.assert_frame_equal(result, expected)
478487

479488

480-
def test_extractall_single_group_with_quantifier():
481-
# extractall(one un-named group with quantifier) returns
482-
# DataFrame with one un-named column (GH13382).
483-
s = Series(["ab3", "abc3", "d4cd2"], name="series_name")
484-
r = s.str.extractall(r"([a-z]+)")
485-
i = MultiIndex.from_tuples([(0, 0), (1, 0), (2, 0), (2, 1)], names=(None, "match"))
486-
e = DataFrame(["ab", "abc", "d", "cd"], i)
487-
tm.assert_frame_equal(r, e)
489+
def test_extractall_single_group_with_quantifier(any_string_dtype):
490+
# GH#13382
491+
# extractall(one un-named group with quantifier) returns DataFrame with one un-named
492+
# column.
493+
s = Series(["ab3", "abc3", "d4cd2"], name="series_name", dtype=any_string_dtype)
494+
result = s.str.extractall(r"([a-z]+)")
495+
expected = DataFrame(
496+
["ab", "abc", "d", "cd"],
497+
index=MultiIndex.from_tuples(
498+
[(0, 0), (1, 0), (2, 0), (2, 1)], names=(None, "match")
499+
),
500+
dtype=any_string_dtype,
501+
)
502+
tm.assert_frame_equal(result, expected)
488503

489504

490505
@pytest.mark.parametrize(
@@ -500,78 +515,91 @@ def test_extractall_single_group_with_quantifier():
500515
(["a3", "b3", "d4c2"], ("i1", "i2")),
501516
],
502517
)
503-
def test_extractall_no_matches(data, names):
518+
def test_extractall_no_matches(data, names, any_string_dtype):
504519
# GH19075 extractall with no matches should return a valid MultiIndex
505520
n = len(data)
506521
if len(names) == 1:
507-
i = Index(range(n), name=names[0])
522+
index = Index(range(n), name=names[0])
508523
else:
509-
a = (tuple([i] * (n - 1)) for i in range(n))
510-
i = MultiIndex.from_tuples(a, names=names)
511-
s = Series(data, name="series_name", index=i, dtype="object")
512-
ei = MultiIndex.from_tuples([], names=(names + ("match",)))
524+
tuples = (tuple([i] * (n - 1)) for i in range(n))
525+
index = MultiIndex.from_tuples(tuples, names=names)
526+
s = Series(data, name="series_name", index=index, dtype=any_string_dtype)
527+
expected_index = MultiIndex.from_tuples([], names=(names + ("match",)))
513528

514529
# one un-named group.
515-
r = s.str.extractall("(z)")
516-
e = DataFrame(columns=[0], index=ei)
517-
tm.assert_frame_equal(r, e)
530+
result = s.str.extractall("(z)")
531+
expected = DataFrame(columns=[0], index=expected_index, dtype=any_string_dtype)
532+
tm.assert_frame_equal(result, expected)
518533

519534
# two un-named groups.
520-
r = s.str.extractall("(z)(z)")
521-
e = DataFrame(columns=[0, 1], index=ei)
522-
tm.assert_frame_equal(r, e)
535+
result = s.str.extractall("(z)(z)")
536+
expected = DataFrame(columns=[0, 1], index=expected_index, dtype=any_string_dtype)
537+
tm.assert_frame_equal(result, expected)
523538

524539
# one named group.
525-
r = s.str.extractall("(?P<first>z)")
526-
e = DataFrame(columns=["first"], index=ei)
527-
tm.assert_frame_equal(r, e)
540+
result = s.str.extractall("(?P<first>z)")
541+
expected = DataFrame(
542+
columns=["first"], index=expected_index, dtype=any_string_dtype
543+
)
544+
tm.assert_frame_equal(result, expected)
528545

529546
# two named groups.
530-
r = s.str.extractall("(?P<first>z)(?P<second>z)")
531-
e = DataFrame(columns=["first", "second"], index=ei)
532-
tm.assert_frame_equal(r, e)
547+
result = s.str.extractall("(?P<first>z)(?P<second>z)")
548+
expected = DataFrame(
549+
columns=["first", "second"], index=expected_index, dtype=any_string_dtype
550+
)
551+
tm.assert_frame_equal(result, expected)
533552

534553
# one named, one un-named.
535-
r = s.str.extractall("(z)(?P<second>z)")
536-
e = DataFrame(columns=[0, "second"], index=ei)
537-
tm.assert_frame_equal(r, e)
554+
result = s.str.extractall("(z)(?P<second>z)")
555+
expected = DataFrame(
556+
columns=[0, "second"], index=expected_index, dtype=any_string_dtype
557+
)
558+
tm.assert_frame_equal(result, expected)
538559

539560

540-
def test_extractall_stringindex():
541-
s = Series(["a1a2", "b1", "c1"], name="xxx")
542-
res = s.str.extractall(r"[ab](?P<digit>\d)")
543-
exp_idx = MultiIndex.from_tuples([(0, 0), (0, 1), (1, 0)], names=[None, "match"])
544-
exp = DataFrame({"digit": ["1", "2", "1"]}, index=exp_idx)
545-
tm.assert_frame_equal(res, exp)
561+
def test_extractall_stringindex(any_string_dtype):
562+
s = Series(["a1a2", "b1", "c1"], name="xxx", dtype=any_string_dtype)
563+
result = s.str.extractall(r"[ab](?P<digit>\d)")
564+
expected = DataFrame(
565+
{"digit": ["1", "2", "1"]},
566+
index=MultiIndex.from_tuples([(0, 0), (0, 1), (1, 0)], names=[None, "match"]),
567+
dtype=any_string_dtype,
568+
)
569+
tm.assert_frame_equal(result, expected)
546570

547-
# index should return the same result as the default index without name
548-
# thus index.name doesn't affect to the result
549-
for idx in [
550-
Index(["a1a2", "b1", "c1"]),
551-
Index(["a1a2", "b1", "c1"], name="xxx"),
552-
]:
571+
# index should return the same result as the default index without name thus
572+
# index.name doesn't affect to the result
573+
if any_string_dtype == "object":
574+
for idx in [
575+
Index(["a1a2", "b1", "c1"]),
576+
Index(["a1a2", "b1", "c1"], name="xxx"),
577+
]:
553578

554-
res = idx.str.extractall(r"[ab](?P<digit>\d)")
555-
tm.assert_frame_equal(res, exp)
579+
result = idx.str.extractall(r"[ab](?P<digit>\d)")
580+
tm.assert_frame_equal(result, expected)
556581

557582
s = Series(
558583
["a1a2", "b1", "c1"],
559584
name="s_name",
560585
index=Index(["XX", "yy", "zz"], name="idx_name"),
586+
dtype=any_string_dtype,
561587
)
562-
res = s.str.extractall(r"[ab](?P<digit>\d)")
563-
exp_idx = MultiIndex.from_tuples(
564-
[("XX", 0), ("XX", 1), ("yy", 0)], names=["idx_name", "match"]
588+
result = s.str.extractall(r"[ab](?P<digit>\d)")
589+
expected = DataFrame(
590+
{"digit": ["1", "2", "1"]},
591+
index=MultiIndex.from_tuples(
592+
[("XX", 0), ("XX", 1), ("yy", 0)], names=["idx_name", "match"]
593+
),
594+
dtype=any_string_dtype,
565595
)
566-
exp = DataFrame({"digit": ["1", "2", "1"]}, index=exp_idx)
567-
tm.assert_frame_equal(res, exp)
596+
tm.assert_frame_equal(result, expected)
568597

569598

570-
def test_extractall_errors():
571-
# Does not make sense to use extractall with a regex that has
572-
# no capture groups. (it returns DataFrame with one column for
573-
# each capture group)
574-
s = Series(["a3", "b3", "d4c2"], name="series_name")
599+
def test_extractall_no_capture_groups_raises(any_string_dtype):
600+
# Does not make sense to use extractall with a regex that has no capture groups.
601+
# (it returns DataFrame with one column for each capture group)
602+
s = Series(["a3", "b3", "d4c2"], name="series_name", dtype=any_string_dtype)
575603
with pytest.raises(ValueError, match="no capture groups"):
576604
s.str.extractall(r"[a-z]")
577605

@@ -591,8 +619,8 @@ def test_extract_index_one_two_groups():
591619
tm.assert_frame_equal(r, e)
592620

593621

594-
def test_extractall_same_as_extract():
595-
s = Series(["a3", "b3", "c2"], name="series_name")
622+
def test_extractall_same_as_extract(any_string_dtype):
623+
s = Series(["a3", "b3", "c2"], name="series_name", dtype=any_string_dtype)
596624

597625
pattern_two_noname = r"([a-z])([0-9])"
598626
extract_two_noname = s.str.extract(pattern_two_noname, expand=True)
@@ -619,13 +647,13 @@ def test_extractall_same_as_extract():
619647
tm.assert_frame_equal(extract_one_noname, no_multi_index)
620648

621649

622-
def test_extractall_same_as_extract_subject_index():
650+
def test_extractall_same_as_extract_subject_index(any_string_dtype):
623651
# same as above tests, but s has an MultiIndex.
624-
i = MultiIndex.from_tuples(
652+
mi = MultiIndex.from_tuples(
625653
[("A", "first"), ("B", "second"), ("C", "third")],
626654
names=("capital", "ordinal"),
627655
)
628-
s = Series(["a3", "b3", "c2"], i, name="series_name")
656+
s = Series(["a3", "b3", "c2"], index=mi, name="series_name", dtype=any_string_dtype)
629657

630658
pattern_two_noname = r"([a-z])([0-9])"
631659
extract_two_noname = s.str.extract(pattern_two_noname, expand=True)

0 commit comments

Comments
 (0)