Skip to content

Commit d4cd300

Browse files
BUG: Fixed issue #4
Because assignment `a[tind] = func(...)` amounts to overwriting data in array `a`, a copy is needed, if not already made. ```ipython In [1]: import numpy as np, mkl_fft In [2]: x = np.random.randn(4,4,4) In [3]: xc = x.copy() In [4]: y = mkl_fft._numpy_fft.rfftn(x) In [5]: yc = y.copy() In [6]: z = mkl_fft._numpy_fft.irfftn(y) In [7]: np.allclose(y, yc) Out[7]: True In [8]: np.allclose(z, x) Out[8]: True ```
1 parent 21ef9d4 commit d4cd300

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

mkl_fft/_pydfti.pyx

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -105,12 +105,12 @@ cdef int _datacopied(cnp.ndarray arr, object orig):
105105
Strict check for `arr` not sharing any data with `original`,
106106
under the assumption that arr = asarray(original)
107107
"""
108-
if arr is orig:
109-
return 0
110108
if not cnp.PyArray_Check(orig) and PyObject_HasAttrString(orig, '__array__'):
111109
return 0
110+
if isinstance(orig, np.ndarray) and (arr is (<cnp.ndarray> orig)):
111+
return 0
112112
arr_obj = <object> arr
113-
return 1 if arr_obj.base is None else 0
113+
return 1 if (arr_obj.base is None) else 0
114114

115115

116116
def fft(x, n=None, axis=-1, overwrite_x=False):
@@ -812,7 +812,7 @@ def rfftn_numpy(x, s=None, axes=None):
812812
no_trim = (s is None) and (axes is None)
813813
s, axes = _cook_nd_args(a, s, axes)
814814
la = axes[-1]
815-
# trim array, so that rfft_numpy avoid doing
815+
# trim array, so that rfft_numpy avoids doing
816816
# unnecessary computations
817817
if not no_trim:
818818
a = _trim_array(a, s, axes)
@@ -847,15 +847,18 @@ def irfftn_numpy(x, s=None, axes=None):
847847
a = _fix_dimensions(a, s, axes)
848848
ovr_x = True if _datacopied(<cnp.ndarray> a, x) else False
849849
if len(set(axes)) == len(axes) and len(axes) == a.ndim and len(axes) > 2:
850+
# due to need to write into a, we must copy
851+
if not ovr_x:
852+
a = a.copy()
853+
ovr_x = True
850854
ss, aa = _remove_axis(s, axes, la)
851855
ind = [slice(None,None,1),] * len(s)
852856
for ii in range(a.shape[la]):
853857
ind[la] = ii
854858
tind = tuple(ind)
855859
a[tind] = _fftnd_impl(
856860
a[tind], shape=ss, axes=aa,
857-
overwrite_x=ovr_x, direction=-1)
858-
ovr_x = True
861+
overwrite_x=True, direction=-1)
859862
else:
860863
for ii in range(len(axes)-1):
861864
a = ifft(a, s[ii], axes[ii], overwrite_x=ovr_x)

0 commit comments

Comments
 (0)