Skip to content

Commit 9c656ca

Browse files
committed
Fixups
1 parent 4f93456 commit 9c656ca

File tree

4 files changed

+69
-22
lines changed

4 files changed

+69
-22
lines changed

mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
4545
// Generalizations of SVBool and SVEPredicate to ranks >= 1.
4646
// These are masks with a single trailing scalable dimension.
4747
def SVBoolMask : TrailingScalableVectorOfSizeAndType<[16], [I1]>;
48-
def SVEMask : TrailingScalableVectorOfSizeAndType<[16, 8, 4, 2, 1], [I1]>;
48+
def SVEPredicateMask : TrailingScalableVectorOfSizeAndType<[16, 8, 4, 2, 1], [I1]>;
4949

5050
//===----------------------------------------------------------------------===//
5151
// ArmSVE op definitions
@@ -243,14 +243,13 @@ def UmmlaOp : ArmSVE_Op<"ummla",
243243
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
244244
}
245245

246-
247-
class SvboolTypeContraint<string lhsArg, string rhsArg> : TypesMatchWith<
246+
class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith<
248247
"expected corresponding svbool type widened to [16]xi1",
249248
lhsArg, rhsArg,
250249
"VectorType(VectorType::Builder(::llvm::cast<VectorType>($_self)).setDim(::llvm::cast<VectorType>($_self).getRank() - 1, 16))">;
251250

252251
def ConvertFromSvboolOp : ArmSVE_Op<"convert_from_svbool",
253-
[Pure, SvboolTypeContraint<"result", "source">]>
252+
[Pure, SvboolTypeConstraint<"result", "source">]>
254253
{
255254
let summary = "Convert a svbool type to a SVE predicate type";
256255
let description = [{
@@ -260,45 +259,61 @@ def ConvertFromSvboolOp : ArmSVE_Op<"convert_from_svbool",
260259

261260
Example 1: Convert a 1-D svbool mask to a SVE predicate.
262261
```mlir
263-
%svbool = vector.load %memref[%c0] : memref<?xi1>, vector<[16]xi1>
264-
%mask = arm_sve.convert_from_svbool %svbool : vector<[4]xi1>
262+
%source = vector.load %memref[%c0] : memref<?xi1>, vector<[16]xi1>
263+
%result = arm_sve.convert_from_svbool %source : vector<[4]xi1>
265264
```
266265

267266
Example 2: Convert a 2-D svbool mask to a mask of SVE predicates.
268267
```mlir
269-
%svbool = vector.load %memref[%c0, %c0] : memref<2x?xi1>, vector<2x[16]xi1>
270-
%mask = arm_sve.convert_from_svbool %svbool : vector<2x[8]xi1>
268+
%source = vector.load %memref[%c0, %c0] : memref<2x?xi1>, vector<2x[16]xi1>
269+
%result = arm_sve.convert_from_svbool %source : vector<2x[8]xi1>
271270
```
271+
272+
---
273+
274+
A `svbool` is the smallest SVE predicate type that has a in-memory
275+
representation (and maps to a full predicate register). In MLIR `svbool` is
276+
represented as `vector<[16]xi1>`. Smaller SVE predicate types
277+
(`vector<[1|2|4|8]xi1>`) must be stored as `svbool` then converted back to
278+
a predicate after loading.
272279
}];
273280
let arguments = (ins SVBoolMask:$source);
274-
let results = (outs SVEMask:$result);
281+
let results = (outs SVEPredicateMask:$result);
275282
let assemblyFormat = "$source attr-dict `:` type($result)";
276283
}
277284

278285
def ConvertToSvboolOp : ArmSVE_Op<"convert_to_svbool",
279-
[Pure, SvboolTypeContraint<"source", "result">]>
286+
[Pure, SvboolTypeConstraint<"source", "result">]>
280287
{
281-
let summary = "Convert a predicate type to a svbool type";
288+
let summary = "Convert a SVE predicate type to a svbool type";
282289
let description = [{
283290
Converts SVE predicate types (or vectors of predicate types, e.g.
284291
`vector<4x[4]xi1>`) to svbool types. Note: Only the trailing dimension can
285292
be scalable.
286293

287294
Example 1: Convert a 1-D SVE predicate to a svbool mask.
288295
```mlir
289-
%mask = vector.create_mask %dim_size : vector<[4]xi1>
290-
%svbool = arm_sve.convert_to_svbool %mask : vector<[4]xi1>
296+
%source = vector.create_mask %dim_size : vector<[4]xi1>
297+
%result = arm_sve.convert_to_svbool %source : vector<[4]xi1>
291298
// => Results in vector<[16]xi1>
292299
```
293300

294301
Example 2: Convert a 2-D mask of SVE predicates to a svbool mask.
295302
```mlir
296-
%mask = vector.create_mask %c2, %dim_size : vector<2x[2]xi1>
297-
%svbool = arm_sve.convert_to_svbool %mask : vector<2x[2]xi1>
303+
%source = vector.create_mask %c2, %dim_size : vector<2x[2]xi1>
304+
%result = arm_sve.convert_to_svbool %source : vector<2x[2]xi1>
298305
// => Results in vector<2x[16]xi1>
299306
```
307+
308+
---
309+
310+
A `svbool` is the smallest SVE predicate type that has a in-memory
311+
representation (and maps to a full predicate register). In MLIR `svbool` is
312+
represented as `vector<[16]xi1>`. Smaller SVE predicate types
313+
(`vector<[1|2|4|8]xi1>`) must be converted to a `svbool` before they can be
314+
stored.
300315
}];
301-
let arguments = (ins SVEMask:$source);
316+
let arguments = (ins SVEPredicateMask:$source);
302317
let results = (outs SVBoolMask:$result);
303318
let assemblyFormat = "$source attr-dict `:` type($source)";
304319
}

mlir/include/mlir/IR/CommonTypeConstraints.td

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,17 @@ def IsScalableVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &
3838
::llvm::cast<VectorType>($_self).isScalable()}]>;
3939

4040
// Whether a type is a scalable VectorType, with a single trailing scalable dimension.
41-
def IsTrailingScalableVectorTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
42-
CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">,
43-
CPred<"::llvm::cast<::mlir::VectorType>($_self).getScalableDims().back()">,
44-
CPred<"!llvm::is_contained(::llvm::cast<::mlir::VectorType>($_self).getScalableDims().drop_back(), true)">]>;
41+
// Examples:
42+
// Valid:
43+
// - vector<[4]xf32>, vector<2x3x[2]xi64>, vector<32x[8]xi32>
44+
// Invalid
45+
// - vector<[4]x8xi32>, vector<[2]x[2]xf64>, vector<2x[8]x4xi32>
46+
def IsOnlyTrailingDimScalablePred : And<[
47+
CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
48+
CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">,
49+
CPred<"::llvm::cast<::mlir::VectorType>($_self).getScalableDims().back()">,
50+
CPred<"!llvm::is_contained(::llvm::cast<::mlir::VectorType>($_self).getScalableDims().drop_back(), true)">
51+
]>;
4552

4653
// Whether a type is a VectorType and all dimensions are scalable.
4754
def allDimsScalableVectorTypePred : And<[
@@ -410,8 +417,10 @@ class ScalableVectorOf<list<Type> allowedTypes> :
410417
ShapedContainerType<allowedTypes, IsScalableVectorTypePred,
411418
"scalable vector", "::mlir::VectorType">;
412419

420+
// Any vector with a single trailing scalable dimension, with an element type in
421+
// the `allowedTypes` list.
413422
class TrailingScalableVectorOf<list<Type> allowedTypes> :
414-
ShapedContainerType<allowedTypes, IsTrailingScalableVectorTypePred,
423+
ShapedContainerType<allowedTypes, IsOnlyTrailingDimScalablePred,
415424
"trailing scalable vector", "::mlir::VectorType">;
416425

417426
// Whether the number of elements of a vector is from the given

mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,25 @@ using ScalableMaskedDivFOpLowering =
7070

7171
namespace {
7272

73+
/// Unrolls a conversion to/from equivalent vector types, to allow using a
74+
/// conversion intrinsic that only supports 1-D vector types.
75+
///
76+
/// Example:
77+
/// ```
78+
/// %result = arm_sve.convert_to_svbool %source : vector<2x[4]xi1>
79+
/// ```
80+
/// is rewritten into:
81+
/// ```
82+
/// %cst = arith.constant dense<false> : vector<2x[16]xi1>
83+
/// %1 = vector.extract %source[0] : vector<[4]xi1> from vector<2x[4]xi1>
84+
/// %2 = "arm_sve.intr.convert.to.svbool"(%1)
85+
/// : (vector<[4]xi1>) -> vector<[16]xi1>
86+
/// %3 = vector.insert %2, %cst [0] : vector<[16]xi1> into vector<2x[16]xi1>
87+
/// %4 = vector.extract %source[1] : vector<[4]xi1> from vector<2x[4]xi1>
88+
/// %5 = "arm_sve.intr.convert.to.svbool"(%4)
89+
/// : (vector<[4]xi1>) -> vector<[16]xi1>
90+
/// %result = vector.insert %5, %3 [1] : vector<[16]xi1> into vector<2x[16]xi1>
91+
/// ```
7392
template <typename Op, typename IntrOp>
7493
struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> {
7594
using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern;
@@ -86,9 +105,13 @@ struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> {
86105
Value result = rewriter.create<arith::ConstantOp>(
87106
loc, resultType, rewriter.getZeroAttr(resultType));
88107

108+
// We want to iterate over the input vector in steps of the trailing
109+
// dimension. So this creates tile shape where all leading dimensions are 1,
110+
// and the trailing dimension step is the size of the dimension.
89111
SmallVector<int64_t> tileShape(sourceType.getRank(), 1);
90112
tileShape.back() = sourceType.getShape().back();
91113

114+
// Iterate over all scalable mask/predicate slices of the source vector.
92115
for (SmallVector<int64_t> index :
93116
StaticTileOffsetRange(sourceType.getShape(), tileShape)) {
94117
auto extractOrInsertPosition = ArrayRef(index).drop_back();

mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt -convert-vector-to-llvm="enable-arm-sve" -convert-func-to-llvm -reconcile-unrealized-casts -canonicalize -split-input-file %s | FileCheck %s
1+
// RUN: mlir-opt -convert-vector-to-llvm="enable-arm-sve" -convert-func-to-llvm -reconcile-unrealized-casts -split-input-file %s | FileCheck %s
22

33
func.func @arm_sve_sdot(%a: vector<[16]xi8>,
44
%b: vector<[16]xi8>,

0 commit comments

Comments
 (0)