Skip to content

Commit f8e94fa

Browse files
committed
Merge pull request #7529 from sinhrks/dtibool
BUG: DatetimeIndex comparison handles NaT incorrectly
2 parents 441a1f2 + 589d30a commit f8e94fa

File tree

4 files changed

+116
-20
lines changed

4 files changed

+116
-20
lines changed

doc/source/v0.14.1.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ Bug Fixes
237237
- Bug when writing Stata files where the encoding was ignored (:issue:`7286`)
238238

239239

240+
- Bug in ``DatetimeIndex`` comparison not handling ``NaT`` properly (:issue:`7529`)
240241

241242

242243
- Bug in passing input with ``tzinfo`` to some offsets ``apply``, ``rollforward`` or ``rollback`` resets ``tzinfo`` or raises ``ValueError`` (:issue:`7465`)

pandas/tseries/index.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -74,22 +74,35 @@ def wrapper(left, right):
7474
return wrapper
7575

7676

77-
def _dt_index_cmp(opname):
77+
def _dt_index_cmp(opname, nat_result=False):
7878
"""
7979
Wrap comparison operations to convert datetime-like to datetime64
8080
"""
8181
def wrapper(self, other):
8282
func = getattr(super(DatetimeIndex, self), opname)
83-
if isinstance(other, datetime):
83+
if isinstance(other, datetime) or isinstance(other, compat.string_types):
8484
other = _to_m8(other, tz=self.tz)
85-
elif isinstance(other, list):
86-
other = DatetimeIndex(other)
87-
elif isinstance(other, compat.string_types):
88-
other = _to_m8(other, tz=self.tz)
89-
elif not isinstance(other, (np.ndarray, ABCSeries)):
90-
other = _ensure_datetime64(other)
91-
result = func(other)
85+
result = func(other)
86+
if com.isnull(other):
87+
result.fill(nat_result)
88+
else:
89+
if isinstance(other, list):
90+
other = DatetimeIndex(other)
91+
elif not isinstance(other, (np.ndarray, ABCSeries)):
92+
other = _ensure_datetime64(other)
93+
result = func(other)
9294

95+
if isinstance(other, Index):
96+
o_mask = other.values.view('i8') == tslib.iNaT
97+
else:
98+
o_mask = other.view('i8') == tslib.iNaT
99+
100+
if o_mask.any():
101+
result[o_mask] = nat_result
102+
103+
mask = self.asi8 == tslib.iNaT
104+
if mask.any():
105+
result[mask] = nat_result
93106
return result.view(np.ndarray)
94107

95108
return wrapper
@@ -142,7 +155,7 @@ class DatetimeIndex(DatetimeIndexOpsMixin, Int64Index):
142155
_arrmap = None
143156

144157
__eq__ = _dt_index_cmp('__eq__')
145-
__ne__ = _dt_index_cmp('__ne__')
158+
__ne__ = _dt_index_cmp('__ne__', nat_result=True)
146159
__lt__ = _dt_index_cmp('__lt__')
147160
__gt__ = _dt_index_cmp('__gt__')
148161
__le__ = _dt_index_cmp('__le__')

pandas/tseries/period.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -498,16 +498,11 @@ def dt64arr_to_periodarr(data, freq, tz):
498498

499499
# --- Period index sketch
500500

501-
def _period_index_cmp(opname):
501+
def _period_index_cmp(opname, nat_result=False):
502502
"""
503503
Wrap comparison operations to convert datetime-like to datetime64
504504
"""
505505
def wrapper(self, other):
506-
if opname == '__ne__':
507-
fill_value = True
508-
else:
509-
fill_value = False
510-
511506
if isinstance(other, Period):
512507
func = getattr(self.values, opname)
513508
if other.freq != self.freq:
@@ -523,7 +518,7 @@ def wrapper(self, other):
523518
mask = (com.mask_missing(self.values, tslib.iNaT) |
524519
com.mask_missing(other.values, tslib.iNaT))
525520
if mask.any():
526-
result[mask] = fill_value
521+
result[mask] = nat_result
527522

528523
return result
529524
else:
@@ -532,10 +527,10 @@ def wrapper(self, other):
532527
result = func(other.ordinal)
533528

534529
if other.ordinal == tslib.iNaT:
535-
result.fill(fill_value)
530+
result.fill(nat_result)
536531
mask = self.values == tslib.iNaT
537532
if mask.any():
538-
result[mask] = fill_value
533+
result[mask] = nat_result
539534

540535
return result
541536
return wrapper
@@ -595,7 +590,7 @@ class PeriodIndex(DatetimeIndexOpsMixin, Int64Index):
595590
_allow_period_index_ops = True
596591

597592
__eq__ = _period_index_cmp('__eq__')
598-
__ne__ = _period_index_cmp('__ne__')
593+
__ne__ = _period_index_cmp('__ne__', nat_result=True)
599594
__lt__ = _period_index_cmp('__lt__')
600595
__gt__ = _period_index_cmp('__gt__')
601596
__le__ = _period_index_cmp('__le__')

pandas/tseries/tests/test_timeseries.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2179,6 +2179,93 @@ def test_comparisons_coverage(self):
21792179
exp = rng == rng
21802180
self.assert_numpy_array_equal(result, exp)
21812181

2182+
def test_comparisons_nat(self):
2183+
fidx1 = pd.Index([1.0, np.nan, 3.0, np.nan, 5.0, 7.0])
2184+
fidx2 = pd.Index([2.0, 3.0, np.nan, np.nan, 6.0, 7.0])
2185+
2186+
didx1 = pd.DatetimeIndex(['2014-01-01', pd.NaT, '2014-03-01', pd.NaT,
2187+
'2014-05-01', '2014-07-01'])
2188+
didx2 = pd.DatetimeIndex(['2014-02-01', '2014-03-01', pd.NaT, pd.NaT,
2189+
'2014-06-01', '2014-07-01'])
2190+
darr = np.array([np.datetime64('2014-02-01 00:00Z'),
2191+
np.datetime64('2014-03-01 00:00Z'),
2192+
np.datetime64('nat'), np.datetime64('nat'),
2193+
np.datetime64('2014-06-01 00:00Z'),
2194+
np.datetime64('2014-07-01 00:00Z')])
2195+
2196+
if _np_version_under1p7:
2197+
# cannot test array because np.datetime64('nat') returns today's date
2198+
cases = [(fidx1, fidx2), (didx1, didx2)]
2199+
else:
2200+
cases = [(fidx1, fidx2), (didx1, didx2), (didx1, darr)]
2201+
2202+
# Check pd.NaT is handled the same as np.nan
2203+
for idx1, idx2 in cases:
2204+
result = idx1 < idx2
2205+
expected = np.array([True, False, False, False, True, False])
2206+
self.assert_numpy_array_equal(result, expected)
2207+
result = idx2 > idx1
2208+
expected = np.array([True, False, False, False, True, False])
2209+
self.assert_numpy_array_equal(result, expected)
2210+
2211+
result = idx1 <= idx2
2212+
expected = np.array([True, False, False, False, True, True])
2213+
self.assert_numpy_array_equal(result, expected)
2214+
result = idx2 >= idx1
2215+
expected = np.array([True, False, False, False, True, True])
2216+
self.assert_numpy_array_equal(result, expected)
2217+
2218+
result = idx1 == idx2
2219+
expected = np.array([False, False, False, False, False, True])
2220+
self.assert_numpy_array_equal(result, expected)
2221+
2222+
result = idx1 != idx2
2223+
expected = np.array([True, True, True, True, True, False])
2224+
self.assert_numpy_array_equal(result, expected)
2225+
2226+
for idx1, val in [(fidx1, np.nan), (didx1, pd.NaT)]:
2227+
result = idx1 < val
2228+
expected = np.array([False, False, False, False, False, False])
2229+
self.assert_numpy_array_equal(result, expected)
2230+
result = idx1 > val
2231+
self.assert_numpy_array_equal(result, expected)
2232+
2233+
result = idx1 <= val
2234+
self.assert_numpy_array_equal(result, expected)
2235+
result = idx1 >= val
2236+
self.assert_numpy_array_equal(result, expected)
2237+
2238+
result = idx1 == val
2239+
self.assert_numpy_array_equal(result, expected)
2240+
2241+
result = idx1 != val
2242+
expected = np.array([True, True, True, True, True, True])
2243+
self.assert_numpy_array_equal(result, expected)
2244+
2245+
# Check pd.NaT is handled the same as np.nan when comparing with a scalar value
2246+
for idx1, val in [(fidx1, 3), (didx1, datetime(2014, 3, 1))]:
2247+
result = idx1 < val
2248+
expected = np.array([True, False, False, False, False, False])
2249+
self.assert_numpy_array_equal(result, expected)
2250+
result = idx1 > val
2251+
expected = np.array([False, False, False, False, True, True])
2252+
self.assert_numpy_array_equal(result, expected)
2253+
2254+
result = idx1 <= val
2255+
expected = np.array([True, False, True, False, False, False])
2256+
self.assert_numpy_array_equal(result, expected)
2257+
result = idx1 >= val
2258+
expected = np.array([False, False, True, False, True, True])
2259+
self.assert_numpy_array_equal(result, expected)
2260+
2261+
result = idx1 == val
2262+
expected = np.array([False, False, True, False, False, False])
2263+
self.assert_numpy_array_equal(result, expected)
2264+
2265+
result = idx1 != val
2266+
expected = np.array([True, True, False, True, True, True])
2267+
self.assert_numpy_array_equal(result, expected)
2268+
21822269
def test_map(self):
21832270
rng = date_range('1/1/2000', periods=10)
21842271

0 commit comments

Comments
 (0)