linalg.solve broadcasting behavior is ambiguous #285

Closed
@asmeurer

Description

The spec for the linalg.solve function seems ambiguous. In solve(x1, x2), x1 has shape (..., M, M) and x2 either has shape (..., M) or (..., M, K). In either case, the ... parts should be broadcast compatible.

This is ambiguous. For example, if x1 has shape (2, 2, 2) and x2 has shape (2, 2), should x2 be interpreted as a (2,) stack of (2,) vectors, i.e., the result would be (2, 2, 2, 1) after broadcasting, or as a single 2x2 matrix, i.e., resulting in (2, 2, 2, 2)?

As for NumPy, it seems to sometimes pick one interpretation over the other, even when only the other one makes sense. For example:

>>> x1 = np.eye(1)
>>> x2 = np.asarray([[0.], [0.]])
>>> x1.shape
(1, 1)
>>> x2.shape
(2, 1)
>>> np.linalg.solve(x1, x2)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<__array_function__ internals>", line 5, in solve
  File "/Users/aaronmeurer/anaconda3/envs/array-apis/lib/python3.9/site-packages/numpy/linalg/linalg.py", line 393, in solve
    r = gufunc(a, b, signature=signature, extobj=extobj)
ValueError: solve: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (m,m),(m,n)->(m,n) (size 2 is different from 1)

Here it wants to treat x2 as a single 2x1 matrix, which is shape-incompatible with the 1x1 x1, but it could also treat it as a (2,) stack of length-1 vectors.
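To make the two readings concrete, here is a rough shape-only sketch in plain Python (no NumPy involved). The helper names `broadcast_stack`, `as_matrix`, and `as_vector_stack` are made up for illustration and are not part of the spec or of any library; they only compute the result shape each reading would imply for the shapes in the example above.

```python
from itertools import zip_longest


def broadcast_stack(a, b):
    """Broadcast two stack shapes (tuples of ints) with NumPy-style rules."""
    out = []
    for m, n in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if m != n and 1 not in (m, n):
            raise ValueError(f"stack shapes {a} and {b} do not broadcast")
        out.append(n if m == 1 else m)
    return tuple(reversed(out))


def as_matrix(x1, x2):
    """Hypothetical reading 1: x2 is (..., M, K), a stack of matrices."""
    if len(x2) < 2 or x2[-2] != x1[-1]:
        raise ValueError(f"core dimension mismatch between {x1} and {x2}")
    return broadcast_stack(x1[:-2], x2[:-2]) + x2[-2:]


def as_vector_stack(x1, x2):
    """Hypothetical reading 2: x2 is (..., M), a stack of vectors."""
    if x2[-1] != x1[-1]:
        raise ValueError(f"core dimension mismatch between {x1} and {x2}")
    return broadcast_stack(x1[:-2], x2[:-1]) + x2[-1:]


# Shapes from the session above: x1 is (1, 1), x2 is (2, 1).
x1_shape, x2_shape = (1, 1), (2, 1)
for reading in (as_matrix, as_vector_stack):
    try:
        print(reading.__name__, "->", reading(x1_shape, x2_shape))
    except ValueError as exc:
        print(reading.__name__, "-> error:", exc)
```

Under these assumptions the matrix reading fails on the core dimension (2 vs. 1), while the vector-stack reading is well defined and gives (2, 1).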

I think there are also some issues with the way the spec describes broadcasting. It says "shape(x2)[:-1] must be compatible with shape(x1)[:-1]", but I think this should be shape(x2)[:-2] and shape(x1)[:-2], since matrix dimensions should never broadcast with each other. It also says that the output should always have the same shape as x2, which contradicts the requirement that the inputs broadcast together.
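A small shape-only sketch of both problems, in plain Python. The shapes below are ones I picked for illustration, and `broadcast_shapes` is a hand-written helper, not a spec or library function.

```python
from itertools import zip_longest


def broadcast_shapes(a, b):
    """NumPy-style broadcasting of two shape tuples (illustrative helper)."""
    out = []
    for m, n in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if m != n and 1 not in (m, n):
            raise ValueError(f"shapes {a} and {b} do not broadcast")
        out.append(n if m == 1 else m)
    return tuple(reversed(out))


# Problem 1: slicing with [:-1] lets a matrix dimension take part in broadcasting.
a_shape, b_shape = (2, 2), (1, 5)           # a 2x2 system with a 1x5 right-hand side
loose = broadcast_shapes(a_shape[:-1], b_shape[:-1])   # (2,) vs (1,): accepted
core_matches = a_shape[-1] == b_shape[-2]              # 2 == 1: False
print(loose, core_matches)                  # (2,) False

# Slicing with [:-2] compares only the stack parts and leaves M to a separate check.
stack = broadcast_shapes(a_shape[:-2], b_shape[:-2])   # () vs ()
print(stack)                                # ()

# Problem 2: the broadcast result need not have the shape of x2.
c_shape, d_shape = (3, 2, 2), (2, 5)        # three 2x2 systems, one 2x5 right-hand side
result_shape = broadcast_shapes(c_shape[:-2], d_shape[:-2]) + d_shape[-2:]
print(result_shape, result_shape == d_shape)  # (3, 2, 5) False
```

In the first case the `[:-1]` wording would call the shapes compatible even though the system is 2x2 and the right-hand side has only one row. In the second, broadcasting the stack parts gives a result with a leading 3 that x2 does not have.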

If I am reading the PyTorch docs correctly, it resolves this by only allowing broadcasting in the case where x2 is exactly 1- or 2-dimensional. Otherwise, when x2 is a stack of matrices, the stack part of its shape has to match the stack part of shape(x1) exactly.

However, I think this is still ambiguous in the case I noted above, where x1 is (2, 2, 2) and x2 is (2, 2). x2 could be a single matrix, which would broadcast, or a (2,) stack of (2,) vectors, whose stack shape matches that of x1.

So I think more is required to disambiguate, e.g., only allow broadcasting for matrices and not for vectors. One could also remove the vector case completely, or only allow it in the simple case of x2 being 1-D (i.e., no stacks of 1-D vectors).
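As a rough sketch of the last option (vectors only when x2 is exactly 1-D, everything else read as a stack of matrices), the shape rule could look like the following. This is only one possible resolution, written in plain Python; `solve_result_shape` and `broadcast_stack` are hypothetical names, not spec or library functions.

```python
from itertools import zip_longest


def broadcast_stack(a, b):
    """NumPy-style broadcasting of two stack shapes (illustrative helper)."""
    out = []
    for m, n in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if m != n and 1 not in (m, n):
            raise ValueError(f"stack shapes {a} and {b} do not broadcast")
        out.append(n if m == 1 else m)
    return tuple(reversed(out))


def solve_result_shape(x1, x2):
    """Result shape under the assumed rule: x2 is a vector only if it is 1-D."""
    if len(x1) < 2 or x1[-1] != x1[-2]:
        raise ValueError(f"x1 must have shape (..., M, M), got {x1}")
    m = x1[-1]
    if len(x2) == 1:
        # Single vector of length M, shared by every matrix in the x1 stack.
        if x2[0] != m:
            raise ValueError(f"x2 must have length {m}, got {x2}")
        return x1[:-2] + (m,)
    # Anything else must be (..., M, K); only the stack parts broadcast.
    if len(x2) < 2 or x2[-2] != m:
        raise ValueError(f"x2 must have shape (..., {m}, K), got {x2}")
    return broadcast_stack(x1[:-2], x2[:-2]) + x2[-2:]


print(solve_result_shape((3, 2, 2), (2,)))     # (3, 2)
print(solve_result_shape((3, 2, 2), (2, 5)))   # (3, 2, 5)
```

With a rule like this, a 2-D x2 is always a single matrix, so each pair of input shapes has at most one reading. The cost is that a stack of vectors has to be passed with an explicit trailing dimension, e.g. shape (..., M, 1).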
