Skip to content

Commit aeb56b7

Browse files
Merge pull request #44 from IntelPython/feature/tls-dfti-cache
Feature/tls dfti cache
2 parents b94dc80 + ceb7b3f commit aeb56b7

File tree

3 files changed

+297
-254
lines changed

3 files changed

+297
-254
lines changed

mkl_fft/_pydfti.pyx

Lines changed: 154 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,50 @@ except ImportError:
3434
from numpy.core._multiarray_tests import internal_overlap
3535

3636
from libc.string cimport memcpy
37+
cimport cpython.pycapsule
38+
from cpython.exc cimport (PyErr_Occurred, PyErr_Clear)
39+
from cpython.mem cimport (PyMem_Malloc, PyMem_Free)
40+
41+
from threading import local as threading_local
42+
43+
# thread-local storage
44+
_tls = threading_local()
45+
46+
cdef const char *capsule_name = "dfti_cache"
47+
48+
cdef void _capsule_destructor(object caps):
49+
cdef DftiCache *_cache = NULL
50+
cdef int status = 0
51+
if (caps is None):
52+
print("Nothing to destroy")
53+
return
54+
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(caps, capsule_name)
55+
status = _free_dfti_cache(_cache)
56+
PyMem_Free(_cache)
57+
if (status != 0):
58+
raise ValueError("Internal Error: Freeing DFTI Cache returned with error = {}".format(status))
59+
60+
61+
def _tls_dfti_cache_capsule():
62+
cdef DftiCache *_cache_struct
63+
64+
init = getattr(_tls, 'initialized', None)
65+
if (init is None):
66+
_cache_struct = <DftiCache *> PyMem_Malloc(sizeof(DftiCache));
67+
# important to initialized
68+
_cache_struct.initialized = 0
69+
_cache_struct.hand = NULL
70+
_tls.initialized = True
71+
_tls.capsule = cpython.pycapsule.PyCapsule_New(<void *>_cache_struct, capsule_name, &_capsule_destructor)
72+
capsule = getattr(_tls, 'capsule', None)
73+
if (not cpython.pycapsule.PyCapsule_IsValid(capsule, capsule_name)):
74+
raise ValueError("Internal Error: invalid capsule stored in TLS")
75+
return capsule
3776

38-
from threading import Lock
39-
_lock = Lock()
4077

4178
cdef extern from "Python.h":
4279
ctypedef int size_t
4380

44-
void* PyMem_Malloc(size_t n)
45-
void PyMem_Free(void* buf)
46-
47-
int PyErr_Occurred()
48-
void PyErr_Clear()
4981
long PyInt_AsLong(object ob)
5082
int PyObject_HasAttrString(object, char*)
5183

@@ -58,32 +90,36 @@ cdef extern from *:
5890
object PyArray_BASE(cnp.ndarray)
5991

6092
cdef extern from "src/mklfft.h":
61-
int cdouble_mkl_fft1d_in(cnp.ndarray, int, int)
62-
int cfloat_mkl_fft1d_in(cnp.ndarray, int, int)
63-
int float_cfloat_mkl_fft1d_out(cnp.ndarray, int, int, cnp.ndarray, int)
64-
int cfloat_cfloat_mkl_fft1d_out(cnp.ndarray, int, int, cnp.ndarray)
65-
int double_cdouble_mkl_fft1d_out(cnp.ndarray, int, int, cnp.ndarray, int)
66-
int cdouble_cdouble_mkl_fft1d_out(cnp.ndarray, int, int, cnp.ndarray)
67-
68-
int cdouble_mkl_ifft1d_in(cnp.ndarray, int, int)
69-
int cfloat_mkl_ifft1d_in(cnp.ndarray, int, int)
70-
int float_cfloat_mkl_ifft1d_out(cnp.ndarray, int, int, cnp.ndarray, int)
71-
int cfloat_cfloat_mkl_ifft1d_out(cnp.ndarray, int, int, cnp.ndarray)
72-
int double_cdouble_mkl_ifft1d_out(cnp.ndarray, int, int, cnp.ndarray, int)
73-
int cdouble_cdouble_mkl_ifft1d_out(cnp.ndarray, int, int, cnp.ndarray)
74-
75-
int double_mkl_rfft_in(cnp.ndarray, int, int)
76-
int double_mkl_irfft_in(cnp.ndarray, int, int)
77-
int float_mkl_rfft_in(cnp.ndarray, int, int)
78-
int float_mkl_irfft_in(cnp.ndarray, int, int)
79-
80-
int double_double_mkl_rfft_out(cnp.ndarray, int, int, cnp.ndarray)
81-
int double_double_mkl_irfft_out(cnp.ndarray, int, int, cnp.ndarray)
82-
int float_float_mkl_rfft_out(cnp.ndarray, int, int, cnp.ndarray)
83-
int float_float_mkl_irfft_out(cnp.ndarray, int, int, cnp.ndarray)
84-
85-
int cdouble_double_mkl_irfft_out(cnp.ndarray, int, int, cnp.ndarray)
86-
int cfloat_float_mkl_irfft_out(cnp.ndarray, int, int, cnp.ndarray)
93+
cdef struct DftiCache:
94+
void * hand
95+
int initialized
96+
int _free_dfti_cache(DftiCache *)
97+
int cdouble_mkl_fft1d_in(cnp.ndarray, int, int, DftiCache*)
98+
int cfloat_mkl_fft1d_in(cnp.ndarray, int, int, DftiCache*)
99+
int float_cfloat_mkl_fft1d_out(cnp.ndarray, int, int, cnp.ndarray, int, DftiCache*)
100+
int cfloat_cfloat_mkl_fft1d_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)
101+
int double_cdouble_mkl_fft1d_out(cnp.ndarray, int, int, cnp.ndarray, int, DftiCache*)
102+
int cdouble_cdouble_mkl_fft1d_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)
103+
104+
int cdouble_mkl_ifft1d_in(cnp.ndarray, int, int, DftiCache*)
105+
int cfloat_mkl_ifft1d_in(cnp.ndarray, int, int, DftiCache*)
106+
int float_cfloat_mkl_ifft1d_out(cnp.ndarray, int, int, cnp.ndarray, int, DftiCache*)
107+
int cfloat_cfloat_mkl_ifft1d_out(cnp.ndarray, int, int, cnp.ndarra, DftiCache*)
108+
int double_cdouble_mkl_ifft1d_out(cnp.ndarray, int, int, cnp.ndarray, int, DftiCache*)
109+
int cdouble_cdouble_mkl_ifft1d_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)
110+
111+
int double_mkl_rfft_in(cnp.ndarray, int, int, DftiCache*)
112+
int double_mkl_irfft_in(cnp.ndarray, int, int, DftiCache*)
113+
int float_mkl_rfft_in(cnp.ndarray, int, int, DftiCache*)
114+
int float_mkl_irfft_in(cnp.ndarray, int, int, DftiCache*)
115+
116+
int double_double_mkl_rfft_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)
117+
int double_double_mkl_irfft_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)
118+
int float_float_mkl_rfft_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)
119+
int float_float_mkl_irfft_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)
120+
121+
int cdouble_double_mkl_irfft_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)
122+
int cfloat_float_mkl_irfft_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)
87123

88124
int cdouble_cdouble_mkl_fftnd_in(cnp.ndarray)
89125
int cdouble_cdouble_mkl_ifftnd_in(cnp.ndarray)
@@ -268,6 +304,7 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
268304
cdef int ALL_HARMONICS = 1
269305
cdef char * c_error_msg = NULL
270306
cdef bytes py_error_msg
307+
cdef DftiCache *_cache
271308

272309
x_arr = __process_arguments(x, n, axis, overwrite_arg, direction,
273310
&axis_, &n_, &in_place, &xnd, &dir_, 0)
@@ -295,19 +332,20 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
295332
in_place = 1
296333

297334
if in_place:
298-
with _lock:
299-
if x_type is cnp.NPY_CDOUBLE:
300-
if dir_ < 0:
301-
status = cdouble_mkl_ifft1d_in(x_arr, n_, <int> axis_)
302-
else:
303-
status = cdouble_mkl_fft1d_in(x_arr, n_, <int> axis_)
304-
elif x_type is cnp.NPY_CFLOAT:
305-
if dir_ < 0:
306-
status = cfloat_mkl_ifft1d_in(x_arr, n_, <int> axis_)
307-
else:
308-
status = cfloat_mkl_fft1d_in(x_arr, n_, <int> axis_)
335+
_cache_capsule = _tls_dfti_cache_capsule()
336+
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
337+
if x_type is cnp.NPY_CDOUBLE:
338+
if dir_ < 0:
339+
status = cdouble_mkl_ifft1d_in(x_arr, n_, <int> axis_, _cache)
340+
else:
341+
status = cdouble_mkl_fft1d_in(x_arr, n_, <int> axis_, _cache)
342+
elif x_type is cnp.NPY_CFLOAT:
343+
if dir_ < 0:
344+
status = cfloat_mkl_ifft1d_in(x_arr, n_, <int> axis_, _cache)
309345
else:
310-
status = 1
346+
status = cfloat_mkl_fft1d_in(x_arr, n_, <int> axis_, _cache)
347+
else:
348+
status = 1
311349

312350
if status:
313351
c_error_msg = mkl_dfti_error(status)
@@ -327,37 +365,38 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
327365
f_arr = __allocate_result(x_arr, n_, axis_, f_type);
328366

329367
# call out-of-place FFT
330-
with _lock:
331-
if f_type is cnp.NPY_CDOUBLE:
332-
if x_type is cnp.NPY_DOUBLE:
333-
if dir_ < 0:
334-
status = double_cdouble_mkl_ifft1d_out(
335-
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS)
336-
else:
337-
status = double_cdouble_mkl_fft1d_out(
338-
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS)
339-
elif x_type is cnp.NPY_CDOUBLE:
340-
if dir_ < 0:
341-
status = cdouble_cdouble_mkl_ifft1d_out(
342-
x_arr, n_, <int> axis_, f_arr)
343-
else:
344-
status = cdouble_cdouble_mkl_fft1d_out(
345-
x_arr, n_, <int> axis_, f_arr)
346-
else:
347-
if x_type is cnp.NPY_FLOAT:
348-
if dir_ < 0:
349-
status = float_cfloat_mkl_ifft1d_out(
350-
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS)
351-
else:
352-
status = float_cfloat_mkl_fft1d_out(
353-
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS)
354-
elif x_type is cnp.NPY_CFLOAT:
355-
if dir_ < 0:
356-
status = cfloat_cfloat_mkl_ifft1d_out(
357-
x_arr, n_, <int> axis_, f_arr)
358-
else:
359-
status = cfloat_cfloat_mkl_fft1d_out(
360-
x_arr, n_, <int> axis_, f_arr)
368+
_cache_capsule = _tls_dfti_cache_capsule()
369+
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
370+
if f_type is cnp.NPY_CDOUBLE:
371+
if x_type is cnp.NPY_DOUBLE:
372+
if dir_ < 0:
373+
status = double_cdouble_mkl_ifft1d_out(
374+
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS, _cache)
375+
else:
376+
status = double_cdouble_mkl_fft1d_out(
377+
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS, _cache)
378+
elif x_type is cnp.NPY_CDOUBLE:
379+
if dir_ < 0:
380+
status = cdouble_cdouble_mkl_ifft1d_out(
381+
x_arr, n_, <int> axis_, f_arr, _cache)
382+
else:
383+
status = cdouble_cdouble_mkl_fft1d_out(
384+
x_arr, n_, <int> axis_, f_arr, _cache)
385+
else:
386+
if x_type is cnp.NPY_FLOAT:
387+
if dir_ < 0:
388+
status = float_cfloat_mkl_ifft1d_out(
389+
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS, _cache)
390+
else:
391+
status = float_cfloat_mkl_fft1d_out(
392+
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS, _cache)
393+
elif x_type is cnp.NPY_CFLOAT:
394+
if dir_ < 0:
395+
status = cfloat_cfloat_mkl_ifft1d_out(
396+
x_arr, n_, <int> axis_, f_arr, _cache)
397+
else:
398+
status = cfloat_cfloat_mkl_fft1d_out(
399+
x_arr, n_, <int> axis_, f_arr, _cache)
361400

362401
if (status):
363402
c_error_msg = mkl_dfti_error(status)
@@ -388,6 +427,7 @@ def _rrfft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
388427
cdef int x_type, status
389428
cdef char * c_error_msg = NULL
390429
cdef bytes py_error_msg
430+
cdef DftiCache *_cache
391431

392432
x_arr = __process_arguments(x, n, axis, overwrite_arg, direction,
393433
&axis_, &n_, &in_place, &xnd, &dir_, 1)
@@ -413,19 +453,20 @@ def _rrfft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
413453
in_place = 1
414454

415455
if in_place:
416-
with _lock:
417-
if x_type is cnp.NPY_DOUBLE:
418-
if dir_ < 0:
419-
status = double_mkl_irfft_in(x_arr, n_, <int> axis_)
420-
else:
421-
status = double_mkl_rfft_in(x_arr, n_, <int> axis_)
422-
elif x_type is cnp.NPY_FLOAT:
423-
if dir_ < 0:
424-
status = float_mkl_irfft_in(x_arr, n_, <int> axis_)
425-
else:
426-
status = float_mkl_rfft_in(x_arr, n_, <int> axis_)
456+
_cache_capsule = _tls_dfti_cache_capsule()
457+
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
458+
if x_type is cnp.NPY_DOUBLE:
459+
if dir_ < 0:
460+
status = double_mkl_irfft_in(x_arr, n_, <int> axis_, _cache)
461+
else:
462+
status = double_mkl_rfft_in(x_arr, n_, <int> axis_, _cache)
463+
elif x_type is cnp.NPY_FLOAT:
464+
if dir_ < 0:
465+
status = float_mkl_irfft_in(x_arr, n_, <int> axis_, _cache)
427466
else:
428-
status = 1
467+
status = float_mkl_rfft_in(x_arr, n_, <int> axis_, _cache)
468+
else:
469+
status = 1
429470

430471
if status:
431472
c_error_msg = mkl_dfti_error(status)
@@ -443,17 +484,18 @@ def _rrfft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
443484
f_arr = __allocate_result(x_arr, n_, axis_, x_type);
444485

445486
# call out-of-place FFT
446-
with _lock:
447-
if x_type is cnp.NPY_DOUBLE:
448-
if dir_ < 0:
449-
status = double_double_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr)
450-
else:
451-
status = double_double_mkl_rfft_out(x_arr, n_, <int> axis_, f_arr)
487+
_cache_capsule = _tls_dfti_cache_capsule()
488+
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
489+
if x_type is cnp.NPY_DOUBLE:
490+
if dir_ < 0:
491+
status = double_double_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr, _cache)
452492
else:
453-
if dir_ < 0:
454-
status = float_float_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr)
455-
else:
456-
status = float_float_mkl_rfft_out(x_arr, n_, <int> axis_, f_arr)
493+
status = double_double_mkl_rfft_out(x_arr, n_, <int> axis_, f_arr, _cache)
494+
else:
495+
if dir_ < 0:
496+
status = float_float_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr, _cache)
497+
else:
498+
status = float_float_mkl_rfft_out(x_arr, n_, <int> axis_, f_arr, _cache)
457499

458500
if (status):
459501
c_error_msg = mkl_dfti_error(status)
@@ -479,6 +521,7 @@ def _rc_fft1d_impl(x, n=None, axis=-1, overwrite_arg=False):
479521
cdef int direction = 1 # dummy, only used for the sake of arg-processing
480522
cdef char * c_error_msg = NULL
481523
cdef bytes py_error_msg
524+
cdef DftiCache *_cache
482525

483526
x_arr = __process_arguments(x, n, axis, overwrite_arg, direction,
484527
&axis_, &n_, &in_place, &xnd, &dir_, 1)
@@ -509,11 +552,13 @@ def _rc_fft1d_impl(x, n=None, axis=-1, overwrite_arg=False):
509552

510553
# call out-of-place FFT
511554
if x_type is cnp.NPY_FLOAT:
512-
with _lock:
513-
status = float_cfloat_mkl_fft1d_out(x_arr, n_, <int> axis_, f_arr, HALF_HARMONICS)
555+
_cache_capsule = _tls_dfti_cache_capsule()
556+
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
557+
status = float_cfloat_mkl_fft1d_out(x_arr, n_, <int> axis_, f_arr, HALF_HARMONICS, _cache)
514558
else:
515-
with _lock:
516-
status = double_cdouble_mkl_fft1d_out(x_arr, n_, <int> axis_, f_arr, HALF_HARMONICS)
559+
_cache_capsule = _tls_dfti_cache_capsule()
560+
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
561+
status = double_cdouble_mkl_fft1d_out(x_arr, n_, <int> axis_, f_arr, HALF_HARMONICS, _cache)
517562

518563
if (status):
519564
c_error_msg = mkl_dfti_error(status)
@@ -553,6 +598,7 @@ def _rc_ifft1d_impl(x, n=None, axis=-1, overwrite_arg=False):
553598
cdef int direction = 1 # dummy, only used for the sake of arg-processing
554599
cdef char * c_error_msg = NULL
555600
cdef bytes py_error_msg
601+
cdef DftiCache *_cache
556602

557603
int_n = _is_integral(n)
558604
# nn gives the number elements along axis of the input that we use
@@ -591,11 +637,13 @@ def _rc_ifft1d_impl(x, n=None, axis=-1, overwrite_arg=False):
591637

592638
# call out-of-place FFT
593639
if x_type is cnp.NPY_CFLOAT:
594-
with _lock:
595-
status = cfloat_float_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr)
640+
_cache_capsule = _tls_dfti_cache_capsule()
641+
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
642+
status = cfloat_float_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr, _cache)
596643
else:
597-
with _lock:
598-
status = cdouble_double_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr)
644+
_cache_capsule = _tls_dfti_cache_capsule()
645+
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
646+
status = cdouble_double_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr, _cache)
599647

600648
if (status):
601649
c_error_msg = mkl_dfti_error(status)

0 commit comments

Comments
 (0)