@@ -34,18 +34,50 @@ except ImportError:
34
34
from numpy .core ._multiarray_tests import internal_overlap
35
35
36
36
from libc .string cimport memcpy
37
+ cimport cpython .pycapsule
38
+ from cpython .exc cimport (PyErr_Occurred , PyErr_Clear )
39
+ from cpython .mem cimport (PyMem_Malloc , PyMem_Free )
40
+
41
+ from threading import local as threading_local
42
+
43
+ # thread-local storage
44
+ _tls = threading_local ()
45
+
46
+ cdef const char * capsule_name = "dfti_cache"
47
+
48
+ cdef void _capsule_destructor (object caps ):
49
+ cdef DftiCache * _cache = NULL
50
+ cdef int status = 0
51
+ if (caps is None ):
52
+ print ("Nothing to destroy" )
53
+ return
54
+ _cache = < DftiCache * > cpython .pycapsule .PyCapsule_GetPointer (caps , capsule_name )
55
+ status = _free_dfti_cache (_cache )
56
+ PyMem_Free (_cache )
57
+ if (status != 0 ):
58
+ raise ValueError ("Internal Error: Freeing DFTI Cache returned with error = {}" .format (status ))
59
+
60
+
61
+ def _tls_dfti_cache_capsule ():
62
+ cdef DftiCache * _cache_struct
63
+
64
+ init = getattr (_tls , 'initialized' , None )
65
+ if (init is None ):
66
+ _cache_struct = < DftiCache * > PyMem_Malloc (sizeof (DftiCache ));
67
+ # important to initialized
68
+ _cache_struct .initialized = 0
69
+ _cache_struct .hand = NULL
70
+ _tls .initialized = True
71
+ _tls .capsule = cpython .pycapsule .PyCapsule_New (< void * > _cache_struct , capsule_name , & _capsule_destructor )
72
+ capsule = getattr (_tls , 'capsule' , None )
73
+ if (not cpython .pycapsule .PyCapsule_IsValid (capsule , capsule_name )):
74
+ raise ValueError ("Internal Error: invalid capsule stored in TLS" )
75
+ return capsule
37
76
38
- from threading import Lock
39
- _lock = Lock ()
40
77
41
78
cdef extern from "Python.h" :
42
79
ctypedef int size_t
43
80
44
- void * PyMem_Malloc (size_t n )
45
- void PyMem_Free (void * buf )
46
-
47
- int PyErr_Occurred ()
48
- void PyErr_Clear ()
49
81
long PyInt_AsLong (object ob )
50
82
int PyObject_HasAttrString (object , char * )
51
83
@@ -58,32 +90,36 @@ cdef extern from *:
58
90
object PyArray_BASE (cnp .ndarray )
59
91
60
92
cdef extern from "src/mklfft.h" :
61
- int cdouble_mkl_fft1d_in (cnp .ndarray , int , int )
62
- int cfloat_mkl_fft1d_in (cnp .ndarray , int , int )
63
- int float_cfloat_mkl_fft1d_out (cnp .ndarray , int , int , cnp .ndarray , int )
64
- int cfloat_cfloat_mkl_fft1d_out (cnp .ndarray , int , int , cnp .ndarray )
65
- int double_cdouble_mkl_fft1d_out (cnp .ndarray , int , int , cnp .ndarray , int )
66
- int cdouble_cdouble_mkl_fft1d_out (cnp .ndarray , int , int , cnp .ndarray )
67
-
68
- int cdouble_mkl_ifft1d_in (cnp .ndarray , int , int )
69
- int cfloat_mkl_ifft1d_in (cnp .ndarray , int , int )
70
- int float_cfloat_mkl_ifft1d_out (cnp .ndarray , int , int , cnp .ndarray , int )
71
- int cfloat_cfloat_mkl_ifft1d_out (cnp .ndarray , int , int , cnp .ndarray )
72
- int double_cdouble_mkl_ifft1d_out (cnp .ndarray , int , int , cnp .ndarray , int )
73
- int cdouble_cdouble_mkl_ifft1d_out (cnp .ndarray , int , int , cnp .ndarray )
74
-
75
- int double_mkl_rfft_in (cnp .ndarray , int , int )
76
- int double_mkl_irfft_in (cnp .ndarray , int , int )
77
- int float_mkl_rfft_in (cnp .ndarray , int , int )
78
- int float_mkl_irfft_in (cnp .ndarray , int , int )
79
-
80
- int double_double_mkl_rfft_out (cnp .ndarray , int , int , cnp .ndarray )
81
- int double_double_mkl_irfft_out (cnp .ndarray , int , int , cnp .ndarray )
82
- int float_float_mkl_rfft_out (cnp .ndarray , int , int , cnp .ndarray )
83
- int float_float_mkl_irfft_out (cnp .ndarray , int , int , cnp .ndarray )
84
-
85
- int cdouble_double_mkl_irfft_out (cnp .ndarray , int , int , cnp .ndarray )
86
- int cfloat_float_mkl_irfft_out (cnp .ndarray , int , int , cnp .ndarray )
93
+ cdef struct DftiCache :
94
+ void * hand
95
+ int initialized
96
+ int _free_dfti_cache (DftiCache * )
97
+ int cdouble_mkl_fft1d_in (cnp .ndarray , int , int , DftiCache * )
98
+ int cfloat_mkl_fft1d_in (cnp .ndarray , int , int , DftiCache * )
99
+ int float_cfloat_mkl_fft1d_out (cnp .ndarray , int , int , cnp .ndarray , int , DftiCache * )
100
+ int cfloat_cfloat_mkl_fft1d_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
101
+ int double_cdouble_mkl_fft1d_out (cnp .ndarray , int , int , cnp .ndarray , int , DftiCache * )
102
+ int cdouble_cdouble_mkl_fft1d_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
103
+
104
+ int cdouble_mkl_ifft1d_in (cnp .ndarray , int , int , DftiCache * )
105
+ int cfloat_mkl_ifft1d_in (cnp .ndarray , int , int , DftiCache * )
106
+ int float_cfloat_mkl_ifft1d_out (cnp .ndarray , int , int , cnp .ndarray , int , DftiCache * )
107
+ int cfloat_cfloat_mkl_ifft1d_out (cnp .ndarray , int , int , cnp .ndarra , DftiCache * )
108
+ int double_cdouble_mkl_ifft1d_out (cnp .ndarray , int , int , cnp .ndarray , int , DftiCache * )
109
+ int cdouble_cdouble_mkl_ifft1d_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
110
+
111
+ int double_mkl_rfft_in (cnp .ndarray , int , int , DftiCache * )
112
+ int double_mkl_irfft_in (cnp .ndarray , int , int , DftiCache * )
113
+ int float_mkl_rfft_in (cnp .ndarray , int , int , DftiCache * )
114
+ int float_mkl_irfft_in (cnp .ndarray , int , int , DftiCache * )
115
+
116
+ int double_double_mkl_rfft_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
117
+ int double_double_mkl_irfft_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
118
+ int float_float_mkl_rfft_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
119
+ int float_float_mkl_irfft_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
120
+
121
+ int cdouble_double_mkl_irfft_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
122
+ int cfloat_float_mkl_irfft_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
87
123
88
124
int cdouble_cdouble_mkl_fftnd_in (cnp .ndarray )
89
125
int cdouble_cdouble_mkl_ifftnd_in (cnp .ndarray )
@@ -268,6 +304,7 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
268
304
cdef int ALL_HARMONICS = 1
269
305
cdef char * c_error_msg = NULL
270
306
cdef bytes py_error_msg
307
+ cdef DftiCache * _cache
271
308
272
309
x_arr = __process_arguments (x , n , axis , overwrite_arg , direction ,
273
310
& axis_ , & n_ , & in_place , & xnd , & dir_ , 0 )
@@ -295,19 +332,20 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
295
332
in_place = 1
296
333
297
334
if in_place :
298
- with _lock :
299
- if x_type is cnp .NPY_CDOUBLE :
300
- if dir_ < 0 :
301
- status = cdouble_mkl_ifft1d_in (x_arr , n_ , < int > axis_ )
302
- else :
303
- status = cdouble_mkl_fft1d_in (x_arr , n_ , < int > axis_ )
304
- elif x_type is cnp .NPY_CFLOAT :
305
- if dir_ < 0 :
306
- status = cfloat_mkl_ifft1d_in (x_arr , n_ , < int > axis_ )
307
- else :
308
- status = cfloat_mkl_fft1d_in (x_arr , n_ , < int > axis_ )
335
+ _cache_capsule = _tls_dfti_cache_capsule ()
336
+ _cache = < DftiCache * > cpython .pycapsule .PyCapsule_GetPointer (_cache_capsule , capsule_name )
337
+ if x_type is cnp .NPY_CDOUBLE :
338
+ if dir_ < 0 :
339
+ status = cdouble_mkl_ifft1d_in (x_arr , n_ , < int > axis_ , _cache )
340
+ else :
341
+ status = cdouble_mkl_fft1d_in (x_arr , n_ , < int > axis_ , _cache )
342
+ elif x_type is cnp .NPY_CFLOAT :
343
+ if dir_ < 0 :
344
+ status = cfloat_mkl_ifft1d_in (x_arr , n_ , < int > axis_ , _cache )
309
345
else :
310
- status = 1
346
+ status = cfloat_mkl_fft1d_in (x_arr , n_ , < int > axis_ , _cache )
347
+ else :
348
+ status = 1
311
349
312
350
if status :
313
351
c_error_msg = mkl_dfti_error (status )
@@ -327,37 +365,38 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
327
365
f_arr = __allocate_result (x_arr , n_ , axis_ , f_type );
328
366
329
367
# call out-of-place FFT
330
- with _lock :
331
- if f_type is cnp .NPY_CDOUBLE :
332
- if x_type is cnp .NPY_DOUBLE :
333
- if dir_ < 0 :
334
- status = double_cdouble_mkl_ifft1d_out (
335
- x_arr , n_ , < int > axis_ , f_arr , ALL_HARMONICS )
336
- else :
337
- status = double_cdouble_mkl_fft1d_out (
338
- x_arr , n_ , < int > axis_ , f_arr , ALL_HARMONICS )
339
- elif x_type is cnp .NPY_CDOUBLE :
340
- if dir_ < 0 :
341
- status = cdouble_cdouble_mkl_ifft1d_out (
342
- x_arr , n_ , < int > axis_ , f_arr )
343
- else :
344
- status = cdouble_cdouble_mkl_fft1d_out (
345
- x_arr , n_ , < int > axis_ , f_arr )
346
- else :
347
- if x_type is cnp .NPY_FLOAT :
348
- if dir_ < 0 :
349
- status = float_cfloat_mkl_ifft1d_out (
350
- x_arr , n_ , < int > axis_ , f_arr , ALL_HARMONICS )
351
- else :
352
- status = float_cfloat_mkl_fft1d_out (
353
- x_arr , n_ , < int > axis_ , f_arr , ALL_HARMONICS )
354
- elif x_type is cnp .NPY_CFLOAT :
355
- if dir_ < 0 :
356
- status = cfloat_cfloat_mkl_ifft1d_out (
357
- x_arr , n_ , < int > axis_ , f_arr )
358
- else :
359
- status = cfloat_cfloat_mkl_fft1d_out (
360
- x_arr , n_ , < int > axis_ , f_arr )
368
+ _cache_capsule = _tls_dfti_cache_capsule ()
369
+ _cache = < DftiCache * > cpython .pycapsule .PyCapsule_GetPointer (_cache_capsule , capsule_name )
370
+ if f_type is cnp .NPY_CDOUBLE :
371
+ if x_type is cnp .NPY_DOUBLE :
372
+ if dir_ < 0 :
373
+ status = double_cdouble_mkl_ifft1d_out (
374
+ x_arr , n_ , < int > axis_ , f_arr , ALL_HARMONICS , _cache )
375
+ else :
376
+ status = double_cdouble_mkl_fft1d_out (
377
+ x_arr , n_ , < int > axis_ , f_arr , ALL_HARMONICS , _cache )
378
+ elif x_type is cnp .NPY_CDOUBLE :
379
+ if dir_ < 0 :
380
+ status = cdouble_cdouble_mkl_ifft1d_out (
381
+ x_arr , n_ , < int > axis_ , f_arr , _cache )
382
+ else :
383
+ status = cdouble_cdouble_mkl_fft1d_out (
384
+ x_arr , n_ , < int > axis_ , f_arr , _cache )
385
+ else :
386
+ if x_type is cnp .NPY_FLOAT :
387
+ if dir_ < 0 :
388
+ status = float_cfloat_mkl_ifft1d_out (
389
+ x_arr , n_ , < int > axis_ , f_arr , ALL_HARMONICS , _cache )
390
+ else :
391
+ status = float_cfloat_mkl_fft1d_out (
392
+ x_arr , n_ , < int > axis_ , f_arr , ALL_HARMONICS , _cache )
393
+ elif x_type is cnp .NPY_CFLOAT :
394
+ if dir_ < 0 :
395
+ status = cfloat_cfloat_mkl_ifft1d_out (
396
+ x_arr , n_ , < int > axis_ , f_arr , _cache )
397
+ else :
398
+ status = cfloat_cfloat_mkl_fft1d_out (
399
+ x_arr , n_ , < int > axis_ , f_arr , _cache )
361
400
362
401
if (status ):
363
402
c_error_msg = mkl_dfti_error (status )
@@ -388,6 +427,7 @@ def _rrfft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
388
427
cdef int x_type , status
389
428
cdef char * c_error_msg = NULL
390
429
cdef bytes py_error_msg
430
+ cdef DftiCache * _cache
391
431
392
432
x_arr = __process_arguments (x , n , axis , overwrite_arg , direction ,
393
433
& axis_ , & n_ , & in_place , & xnd , & dir_ , 1 )
@@ -413,19 +453,20 @@ def _rrfft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
413
453
in_place = 1
414
454
415
455
if in_place :
416
- with _lock :
417
- if x_type is cnp .NPY_DOUBLE :
418
- if dir_ < 0 :
419
- status = double_mkl_irfft_in (x_arr , n_ , < int > axis_ )
420
- else :
421
- status = double_mkl_rfft_in (x_arr , n_ , < int > axis_ )
422
- elif x_type is cnp .NPY_FLOAT :
423
- if dir_ < 0 :
424
- status = float_mkl_irfft_in (x_arr , n_ , < int > axis_ )
425
- else :
426
- status = float_mkl_rfft_in (x_arr , n_ , < int > axis_ )
456
+ _cache_capsule = _tls_dfti_cache_capsule ()
457
+ _cache = < DftiCache * > cpython .pycapsule .PyCapsule_GetPointer (_cache_capsule , capsule_name )
458
+ if x_type is cnp .NPY_DOUBLE :
459
+ if dir_ < 0 :
460
+ status = double_mkl_irfft_in (x_arr , n_ , < int > axis_ , _cache )
461
+ else :
462
+ status = double_mkl_rfft_in (x_arr , n_ , < int > axis_ , _cache )
463
+ elif x_type is cnp .NPY_FLOAT :
464
+ if dir_ < 0 :
465
+ status = float_mkl_irfft_in (x_arr , n_ , < int > axis_ , _cache )
427
466
else :
428
- status = 1
467
+ status = float_mkl_rfft_in (x_arr , n_ , < int > axis_ , _cache )
468
+ else :
469
+ status = 1
429
470
430
471
if status :
431
472
c_error_msg = mkl_dfti_error (status )
@@ -443,17 +484,18 @@ def _rrfft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
443
484
f_arr = __allocate_result (x_arr , n_ , axis_ , x_type );
444
485
445
486
# call out-of-place FFT
446
- with _lock :
447
- if x_type is cnp .NPY_DOUBLE :
448
- if dir_ < 0 :
449
- status = double_double_mkl_irfft_out (x_arr , n_ , < int > axis_ , f_arr )
450
- else :
451
- status = double_double_mkl_rfft_out (x_arr , n_ , < int > axis_ , f_arr )
487
+ _cache_capsule = _tls_dfti_cache_capsule ()
488
+ _cache = < DftiCache * > cpython .pycapsule .PyCapsule_GetPointer (_cache_capsule , capsule_name )
489
+ if x_type is cnp .NPY_DOUBLE :
490
+ if dir_ < 0 :
491
+ status = double_double_mkl_irfft_out (x_arr , n_ , < int > axis_ , f_arr , _cache )
452
492
else :
453
- if dir_ < 0 :
454
- status = float_float_mkl_irfft_out (x_arr , n_ , < int > axis_ , f_arr )
455
- else :
456
- status = float_float_mkl_rfft_out (x_arr , n_ , < int > axis_ , f_arr )
493
+ status = double_double_mkl_rfft_out (x_arr , n_ , < int > axis_ , f_arr , _cache )
494
+ else :
495
+ if dir_ < 0 :
496
+ status = float_float_mkl_irfft_out (x_arr , n_ , < int > axis_ , f_arr , _cache )
497
+ else :
498
+ status = float_float_mkl_rfft_out (x_arr , n_ , < int > axis_ , f_arr , _cache )
457
499
458
500
if (status ):
459
501
c_error_msg = mkl_dfti_error (status )
@@ -479,6 +521,7 @@ def _rc_fft1d_impl(x, n=None, axis=-1, overwrite_arg=False):
479
521
cdef int direction = 1 # dummy, only used for the sake of arg-processing
480
522
cdef char * c_error_msg = NULL
481
523
cdef bytes py_error_msg
524
+ cdef DftiCache * _cache
482
525
483
526
x_arr = __process_arguments (x , n , axis , overwrite_arg , direction ,
484
527
& axis_ , & n_ , & in_place , & xnd , & dir_ , 1 )
@@ -509,11 +552,13 @@ def _rc_fft1d_impl(x, n=None, axis=-1, overwrite_arg=False):
509
552
510
553
# call out-of-place FFT
511
554
if x_type is cnp .NPY_FLOAT :
512
- with _lock :
513
- status = float_cfloat_mkl_fft1d_out (x_arr , n_ , < int > axis_ , f_arr , HALF_HARMONICS )
555
+ _cache_capsule = _tls_dfti_cache_capsule ()
556
+ _cache = < DftiCache * > cpython .pycapsule .PyCapsule_GetPointer (_cache_capsule , capsule_name )
557
+ status = float_cfloat_mkl_fft1d_out (x_arr , n_ , < int > axis_ , f_arr , HALF_HARMONICS , _cache )
514
558
else :
515
- with _lock :
516
- status = double_cdouble_mkl_fft1d_out (x_arr , n_ , < int > axis_ , f_arr , HALF_HARMONICS )
559
+ _cache_capsule = _tls_dfti_cache_capsule ()
560
+ _cache = < DftiCache * > cpython .pycapsule .PyCapsule_GetPointer (_cache_capsule , capsule_name )
561
+ status = double_cdouble_mkl_fft1d_out (x_arr , n_ , < int > axis_ , f_arr , HALF_HARMONICS , _cache )
517
562
518
563
if (status ):
519
564
c_error_msg = mkl_dfti_error (status )
@@ -553,6 +598,7 @@ def _rc_ifft1d_impl(x, n=None, axis=-1, overwrite_arg=False):
553
598
cdef int direction = 1 # dummy, only used for the sake of arg-processing
554
599
cdef char * c_error_msg = NULL
555
600
cdef bytes py_error_msg
601
+ cdef DftiCache * _cache
556
602
557
603
int_n = _is_integral (n )
558
604
# nn gives the number elements along axis of the input that we use
@@ -591,11 +637,13 @@ def _rc_ifft1d_impl(x, n=None, axis=-1, overwrite_arg=False):
591
637
592
638
# call out-of-place FFT
593
639
if x_type is cnp .NPY_CFLOAT :
594
- with _lock :
595
- status = cfloat_float_mkl_irfft_out (x_arr , n_ , < int > axis_ , f_arr )
640
+ _cache_capsule = _tls_dfti_cache_capsule ()
641
+ _cache = < DftiCache * > cpython .pycapsule .PyCapsule_GetPointer (_cache_capsule , capsule_name )
642
+ status = cfloat_float_mkl_irfft_out (x_arr , n_ , < int > axis_ , f_arr , _cache )
596
643
else :
597
- with _lock :
598
- status = cdouble_double_mkl_irfft_out (x_arr , n_ , < int > axis_ , f_arr )
644
+ _cache_capsule = _tls_dfti_cache_capsule ()
645
+ _cache = < DftiCache * > cpython .pycapsule .PyCapsule_GetPointer (_cache_capsule , capsule_name )
646
+ status = cdouble_double_mkl_irfft_out (x_arr , n_ , < int > axis_ , f_arr , _cache )
599
647
600
648
if (status ):
601
649
c_error_msg = mkl_dfti_error (status )
0 commit comments