Skip to content

Fix MKL FFT descriptor corruption in threaded Python scripts #14

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged 1 commit on Jul 19, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 76 additions & 65 deletions mkl_fft/_pydfti.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ except ModuleNotFoundError:

from libc.string cimport memcpy

from threading import Lock
# Module-level lock shared by the FFT implementation functions in this file:
# each of them wraps its call into the MKL FFT routines in `with _lock:`, so
# only one thread at a time executes those calls.
_lock = Lock()

cdef extern from "Python.h":
ctypedef int size_t

Expand Down Expand Up @@ -289,18 +292,19 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
in_place = 1

if in_place:
if x_type is cnp.NPY_CDOUBLE:
if dir_ < 0:
status = cdouble_mkl_ifft1d_in(x_arr, n_, <int> axis_)
else:
status = cdouble_mkl_fft1d_in(x_arr, n_, <int> axis_)
elif x_type is cnp.NPY_CFLOAT:
if dir_ < 0:
status = cfloat_mkl_ifft1d_in(x_arr, n_, <int> axis_)
with _lock:
if x_type is cnp.NPY_CDOUBLE:
if dir_ < 0:
status = cdouble_mkl_ifft1d_in(x_arr, n_, <int> axis_)
else:
status = cdouble_mkl_fft1d_in(x_arr, n_, <int> axis_)
elif x_type is cnp.NPY_CFLOAT:
if dir_ < 0:
status = cfloat_mkl_ifft1d_in(x_arr, n_, <int> axis_)
else:
status = cfloat_mkl_fft1d_in(x_arr, n_, <int> axis_)
else:
status = cfloat_mkl_fft1d_in(x_arr, n_, <int> axis_)
else:
status = 1
status = 1

if status:
raise ValueError("Internal error, status={}".format(status))
Expand All @@ -318,36 +322,37 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
f_arr = __allocate_result(x_arr, n_, axis_, f_type);

# call out-of-place FFT
if f_type is cnp.NPY_CDOUBLE:
if x_type is cnp.NPY_DOUBLE:
if dir_ < 0:
status = double_cdouble_mkl_ifft1d_out(
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS)
else:
status = double_cdouble_mkl_fft1d_out(
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS)
elif x_type is cnp.NPY_CDOUBLE:
if dir_ < 0:
status = cdouble_cdouble_mkl_ifft1d_out(
x_arr, n_, <int> axis_, f_arr)
else:
status = cdouble_cdouble_mkl_fft1d_out(
x_arr, n_, <int> axis_, f_arr)
else:
if x_type is cnp.NPY_FLOAT:
if dir_ < 0:
status = float_cfloat_mkl_ifft1d_out(
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS)
else:
status = float_cfloat_mkl_fft1d_out(
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS)
elif x_type is cnp.NPY_CFLOAT:
if dir_ < 0:
status = cfloat_cfloat_mkl_ifft1d_out(
x_arr, n_, <int> axis_, f_arr)
else:
status = cfloat_cfloat_mkl_fft1d_out(
x_arr, n_, <int> axis_, f_arr)
with _lock:
if f_type is cnp.NPY_CDOUBLE:
if x_type is cnp.NPY_DOUBLE:
if dir_ < 0:
status = double_cdouble_mkl_ifft1d_out(
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS)
else:
status = double_cdouble_mkl_fft1d_out(
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS)
elif x_type is cnp.NPY_CDOUBLE:
if dir_ < 0:
status = cdouble_cdouble_mkl_ifft1d_out(
x_arr, n_, <int> axis_, f_arr)
else:
status = cdouble_cdouble_mkl_fft1d_out(
x_arr, n_, <int> axis_, f_arr)
else:
if x_type is cnp.NPY_FLOAT:
if dir_ < 0:
status = float_cfloat_mkl_ifft1d_out(
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS)
else:
status = float_cfloat_mkl_fft1d_out(
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS)
elif x_type is cnp.NPY_CFLOAT:
if dir_ < 0:
status = cfloat_cfloat_mkl_ifft1d_out(
x_arr, n_, <int> axis_, f_arr)
else:
status = cfloat_cfloat_mkl_fft1d_out(
x_arr, n_, <int> axis_, f_arr)

if (status):
raise ValueError("Internal error occurred, status={}".format(status))
Expand Down Expand Up @@ -399,18 +404,19 @@ def _rrfft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
in_place = 1

if in_place:
if x_type is cnp.NPY_DOUBLE:
if dir_ < 0:
status = double_mkl_irfft_in(x_arr, n_, <int> axis_)
else:
status = double_mkl_rfft_in(x_arr, n_, <int> axis_)
elif x_type is cnp.NPY_FLOAT:
if dir_ < 0:
status = float_mkl_irfft_in(x_arr, n_, <int> axis_)
with _lock:
if x_type is cnp.NPY_DOUBLE:
if dir_ < 0:
status = double_mkl_irfft_in(x_arr, n_, <int> axis_)
else:
status = double_mkl_rfft_in(x_arr, n_, <int> axis_)
elif x_type is cnp.NPY_FLOAT:
if dir_ < 0:
status = float_mkl_irfft_in(x_arr, n_, <int> axis_)
else:
status = float_mkl_rfft_in(x_arr, n_, <int> axis_)
else:
status = float_mkl_rfft_in(x_arr, n_, <int> axis_)
else:
status = 1
status = 1

if status:
raise ValueError("Internal error, status={}".format(status))
Expand All @@ -426,16 +432,17 @@ def _rrfft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
f_arr = __allocate_result(x_arr, n_, axis_, x_type);

# call out-of-place FFT
if x_type is cnp.NPY_DOUBLE:
if dir_ < 0:
status = double_double_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr)
else:
status = double_double_mkl_rfft_out(x_arr, n_, <int> axis_, f_arr)
else:
if dir_ < 0:
status = float_float_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr)
with _lock:
if x_type is cnp.NPY_DOUBLE:
if dir_ < 0:
status = double_double_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr)
else:
status = double_double_mkl_rfft_out(x_arr, n_, <int> axis_, f_arr)
else:
status = float_float_mkl_rfft_out(x_arr, n_, <int> axis_, f_arr)
if dir_ < 0:
status = float_float_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr)
else:
status = float_float_mkl_rfft_out(x_arr, n_, <int> axis_, f_arr)

if (status):
raise ValueError("Internal error occurred, status={}".format(status))
Expand Down Expand Up @@ -487,9 +494,11 @@ def _rc_fft1d_impl(x, n=None, axis=-1, overwrite_arg=False):

# call out-of-place FFT
if x_type is cnp.NPY_FLOAT:
status = float_cfloat_mkl_fft1d_out(x_arr, n_, <int> axis_, f_arr, HALF_HARMONICS)
with _lock:
status = float_cfloat_mkl_fft1d_out(x_arr, n_, <int> axis_, f_arr, HALF_HARMONICS)
else:
status = double_cdouble_mkl_fft1d_out(x_arr, n_, <int> axis_, f_arr, HALF_HARMONICS)
with _lock:
status = double_cdouble_mkl_fft1d_out(x_arr, n_, <int> axis_, f_arr, HALF_HARMONICS)

if (status):
raise ValueError("Internal error occurred, with status={}".format(status))
Expand Down Expand Up @@ -563,9 +572,11 @@ def _rc_ifft1d_impl(x, n=None, axis=-1, overwrite_arg=False):

# call out-of-place FFT
if x_type is cnp.NPY_CFLOAT:
status = cfloat_float_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr)
with _lock:
status = cfloat_float_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr)
else:
status = cdouble_double_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr)
with _lock:
status = cdouble_double_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr)

if (status):
raise ValueError("Internal error occurred, status={}".format(status))
Expand Down
1 change: 0 additions & 1 deletion mkl_fft/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ def configuration(parent_package='',top_path=None):

config.add_extension(
name = '_pydfti',
# module_name = 'mkl_fft._pydfti',
sources = [
join(wdir, 'mklfft.c.src'),
join(wdir, 'multi_iter.c'),
Expand Down