Skip to content

Commit 95223b0

Browse files
Merge pull request #6 from IntelPython/update-with-forthcoming-numpy-changes
Update with forthcoming numpy changes
2 parents 5262d8d + d4cd300 commit 95223b0

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

mkl_fft/_pydfti.pyx

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@
2626

2727
import numpy as np
2828
cimport numpy as cnp
29-
from numpy.core.multiarray_tests import internal_overlap
29+
try:
30+
from numpy.core.multiarray_tests import internal_overlap
31+
except ModuleNotFoundError:
32+
# Module has been renamed in NumPy 1.15
33+
from numpy.core._multiarray_tests import internal_overlap
3034

3135
from libc.string cimport memcpy
3236

@@ -101,12 +105,12 @@ cdef int _datacopied(cnp.ndarray arr, object orig):
101105
Strict check for `arr` not sharing any data with `original`,
102106
under the assumption that arr = asarray(original)
103107
"""
104-
if arr is orig:
105-
return 0
106108
if not cnp.PyArray_Check(orig) and PyObject_HasAttrString(orig, '__array__'):
107109
return 0
110+
if isinstance(orig, np.ndarray) and (arr is (<cnp.ndarray> orig)):
111+
return 0
108112
arr_obj = <object> arr
109-
return 1 if arr_obj.base is None else 0
113+
return 1 if (arr_obj.base is None) else 0
110114

111115

112116
def fft(x, n=None, axis=-1, overwrite_x=False):
@@ -808,7 +812,7 @@ def rfftn_numpy(x, s=None, axes=None):
808812
no_trim = (s is None) and (axes is None)
809813
s, axes = _cook_nd_args(a, s, axes)
810814
la = axes[-1]
811-
# trim array, so that rfft_numpy avoid doing
815+
# trim array, so that rfft_numpy avoids doing
812816
# unnecessary computations
813817
if not no_trim:
814818
a = _trim_array(a, s, axes)
@@ -843,15 +847,18 @@ def irfftn_numpy(x, s=None, axes=None):
843847
a = _fix_dimensions(a, s, axes)
844848
ovr_x = True if _datacopied(<cnp.ndarray> a, x) else False
845849
if len(set(axes)) == len(axes) and len(axes) == a.ndim and len(axes) > 2:
850+
# due to need to write into a, we must copy
851+
if not ovr_x:
852+
a = a.copy()
853+
ovr_x = True
846854
ss, aa = _remove_axis(s, axes, la)
847855
ind = [slice(None,None,1),] * len(s)
848856
for ii in range(a.shape[la]):
849857
ind[la] = ii
850858
tind = tuple(ind)
851859
a[tind] = _fftnd_impl(
852860
a[tind], shape=ss, axes=aa,
853-
overwrite_x=ovr_x, direction=-1)
854-
ovr_x = True
861+
overwrite_x=True, direction=-1)
855862
else:
856863
for ii in range(len(axes)-1):
857864
a = ifft(a, s[ii], axes[ii], overwrite_x=ovr_x)

0 commit comments

Comments
 (0)