NumPy has a C-API function, https://numpy.org/devdocs/reference/c-api/array.html#c.PyArray_MatrixProduct2, that we can probably use to implement the Blockwise of Dot (Matmul) and replace the code below.
This would also make the BatchedDot Op redundant, so we can save a lot of code (see the sketch after the quoted code):
pytensor/pytensor/tensor/blas.py, lines 1314 to 1711 at f10a603:

```python
class BatchedDot(COp):
    """
    Computes a batch matrix-matrix dot with tensor3 variables

        batched_dot(a, b)[i] = dot(a[i], b[i])
    """

    __props__ = ()
    gufunc_signature = "(b,m,k),(b,k,n)->(b,m,n)"

    def make_node(self, x, y):
        x = as_tensor_variable(x)
        y = as_tensor_variable(y)

        if not (
            isinstance(x.type, DenseTensorType) and isinstance(y.type, DenseTensorType)
        ):
            raise NotImplementedError("Only dense tensor types are supported")

        if not (x.type.ndim == 3 and y.type.ndim == 3):
            raise TypeError(
                f"Inputs must have 3 ndim, but got {x.type.ndim} and {y.type.ndim}. "
                "Consider calling batched_dot instead."
            )

        def extract_static_dim(dim_x, dim_y):
            dims = {dim_x, dim_y} - {None}
            if len(dims) > 1:
                # BatchedDot doesn't allow broadcasting
                raise ValueError(
                    f"Static dimensions of BatchedDot don't match, got {x.type.shape} and {y.type.shape}"
                )
            elif not dims:
                return None
            else:
                return dims.pop()

        x_batch_dim, x_row_dim, x_sum_dim = x.type.shape
        y_batch_dim, y_sum_dim, y_col_dim = y.type.shape
        batch_dim = extract_static_dim(x_batch_dim, y_batch_dim)
        # Raise if static sum dimensions do not match
        _ = extract_static_dim(x_sum_dim, y_sum_dim)
        out_shape = (batch_dim, x_row_dim, y_col_dim)

        # Change dtype if needed
        dtype = pytensor.scalar.upcast(x.type.dtype, y.type.dtype)
        x, y = cast(x, dtype), cast(y, dtype)
        out = tensor(dtype=dtype, shape=out_shape)
        return Apply(self, [x, y], [out])

    def perform(self, node, inp, out):
        x, y = inp
        (z,) = out

        if x.shape[0] != y.shape[0]:
            raise TypeError(
                f"Inputs [{', '.join(map(str, inp))}] must have the"
                f" same size in axis 0, but have sizes [{', '.join(str(i.shape[0]) for i in inp)}]."
            )

        z[0] = np.matmul(x, y)

    def c_support_code(self, **kwargs):
        batch_gemm_defn = """
        template<typename dtype>
        bool batch_gemm(void (*gemm)(char*, char*, const int*, const int*, const int*, const dtype*, const dtype*, const int*, const dtype*, const int*, const dtype*, dtype*, const int*),
                        int type_size, PyArrayObject* xs, PyArrayObject* ys,
                        PyArrayObject* zs) {
            npy_intp *Nx = PyArray_DIMS(xs), *Sx = PyArray_STRIDES(xs);
            npy_intp *Ny = PyArray_DIMS(ys), *Sy = PyArray_STRIDES(ys);
            npy_intp *Nz = PyArray_DIMS(zs), *Sz = PyArray_STRIDES(zs);

            if (Nx[0] != Ny[0]) {
                PyErr_Format(PyExc_ValueError,
                             "Shape mismatch: batch sizes unequal."
                             " x.shape is (%d, %d, %d),"
                             " y.shape is (%d, %d, %d).",
                             Nx[0], Nx[1], Nx[2],
                             Ny[0], Ny[1], Ny[2]);
                return 1;
            }

            if (Nx[2] != Ny[1]) {
                PyErr_Format(PyExc_ValueError,
                             "Shape mismatch: summation axis sizes unequal."
                             " x.shape is (%d, %d, %d),"
                             " y.shape is (%d, %d, %d).",
                             Nx[0], Nx[1], Nx[2],
                             Ny[0], Ny[1], Ny[2]);
                return 1;
            }

            /* encode the stride structure of _x,_y,_z into a single integer. */
            int unit = 0;
            unit |= ((Sx[2] == type_size || Nx[2] == 1) ? 0x0 : (Sx[1] == type_size || Nx[1]==1) ? 0x1 : 0x2) << 8;
            unit |= ((Sy[2] == type_size || Ny[2] == 1) ? 0x0 : (Sy[1] == type_size || Ny[1]==1) ? 0x1 : 0x2) << 4;
            unit |= ((Sz[2] == type_size || Nz[2] == 1) ? 0x0 : (Sz[1] == type_size || Nz[1]==1) ? 0x1 : 0x2) << 0;

            /* create appropriate strides for malformed matrices that are row or column
             * vectors, or empty matrices.
             * In that case, the value of the stride does not really matter, but
             * some versions of BLAS insist that:
             * - they are not smaller than the number of elements in the array,
             * - they are not 0.
             */
            int sx_1 = (Nx[1] > 1) ? Sx[1]/type_size : (Nx[2] + 1);
            int sx_2 = (Nx[2] > 1) ? Sx[2]/type_size : (Nx[1] + 1);
            int sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : (Ny[2] + 1);
            int sy_2 = (Ny[2] > 1) ? Sy[2]/type_size : (Ny[1] + 1);
            int sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : (Nz[2] + 1);
            int sz_2 = (Nz[2] > 1) ? Sz[2]/type_size : (Nz[1] + 1);

            dtype* x = (dtype*)PyArray_DATA(xs);
            dtype* y = (dtype*)PyArray_DATA(ys);
            dtype* z = (dtype*)PyArray_DATA(zs);

            dtype a = 1.0;
            dtype b = 0.0;
            char N = 'N';
            char T = 'T';
            int Nz1 = Nz[1], Nz2 = Nz[2], Nx2 = Nx[2];

            // loop over batch axis
            for (int i = 0; i < Nz[0]; i++) {
                switch(unit)
                {
                    case 0x000: gemm(&N, &N, &Nz2, &Nz1, &Nx2, &a, y, &sy_1, x, &sx_1, &b, z, &sz_1); break;
                    case 0x100: gemm(&N, &T, &Nz2, &Nz1, &Nx2, &a, y, &sy_1, x, &sx_2, &b, z, &sz_1); break;
                    case 0x010: gemm(&T, &N, &Nz2, &Nz1, &Nx2, &a, y, &sy_2, x, &sx_1, &b, z, &sz_1); break;
                    case 0x110: gemm(&T, &T, &Nz2, &Nz1, &Nx2, &a, y, &sy_2, x, &sx_2, &b, z, &sz_1); break;
                    case 0x001: gemm(&T, &T, &Nz1, &Nz2, &Nx2, &a, x, &sx_1, y, &sy_1, &b, z, &sz_2); break;
                    case 0x101: gemm(&N, &T, &Nz1, &Nz2, &Nx2, &a, x, &sx_2, y, &sy_1, &b, z, &sz_2); break;
                    case 0x011: gemm(&T, &N, &Nz1, &Nz2, &Nx2, &a, x, &sx_1, y, &sy_2, &b, z, &sz_2); break;
                    case 0x111: gemm(&N, &N, &Nz1, &Nz2, &Nx2, &a, x, &sx_2, y, &sy_2, &b, z, &sz_2); break;
                    default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); return 1;
                };
                x += Sx[0] / type_size;
                y += Sy[0] / type_size;
                z += Sz[0] / type_size;
            }

            return 0;
        }
        """
        return blas_header_text() + batch_gemm_defn

    def c_libraries(self, **kwargs):
        return ldflags()

    def c_compile_args(self, **kwargs):
        return ldflags(libs=False, flags=True)

    def c_lib_dirs(self, **kwargs):
        return ldflags(libs=False, libs_dir=True)

    def c_header_dirs(self, **kwargs):
        return ldflags(libs=False, include_dir=True)

    def c_code(self, node, name, inp, out, sub):
        # Can only compile if linked to blas libraries
        if len(self.c_libraries()) <= 0:
            raise NotImplementedError()

        _x, _y = inp
        (_z,) = out
        fail = sub["fail"]

        # generate contiguity condition
        def contiguous(var, ndim):
            strides = f"PyArray_STRIDES({var})"
            if ndim == 1:
                return f"{strides}[0] == type_size"
            ands = " && ".join(
                f"{strides}[{i}] > 0 && {strides}[{i}] % type_size == 0"
                for i in range(1, ndim)
            )
            ors = " || ".join(f"{strides}[{i}] == type_size" for i in range(1, ndim))
            return f"{ands} && ({ors})"

        x_ndim, y_ndim, z_ndim = (
            node.inputs[0].ndim,
            node.inputs[1].ndim,
            node.outputs[0].ndim,
        )

        # generate code to allocate output based on runtime input shapes
        z_dims = [
            f"PyArray_DIMS({_x})[0]",
            f"PyArray_DIMS({_x})[1]",
            f"PyArray_DIMS({_y})[2]",
        ]
        z_shape_correct = " && ".join(
            f"PyArray_DIMS({_z})[{i}] == {dim}" for i, dim in enumerate(z_dims)
        )
        z_shape = ", ".join(z_dims)
        z_contiguous = contiguous(_z, z_ndim)
        allocate = f"""
            if (NULL == {_z} || !({z_shape_correct}) || !({z_contiguous}))
            {{
                npy_intp dims[{z_ndim}] = {{{z_shape}}};
                Py_XDECREF({_z});
                {_z} = (PyArrayObject*)PyArray_SimpleNew(
                    {z_ndim}, dims, PyArray_TYPE({_x}));
                if(!{_z}) {{
                    PyErr_SetString(PyExc_MemoryError,
                                    "failed to alloc BatchedDot output");
                    {fail}
                }}
            }}
        """

        # code to reallocate inputs contiguously if necessary
        contiguate = []
        for var, ndim in [(_x, x_ndim), (_y, y_ndim)]:
            _contiguous = contiguous(var, ndim)
            contiguate.append(
                f"""
                if (!({_contiguous})) {{
                    PyArrayObject * _copy = (PyArrayObject *) PyArray_Copy({var});
                    if (!_copy)
                        {fail}
                    Py_XDECREF({var});
                    {var} = _copy;
                }}
                """
            )
        contiguate = "\n".join(contiguate)

        return f"""
        int type_num = PyArray_DESCR({_x})->type_num;
        int type_size = PyArray_ITEMSIZE({_x}); // in bytes

        if (PyArray_NDIM({_x}) != 3) {{
            PyErr_Format(PyExc_NotImplementedError,
                         "rank(x) != 3. rank(x) is %d.",
                         PyArray_NDIM({_x}));
            {fail};
        }}
        if (PyArray_NDIM({_y}) != 3) {{
            PyErr_Format(PyExc_NotImplementedError,
                         "rank(y) != 3. rank(y) is %d.",
                         PyArray_NDIM({_y}));
            {fail};
        }}
        if ({_z} && PyArray_NDIM({_z}) != 3) {{
            PyErr_Format(PyExc_NotImplementedError,
                         "rank(z) != 3. rank(z) is %d.",
                         PyArray_NDIM({_z}));
            {fail};
        }}

        // allocate output
        {allocate}
        // reallocate any noncontiguous arrays or arrays with invalid strides
        {contiguate}

        if ((PyArray_DESCR({_x})->type_num != NPY_DOUBLE)
            && (PyArray_DESCR({_x})->type_num != NPY_FLOAT))
        {{PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); {fail};}}

        if ((PyArray_DESCR({_y})->type_num != NPY_DOUBLE)
            && (PyArray_DESCR({_y})->type_num != NPY_FLOAT))
        {{PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); {fail};}}

        if ((PyArray_DESCR({_z})->type_num != NPY_DOUBLE)
            && (PyArray_DESCR({_z})->type_num != NPY_FLOAT))
        {{PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); {fail};}}

        if ((PyArray_DESCR({_x})->type_num != PyArray_DESCR({_y})->type_num)
            ||(PyArray_DESCR({_x})->type_num != PyArray_DESCR({_z})->type_num))
        {{ PyErr_SetString(PyExc_NotImplementedError, "type(x), type(y), type(z) are not all the same"); {fail}; }}

        switch (type_num)
        {{
            case NPY_FLOAT:
            if (batch_gemm<float>(sgemm_, type_size, {_x}, {_y}, {_z})) {{
                {fail};
            }}
            break;
            case NPY_DOUBLE:
            if (batch_gemm<double>(dgemm_, type_size, {_x}, {_y}, {_z})) {{
                {fail};
            }}
            break;
        }}
        """

    def c_code_cache_version(self):
        from pytensor.tensor.blas_headers import blas_header_version

        return (6, blas_header_version())

    def grad(self, inp, grads):
        x, y = inp
        (gz,) = grads

        xgrad = _batched_dot(gz, y.dimshuffle(0, 2, 1))
        ygrad = _batched_dot(x.dimshuffle(0, 2, 1), gz)

        # If x or y contain broadcastable dimensions but only one of
        # them know that a matching dimensions is broadcastable, the
        # above code don't always return the right broadcast pattern.
        # This cause problem down the road. See gh-1461.
        if xgrad.broadcastable != x.broadcastable:
            xgrad = specify_broadcastable(
                xgrad, *(ax for (ax, b) in enumerate(x.type.broadcastable) if b)
            )
        if ygrad.broadcastable != y.broadcastable:
            ygrad = specify_broadcastable(
                ygrad, *(ax for (ax, b) in enumerate(y.type.broadcastable) if b)
            )

        return xgrad, ygrad

    def R_op(self, inputs, eval_points):
        # R_op for batched_dot(a, b) evaluated at c for a and d for b is
        # simply batched_dot(c, b) + batched_dot(a, d)
        assert len(inputs) == 2
        assert len(eval_points) == 2
        if eval_points[0] is None and eval_points[1] is None:
            return [None]

        test_values_enabled = config.compute_test_value != "off"

        if test_values_enabled:
            try:
                iv0 = pytensor.graph.op.get_test_value(inputs[0])
            except TestValueError:
                pytensor.graph.op.missing_test_message(
                    "first input passed to BatchedDot.R_op has no test value"
                )
                test_values_enabled = False

            try:
                iv1 = pytensor.graph.op.get_test_value(inputs[1])
            except TestValueError:
                pytensor.graph.op.missing_test_message(
                    "second input passed to BatchedDot.R_op has no test value"
                )
                test_values_enabled = False

            if eval_points[0]:
                try:
                    ev0 = pytensor.graph.op.get_test_value(eval_points[0])
                except TestValueError:
                    pytensor.graph.op.missing_test_message(
                        "first eval point passed to BatchedDot.R_op "
                        "has no test value"
                    )
                    test_values_enabled = False
            if eval_points[1]:
                try:
                    ev1 = pytensor.graph.op.get_test_value(eval_points[1])
                except TestValueError:
                    pytensor.graph.op.missing_test_message(
                        "second eval point passed to BatchedDot.R_op "
                        "has no test value"
                    )
                    test_values_enabled = False

        if test_values_enabled:
            input_values = [iv0, iv1]
            eval_point_values = [ev0, ev1]

            for i in range(2):
                if (
                    eval_point_values[i] is not None
                    and input_values[i].shape != eval_point_values[i].shape
                ):
                    raise ValueError(
                        "input "
                        + str(i)
                        + " and eval_point "
                        + str(i)
                        + " to BatchedDot.R_op should have the same shape, but "
                        f"their shapes are {input_values[i].shape} and {eval_point_values[i].shape}, respectively"
                    )

        if eval_points[0]:
            t1 = self(eval_points[0], inputs[1])
        if eval_points[1]:
            t2 = self(inputs[0], eval_points[1])

        if eval_points[0] and eval_points[1]:
            return [t1 + t2]
        elif eval_points[0]:
            return [t1]
        else:
            return [t2]

    def infer_shape(self, fgraph, node, shapes):
        xshp, yshp = shapes
        return [xshp[:-1] + yshp[2:]]


_batched_dot = BatchedDot()
```
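As a rough illustration of the idea, here is a minimal C sketch of how the core 2-D product of one Blockwise(Dot) slice could be delegated to NumPy's C API instead of the hand-written batched GEMM dispatch above. `core_dot_slice` is a hypothetical helper name, not PyTensor code; the surrounding Blockwise batching loop and output management are omitted, and the point is only that `PyArray_MatrixProduct2` already provides np.dot semantics:

```c
/* Hedged sketch, not the actual PyTensor implementation: delegate one core
 * 2-D dot of a Blockwise(Dot) slice to NumPy's C API. Assumes the NumPy
 * headers are on the include path and import_array() was called at module
 * init. */
#include <Python.h>
#include <numpy/arrayobject.h>

/* Compute dot(x, y) for a single core slice. If `out` is non-NULL the result
 * is written into it (np.dot(x, y, out=out) semantics); otherwise NumPy
 * allocates the output. Returns a new reference, or NULL with an exception
 * set on failure. */
static PyObject *
core_dot_slice(PyArrayObject *x, PyArrayObject *y, PyArrayObject *out)
{
    /* Shape checking and dtype handling are done by NumPy itself, which is
     * what would make most of BatchedDot's C code unnecessary. */
    return PyArray_MatrixProduct2((PyObject *)x, (PyObject *)y, out);
}
```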
The code we could remove also includes some rewrites that try to introduce BatchedDot. We may also want to have a look at https://numpy.org/devdocs/reference/c-api/array.html#c.PyArray_InnerProduct for the corresponding Blockwise of Dot.
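For reference, `PyArray_InnerProduct` follows np.inner semantics (contraction over the last axis of both operands), so whether it matches every case the Dot Op supports would still need to be checked. A minimal hedged sketch of the call, with a hypothetical helper name:

```c
/* Hedged sketch only: np.inner-style contraction via the NumPy C API. */
#include <Python.h>
#include <numpy/arrayobject.h>

static PyObject *
core_inner_slice(PyObject *op1, PyObject *op2)
{
    /* Equivalent to np.inner(op1, op2); returns a new reference,
     * or NULL with an exception set on failure. */
    return PyArray_InnerProduct(op1, op2);
}
```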