Skip to content

Commit 794cf94

Browse files
author
Trevor Bye
committed
Merge branch 'dtype-patch' of https://github.com/trevorbye/pandas into dtype-patch
2 parents ec0996c + 0039b50 commit 794cf94

File tree

3 files changed

+43
-6
lines changed

3 files changed

+43
-6
lines changed

pandas/core/frame.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4307,6 +4307,7 @@ def set_index(
43074307
"one-dimensional arrays."
43084308
)
43094309

4310+
current_dtype = None
43104311
missing: List[Optional[Hashable]] = []
43114312
for col in keys:
43124313
if isinstance(
@@ -4320,6 +4321,9 @@ def set_index(
43204321
# everything else gets tried as a key; see GH 24969
43214322
try:
43224323
found = col in self.columns
4324+
if found:
4325+
# get current dtype to preserve through index creation
4326+
current_dtype = self.dtypes.get(col).type
43234327
except TypeError:
43244328
raise TypeError(f"{err_msg}. Received column of type {type(col)}")
43254329
else:
@@ -4375,7 +4379,7 @@ def set_index(
43754379
f"received array of length {len(arrays[-1])}"
43764380
)
43774381

4378-
index = ensure_index_from_sequences(arrays, names)
4382+
index = ensure_index_from_sequences(arrays, names, current_dtype)
43794383

43804384
if verify_integrity and not index.is_unique:
43814385
duplicates = index[index.duplicated()].unique()
@@ -4550,9 +4554,6 @@ class max type
45504554

45514555
def _maybe_casted_values(index, labels=None):
45524556
values = index._values
4553-
if not isinstance(index, (PeriodIndex, DatetimeIndex)):
4554-
if values.dtype == np.object_:
4555-
values = lib.maybe_convert_objects(values)
45564557

45574558
# if we have the labels, extract the values with a mask
45584559
if labels is not None:

pandas/core/indexes/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5501,7 +5501,7 @@ def shape(self):
55015501
Index._add_comparison_methods()
55025502

55035503

5504-
def ensure_index_from_sequences(sequences, names=None):
5504+
def ensure_index_from_sequences(sequences, names=None, dtype=None):
55055505
"""
55065506
Construct an index from sequences of data.
55075507
@@ -5512,6 +5512,7 @@ def ensure_index_from_sequences(sequences, names=None):
55125512
----------
55135513
sequences : sequence of sequences
55145514
names : sequence of str
5515+
dtype : NumPy dtype
55155516
55165517
Returns
55175518
-------
@@ -5537,7 +5538,7 @@ def ensure_index_from_sequences(sequences, names=None):
55375538
if len(sequences) == 1:
55385539
if names is not None:
55395540
names = names[0]
5540-
return Index(sequences[0], name=names)
5541+
return Index(sequences[0], name=names, dtype=dtype)
55415542
else:
55425543
return MultiIndex.from_arrays(sequences, names=names)
55435544

pandas/tests/frame/test_alter_axes.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1486,6 +1486,41 @@ def test_droplevel(self):
14861486
result = df.droplevel("level_2", axis="columns")
14871487
tm.assert_frame_equal(result, expected)
14881488

1489+
@pytest.mark.parametrize('test_dtype', [object, 'int64'])
1490+
def test_dtypes(self, test_dtype):
1491+
df = DataFrame({'A': Series([1, 2, 3], dtype=test_dtype), 'B': [1, 2, 3]})
1492+
expected = df.dtypes.values[0].type
1493+
1494+
result = df.set_index('A').index.dtype.type
1495+
assert result == expected
1496+
1497+
@pytest.fixture
1498+
def mixed_series(self):
1499+
return Series([1, 2, 3, 'apple', 'corn'], dtype=object)
1500+
1501+
@pytest.fixture
1502+
def int_series(self):
1503+
return Series([100, 200, 300, 400, 500])
1504+
1505+
def test_dtypes_between_queries(self, mixed_series, int_series):
1506+
df = DataFrame({'item': mixed_series, 'cost': int_series})
1507+
1508+
orig_dtypes = df.dtypes
1509+
item_dtype = orig_dtypes.get('item').type
1510+
cost_dtype = orig_dtypes.get('cost').type
1511+
expected = {'item': item_dtype, 'cost': cost_dtype}
1512+
1513+
# after applying a query that would remove strings from the 'item' series with
1514+
# dtype: object, that series should remain as dtype: object as it becomes an
1515+
# index, and still be object when reset_index() turns it back into a column
1516+
dtypes_transformed = df.query('cost < 400').set_index(
1517+
'item').reset_index().dtypes
1518+
item_dtype_transformed = dtypes_transformed.get('item').type
1519+
cost_dtype_transformed = dtypes_transformed.get('cost').type
1520+
result = {'item': item_dtype_transformed, 'cost': cost_dtype_transformed}
1521+
1522+
assert result == expected
1523+
14891524

14901525
class TestIntervalIndex:
14911526
def test_setitem(self):

0 commit comments

Comments
 (0)