Skip to content

Commit b833bcb

Browse files
authored
[mlir][ArmSVE] Add convert_to/from_svbool ops (#68586)
This adds slightly higher-level ops for converting masks between svbool and SVE predicate types. The main reason to use these over the intrinsics is these ops support vectors of masks (via unrolling). E.g. ``` // Convert a svbool mask to a mask of SVE predicates: %svbool = vector.load %memref[%c0, %c0] : memref<2x?xi1>, vector<2x[16]xi1> %mask = arm_sve.convert_from_svbool %svbool : vector<2x[8]xi1> // => Results in vector<2x[8]xi1> ``` Or: ``` // Convert a mask of SVE predicates to a svbool mask: %mask = vector.create_mask %c2, %dim_size : vector<2x[2]xi1> %svbool = arm_sve.convert_to_svbool %mask : vector<2x[2]xi1> // => Results in vector<2x[16]xi1> ``` Depends on #68418
1 parent 1c12dcc commit b833bcb

File tree

9 files changed

+448
-5
lines changed

9 files changed

+448
-5
lines changed

mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td

+84
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ def ArmSVE_Dialect : Dialect {
2828
This dialect contains the definitions necessary to target specific Arm SVE
2929
scalable vector operations.
3030
}];
31+
32+
let dependentDialects = ["vector::VectorDialect"];
3133
}
3234

3335
//===----------------------------------------------------------------------===//
@@ -40,6 +42,13 @@ def SVBool : ScalableVectorOfRankAndLengthAndType<
4042
def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
4143
[1], [16, 8, 4, 2, 1], [I1]>;
4244

45+
// Generalizations of SVBool and SVEPredicate to ranks >= 1.
46+
// These are masks with a single trailing scalable dimension.
47+
def SVBoolMask : VectorWithTrailingDimScalableOfSizeAndType<
48+
[16], [I1]>;
49+
def SVEPredicateMask : VectorWithTrailingDimScalableOfSizeAndType<
50+
[16, 8, 4, 2, 1], [I1]>;
51+
4352
//===----------------------------------------------------------------------===//
4453
// ArmSVE op definitions
4554
//===----------------------------------------------------------------------===//
@@ -236,6 +245,81 @@ def UmmlaOp : ArmSVE_Op<"ummla",
236245
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
237246
}
238247

248+
class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith<
249+
"expected corresponding svbool type widened to [16]xi1",
250+
lhsArg, rhsArg,
251+
"VectorType(VectorType::Builder(::llvm::cast<VectorType>($_self)).setDim(::llvm::cast<VectorType>($_self).getRank() - 1, 16))">;
252+
253+
def ConvertFromSvboolOp : ArmSVE_Op<"convert_from_svbool",
254+
[Pure, SvboolTypeConstraint<"result", "source">]>
255+
{
256+
let summary = "Convert a svbool type to a SVE predicate type";
257+
let description = [{
258+
Converts svbool types (`vector<[16]xi1>` or vectors of that type, e.g.
259+
`vector<2x3x[16]xi1>`) to SVE predicate types. Note: Only the trailing
260+
dimension can be scalable.
261+
262+
Example 1: Convert a 1-D svbool mask to a SVE predicate.
263+
```mlir
264+
%source = vector.load %memref[%c0] : memref<?xi1>, vector<[16]xi1>
265+
%result = arm_sve.convert_from_svbool %source : vector<[4]xi1>
266+
```
267+
268+
Example 2: Convert a 2-D svbool mask to a mask of SVE predicates.
269+
```mlir
270+
%source = vector.load %memref[%c0, %c0] : memref<2x?xi1>, vector<2x[16]xi1>
271+
%result = arm_sve.convert_from_svbool %source : vector<2x[8]xi1>
272+
```
273+
274+
---
275+
276+
A `svbool` is the smallest SVE predicate type that has a in-memory
277+
representation (and maps to a full predicate register). In MLIR `svbool` is
278+
represented as `vector<[16]xi1>`. Smaller SVE predicate types
279+
(`vector<[1|2|4|8]xi1>`) must be stored as a `svbool` then converted back to
280+
the original predicate type after loading.
281+
}];
282+
let arguments = (ins SVBoolMask:$source);
283+
let results = (outs SVEPredicateMask:$result);
284+
let assemblyFormat = "$source attr-dict `:` type($result)";
285+
}
286+
287+
def ConvertToSvboolOp : ArmSVE_Op<"convert_to_svbool",
288+
[Pure, SvboolTypeConstraint<"source", "result">]>
289+
{
290+
let summary = "Convert a SVE predicate type to a svbool type";
291+
let description = [{
292+
Converts SVE predicate types (or vectors of predicate types, e.g.
293+
`vector<4x[4]xi1>`) to svbool types. Note: Only the trailing dimension can
294+
be scalable.
295+
296+
Example 1: Convert a 1-D SVE predicate to a svbool mask.
297+
```mlir
298+
%source = vector.create_mask %dim_size : vector<[4]xi1>
299+
%result = arm_sve.convert_to_svbool %source : vector<[4]xi1>
300+
// => Results in vector<[16]xi1>
301+
```
302+
303+
Example 2: Convert a 2-D mask of SVE predicates to a svbool mask.
304+
```mlir
305+
%source = vector.create_mask %c2, %dim_size : vector<2x[2]xi1>
306+
%result = arm_sve.convert_to_svbool %source : vector<2x[2]xi1>
307+
// => Results in vector<2x[16]xi1>
308+
```
309+
310+
---
311+
312+
A `svbool` is the smallest SVE predicate type that has a in-memory
313+
representation (and maps to a full predicate register). In MLIR `svbool` is
314+
represented as `vector<[16]xi1>`. Smaller SVE predicate types
315+
(`vector<[1|2|4|8]xi1>`) must be converted to a `svbool` before they can be
316+
stored.
317+
}];
318+
let arguments = (ins SVEPredicateMask:$source);
319+
let results = (outs SVBoolMask:$result);
320+
let assemblyFormat = "$source attr-dict `:` type($source)";
321+
}
322+
239323
def ScalableMaskedAddIOp : ScalableMaskedIOp<"masked.addi", "addition",
240324
[Commutative]>;
241325

mlir/include/mlir/IR/CommonTypeConstraints.td

+74
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,19 @@ def IsFixedVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
3737
def IsScalableVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
3838
::llvm::cast<VectorType>($_self).isScalable()}]>;
3939

40+
// Whether a type is a scalable VectorType, with a single trailing scalable dimension.
41+
// Examples:
42+
// Valid:
43+
// - vector<[4]xf32>, vector<2x3x[2]xi64>, vector<32x[8]xi32>
44+
// Invalid
45+
// - vector<[4]x8xi32>, vector<[2]x[2]xf64>, vector<2x[8]x4xi32>
46+
def IsVectorTypeWithOnlyTrailingDimScalablePred : And<[
47+
CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
48+
CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">,
49+
CPred<"::llvm::cast<::mlir::VectorType>($_self).getScalableDims().back()">,
50+
CPred<"!llvm::is_contained(::llvm::cast<::mlir::VectorType>($_self).getScalableDims().drop_back(), true)">
51+
]>;
52+
4053
// Whether a type is a VectorType and all dimensions are scalable.
4154
def allDimsScalableVectorTypePred : And<[
4255
IsVectorTypePred,
@@ -404,6 +417,15 @@ class ScalableVectorOf<list<Type> allowedTypes> :
404417
ShapedContainerType<allowedTypes, IsScalableVectorTypePred,
405418
"scalable vector", "::mlir::VectorType">;
406419

420+
// Any vector with a single trailing scalable dimension, with an element type in
421+
// the `allowedTypes` list.
422+
//
423+
// Note: This Similar to ScalableVectorOf, with the extra requirement that only
424+
// the trailing dim is scalable.
425+
class VectorWithTrailingDimScalableOf<list<Type> allowedTypes> :
426+
ShapedContainerType<allowedTypes, IsVectorTypeWithOnlyTrailingDimScalablePred,
427+
"trailing scalable vector", "::mlir::VectorType">;
428+
407429
// Whether the number of elements of a vector is from the given
408430
// `allowedRanks` list
409431
class IsVectorOfRankPred<list<int> allowedRanks> :
@@ -481,6 +503,40 @@ class IsScalableVectorOfLengthPred<list<int> allowedLengths> :
481503
== }]
482504
# allowedlength>)>]>;
483505

506+
// Normalizes an index so the indices in both directions have the same value.
507+
// For example, when indexing forwards index 2 is the third element. When
508+
// indexing in reverse the third element is -3. This helper would map both of
509+
// these to the "normalized" index of 3. This makes the bounds checking in
510+
// IsNthDimSizeIsOneOfPred simpler (see first CPred).
511+
class NormalizeIndex<int value> {
512+
int ret = !if(!lt(value, 0),
513+
!sub(0, value) /* -value if negative */,
514+
!add(value, 1) /* value + 1 if positive*/);
515+
}
516+
517+
// Whether the n-th dim of the shape is contained within `allowedSizes`.
518+
// Negative values for `n` index in reverse.
519+
//
520+
// Examples:
521+
// IsNthDimSizeIsOneOfPred<0, {2, 3, 4}>
522+
// - Accepts any shape where the first dim is 2, 3, or 4.
523+
// * This means shapes like: 2x8x9x5, 4, 3x1, 4x?, etc
524+
// IsNthDimSizeIsOneOfPred<-1, {16}>
525+
// - Accepts any shape where the last dim is 16.
526+
// * This means shapes like 2x16, 16, 1x2x3x4x16, etc
527+
// IsNthDimSizeIsOneOfPred<-2, {10, 5}>
528+
// - Accepts any shape where the second to last dim is 10 or 5.
529+
// * This means shapes like: 1x10x2, 2x1x4x5x6, 8x10x?, etc
530+
class IsNthDimSizeIsOneOfPred<int n, list<int> allowedSizes>
531+
: And<[
532+
CPred<"::llvm::cast<::mlir::ShapedType>($_self).getRank() >= " # NormalizeIndex<n>.ret>,
533+
CPred<"::llvm::is_contained(ArrayRef<int64_t>({" # !interleave(allowedSizes, ", ") # "}), "
534+
# "::llvm::cast<::mlir::ShapedType>($_self).getDimSize("
535+
# !if(!lt(n, 0),
536+
"::llvm::cast<::mlir::ShapedType>($_self).getRank() + " # n,
537+
"" # n)
538+
# "))">]>;
539+
484540
// Whether the shape of a vector matches the given `shape` list.
485541
class IsVectorOfShape<list<int> shape>
486542
: CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>({" # !interleave(shape, ", ") # "})">;
@@ -546,6 +602,24 @@ class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
546602
ScalableVectorOfLength<allowedLengths>.summary,
547603
"::mlir::VectorType">;
548604

605+
// Any ShapedType where the size of the n-th dim is contained in `allowedSizes`.
606+
// Negative values for `n` index in reverse.
607+
class ShapedTypeWithNthDimOfSize<int n, list<int> allowedSizes> : Type<
608+
IsNthDimSizeIsOneOfPred<n, allowedSizes>,
609+
" with dim " # n # " having a size of {" # !interleave(allowedSizes, ", ") # "}",
610+
"::mlir::ShapedType">;
611+
612+
// Any scalable vector with a single trailing scalable dimensions, where the
613+
// size of the trailing dimension is in `allowedTrailingSizes` list, and the
614+
// type is in the `allowedTypes` list.
615+
class VectorWithTrailingDimScalableOfSizeAndType<list<int> allowedTrailingSizes,
616+
list<Type> allowedTypes> : AllOfType<
617+
[VectorWithTrailingDimScalableOf<allowedTypes>,
618+
ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>],
619+
VectorWithTrailingDimScalableOf<allowedTypes>.summary #
620+
ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>.summary,
621+
"::mlir::VectorType">;
622+
549623
def AnyVector : VectorOf<[AnyType]>;
550624
// Temporary vector type clone that allows gradual transition to 0-D vectors.
551625
def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;

mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
1414
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
15+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1516
#include "mlir/IR/Builders.h"
1617
#include "mlir/IR/DialectImplementation.h"
1718
#include "mlir/IR/OpImplementation.h"

mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,6 @@ add_mlir_dialect_library(MLIRArmSVEDialect
1010
LINK_LIBS PUBLIC
1111
MLIRIR
1212
MLIRLLVMDialect
13+
MLIRVectorDialect
1314
MLIRSideEffectInterfaces
1415
)

mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRArmSVETransforms
77
LINK_LIBS PUBLIC
88
MLIRArmSVEDialect
99
MLIRFuncDialect
10+
MLIRVectorDialect
1011
MLIRIR
1112
MLIRLLVMCommonConversion
1213
MLIRLLVMDialect

mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp

+82-3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
1313
#include "mlir/Dialect/Func/IR/FuncOps.h"
1414
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15+
#include "mlir/Dialect/Utils/IndexingUtils.h"
16+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1517
#include "mlir/IR/BuiltinOps.h"
1618
#include "mlir/IR/PatternMatch.h"
1719

@@ -66,6 +68,77 @@ using ScalableMaskedDivFOpLowering =
6668
OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
6769
ScalableMaskedDivFIntrOp>;
6870

71+
namespace {
72+
73+
/// Unrolls a conversion to/from equivalent vector types, to allow using a
74+
/// conversion intrinsic that only supports 1-D vector types.
75+
///
76+
/// Example:
77+
/// ```
78+
/// %result = arm_sve.convert_to_svbool %source : vector<2x[4]xi1>
79+
/// ```
80+
/// is rewritten into:
81+
/// ```
82+
/// %cst = arith.constant dense<false> : vector<2x[16]xi1>
83+
/// %1 = vector.extract %source[0] : vector<[4]xi1> from vector<2x[4]xi1>
84+
/// %2 = "arm_sve.intr.convert.to.svbool"(%1)
85+
/// : (vector<[4]xi1>) -> vector<[16]xi1>
86+
/// %3 = vector.insert %2, %cst[0] : vector<[16]xi1> into vector<2x[16]xi1>
87+
/// %4 = vector.extract %source[1] : vector<[4]xi1> from vector<2x[4]xi1>
88+
/// %5 = "arm_sve.intr.convert.to.svbool"(%4)
89+
/// : (vector<[4]xi1>) -> vector<[16]xi1>
90+
/// %result = vector.insert %5, %3[1] : vector<[16]xi1> into vector<2x[16]xi1>
91+
/// ```
92+
template <typename Op, typename IntrOp>
93+
struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> {
94+
using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern;
95+
96+
LogicalResult
97+
matchAndRewrite(Op convertOp, typename Op::Adaptor,
98+
ConversionPatternRewriter &rewriter) const override {
99+
auto loc = convertOp.getLoc();
100+
101+
auto source = convertOp.getSource();
102+
VectorType sourceType = source.getType();
103+
VectorType resultType = convertOp.getResult().getType();
104+
105+
Value result = rewriter.create<arith::ConstantOp>(
106+
loc, resultType, rewriter.getZeroAttr(resultType));
107+
108+
// We want to iterate over the input vector in steps of the trailing
109+
// dimension. So this creates tile shape where all leading dimensions are 1,
110+
// and the trailing dimension step is the size of the dimension.
111+
SmallVector<int64_t> tileShape(sourceType.getRank(), 1);
112+
tileShape.back() = sourceType.getShape().back();
113+
114+
// Iterate over all scalable mask/predicate slices of the source vector.
115+
for (SmallVector<int64_t> index :
116+
StaticTileOffsetRange(sourceType.getShape(), tileShape)) {
117+
auto extractOrInsertPosition = ArrayRef(index).drop_back();
118+
auto sourceVector = rewriter.create<vector::ExtractOp>(
119+
loc, source, extractOrInsertPosition);
120+
auto convertedType =
121+
VectorType::Builder(llvm::cast<VectorType>(sourceVector.getType()))
122+
.setDim(0, resultType.getShape().back());
123+
auto convertedVector =
124+
rewriter.create<IntrOp>(loc, TypeRange{convertedType}, sourceVector);
125+
result = rewriter.create<vector::InsertOp>(loc, convertedVector, result,
126+
extractOrInsertPosition);
127+
}
128+
129+
rewriter.replaceOp(convertOp, result);
130+
return success();
131+
}
132+
};
133+
134+
using ConvertToSvboolOpLowering =
135+
SvboolConversionOpLowering<ConvertToSvboolOp, ConvertToSvboolIntrOp>;
136+
137+
using ConvertFromSvboolOpLowering =
138+
SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>;
139+
140+
} // namespace
141+
69142
/// Populate the given list with patterns that convert from ArmSVE to LLVM.
70143
void mlir::populateArmSVELegalizeForLLVMExportPatterns(
71144
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
@@ -88,7 +161,9 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
88161
ScalableMaskedMulFOpLowering,
89162
ScalableMaskedSDivIOpLowering,
90163
ScalableMaskedUDivIOpLowering,
91-
ScalableMaskedDivFOpLowering>(converter);
164+
ScalableMaskedDivFOpLowering,
165+
ConvertToSvboolOpLowering,
166+
ConvertFromSvboolOpLowering>(converter);
92167
// clang-format on
93168
}
94169

@@ -107,7 +182,9 @@ void mlir::configureArmSVELegalizeForExportTarget(
107182
ScalableMaskedMulFIntrOp,
108183
ScalableMaskedSDivIIntrOp,
109184
ScalableMaskedUDivIIntrOp,
110-
ScalableMaskedDivFIntrOp>();
185+
ScalableMaskedDivFIntrOp,
186+
ConvertToSvboolIntrOp,
187+
ConvertFromSvboolIntrOp>();
111188
target.addIllegalOp<SdotOp,
112189
SmmlaOp,
113190
UdotOp,
@@ -120,6 +197,8 @@ void mlir::configureArmSVELegalizeForExportTarget(
120197
ScalableMaskedMulFOp,
121198
ScalableMaskedSDivIOp,
122199
ScalableMaskedUDivIOp,
123-
ScalableMaskedDivFOp>();
200+
ScalableMaskedDivFOp,
201+
ConvertToSvboolOp,
202+
ConvertFromSvboolOp>();
124203
// clang-format on
125204
}

0 commit comments

Comments
 (0)