Skip to content

ENH: improve DataFrame read_csv / to_csv for Index/MultiIndex #151

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,8 +431,9 @@ def from_csv(cls, path, header=0, delimiter=',', index_col=0):
header : int, default 0
Row to use as header (skip prior rows)
delimiter : string, default ','
index_col : int, default 0
Column to use for index
index_col : int or sequence, default 0
Column to use for index. If a sequence is given, a MultiIndex
is used.

Notes
-----
Expand Down Expand Up @@ -482,8 +483,10 @@ def to_csv(self, path, nanRep='', cols=None, header=True,
Write out column names
index : boolean, default True
Write row names (index)
index_label : string, default None
Column label for index column if desired
index_label : string or sequence, default None
Column label for index column(s) if desired. If None is given, and
`header` and `index` are True, then the index names are used. A
sequence should be given if the DataFrame uses MultiIndex.
mode : Python write mode, default 'wb'
"""
f = open(path, mode)
Expand All @@ -494,15 +497,25 @@ def to_csv(self, path, nanRep='', cols=None, header=True,
series = self._series
if header:
joined_cols = ','.join([str(c) for c in cols])
if index and index_label:
f.write('%s,%s' % (index_label, joined_cols))
if index:
# should write something for index label
if index_label is None:
index_label = getattr(self.index, 'names', ['index'])
elif not isinstance(index_label, (list, tuple, np.ndarray)):
# given a string for a DF with Index
index_label = [index_label]
f.write('%s,%s' % (",".join(index_label), joined_cols))
else:
f.write(joined_cols)
f.write('\n')

nlevels = getattr(self.index, 'nlevels', 1)
for idx in self.index:
if index:
f.write(str(idx))
if nlevels == 1:
f.write(str(idx))
else: # handle MultiIndex
f.write(",".join([str(i) for i in idx]))
for i, col in enumerate(cols):
val = series[col].get(idx)
if isnull(val):
Expand Down
68 changes: 49 additions & 19 deletions pandas/io/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import numpy as np

from pandas.core.index import Index
from pandas.core.index import Index, MultiIndex
from pandas.core.frame import DataFrame

def read_csv(filepath_or_buffer, sep=None, header=0, skiprows=None, index_col=0,
Expand All @@ -27,9 +27,9 @@ def read_csv(filepath_or_buffer, sep=None, header=0, skiprows=None, index_col=0,
Row to use for the column labels of the parsed DataFrame
skiprows : list-like
Row numbers to skip (0-indexed)
index_col : int, default 0
index_col : int or sequence, default 0
Column to use as the row labels of the DataFrame. Pass None if there is
no such column
no such column. If a sequence is given, a MultiIndex is used.
na_values : list-like, default None
List of additional strings to recognize as NA/NaN
date_parser : function
Expand Down Expand Up @@ -65,7 +65,7 @@ def read_csv(filepath_or_buffer, sep=None, header=0, skiprows=None, index_col=0,
sniffed = csv.Sniffer().sniff(sample)
dia.delimiter = sniffed.delimiter
f.seek(0)

reader = csv.reader(f, dialect=dia)

if skiprows is not None:
Expand All @@ -92,9 +92,9 @@ def read_table(filepath_or_buffer, sep='\t', header=0, skiprows=None,
Row to use for the column labels of the parsed DataFrame
skiprows : list-like
Row numbers to skip (0-indexed)
index_col : int, default 0
index_col : int or sequence, default 0
Column to use as the row labels of the DataFrame. Pass None if there is
no such column
no such column. If a sequence is given, a MultiIndex is used.
na_values : list-like, default None
List of additional strings to recognize as NA/NaN
date_parser : function
Expand All @@ -107,7 +107,7 @@ def read_table(filepath_or_buffer, sep='\t', header=0, skiprows=None,
-------
parsed : DataFrame
"""
return read_csv(filepath_or_buffer, sep, header, skiprows,
return read_csv(filepath_or_buffer, sep, header, skiprows,
index_col, na_values, date_parser, names)

def _simple_parser(lines, colNames=None, header=0, indexCol=0,
Expand Down Expand Up @@ -149,27 +149,43 @@ def _simple_parser(lines, colNames=None, header=0, indexCol=0,

# no index column specified, so infer that's what is wanted
if indexCol is not None:
if indexCol == 0 and len(content[0]) == len(columns) + 1:
index = zipped_content[0]
zipped_content = zipped_content[1:]
if np.isscalar(indexCol):
if indexCol == 0 and len(content[0]) == len(columns) + 1:
index = zipped_content[0]
zipped_content = zipped_content[1:]
else:
index = zipped_content.pop(indexCol)
columns.pop(indexCol)
else: # given a list of index
idx_names = []
index = []
for idx in indexCol:
idx_names.append(columns[idx])
index.append(zipped_content[idx])
#remove index items from content and columns, don't pop in loop
for i in range(len(indexCol)):
columns.remove(idx_names[i])
zipped_content.remove(index[i])


if np.isscalar(indexCol):
if parse_dates:
index = _try_parse_dates(index, parser=date_parser)
index = Index(_maybe_convert_int(np.array(index, dtype=object)))
else:
index = zipped_content.pop(indexCol)
columns.pop(indexCol)

if parse_dates:
index = _try_parse_dates(index, parser=date_parser)

index = _maybe_convert_int(np.array(index, dtype=object))
index = MultiIndex.from_arrays(_maybe_convert_int_mindex(index,
parse_dates, date_parser),
names=idx_names)
else:
index = np.arange(len(content))
index = Index(np.arange(len(content)))

if len(columns) != len(zipped_content):
raise Exception('wrong number of columns')

data = dict(izip(columns, zipped_content))
data = _floatify(data, na_values=na_values)
data = _convert_to_ndarrays(data)
return DataFrame(data=data, columns=columns, index=Index(index))
return DataFrame(data=data, columns=columns, index=index)

def _floatify(data_dict, na_values=None):
"""
Expand Down Expand Up @@ -218,6 +234,20 @@ def _maybe_convert_int(arr):

return arr

def _maybe_convert_int_mindex(index, parse_dates, date_parser):
if len(index) == 0:
return index

for i in range(len(index)):
try:
int(index[i][0])
index[i] = map(int, index[i])
except ValueError:
if parse_dates:
index[i] = _try_parse_dates(index[i], date_parser)

return index

def _convert_to_ndarrays(dct):
result = {}
for c, values in dct.iteritems():
Expand Down
45 changes: 44 additions & 1 deletion pandas/tests/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@

import pandas.core.datetools as datetools
from pandas.core.index import NULL_INDEX
from pandas.core.api import (DataFrame, Index, Series, notnull, isnull)
from pandas.core.api import (DataFrame, Index, Series, notnull, isnull,
MultiIndex)

from pandas.util.testing import (assert_almost_equal,
assert_series_equal,
Expand Down Expand Up @@ -1462,6 +1463,48 @@ def test_to_csv_from_csv(self):

os.remove(path)

def test_to_csv_multiindex(self):
    """
    Round-trip frames that carry a MultiIndex through to_csv / from_csv.

    The shared fixtures ``self.frame`` and ``self.tsframe`` get a
    MultiIndex assigned for the duration of the test. Their original
    indexes and the temporary file are restored/removed in a ``finally``
    clause so a failing assertion cannot leak state into other tests.
    """
    path = '__tmp__'

    frame = self.frame
    tsframe = self.tsframe
    old_frame_index = frame.index
    old_ts_index = tsframe.index

    try:
        new_index = MultiIndex.from_arrays(
            np.arange(len(old_frame_index) * 2).reshape(2, -1))
        frame.index = new_index
        frame.to_csv(path, header=False)
        frame.to_csv(path, cols=['A', 'B'])

        # round trip
        frame.to_csv(path)
        df = DataFrame.from_csv(path, index_col=[0, 1])
        assert_frame_equal(frame, df)

        # try multiindex with dates
        new_index = [old_ts_index, np.arange(len(old_ts_index))]
        tsframe.index = MultiIndex.from_arrays(new_index)

        tsframe.to_csv(path, index_label=['time', 'foo'])
        recons = DataFrame.from_csv(path, index_col=[0, 1])
        assert_frame_equal(tsframe, recons)

        # do not load index: both index levels come back as data columns
        tsframe.to_csv(path)
        recons = DataFrame.from_csv(path, index_col=None)
        np.testing.assert_equal(len(recons.columns),
                                len(tsframe.columns) + 2)

        # no index
        tsframe.to_csv(path, index=False)
        recons = DataFrame.from_csv(path, index_col=None)
        assert_almost_equal(recons.values, self.tsframe.values)
    finally:
        # needed if setUp becomes a classmethod; also runs on failure
        frame.index = old_frame_index
        tsframe.index = old_ts_index
        if os.path.exists(path):
            os.remove(path)

def test_info(self):
io = StringIO()
self.frame.info(buf=io)
Expand Down