37
37
38
38
"""
39
39
40
+ import math
41
+
40
42
import dpctl .tensor as dpt
41
43
import dpctl .utils as dpu
42
44
import numpy
64
66
dpnp_cov ,
65
67
)
66
68
69
+ min_ = min # pylint: disable=used-before-assignment
70
+
67
71
__all__ = [
68
72
"amax" ,
69
73
"amin" ,
@@ -482,17 +486,57 @@ def _get_padding(a_size, v_size, mode):
482
486
return l_pad , r_pad
483
487
484
488
485
- def _run_native_sliding_dot_product1d (a , v , l_pad , r_pad ):
489
def _choose_conv_method(a, v, rdtype):
    """
    Select the algorithm used to combine `a` and `v`.

    Parameters
    ----------
    a, v : array
        Input sequences; `a` must be at least as long as `v`.
    rdtype : dtype
        Result data type of the operation.

    Returns
    -------
    method : {"direct", "fft"}
        ``"direct"`` when `rdtype` is boolean, when either input holds fewer
        than ``10**4`` elements, or when `rdtype` is an integer type and
        ``max(|a|) * sum(|v|)`` exceeds ``2**nmant - 1`` for the default
        floating-point type of the device of `a`. ``"fft"`` for every other
        numeric `rdtype`.

    Raises
    ------
    ValueError
        If `rdtype` is neither boolean nor a numeric type.

    """

    assert a.size >= v.size

    size_threshold = 10**4
    is_boolean = rdtype == dpnp.bool
    if is_boolean or v.size < size_threshold or a.size < size_threshold:
        return "direct"

    if dpnp.issubdtype(rdtype, dpnp.integer):
        # Upper bound on the magnitude of any element of the result.
        peak_a = int(dpnp.max(dpnp.abs(a)))
        total_v = int(dpnp.sum(dpnp.abs(v)))
        bound = peak_a * total_v

        # nmant is the mantissa width of the device default float type;
        # presumably larger results cannot be recovered exactly from the
        # floating-point FFT, hence the fallback to the direct method.
        float_type = dpnp.default_float_type(a.sycl_device)
        mantissa_bits = numpy.finfo(float_type).nmant
        if bound > 2**mantissa_bits - 1:
            return "direct"

    if not dpnp.issubdtype(rdtype, dpnp.number):
        raise ValueError(f"Unsupported dtype: {rdtype}")

    return "fft"
510
+
511
+
512
+ def _run_native_sliding_dot_product1d (a , v , l_pad , r_pad , rdtype ):
486
513
queue = a .sycl_queue
514
+ device = a .sycl_device
515
+
516
+ supported_types = statistics_ext .sliding_dot_product1d_dtypes ()
517
+ supported_dtype = to_supported_dtypes (rdtype , supported_types , device )
487
518
488
- usm_type = dpu .get_coerced_usm_type ([a .usm_type , v .usm_type ])
489
- out_size = l_pad + r_pad + a .size - v .size + 1
519
+ if supported_dtype is None :
520
+ raise ValueError (
521
+ f"Unsupported input types ({ a .dtype } , { v .dtype } ), "
522
+ "and the inputs could not be coerced to any "
523
+ f"supported types. List of supported types: { supported_types } "
524
+ )
525
+
526
+ a_casted = dpnp .asarray (a , dtype = supported_dtype , order = "C" )
527
+ v_casted = dpnp .asarray (v , dtype = supported_dtype , order = "C" )
528
+
529
+ usm_type = dpu .get_coerced_usm_type ([a_casted .usm_type , v_casted .usm_type ])
530
+ out_size = l_pad + r_pad + a_casted .size - v_casted .size + 1
490
531
out = dpnp .empty (
491
- shape = out_size , sycl_queue = queue , dtype = a .dtype , usm_type = usm_type
532
+ shape = out_size ,
533
+ sycl_queue = queue ,
534
+ dtype = supported_dtype ,
535
+ usm_type = usm_type ,
492
536
)
493
537
494
- a_usm = dpnp .get_usm_ndarray (a )
495
- v_usm = dpnp .get_usm_ndarray (v )
538
+ a_usm = dpnp .get_usm_ndarray (a_casted )
539
+ v_usm = dpnp .get_usm_ndarray (v_casted )
496
540
out_usm = dpnp .get_usm_ndarray (out )
497
541
498
542
_manager = dpu .SequentialOrderManager [queue ]
@@ -510,7 +554,30 @@ def _run_native_sliding_dot_product1d(a, v, l_pad, r_pad):
510
554
return out
511
555
512
556
513
- def correlate (a , v , mode = "valid" ):
557
def _convolve_fft(a, v, l_pad, r_pad, rtype):
    """
    Convolve `a` with `v` through the Fourier domain and crop the result.

    Both inputs are transformed with a common power-of-two FFT length, the
    spectra are multiplied, and the inverse transform is cropped to the
    window ``[v.size - 1 - l_pad, a.size + r_pad)``.

    Parameters
    ----------
    a, v : array
        Input sequences; `a` must be at least as long as `v`.
    l_pad, r_pad : int
        Left and right padding of the requested output window;
        `l_pad` must be smaller than ``v.size``.
    rtype : dtype
        Result data type. It only decides the post-processing of the inverse
        transform: the real part is taken for floating types, the rounded
        real part for integer and boolean types, and the complex result is
        kept otherwise. The output is not cast to `rtype` here.

    Returns
    -------
    out : array
        The cropped convolution result.

    """

    assert a.size >= v.size
    assert l_pad < v.size

    # The extra element (+1) is needed to avoid circular convolution.
    total_len = a.size + r_pad + 1
    # Round the transform length up to the next power of two.
    fft_len = 1 << math.ceil(math.log2(total_len))

    first = v.size - 1 - l_pad
    last = total_len - 1

    # pylint: disable=no-member
    spectrum = dpnp.fft.fft(a, fft_len) * dpnp.fft.fft(v, fft_len)
    product = dpnp.fft.ifft(spectrum)
    # pylint: enable=no-member

    if dpnp.issubdtype(rtype, dpnp.floating):
        return product.real[first:last]

    if dpnp.issubdtype(rtype, dpnp.integer) or rtype == dpnp.bool:
        return product.real.round()[first:last]

    return product[first:last]
578
+
579
+
580
+ def correlate (a , v , mode = "valid" , method = "auto" ):
514
581
r"""
515
582
Cross-correlation of two 1-dimensional sequences.
516
583
@@ -535,6 +602,20 @@ def correlate(a, v, mode="valid"):
535
602
is ``'valid'``, unlike :obj:`dpnp.convolve`, which uses ``'full'``.
536
603
537
604
Default: ``'valid'``.
605
+ method : {'auto', 'direct', 'fft'}, optional
606
+ `'direct'`: The correlation is determined directly from sums.
607
+
608
+ `'fft'`: The Fourier Transform is used to perform the calculations.
609
+ This method is faster for long sequences but can have accuracy issues.
610
+
611
+ `'auto'`: Automatically chooses direct or Fourier method based on
612
+ an estimate of which is faster.
613
+
614
+ Note: Use of the FFT convolution on input containing NAN or INF
615
+ will lead to the entire output being NAN or INF.
616
+ Use method='direct' when your input contains NAN or INF values.
617
+
618
+ Default: ``'auto'``.
538
619
539
620
Notes
540
621
-----
@@ -560,7 +641,6 @@ def correlate(a, v, mode="valid"):
560
641
:obj:`dpnp.convolve` : Discrete, linear convolution of two
561
642
one-dimensional sequences.
562
643
563
-
564
644
Examples
565
645
--------
566
646
>>> import dpnp as np
@@ -602,19 +682,14 @@ def correlate(a, v, mode="valid"):
602
682
f"Received shapes: a.shape={ a .shape } , v.shape={ v .shape } "
603
683
)
604
684
605
- supported_types = statistics_ext .sliding_dot_product1d_dtypes ()
685
+ supported_methods = ["auto" , "direct" , "fft" ]
686
+ if method not in supported_methods :
687
+ raise ValueError (
688
+ f"Unknown method: { method } . Supported methods: { supported_methods } "
689
+ )
606
690
607
691
device = a .sycl_device
608
692
rdtype = result_type_for_device ([a .dtype , v .dtype ], device )
609
- supported_dtype = to_supported_dtypes (rdtype , supported_types , device )
610
-
611
- if supported_dtype is None :
612
- raise ValueError (
613
- f"function '{ correlate } ' does not support input types "
614
- f"({ a .dtype } , { v .dtype } ), "
615
- "and the inputs could not be coerced to any "
616
- f"supported types. List of supported types: { supported_types } "
617
- )
618
693
619
694
if dpnp .issubdtype (v .dtype , dpnp .complexfloating ):
620
695
v = dpnp .conj (v )
@@ -626,13 +701,15 @@ def correlate(a, v, mode="valid"):
626
701
627
702
l_pad , r_pad = _get_padding (a .size , v .size , mode )
628
703
629
- a_casted = dpnp .asarray (a , dtype = supported_dtype , order = "C" )
630
- v_casted = dpnp .asarray (v , dtype = supported_dtype , order = "C" )
631
-
632
- if v .size > a .size :
633
- a_casted , v_casted = v_casted , a_casted
704
+ if method == "auto" :
705
+ method = _choose_conv_method (a , v , rdtype )
634
706
635
- r = _run_native_sliding_dot_product1d (a_casted , v_casted , l_pad , r_pad )
707
+ if method == "direct" :
708
+ r = _run_native_sliding_dot_product1d (a , v , l_pad , r_pad , rdtype )
709
+ elif method == "fft" :
710
+ r = _convolve_fft (a , v [::- 1 ], l_pad , r_pad , rdtype )
711
+ else :
712
+ raise ValueError (f"Unknown method: { method } " )
636
713
637
714
if revert :
638
715
r = r [::- 1 ]
0 commit comments