Skip to content

Commit 8a80e33

Browse files
authored
Add isBatchVecmat utilities for linalg.batch_vecmat (#70284)
`linalg.batch_vecmat` was just added in #70218, but I forgot then to add the standard `isBatchVecmat` utilities
1 parent 9db8f99 commit 8a80e33

File tree

4 files changed

+94
-0
lines changed

4 files changed

+94
-0
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,17 @@ def LinalgContractionOpInterface : OpInterface<"ContractionOpInterface"> {
9898
return mlir::isVecmat($_op.getIndexingMaps());
9999
}]>,
100100
InterfaceMethod<
101+
/*desc=*/[{
102+
Returns whether the given op has indexing maps that correspond to a
103+
batched vector-matrix multiplication.
104+
}],
105+
/*retTy=*/"bool",
106+
/*methodName=*/"isBatchVecmat",
107+
/*args=*/(ins),
108+
/*methodBody=*/[{
109+
return mlir::isBatchVecmat($_op.getIndexingMaps());
110+
}]>,
111+
InterfaceMethod<
101112
/*desc=*/[{
102113
Returns whether the given op has indexing maps that correspond to a
103114
matrix-vector multiplication.

mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,12 @@ bool isRowMajorBatchMatmul(ArrayAttr indexingMaps);
5555
/// performed within the reduction.
5656
bool isVecmat(ArrayAttr indexingMaps);
5757

58+
/// Tests whether the given maps describe a batch vector matrix multiplication.
59+
/// The test is permutation-invariant. Note that this only checks the affine
60+
/// maps from an operation, so does not perform any checks on the math being
61+
/// performed within the reduction.
62+
bool isBatchVecmat(ArrayAttr indexingMaps);
63+
5864
/// Tests whether the given maps describe a matrix vector multiplication. The
5965
/// test is permutation-invariant. Note that this only checks the affine maps
6066
/// from an operation, so does not perform any checks on the math being

mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,31 @@ bool mlir::isVecmat(ArrayAttr indexingMaps) {
120120
return indexingMaps == maps;
121121
}
122122

123+
bool mlir::isBatchVecmat(ArrayAttr indexingMaps) {
124+
if (indexingMaps.size() != 3)
125+
return false;
126+
AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
127+
AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
128+
AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
129+
130+
if (map0.getNumResults() != 2 || map1.getNumResults() != 3 ||
131+
map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
132+
map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
133+
return false;
134+
}
135+
136+
// Extract dimensions for B*K * B*K*N -> B*N
137+
AffineExpr b = map0.getResult(0);
138+
AffineExpr k = map0.getResult(1);
139+
AffineExpr n = map2.getResult(1);
140+
auto *context = indexingMaps.getContext();
141+
auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {b, k}, context));
142+
auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {b, k, n}, context));
143+
auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {b, n}, context));
144+
auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
145+
return indexingMaps == maps;
146+
}
147+
123148
bool mlir::isMatvec(ArrayAttr indexingMaps) {
124149
if (indexingMaps.size() != 3)
125150
return false;

mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,4 +370,56 @@ TEST(isBatchMatvec, WrongDimOrderMatrix) {
370370
EXPECT_THAT(maps, Not(Truly(isBatchMatvec)));
371371
}
372372

373+
TEST(isBatchVecmat, Simple) {
374+
MLIRContext context;
375+
376+
AffineExpr batch, k, n;
377+
bindDims(&context, batch, k, n);
378+
auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context));
379+
auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k, n}, &context));
380+
auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context));
381+
auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
382+
383+
EXPECT_THAT(maps, Truly(isBatchVecmat));
384+
}
385+
386+
TEST(isBatchVecmat, BindingSwapped) {
387+
MLIRContext context;
388+
389+
AffineExpr batch, k, n;
390+
bindDims(&context, batch, n, k); // bind in different order
391+
auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context));
392+
auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k, n}, &context));
393+
auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context));
394+
auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
395+
396+
EXPECT_THAT(maps, Truly(isBatchVecmat));
397+
}
398+
399+
TEST(isBatchVecmat, Matmul) {
400+
MLIRContext context;
401+
402+
AffineExpr m, n, k;
403+
bindDims(&context, m, n, k);
404+
auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
405+
auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
406+
auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
407+
auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
408+
409+
EXPECT_THAT(maps, Not(Truly(isBatchVecmat)));
410+
}
411+
412+
TEST(isBatchVecmat, WrongDimOrderMatrix) {
413+
MLIRContext context;
414+
415+
AffineExpr batch, k, n;
416+
bindDims(&context, batch, k, n);
417+
auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context));
418+
auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n, k}, &context));
419+
auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context));
420+
auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
421+
422+
EXPECT_THAT(maps, Not(Truly(isBatchVecmat)));
423+
}
424+
373425
} // namespace

0 commit comments

Comments
 (0)