Skip to content

Commit f18bbf2

Browse files
Correalation via fft implementation
1 parent 2514037 commit f18bbf2

File tree

2 files changed

+186
-31
lines changed

2 files changed

+186
-31
lines changed

dpnp/dpnp_iface_statistics.py

Lines changed: 101 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
3838
"""
3939

40+
import math
41+
4042
import dpctl.tensor as dpt
4143
import dpctl.utils as dpu
4244
import numpy
@@ -64,6 +66,8 @@
6466
dpnp_cov,
6567
)
6668

69+
min_ = min # pylint: disable=used-before-assignment
70+
6771
__all__ = [
6872
"amax",
6973
"amin",
@@ -482,17 +486,57 @@ def _get_padding(a_size, v_size, mode):
482486
return l_pad, r_pad
483487

484488

485-
def _run_native_sliding_dot_product1d(a, v, l_pad, r_pad):
489+
def _choose_conv_method(a, v, rdtype):
490+
assert a.size >= v.size
491+
if rdtype == dpnp.bool:
492+
return "direct"
493+
494+
if v.size < 10**4 or a.size < 10**4:
495+
return "direct"
496+
497+
if dpnp.issubdtype(rdtype, dpnp.integer):
498+
max_a = int(dpnp.max(dpnp.abs(a)))
499+
sum_v = int(dpnp.sum(dpnp.abs(v)))
500+
max_value = int(max_a * sum_v)
501+
502+
default_float = dpnp.default_float_type(a.sycl_device)
503+
if max_value > 2 ** numpy.finfo(default_float).nmant - 1:
504+
return "direct"
505+
506+
if dpnp.issubdtype(rdtype, dpnp.number):
507+
return "fft"
508+
509+
raise ValueError(f"Unsupported dtype: {rdtype}")
510+
511+
512+
def _run_native_sliding_dot_product1d(a, v, l_pad, r_pad, rdtype):
486513
queue = a.sycl_queue
514+
device = a.sycl_device
515+
516+
supported_types = statistics_ext.sliding_dot_product1d_dtypes()
517+
supported_dtype = to_supported_dtypes(rdtype, supported_types, device)
487518

488-
usm_type = dpu.get_coerced_usm_type([a.usm_type, v.usm_type])
489-
out_size = l_pad + r_pad + a.size - v.size + 1
519+
if supported_dtype is None:
520+
raise ValueError(
521+
f"Unsupported input types ({a.dtype}, {v.dtype}), "
522+
"and the inputs could not be coerced to any "
523+
f"supported types. List of supported types: {supported_types}"
524+
)
525+
526+
a_casted = dpnp.asarray(a, dtype=supported_dtype, order="C")
527+
v_casted = dpnp.asarray(v, dtype=supported_dtype, order="C")
528+
529+
usm_type = dpu.get_coerced_usm_type([a_casted.usm_type, v_casted.usm_type])
530+
out_size = l_pad + r_pad + a_casted.size - v_casted.size + 1
490531
out = dpnp.empty(
491-
shape=out_size, sycl_queue=queue, dtype=a.dtype, usm_type=usm_type
532+
shape=out_size,
533+
sycl_queue=queue,
534+
dtype=supported_dtype,
535+
usm_type=usm_type,
492536
)
493537

494-
a_usm = dpnp.get_usm_ndarray(a)
495-
v_usm = dpnp.get_usm_ndarray(v)
538+
a_usm = dpnp.get_usm_ndarray(a_casted)
539+
v_usm = dpnp.get_usm_ndarray(v_casted)
496540
out_usm = dpnp.get_usm_ndarray(out)
497541

498542
_manager = dpu.SequentialOrderManager[queue]
@@ -510,7 +554,30 @@ def _run_native_sliding_dot_product1d(a, v, l_pad, r_pad):
510554
return out
511555

512556

513-
def correlate(a, v, mode="valid"):
557+
def _convolve_fft(a, v, l_pad, r_pad, rtype):
558+
assert a.size >= v.size
559+
assert l_pad < v.size
560+
561+
# +1 is needed to avoid circular convolution
562+
padded_size = a.size + r_pad + 1
563+
fft_size = 2 ** math.ceil(math.log2(padded_size))
564+
565+
af = dpnp.fft.fft(a, fft_size) # pylint: disable=no-member
566+
vf = dpnp.fft.fft(v, fft_size) # pylint: disable=no-member
567+
568+
r = dpnp.fft.ifft(af * vf) # pylint: disable=no-member
569+
if dpnp.issubdtype(rtype, dpnp.floating):
570+
r = r.real
571+
elif dpnp.issubdtype(rtype, dpnp.integer) or rtype == dpnp.bool:
572+
r = r.real.round()
573+
574+
start = v.size - 1 - l_pad
575+
end = padded_size - 1
576+
577+
return r[start:end]
578+
579+
580+
def correlate(a, v, mode="valid", method="auto"):
514581
r"""
515582
Cross-correlation of two 1-dimensional sequences.
516583
@@ -535,6 +602,20 @@ def correlate(a, v, mode="valid"):
535602
is ``'valid'``, unlike :obj:`dpnp.convolve`, which uses ``'full'``.
536603
537604
Default: ``'valid'``.
605+
method : {'auto', 'direct', 'fft'}, optional
606+
`'direct'`: The correlation is determined directly from sums.
607+
608+
`'fft'`: The Fourier Transform is used to perform the calculations.
609+
This method is faster for long sequences but can have accuracy issues.
610+
611+
`'auto'`: Automatically chooses direct or Fourier method based on
612+
an estimate of which is faster.
613+
614+
Note: Use of the FFT convolution on input containing NAN or INF
615+
will lead to the entire output being NAN or INF.
616+
Use method='direct' when your input contains NAN or INF values.
617+
618+
Default: ``'auto'``.
538619
539620
Notes
540621
-----
@@ -560,7 +641,6 @@ def correlate(a, v, mode="valid"):
560641
:obj:`dpnp.convolve` : Discrete, linear convolution of two
561642
one-dimensional sequences.
562643
563-
564644
Examples
565645
--------
566646
>>> import dpnp as np
@@ -602,19 +682,14 @@ def correlate(a, v, mode="valid"):
602682
f"Received shapes: a.shape={a.shape}, v.shape={v.shape}"
603683
)
604684

605-
supported_types = statistics_ext.sliding_dot_product1d_dtypes()
685+
supported_methods = ["auto", "direct", "fft"]
686+
if method not in supported_methods:
687+
raise ValueError(
688+
f"Unknown method: {method}. Supported methods: {supported_methods}"
689+
)
606690

607691
device = a.sycl_device
608692
rdtype = result_type_for_device([a.dtype, v.dtype], device)
609-
supported_dtype = to_supported_dtypes(rdtype, supported_types, device)
610-
611-
if supported_dtype is None:
612-
raise ValueError(
613-
f"function '{correlate}' does not support input types "
614-
f"({a.dtype}, {v.dtype}), "
615-
"and the inputs could not be coerced to any "
616-
f"supported types. List of supported types: {supported_types}"
617-
)
618693

619694
if dpnp.issubdtype(v.dtype, dpnp.complexfloating):
620695
v = dpnp.conj(v)
@@ -626,13 +701,15 @@ def correlate(a, v, mode="valid"):
626701

627702
l_pad, r_pad = _get_padding(a.size, v.size, mode)
628703

629-
a_casted = dpnp.asarray(a, dtype=supported_dtype, order="C")
630-
v_casted = dpnp.asarray(v, dtype=supported_dtype, order="C")
631-
632-
if v.size > a.size:
633-
a_casted, v_casted = v_casted, a_casted
704+
if method == "auto":
705+
method = _choose_conv_method(a, v, rdtype)
634706

635-
r = _run_native_sliding_dot_product1d(a_casted, v_casted, l_pad, r_pad)
707+
if method == "direct":
708+
r = _run_native_sliding_dot_product1d(a, v, l_pad, r_pad, rdtype)
709+
elif method == "fft":
710+
r = _convolve_fft(a, v[::-1], l_pad, r_pad, rdtype)
711+
else:
712+
raise ValueError(f"Unknown method: {method}")
636713

637714
if revert:
638715
r = r[::-1]

dpnp/tests/test_statistics.py

Lines changed: 85 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -634,21 +634,92 @@ class TestCorrelate:
634634
)
635635
@pytest.mark.parametrize("mode", [None, "full", "valid", "same"])
636636
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
637-
def test_correlate(self, a, v, mode, dtype):
637+
@pytest.mark.parametrize("method", [None, "auto", "direct", "fft"])
638+
def test_correlate(self, a, v, mode, dtype, method):
638639
an = numpy.array(a, dtype=dtype)
639640
vn = numpy.array(v, dtype=dtype)
640641
ad = dpnp.array(an)
641642
vd = dpnp.array(vn)
642643

643-
if mode is None:
644-
expected = numpy.correlate(an, vn)
645-
result = dpnp.correlate(ad, vd)
646-
else:
647-
expected = numpy.correlate(an, vn, mode=mode)
648-
result = dpnp.correlate(ad, vd, mode=mode)
644+
dpnp_kwargs = {}
645+
numpy_kwargs = {}
646+
if mode is not None:
647+
dpnp_kwargs["mode"] = mode
648+
numpy_kwargs["mode"] = mode
649+
if method is not None:
650+
dpnp_kwargs["method"] = method
651+
652+
expected = numpy.correlate(an, vn, **numpy_kwargs)
653+
result = dpnp.correlate(ad, vd, **dpnp_kwargs)
649654

650655
assert_dtype_allclose(result, expected)
651656

657+
@pytest.mark.parametrize("a_size", [1, 100, 10000])
658+
@pytest.mark.parametrize("v_size", [1, 100, 10000])
659+
@pytest.mark.parametrize("mode", ["full", "valid", "same"])
660+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
661+
@pytest.mark.parametrize("method", ["auto", "direct", "fft"])
662+
def test_correlate_random(self, a_size, v_size, mode, dtype, method):
663+
if dtype == dpnp.bool:
664+
an = numpy.random.rand(a_size) > 0.9
665+
vn = numpy.random.rand(v_size) > 0.9
666+
else:
667+
an = (100 * numpy.random.rand(a_size)).astype(dtype)
668+
vn = (100 * numpy.random.rand(v_size)).astype(dtype)
669+
670+
if dpnp.issubdtype(dtype, dpnp.complexfloating):
671+
an = an + 1j * (100 * numpy.random.rand(a_size)).astype(dtype)
672+
vn = vn + 1j * (100 * numpy.random.rand(v_size)).astype(dtype)
673+
674+
ad = dpnp.array(an)
675+
vd = dpnp.array(vn)
676+
677+
dpnp_kwargs = {}
678+
numpy_kwargs = {}
679+
if mode is not None:
680+
dpnp_kwargs["mode"] = mode
681+
numpy_kwargs["mode"] = mode
682+
if method is not None:
683+
dpnp_kwargs["method"] = method
684+
685+
result = dpnp.correlate(ad, vd, **dpnp_kwargs)
686+
expected = numpy.correlate(an, vn, **numpy_kwargs)
687+
688+
rdtype = result.dtype
689+
if dpnp.issubdtype(rdtype, dpnp.integer):
690+
rdtype = dpnp.default_float_type(ad.device)
691+
692+
if method != "fft" and (
693+
dpnp.issubdtype(dtype, dpnp.integer) or dtype == dpnp.bool
694+
):
695+
# For 'direct' and 'auto' methods, we expect exact results for integer types
696+
assert_array_equal(result, expected)
697+
else:
698+
result = result.astype(rdtype)
699+
if method == "direct":
700+
expected = numpy.correlate(an, vn, **numpy_kwargs)
701+
# For 'direct' method we can use standard validation
702+
assert_dtype_allclose(result, expected)
703+
else:
704+
rtol = 1e-3
705+
atol = 1e-10
706+
707+
if rdtype == dpnp.bool:
708+
result = result.astype(dpnp.int32)
709+
rdtype = result.dtype
710+
711+
expected = expected.astype(rdtype)
712+
713+
diff = numpy.abs(result.asnumpy() - expected)
714+
invalid = diff > atol + rtol * numpy.abs(expected)
715+
716+
# When using the 'fft' method, we might encounter outliers.
717+
# This usually happens when the resulting array contains values close to zero.
718+
# For these outliers, the relative error can be significant.
719+
# We can tolerate a few such outliers.
720+
if invalid.sum() > 8:
721+
assert_dtype_allclose(result, expected, factor=1000)
722+
652723
def test_correlate_mode_error(self):
653724
a = dpnp.arange(5)
654725
v = dpnp.arange(3)
@@ -700,6 +771,13 @@ def test_correlate_another_sycl_queue(self):
700771
with pytest.raises(ValueError):
701772
dpnp.correlate(a, v)
702773

774+
def test_correlate_unkown_method(self):
775+
a = dpnp.arange(5)
776+
v = dpnp.arange(3)
777+
778+
with pytest.raises(ValueError):
779+
dpnp.correlate(a, v, method="unknown")
780+
703781

704782
@pytest.mark.parametrize(
705783
"dtype", get_all_dtypes(no_bool=True, no_none=True, no_complex=True)

0 commit comments

Comments
 (0)