Skip to content

Commit 824e96b

Browse files
committed
TST/CLN: break up & parametrize tests for df.set_index
1 parent 25e6a21 commit 824e96b

File tree

3 files changed

+572
-402
lines changed

3 files changed

+572
-402
lines changed

pandas/core/frame.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3862,10 +3862,29 @@ def set_index(self, keys, drop=True, append=False, inplace=False,
38623862
-------
38633863
dataframe : DataFrame
38643864
"""
3865-
inplace = validate_bool_kwarg(inplace, 'inplace')
3865+
from pandas import Series
3866+
38663867
if not isinstance(keys, list):
38673868
keys = [keys]
38683869

3870+
# collect elements from "keys" that are not allowed array types
3871+
col_labels = [x for x in keys
3872+
if not isinstance(x, (Series, Index, MultiIndex,
3873+
list, np.ndarray))]
3874+
if any(x not in self for x in col_labels):
3875+
# if there are any labels that are invalid, we raise a KeyError
3876+
missing = [x for x in col_labels if x not in self]
3877+
raise KeyError('{}'.format(missing))
3878+
3879+
elif len(set(col_labels)) < len(col_labels):
3880+
# if all are valid labels, but there are duplicates
3881+
dup = Series(col_labels)
3882+
dup = list(dup.loc[dup.duplicated()])
3883+
raise ValueError('Passed duplicate column names '
3884+
'to keys: {dup}'.format(dup=dup))
3885+
3886+
inplace = validate_bool_kwarg(inplace, 'inplace')
3887+
38693888
if inplace:
38703889
frame = self
38713890
else:

pandas/tests/frame/conftest.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import pytest
2+
3+
import numpy as np
4+
5+
from pandas import compat
6+
import pandas.util.testing as tm
7+
from pandas import DataFrame, date_range, NaT
8+
9+
10+
@pytest.fixture
def frame():
    """
    Fixture returning a DataFrame built from ``tm.getSeriesData()``.

    NOTE(review): the column labels are presumably 'A'..'D' (the sibling
    fixtures select exactly those names) -- confirm against the helper.
    """
    series_data = tm.getSeriesData()
    return DataFrame(series_data)
13+
14+
15+
@pytest.fixture
def frame2():
    """
    Fixture returning a DataFrame built from ``tm.getSeriesData()`` with
    the columns requested in the explicit order 'D', 'C', 'B', 'A'.
    """
    column_order = ['D', 'C', 'B', 'A']
    series_data = tm.getSeriesData()
    return DataFrame(series_data, columns=column_order)
18+
19+
20+
@pytest.fixture
def intframe():
    """
    Fixture returning an integer DataFrame with dtype ``np.int64``.

    Every series from ``tm.getSeriesData()`` is first cast with
    ``astype(int)``; the frame is then rebuilt with an explicit int64
    dtype.
    """
    as_int = {name: values.astype(int)
              for name, values in compat.iteritems(tm.getSeriesData())}
    columns = dict(compat.iteritems(DataFrame(as_int)))
    # force these all to int64 to avoid platform testing issues
    return DataFrame(columns, dtype=np.int64)
26+
27+
28+
@pytest.fixture
def tsframe():
    """
    Fixture returning a DataFrame built from ``tm.getTimeSeriesData()``.

    NOTE(review): presumably indexed by dates -- confirm against the helper.
    """
    ts_data = tm.getTimeSeriesData()
    return DataFrame(ts_data)
31+
32+
33+
@pytest.fixture
def mixed_frame():
    """
    Fixture returning a DataFrame built from ``tm.getSeriesData()`` plus
    an extra column 'foo' whose every value is the string 'bar'.
    """
    result = DataFrame(tm.getSeriesData())
    # scalar assignment broadcasts 'bar' to every row
    result['foo'] = 'bar'
    return result
38+
39+
40+
@pytest.fixture
def mixed_float():
    """
    Fixture returning a DataFrame of mixed float dtypes.

    Starting from ``tm.getSeriesData()``, column 'A' is cast to float16,
    'B' to float32 and 'C' to float64; any other column is left as is.
    """
    result = DataFrame(tm.getSeriesData())
    casts = (('A', 'float16'),
             ('B', 'float32'),
             ('C', 'float64'))
    for column, dtype in casts:
        result[column] = result[column].astype(dtype)
    return result
47+
48+
49+
@pytest.fixture
def mixed_float2():
    """
    Fixture returning a DataFrame of mixed float dtypes.

    Starting from ``tm.getSeriesData()``, column 'D' is cast to float16,
    'C' to float32 and 'B' to float64; any other column is left as is.
    """
    result = DataFrame(tm.getSeriesData())
    casts = (('D', 'float16'),
             ('C', 'float32'),
             ('B', 'float64'))
    for column, dtype in casts:
        result[column] = result[column].astype(dtype)
    return result
56+
57+
58+
@pytest.fixture
def mixed_int():
    """
    Fixture returning a DataFrame of mixed integer dtypes.

    Every series from ``tm.getSeriesData()`` is cast with ``astype(int)``;
    then 'A' becomes uint8, 'B' int32 and 'C' int64, while 'D' is replaced
    by an array of ones with dtype uint64.
    """
    as_int = {name: values.astype(int)
              for name, values in compat.iteritems(tm.getSeriesData())}
    result = DataFrame(as_int)
    casts = (('A', 'uint8'),
             ('B', 'int32'),
             ('C', 'int64'))
    for column, dtype in casts:
        result[column] = result[column].astype(dtype)
    # 'D' is overwritten rather than cast: all ones, same length as before
    n_rows = len(result['D'])
    result['D'] = np.ones(n_rows, dtype='uint64')
    return result
67+
68+
69+
@pytest.fixture
def all_mixed():
    """
    Fixture returning a ten-row DataFrame mixing float, int, string,
    float32 and int32 columns.

    Columns 'a', 'b' and 'c' are scalars (1., 2, 'foo') broadcast over the
    index ``np.arange(10)``; 'float32' and 'int32' are arrays of ten ones
    with the dtype their name indicates.
    """
    data = {'a': 1., 'b': 2, 'c': 'foo',
            'float32': np.array([1.] * 10, dtype='float32'),
            'int32': np.array([1] * 10, dtype='int32')}
    row_labels = np.arange(10)
    return DataFrame(data, index=row_labels)
75+
76+
77+
@pytest.fixture
def tzframe():
    """
    Fixture returning a three-row DataFrame of datetime columns.

    'A' is a date range without a timezone, 'B' uses tz 'US/Eastern' and
    'C' uses tz 'CET'; all start at '20130101'. The entries at row
    position 1, column positions 1 and 2 are set to NaT.
    """
    naive = date_range('20130101', periods=3)
    eastern = date_range('20130101', periods=3, tz='US/Eastern')
    cet = date_range('20130101', periods=3, tz='CET')
    result = DataFrame({'A': naive, 'B': eastern, 'C': cet})
    # blank out the middle row of the column positions 1 and 2
    for col_pos in (1, 2):
        result.iloc[1, col_pos] = NaT
    return result
87+
88+
89+
@pytest.fixture
def empty():
    """
    Fixture returning a DataFrame constructed from an empty dict.
    """
    no_data = {}
    return DataFrame(no_data)
92+
93+
94+
@pytest.fixture
def ts1():
    """
    Fixture returning the result of ``tm.makeTimeSeries(nper=30)``.

    NOTE(review): presumably a 30-element time series -- confirm against
    the helper.
    """
    series = tm.makeTimeSeries(nper=30)
    return series
97+
98+
99+
@pytest.fixture
def ts2():
    """
    Fixture returning ``tm.makeTimeSeries(nper=30)`` with its first five
    elements sliced off (``[5:]``).
    """
    series = tm.makeTimeSeries(nper=30)
    return series[5:]
102+
103+
104+
@pytest.fixture
def simple():
    """
    Fixture returning a 3x3 float DataFrame holding 1. through 9. in
    row-major order, with columns 'one', 'two', 'three' and index
    'a', 'b', 'c'.
    """
    values = np.array([[1., 2., 3.],
                       [4., 5., 6.],
                       [7., 8., 9.]])
    column_labels = ['one', 'two', 'three']
    row_labels = ['a', 'b', 'c']
    return DataFrame(values, columns=column_labels, index=row_labels)
112+
113+
114+
@pytest.fixture
def frame_of_index_cols():
    """
    Fixture returning a five-row DataFrame with three string columns
    ('A', 'B', 'C') and two columns of random normal floats ('D', 'E').

    The values of 'D' and 'E' come from unseeded ``np.random.randn`` and
    therefore differ between invocations.
    """
    data = {'A': ['foo', 'foo', 'foo', 'bar', 'bar'],
            'B': ['one', 'two', 'three', 'one', 'two'],
            'C': ['a', 'b', 'c', 'd', 'e'],
            'D': np.random.randn(5),
            'E': np.random.randn(5)}
    return DataFrame(data)

0 commit comments

Comments
 (0)