Skip to content

Commit 4f93456

Browse files
committed
[mlir][ArmSVE] Add convert_to/from_svbool ops
This adds slightly higher-level ops for converting masks between svbool and SVE predicate types. The main reason to use these over the intrinsics is these ops support vectors of masks (via unrolling). E.g. ``` // Convert a svbool mask to a mask of SVE predicates: %svbool = vector.load %memref[%c0, %c0] : memref<2x?xi1>, vector<2x[16]xi1> %mask = arm_sve.convert_from_svbool %svbool : vector<2x[8]xi1> // => Results in vector<2x[8]xi1> ``` Or: ``` // Convert a mask of SVE predicates to a svbool mask: %mask = vector.create_mask %c2, %dim_size : vector<2x[2]xi1> %svbool = arm_sve.convert_to_svbool %mask : vector<2x[2]xi1> // => Results in vector<2x[16]xi1> ```
1 parent be81f42 commit 4f93456

File tree

9 files changed

+343
-5
lines changed

9 files changed

+343
-5
lines changed

mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ def ArmSVE_Dialect : Dialect {
2828
This dialect contains the definitions necessary to target specific Arm SVE
2929
scalable vector operations.
3030
}];
31+
32+
let dependentDialects = ["vector::VectorDialect"];
3133
}
3234

3335
//===----------------------------------------------------------------------===//
@@ -40,6 +42,11 @@ def SVBool : ScalableVectorOfRankAndLengthAndType<
4042
def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
4143
[1], [16, 8, 4, 2, 1], [I1]>;
4244

45+
// Generalizations of SVBool and SVEPredicate to ranks >= 1.
46+
// These are masks with a single trailing scalable dimension.
47+
def SVBoolMask : TrailingScalableVectorOfSizeAndType<[16], [I1]>;
48+
def SVEMask : TrailingScalableVectorOfSizeAndType<[16, 8, 4, 2, 1], [I1]>;
49+
4350
//===----------------------------------------------------------------------===//
4451
// ArmSVE op definitions
4552
//===----------------------------------------------------------------------===//
@@ -236,6 +243,66 @@ def UmmlaOp : ArmSVE_Op<"ummla",
236243
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
237244
}
238245

246+
247+
class SvboolTypeContraint<string lhsArg, string rhsArg> : TypesMatchWith<
248+
"expected corresponding svbool type widened to [16]xi1",
249+
lhsArg, rhsArg,
250+
"VectorType(VectorType::Builder(::llvm::cast<VectorType>($_self)).setDim(::llvm::cast<VectorType>($_self).getRank() - 1, 16))">;
251+
252+
def ConvertFromSvboolOp : ArmSVE_Op<"convert_from_svbool",
253+
[Pure, SvboolTypeContraint<"result", "source">]>
254+
{
255+
let summary = "Convert a svbool type to a SVE predicate type";
256+
let description = [{
257+
Converts svbool types (`vector<[16]xi1>` or vectors of that type, e.g.
258+
`vector<2x3x[16]xi1>`) to SVE predicate types. Note: Only the trailing
259+
dimension can be scalable.
260+
261+
Example 1: Convert a 1-D svbool mask to a SVE predicate.
262+
```mlir
263+
%svbool = vector.load %memref[%c0] : memref<?xi1>, vector<[16]xi1>
264+
%mask = arm_sve.convert_from_svbool %svbool : vector<[4]xi1>
265+
```
266+
267+
Example 2: Convert a 2-D svbool mask to a mask of SVE predicates.
268+
```mlir
269+
%svbool = vector.load %memref[%c0, %c0] : memref<2x?xi1>, vector<2x[16]xi1>
270+
%mask = arm_sve.convert_from_svbool %svbool : vector<2x[8]xi1>
271+
```
272+
}];
273+
let arguments = (ins SVBoolMask:$source);
274+
let results = (outs SVEMask:$result);
275+
let assemblyFormat = "$source attr-dict `:` type($result)";
276+
}
277+
278+
def ConvertToSvboolOp : ArmSVE_Op<"convert_to_svbool",
279+
[Pure, SvboolTypeContraint<"source", "result">]>
280+
{
281+
let summary = "Convert a predicate type to a svbool type";
282+
let description = [{
283+
Converts SVE predicate types (or vectors of predicate types, e.g.
284+
`vector<4x[4]xi1>`) to svbool types. Note: Only the trailing dimension can
285+
be scalable.
286+
287+
Example 1: Convert a 1-D SVE predicate to a svbool mask.
288+
```mlir
289+
%mask = vector.create_mask %dim_size : vector<[4]xi1>
290+
%svbool = arm_sve.convert_to_svbool %mask : vector<[4]xi1>
291+
// => Results in vector<[16]xi1>
292+
```
293+
294+
Example 2: Convert a 2-D mask of SVE predicates to a svbool mask.
295+
```mlir
296+
%mask = vector.create_mask %c2, %dim_size : vector<2x[2]xi1>
297+
%svbool = arm_sve.convert_to_svbool %mask : vector<2x[2]xi1>
298+
// => Results in vector<2x[16]xi1>
299+
```
300+
}];
301+
let arguments = (ins SVEMask:$source);
302+
let results = (outs SVBoolMask:$result);
303+
let assemblyFormat = "$source attr-dict `:` type($source)";
304+
}
305+
239306
def ScalableMaskedAddIOp : ScalableMaskedIOp<"masked.addi", "addition",
240307
[Commutative]>;
241308

mlir/include/mlir/IR/CommonTypeConstraints.td

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,12 @@ def IsFixedVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
3737
def IsScalableVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
3838
::llvm::cast<VectorType>($_self).isScalable()}]>;
3939

40+
// Whether a type is a scalable VectorType, with a single trailing scalable dimension.
41+
def IsTrailingScalableVectorTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
42+
CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">,
43+
CPred<"::llvm::cast<::mlir::VectorType>($_self).getScalableDims().back()">,
44+
CPred<"!llvm::is_contained(::llvm::cast<::mlir::VectorType>($_self).getScalableDims().drop_back(), true)">]>;
45+
4046
// Whether a type is a VectorType and all dimensions are scalable.
4147
def allDimsScalableVectorTypePred : And<[
4248
IsVectorTypePred,
@@ -404,6 +410,10 @@ class ScalableVectorOf<list<Type> allowedTypes> :
404410
ShapedContainerType<allowedTypes, IsScalableVectorTypePred,
405411
"scalable vector", "::mlir::VectorType">;
406412

413+
class TrailingScalableVectorOf<list<Type> allowedTypes> :
414+
ShapedContainerType<allowedTypes, IsTrailingScalableVectorTypePred,
415+
"trailing scalable vector", "::mlir::VectorType">;
416+
407417
// Whether the number of elements of a vector is from the given
408418
// `allowedRanks` list
409419
class IsVectorOfRankPred<list<int> allowedRanks> :
@@ -481,10 +491,32 @@ class IsScalableVectorOfLengthPred<list<int> allowedLengths> :
481491
== }]
482492
# allowedlength>)>]>;
483493

494+
class abs<int value> {
495+
int ret = !if(!lt(value, 0), !sub(0, value), value);
496+
}
497+
498+
// Whether the n-th (starting from 1) dim of the shape matches the given `size`.
499+
// Negative values index in reverse.
500+
class IsNthDimSizeIsOneOfPred<int n, list<int> allowedSizes>
501+
: And<[CPred<"::llvm::cast<::mlir::ShapedType>($_self).getRank() >= " # abs<n>.ret>,
502+
CPred<"::llvm::is_contained(ArrayRef<int64_t>({" # !interleave(allowedSizes, ", ") # "}), "
503+
# "::llvm::cast<::mlir::ShapedType>($_self).getDimSize("
504+
# !if(!lt(n, 0),
505+
"::llvm::cast<::mlir::ShapedType>($_self).getRank() + " # n,
506+
"" # !sub(n, 1))
507+
# "))">]>;
508+
484509
// Whether the shape of a vector matches the given `shape` list.
485510
class IsVectorOfShape<list<int> shape>
486511
: CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>({" # !interleave(shape, ", ") # "})">;
487512

513+
// Any ShapedType where the size of the n-th dim is contained in `sizes`.
514+
// Negative values index in reverse.
515+
class ShapedTypeWithNthDimOfSize<int n, list<int> allowedSizes> : Type<
516+
IsNthDimSizeIsOneOfPred<n, allowedSizes>,
517+
" with dim " # n # " having a size of {" # !interleave(allowedSizes, ", ") # "}",
518+
"::mlir::ShapedType">;
519+
488520
// Any vector where the number of elements is from the given
489521
// `allowedLengths` list
490522
class VectorOfLength<list<int> allowedLengths> : Type<
@@ -546,6 +578,17 @@ class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
546578
ScalableVectorOfLength<allowedLengths>.summary,
547579
"::mlir::VectorType">;
548580

581+
// Any scalable vector with a single trailing scalable dimensions, where the
582+
// size of the trailing dimension is in `allowedTrailingSizes` list, and the
583+
// type is in the `allowedTypes` list.
584+
class TrailingScalableVectorOfSizeAndType<list<int> allowedTrailingSizes,
585+
list<Type> allowedTypes> : AllOfType<
586+
[TrailingScalableVectorOf<allowedTypes>,
587+
ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>],
588+
TrailingScalableVectorOf<allowedTypes>.summary #
589+
ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>.summary,
590+
"::mlir::VectorType">;
591+
549592
def AnyVector : VectorOf<[AnyType]>;
550593
// Temporary vector type clone that allows gradual transition to 0-D vectors.
551594
def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;

mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
1414
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
15+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1516
#include "mlir/IR/Builders.h"
1617
#include "mlir/IR/DialectImplementation.h"
1718
#include "mlir/IR/OpImplementation.h"

mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,6 @@ add_mlir_dialect_library(MLIRArmSVEDialect
1010
LINK_LIBS PUBLIC
1111
MLIRIR
1212
MLIRLLVMDialect
13+
MLIRVectorDialect
1314
MLIRSideEffectInterfaces
1415
)

mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRArmSVETransforms
77
LINK_LIBS PUBLIC
88
MLIRArmSVEDialect
99
MLIRFuncDialect
10+
MLIRVectorDialect
1011
MLIRIR
1112
MLIRLLVMCommonConversion
1213
MLIRLLVMDialect

mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
1313
#include "mlir/Dialect/Func/IR/FuncOps.h"
1414
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15+
#include "mlir/Dialect/Utils/IndexingUtils.h"
16+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1517
#include "mlir/IR/BuiltinOps.h"
1618
#include "mlir/IR/PatternMatch.h"
1719

@@ -66,6 +68,54 @@ using ScalableMaskedDivFOpLowering =
6668
OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
6769
ScalableMaskedDivFIntrOp>;
6870

71+
namespace {
72+
73+
template <typename Op, typename IntrOp>
74+
struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> {
75+
using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern;
76+
77+
LogicalResult
78+
matchAndRewrite(Op convertOp, typename Op::Adaptor,
79+
ConversionPatternRewriter &rewriter) const override {
80+
auto loc = convertOp.getLoc();
81+
82+
auto source = convertOp.getSource();
83+
VectorType sourceType = source.getType();
84+
VectorType resultType = convertOp.getResult().getType();
85+
86+
Value result = rewriter.create<arith::ConstantOp>(
87+
loc, resultType, rewriter.getZeroAttr(resultType));
88+
89+
SmallVector<int64_t> tileShape(sourceType.getRank(), 1);
90+
tileShape.back() = sourceType.getShape().back();
91+
92+
for (SmallVector<int64_t> index :
93+
StaticTileOffsetRange(sourceType.getShape(), tileShape)) {
94+
auto extractOrInsertPosition = ArrayRef(index).drop_back();
95+
auto sourceVector = rewriter.create<vector::ExtractOp>(
96+
loc, source, extractOrInsertPosition);
97+
auto convertedType =
98+
VectorType::Builder(llvm::cast<VectorType>(sourceVector.getType()))
99+
.setDim(0, resultType.getShape().back());
100+
auto convertedVector =
101+
rewriter.create<IntrOp>(loc, TypeRange{convertedType}, sourceVector);
102+
result = rewriter.create<vector::InsertOp>(loc, convertedVector, result,
103+
extractOrInsertPosition);
104+
}
105+
106+
rewriter.replaceOp(convertOp, result);
107+
return success();
108+
}
109+
};
110+
111+
using ConvertToSvboolOpLowering =
112+
SvboolConversionOpLowering<ConvertToSvboolOp, ConvertToSvboolIntrOp>;
113+
114+
using ConvertFromSvboolOpLowering =
115+
SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>;
116+
117+
} // namespace
118+
69119
/// Populate the given list with patterns that convert from ArmSVE to LLVM.
70120
void mlir::populateArmSVELegalizeForLLVMExportPatterns(
71121
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
@@ -88,7 +138,9 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
88138
ScalableMaskedMulFOpLowering,
89139
ScalableMaskedSDivIOpLowering,
90140
ScalableMaskedUDivIOpLowering,
91-
ScalableMaskedDivFOpLowering>(converter);
141+
ScalableMaskedDivFOpLowering,
142+
ConvertToSvboolOpLowering,
143+
ConvertFromSvboolOpLowering>(converter);
92144
// clang-format on
93145
}
94146

@@ -107,7 +159,9 @@ void mlir::configureArmSVELegalizeForExportTarget(
107159
ScalableMaskedMulFIntrOp,
108160
ScalableMaskedSDivIIntrOp,
109161
ScalableMaskedUDivIIntrOp,
110-
ScalableMaskedDivFIntrOp>();
162+
ScalableMaskedDivFIntrOp,
163+
ConvertToSvboolIntrOp,
164+
ConvertFromSvboolIntrOp>();
111165
target.addIllegalOp<SdotOp,
112166
SmmlaOp,
113167
UdotOp,
@@ -120,6 +174,8 @@ void mlir::configureArmSVELegalizeForExportTarget(
120174
ScalableMaskedMulFOp,
121175
ScalableMaskedSDivIOp,
122176
ScalableMaskedUDivIOp,
123-
ScalableMaskedDivFOp>();
177+
ScalableMaskedDivFOp,
178+
ConvertToSvboolOp,
179+
ConvertFromSvboolOp>();
124180
// clang-format on
125181
}

mlir/test/Dialect/ArmSVE/invalid.mlir

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
2+
3+
// -----
4+
5+
func.func @arm_sve_convert_from_svbool__bad_mask_type(%bool: vector<2x[16]xi1>) -> vector<2x[8]xi2> {
6+
// expected-error@+1 {{'result' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<2x[8]xi2>'}}
7+
%mask = arm_sve.convert_from_svbool %bool : vector<2x[8]xi2>
8+
return %mask : vector<2x[8]xi2>
9+
}
10+
11+
// -----
12+
13+
func.func @arm_sve_convert_from_svbool__bad_mask_shape(%bool : vector<[16]xi1>) -> vector<[7]xi1> {
14+
// expected-error@+1 {{'result' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<[7]xi1>'}}
15+
%mask = arm_sve.convert_from_svbool %bool : vector<[7]xi1>
16+
return %mask : vector<[7]xi1>
17+
}
18+
19+
// -----
20+
21+
func.func @arm_sve_convert_from_svbool__bad_mask_scalability(%bool : vector<[4]x[16]xi1>) -> vector<[4]x[8]xi1> {
22+
// expected-error@+1 {{'result' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<[4]x[8]xi1>'}}
23+
%mask = arm_sve.convert_from_svbool %bool : vector<[4]x[8]xi1>
24+
return %mask : vector<[4]x[8]xi1>
25+
}
26+
27+
// -----
28+
29+
func.func @arm_sve_convert_to_svbool__bad_mask_type(%mask: vector<2x[8]xi2>) -> vector<2x[16]xi1> {
30+
// expected-error@+1 {{'source' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<2x[8]xi2>'}}
31+
%bool = arm_sve.convert_to_svbool %mask : vector<2x[8]xi2>
32+
return %bool : vector<2x[16]xi1>
33+
}
34+
35+
// -----
36+
37+
func.func @arm_sve_convert_to_svbool__bad_mask_shape(%mask : vector<[7]xi1>) -> vector<[16]xi1> {
38+
// expected-error@+1 {{'source' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<[7]xi1>'}}
39+
%bool = arm_sve.convert_to_svbool %mask : vector<[7]xi1>
40+
return
41+
}
42+
43+
// -----
44+
45+
func.func @arm_sve_convert_to_svbool__bad_mask_scalability(%mask : vector<[4]x[8]xi1>) -> vector<[4]x[16]xi1> {
46+
// expected-error@+1 {{'source' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<[4]x[8]xi1>'}}
47+
%bool = arm_sve.convert_to_svbool %mask : vector<[4]x[8]xi1>
48+
return
49+
}
50+
51+

0 commit comments

Comments
 (0)