Implement batched convolve1d #1318

Merged
merged 1 commit on Mar 27, 2025

1 change: 1 addition & 0 deletions pytensor/link/jax/dispatch/__init__.py
@@ -14,6 +14,7 @@
import pytensor.link.jax.dispatch.scalar
import pytensor.link.jax.dispatch.scan
import pytensor.link.jax.dispatch.shape
import pytensor.link.jax.dispatch.signal
import pytensor.link.jax.dispatch.slinalg
import pytensor.link.jax.dispatch.sort
import pytensor.link.jax.dispatch.sparse
1 change: 1 addition & 0 deletions pytensor/link/jax/dispatch/signal/__init__.py
@@ -0,0 +1 @@
import pytensor.link.jax.dispatch.signal.conv
14 changes: 14 additions & 0 deletions pytensor/link/jax/dispatch/signal/conv.py
@@ -0,0 +1,14 @@
import jax

from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.signal.conv import Conv1d


@jax_funcify.register(Conv1d)
def jax_funcify_Conv1d(op, node, **kwargs):
mode = op.mode

def conv1d(data, kernel):
return jax.numpy.convolve(data, kernel, mode=mode)

return conv1d
2 changes: 2 additions & 0 deletions pytensor/link/numba/dispatch/__init__.py
@@ -9,9 +9,11 @@
import pytensor.link.numba.dispatch.random
import pytensor.link.numba.dispatch.scan
import pytensor.link.numba.dispatch.scalar
import pytensor.link.numba.dispatch.signal
import pytensor.link.numba.dispatch.slinalg
import pytensor.link.numba.dispatch.sparse
import pytensor.link.numba.dispatch.subtensor
import pytensor.link.numba.dispatch.tensor_basic


# isort: on
1 change: 1 addition & 0 deletions pytensor/link/numba/dispatch/signal/__init__.py
@@ -0,0 +1 @@
import pytensor.link.numba.dispatch.signal.conv
16 changes: 16 additions & 0 deletions pytensor/link/numba/dispatch/signal/conv.py
@@ -0,0 +1,16 @@
import numpy as np

from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch.basic import numba_njit
from pytensor.tensor.signal.conv import Conv1d


@numba_funcify.register(Conv1d)
def numba_funcify_Conv1d(op, node, **kwargs):
mode = op.mode

@numba_njit
def conv1d(data, kernel):
return np.convolve(data, kernel, mode=mode)

[Codecov / codecov/patch warning on pytensor/link/numba/dispatch/signal/conv.py#L14: added line #L14 was not covered by tests]

return conv1d
1 change: 1 addition & 0 deletions pytensor/tensor/__init__.py
@@ -116,6 +116,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int:
# isort: off
from pytensor.tensor import linalg
from pytensor.tensor import special
from pytensor.tensor import signal

# For backward compatibility
from pytensor.tensor import nlinalg
4 changes: 4 additions & 0 deletions pytensor/tensor/signal/__init__.py
@@ -0,0 +1,4 @@
from pytensor.tensor.signal.conv import convolve1d


__all__ = ("convolve1d",)
132 changes: 132 additions & 0 deletions pytensor/tensor/signal/conv.py
@@ -0,0 +1,132 @@
from typing import TYPE_CHECKING, Literal, cast

from numpy import convolve as numpy_convolve

from pytensor.graph import Apply, Op
from pytensor.scalar.basic import upcast
from pytensor.tensor.basic import as_tensor_variable, join, zeros
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import maximum, minimum
from pytensor.tensor.type import vector
from pytensor.tensor.variable import TensorVariable


if TYPE_CHECKING:
from pytensor.tensor import TensorLike


class Conv1d(Op):
__props__ = ("mode",)
gufunc_signature = "(n),(k)->(o)"

def __init__(self, mode: Literal["full", "valid"] = "full"):
if mode not in ("full", "valid"):
raise ValueError(f"Invalid mode: {mode}")

[Codecov / codecov/patch warning on pytensor/tensor/signal/conv.py#L24: added line #L24 was not covered by tests]
self.mode = mode

def make_node(self, in1, in2):
in1 = as_tensor_variable(in1)
in2 = as_tensor_variable(in2)

assert in1.ndim == 1
assert in2.ndim == 1

dtype = upcast(in1.dtype, in2.dtype)

n = in1.type.shape[0]
k = in2.type.shape[0]

if n is None or k is None:
out_shape = (None,)
elif self.mode == "full":
out_shape = (n + k - 1,)
else: # mode == "valid":
out_shape = (max(n, k) - min(n, k) + 1,)

out = vector(dtype=dtype, shape=out_shape)
return Apply(self, [in1, in2], [out])

def perform(self, node, inputs, outputs):
        # We use numpy_convolve, as that's what scipy uses when method="direct" is passed
        # and mode != "same" (which this Op doesn't cover anyway).
outputs[0][0] = numpy_convolve(*inputs, mode=self.mode)

def infer_shape(self, fgraph, node, shapes):
in1_shape, in2_shape = shapes
n = in1_shape[0]
k = in2_shape[0]
if self.mode == "full":
shape = n + k - 1
else: # mode == "valid":
shape = maximum(n, k) - minimum(n, k) + 1
return [[shape]]

def L_op(self, inputs, outputs, output_grads):
in1, in2 = inputs
[grad] = output_grads

if self.mode == "full":
valid_conv = type(self)(mode="valid")
in1_bar = valid_conv(grad, in2[::-1])
in2_bar = valid_conv(grad, in1[::-1])

else: # mode == "valid":
full_conv = type(self)(mode="full")
n = in1.shape[0]
k = in2.shape[0]
kmn = maximum(0, k - n)
nkm = maximum(0, n - k)
# We need mode="full" if k >= n else "valid" for `in1_bar` (opposite for `in2_bar`), but mode is not symbolic.
# Instead, we always use mode="full" and slice the result so it behaves like "valid" for the input that's shorter.
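            # E.g. with n=3, k=5: `grad` (the valid output) has length 3, full_conv(grad, in2[::-1]) has
            # length 3 + 5 - 1 = 7, and trimming kmn = 2 entries from each end recovers the length-3 in1_bar.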
in1_bar = full_conv(grad, in2[::-1])
in1_bar = in1_bar[kmn : in1_bar.shape[0] - kmn]
in2_bar = full_conv(grad, in1[::-1])
in2_bar = in2_bar[nkm : in2_bar.shape[0] - nkm]

return [in1_bar, in2_bar]


def convolve1d(
in1: "TensorLike",
in2: "TensorLike",
mode: Literal["full", "valid", "same"] = "full",
) -> TensorVariable:
"""Convolve two one-dimensional arrays.

Convolve in1 and in2, with the output size determined by the mode argument.

Parameters
----------
in1 : (..., N,) tensor_like
First input.
in2 : (..., M,) tensor_like
Second input.
mode : {'full', 'valid', 'same'}, optional
A string indicating the size of the output:
- 'full': The output is the full discrete linear convolution of the inputs, with shape (..., N+M-1,).
- 'valid': The output consists only of elements that do not rely on zero-padding, with shape (..., max(N, M) - min(N, M) + 1,).
- 'same': The output is the same size as in1, centered with respect to the 'full' output.

Returns
-------
out: tensor_variable
The discrete linear convolution of in1 with in2.

"""
in1 = as_tensor_variable(in1)
in2 = as_tensor_variable(in2)

if mode == "same":
# We implement "same" as "valid" with padded `in1`.
in1_batch_shape = tuple(in1.shape)[:-1]
zeros_left = in2.shape[0] // 2
zeros_right = (in2.shape[0] - 1) // 2
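        # A total of (in2.shape[0] // 2) + ((in2.shape[0] - 1) // 2) = in2.shape[0] - 1 zeros are added,
        # so for 1d inputs the "valid" convolution below returns len(in1) outputs
        # (e.g. N=5, M=3: padded length 7, and 7 - 3 + 1 = 5).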
in1 = join(
Member
pad wasn't useful here?

Member Author
I don't want pad until we figure out inlining for it. I want PyTensor to optimize across the boundary, especially when gradients get involved.

-1,
zeros((*in1_batch_shape, zeros_left), dtype=in2.dtype),
in1,
zeros((*in1_batch_shape, zeros_right), dtype=in2.dtype),
)
mode = "valid"

return cast(TensorVariable, Blockwise(Conv1d(mode=mode))(in1, in2))
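
For reference, a minimal usage sketch of the convolve1d API added above (an editorial illustration, not part of the diff; dvector is used so the symbolic dtype matches the float64 test values):

import numpy as np

from pytensor.tensor import dvector
from pytensor.tensor.signal import convolve1d

x = dvector("x")  # length N = 5
k = dvector("k")  # length M = 3
x_val = np.arange(5.0)
k_val = np.array([1.0, 0.0, -1.0])

# Output lengths follow the mode definitions in the docstring above.
print(convolve1d(x, k, mode="full").eval({x: x_val, k: k_val}).shape)   # (7,): N + M - 1
print(convolve1d(x, k, mode="valid").eval({x: x_val, k: k_val}).shape)  # (3,): max(N, M) - min(N, M) + 1
print(convolve1d(x, k, mode="same").eval({x: x_val, k: k_val}).shape)   # (5,): same size as in1
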
Empty file.
18 changes: 18 additions & 0 deletions tests/link/jax/signal/test_conv.py
@@ -0,0 +1,18 @@
import numpy as np
import pytest

from pytensor.tensor import dmatrix
from pytensor.tensor.signal import convolve1d
from tests.link.jax.test_basic import compare_jax_and_py


@pytest.mark.parametrize("mode", ["full", "valid", "same"])
def test_convolve1d(mode):
x = dmatrix("x")
y = dmatrix("y")
out = convolve1d(x[None], y[:, None], mode=mode)

rng = np.random.default_rng()
test_x = rng.normal(size=(3, 5))
Member
Suggested change: test_x = rng.normal(size=(3, 5)).astype(floatX) ?

Member Author
Nope, no float32 in jax/numba tests, and I explicitly used dmatrix.

test_y = rng.normal(size=(7, 11))
compare_jax_and_py([x, y], out, [test_x, test_y])
22 changes: 22 additions & 0 deletions tests/link/numba/signal/test_conv.py
@@ -0,0 +1,22 @@
import numpy as np
import pytest

from pytensor.tensor import dmatrix
from pytensor.tensor.signal import convolve1d
from tests.link.numba.test_basic import compare_numba_and_py


pytestmark = pytest.mark.filterwarnings("error")


@pytest.mark.parametrize("mode", ["full", "valid", "same"])
def test_convolve1d(mode):
x = dmatrix("x")
y = dmatrix("y")
out = convolve1d(x[None], y[:, None], mode=mode)

rng = np.random.default_rng()
test_x = rng.normal(size=(3, 5))
test_y = rng.normal(size=(7, 11))
    # Blockwise dispatch for numba can't be run in object mode
compare_numba_and_py([x, y], out, [test_x, test_y], eval_obj_mode=False)
Empty file added tests/tensor/signal/__init__.py
Empty file.
49 changes: 49 additions & 0 deletions tests/tensor/signal/test_conv.py
@@ -0,0 +1,49 @@
from functools import partial

import numpy as np
import pytest
from scipy.signal import convolve as scipy_convolve

from pytensor import config, function
from pytensor.tensor import matrix, vector
from pytensor.tensor.signal.conv import convolve1d
from tests import unittest_tools as utt


@pytest.mark.parametrize("kernel_shape", [3, 5, 8], ids=lambda x: f"kernel_shape={x}")
@pytest.mark.parametrize("data_shape", [3, 5, 8], ids=lambda x: f"data_shape={x}")
@pytest.mark.parametrize("mode", ["full", "valid", "same"])
def test_convolve1d(mode, data_shape, kernel_shape):
data = vector("data")
kernel = vector("kernel")
op = partial(convolve1d, mode=mode)

rng = np.random.default_rng((26, kernel_shape, data_shape, sum(map(ord, mode))))
data_val = rng.normal(size=data_shape).astype(data.dtype)
kernel_val = rng.normal(size=kernel_shape).astype(kernel.dtype)

fn = function([data, kernel], op(data, kernel))
np.testing.assert_allclose(
fn(data_val, kernel_val),
scipy_convolve(data_val, kernel_val, mode=mode),
rtol=1e-6 if config.floatX == "float32" else 1e-15,
)
utt.verify_grad(op=lambda x: op(x, kernel_val), pt=[data_val])


def test_convolve1d_batch():
x = matrix("data")
y = matrix("kernel")
out = convolve1d(x, y)

rng = np.random.default_rng(38)
x_test = rng.normal(size=(2, 8)).astype(x.dtype)
y_test = x_test[::-1]

res = out.eval({x: x_test, y: y_test})
    # The second rows of x and y are the first rows of y and x, respectively,
    # so res[0] and res[1] should be identical (convolution is commutative).
rtol = 1e-6 if config.floatX == "float32" else 1e-15
res_np = np.convolve(x_test[0], y_test[0])
np.testing.assert_allclose(res[0], res_np, rtol=rtol)
np.testing.assert_allclose(res[1], res_np, rtol=rtol)
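
Shape-wise, the Blockwise wrapper broadcasts leading batch dimensions like a NumPy gufunc with signature (n),(k)->(o). A small sketch of the batched case exercised in the tests above (illustrative only, not part of the diff):

import numpy as np

from pytensor.tensor import dtensor3
from pytensor.tensor.signal import convolve1d

x = dtensor3("x")  # e.g. shape (1, 3, 5): batch dims (1, 3), core length 5
y = dtensor3("y")  # e.g. shape (7, 1, 11): batch dims (7, 1), core length 11
out = convolve1d(x, y, mode="full")

rng = np.random.default_rng(0)
x_val = rng.normal(size=(1, 3, 5))
y_val = rng.normal(size=(7, 1, 11))

# Batch dims broadcast (1, 3) with (7, 1) to (7, 3); the core "full" output length is 5 + 11 - 1 = 15.
print(out.eval({x: x_val, y: y_val}).shape)  # (7, 3, 15)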