Description
General bilinear operations take the form x^T @ A @ y, where A has shape (outdim, xdim, ydim) (see https://pytorch.org/docs/stable/generated/torch.nn.Bilinear.html).
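For concreteness, here is the dense version of the operation, checked against `torch.nn.Bilinear` (shapes are toy values):

```python
import torch

dx, dy, dout = 8, 6, 4
x, y = torch.randn(dx), torch.randn(dy)
A = torch.randn(dout, dx, dy)  # same weight layout as torch.nn.Bilinear

# out[o] = x^T @ A[o] @ y for each output dimension o
out = torch.einsum('i,oij,j->o', x, A, y)

bilinear = torch.nn.Bilinear(dx, dy, dout, bias=False)
with torch.no_grad():
    bilinear.weight.copy_(A)
assert torch.allclose(bilinear(x.unsqueeze(0), y.unsqueeze(0)).squeeze(0), out, atol=1e-4)
```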
Currently, as far as I can tell, the only way to implement this for a sparse A with pytorch_sparse is as a sparse-dense matmul of A with the outer product of x and y. The outer product incurs a huge memory cost and, because A is sparse, involves many unnecessary computations.
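To make that cost concrete, here is a minimal sketch of the workaround, assuming `torch_sparse.spmm(index, value, m, n, matrix)` and A flattened to a 2-D sparse matrix of shape [dout, dx*dy] (shapes and nonzeros are made up):

```python
import torch
from torch_sparse import spmm

batch, dx, dy, dout = 32, 1000, 1000, 50
x, y = torch.randn(batch, dx), torch.randn(batch, dy)

# A flattened to 2-D COO: entry A[o, i, j] becomes row o, column i * dy + j.
# Two toy nonzeros: A[0, 0, 5] = 0.5 and A[1, 0, 17] = -1.2.
index = torch.tensor([[0, 1], [5, 17]])
value = torch.tensor([0.5, -1.2])

# The dense outer product of x and y: [batch, dx * dy].
# Here that is 32 * 10^6 floats (~128 MB), almost all of which
# gets multiplied by zeros of A.
outer = torch.einsum('bi,bj->bij', x, y).reshape(batch, dx * dy)

# Sparse-dense matmul: [dout, dx*dy] @ [dx*dy, batch] -> [dout, batch]
out = spmm(index, value, dout, dx * dy, outer.t()).t()
```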
It seems to me that supporting this computation directly would not require dramatic changes to the existing CPU and CUDA matmul code: seemingly just a second outer loop over x in addition to the existing loop over y. But I'm not familiar enough with the code to be sure that's true.
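As a sanity check on that claim, here is a pure-PyTorch reference for what the fused kernel would compute, assuming A is given in 3-D COO form (the function name and storage convention are hypothetical):

```python
import torch

def sparse_bilinear(index, value, x, y, dout):
    """Reference: out[b, o] = sum over nonzeros A[o, i, j] of x[b, i] * A[o, i, j] * y[b, j].

    index: [3, nnz] holding the (o, i, j) coordinates, value: [nnz],
    x: [batch, dx], y: [batch, dy].
    """
    o, i, j = index
    contrib = x[:, i] * y[:, j] * value  # [batch, nnz]: touch only nonzeros of A
    out = x.new_zeros(x.size(0), dout)
    out.index_add_(1, o, contrib)        # scatter-add along the output dimension
    return out
```

Memory and compute here scale with batch * nnz rather than batch * dx * dy, which is the whole point of keeping A sparse.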
Ideally, like existing sparse-dense multiplication, it would support a batch dimension in x and y:
x=[batch, dx], A=[dout, dx, dy], y=[batch, dy] --> [batch, dout]
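i.e., matching this dense reference for the batched signature (toy shapes again):

```python
import torch

batch, dx, dy, dout = 32, 8, 6, 4
x, A, y = torch.randn(batch, dx), torch.randn(dout, dx, dy), torch.randn(batch, dy)
out = torch.einsum('bi,oij,bj->bo', x, A, y)  # [batch, dout]
```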
I would be happy to help with implementing this if you think it is doable and something you'd be interested in having in the library!