-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][ArmSVE] Add convert_to/from_svbool ops #68586
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,6 +37,19 @@ def IsFixedVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) && | |
def IsScalableVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) && | ||
::llvm::cast<VectorType>($_self).isScalable()}]>; | ||
|
||
// Whether a type is a scalable VectorType, with a single trailing scalable dimension. | ||
// Examples: | ||
// Valid: | ||
// - vector<[4]xf32>, vector<2x3x[2]xi64>, vector<32x[8]xi32> | ||
// Invalid | ||
// - vector<[4]x8xi32>, vector<[2]x[2]xf64>, vector<2x[8]x4xi32> | ||
def IsOnlyTrailingDimScalablePred : And<[ | ||
MacDue marked this conversation as resolved.
Show resolved
Hide resolved
|
||
CPred<"::llvm::isa<::mlir::VectorType>($_self)">, | ||
CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">, | ||
CPred<"::llvm::cast<::mlir::VectorType>($_self).getScalableDims().back()">, | ||
CPred<"!llvm::is_contained(::llvm::cast<::mlir::VectorType>($_self).getScalableDims().drop_back(), true)"> | ||
]>; | ||
|
||
// Whether a type is a VectorType and all dimensions are scalable. | ||
def allDimsScalableVectorTypePred : And<[ | ||
IsVectorTypePred, | ||
|
@@ -404,6 +417,12 @@ class ScalableVectorOf<list<Type> allowedTypes> : | |
ShapedContainerType<allowedTypes, IsScalableVectorTypePred, | ||
"scalable vector", "::mlir::VectorType">; | ||
|
||
// Any vector with a single trailing scalable dimension, with an element type in | ||
// the `allowedTypes` list. | ||
class TrailingScalableVectorOf<list<Type> allowedTypes> : | ||
MacDue marked this conversation as resolved.
Show resolved
Hide resolved
MacDue marked this conversation as resolved.
Show resolved
Hide resolved
|
||
ShapedContainerType<allowedTypes, IsOnlyTrailingDimScalablePred, | ||
"trailing scalable vector", "::mlir::VectorType">; | ||
|
||
// Whether the number of elements of a vector is from the given | ||
// `allowedRanks` list | ||
class IsVectorOfRankPred<list<int> allowedRanks> : | ||
|
@@ -481,10 +500,32 @@ class IsScalableVectorOfLengthPred<list<int> allowedLengths> : | |
== }] | ||
# allowedlength>)>]>; | ||
|
||
class abs<int value> { | ||
int ret = !if(!lt(value, 0), !sub(0, value), value); | ||
} | ||
|
||
// Whether the n-th (starting from 1) dim of the shape matches the given `size`. | ||
// Negative values index in reverse. | ||
MacDue marked this conversation as resolved.
Show resolved
Hide resolved
|
||
class IsNthDimSizeIsOneOfPred<int n, list<int> allowedSizes> | ||
: And<[CPred<"::llvm::cast<::mlir::ShapedType>($_self).getRank() >= " # abs<n>.ret>, | ||
CPred<"::llvm::is_contained(ArrayRef<int64_t>({" # !interleave(allowedSizes, ", ") # "}), " | ||
# "::llvm::cast<::mlir::ShapedType>($_self).getDimSize(" | ||
# !if(!lt(n, 0), | ||
"::llvm::cast<::mlir::ShapedType>($_self).getRank() + " # n, | ||
"" # !sub(n, 1)) | ||
# "))">]>; | ||
|
||
// Whether the shape of a vector matches the given `shape` list. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wouldn't "mix" There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've moved it to near the end of the list now |
||
class IsVectorOfShape<list<int> shape> | ||
: CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>({" # !interleave(shape, ", ") # "})">; | ||
|
||
// Any ShapedType where the size of the n-th dim is contained in `sizes`. | ||
MacDue marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// Negative values index in reverse. | ||
class ShapedTypeWithNthDimOfSize<int n, list<int> allowedSizes> : Type< | ||
IsNthDimSizeIsOneOfPred<n, allowedSizes>, | ||
" with dim " # n # " having a size of {" # !interleave(allowedSizes, ", ") # "}", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should this subtract 1 from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think so, the dims are indexed from 1 (so both reverse and forward indexing is symmetrical). |
||
"::mlir::ShapedType">; | ||
|
||
// Any vector where the number of elements is from the given | ||
// `allowedLengths` list | ||
class VectorOfLength<list<int> allowedLengths> : Type< | ||
|
@@ -546,6 +587,17 @@ class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks, | |
ScalableVectorOfLength<allowedLengths>.summary, | ||
"::mlir::VectorType">; | ||
|
||
// Any scalable vector with a single trailing scalable dimensions, where the | ||
// size of the trailing dimension is in `allowedTrailingSizes` list, and the | ||
// type is in the `allowedTypes` list. | ||
class TrailingScalableVectorOfSizeAndType<list<int> allowedTrailingSizes, | ||
MacDue marked this conversation as resolved.
Show resolved
Hide resolved
|
||
list<Type> allowedTypes> : AllOfType< | ||
[TrailingScalableVectorOf<allowedTypes>, | ||
ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>], | ||
TrailingScalableVectorOf<allowedTypes>.summary # | ||
ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>.summary, | ||
"::mlir::VectorType">; | ||
|
||
def AnyVector : VectorOf<[AnyType]>; | ||
// Temporary vector type clone that allows gradual transition to 0-D vectors. | ||
def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,8 @@ | |
#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" | ||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" | ||
#include "mlir/Dialect/Utils/IndexingUtils.h" | ||
#include "mlir/Dialect/Vector/IR/VectorOps.h" | ||
#include "mlir/IR/BuiltinOps.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
|
||
|
@@ -66,6 +68,77 @@ using ScalableMaskedDivFOpLowering = | |
OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp, | ||
ScalableMaskedDivFIntrOp>; | ||
|
||
namespace { | ||
|
||
/// Unrolls a conversion to/from equivalent vector types, to allow using a | ||
/// conversion intrinsic that only supports 1-D vector types. | ||
Comment on lines
+73
to
+74
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is doing two things and the op doesn't map 1-1 with intrinsics. Based on previous feedback I've received and general observations, I wonder if the unrolling should be done as a separate transform and this simply maps rank 1 It wouldn't have to be done as part of this patch, but something to consider. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's the whole reason these op exists though? If this mapped 1-1 to the intrinsics there'd be no reason for these ops to exist. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah I understand that, but it doesn't mean both steps have to be done when lowering to intrinsics. At the Vector dialect level for example operations on higher rank vectors are typically broken down to rank 1 vectors first (VectorToSCF), before lowering to intrinsics. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That'll be quite a big rework as the ArmSVE dialect is pretty much just a skeleton, so currently has no There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I second Cullen here. It's not obvious to me what the right "pass" would be and where would it live, but we should add a TODO in the comments and in the commit message. Something along the lines:
|
||
/// | ||
/// Example: | ||
/// ``` | ||
/// %result = arm_sve.convert_to_svbool %source : vector<2x[4]xi1> | ||
/// ``` | ||
/// is rewritten into: | ||
/// ``` | ||
/// %cst = arith.constant dense<false> : vector<2x[16]xi1> | ||
/// %1 = vector.extract %source[0] : vector<[4]xi1> from vector<2x[4]xi1> | ||
/// %2 = "arm_sve.intr.convert.to.svbool"(%1) | ||
/// : (vector<[4]xi1>) -> vector<[16]xi1> | ||
/// %3 = vector.insert %2, %cst [0] : vector<[16]xi1> into vector<2x[16]xi1> | ||
MacDue marked this conversation as resolved.
Show resolved
Hide resolved
|
||
/// %4 = vector.extract %source[1] : vector<[4]xi1> from vector<2x[4]xi1> | ||
/// %5 = "arm_sve.intr.convert.to.svbool"(%4) | ||
/// : (vector<[4]xi1>) -> vector<[16]xi1> | ||
/// %result = vector.insert %5, %3 [1] : vector<[16]xi1> into vector<2x[16]xi1> | ||
MacDue marked this conversation as resolved.
Show resolved
Hide resolved
|
||
/// ``` | ||
template <typename Op, typename IntrOp> | ||
struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> { | ||
using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern; | ||
|
||
LogicalResult | ||
matchAndRewrite(Op convertOp, typename Op::Adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
auto loc = convertOp.getLoc(); | ||
|
||
auto source = convertOp.getSource(); | ||
VectorType sourceType = source.getType(); | ||
VectorType resultType = convertOp.getResult().getType(); | ||
|
||
Value result = rewriter.create<arith::ConstantOp>( | ||
loc, resultType, rewriter.getZeroAttr(resultType)); | ||
|
||
// We want to iterate over the input vector in steps of the trailing | ||
// dimension. So this creates tile shape where all leading dimensions are 1, | ||
// and the trailing dimension step is the size of the dimension. | ||
SmallVector<int64_t> tileShape(sourceType.getRank(), 1); | ||
tileShape.back() = sourceType.getShape().back(); | ||
|
||
// Iterate over all scalable mask/predicate slices of the source vector. | ||
for (SmallVector<int64_t> index : | ||
StaticTileOffsetRange(sourceType.getShape(), tileShape)) { | ||
auto extractOrInsertPosition = ArrayRef(index).drop_back(); | ||
auto sourceVector = rewriter.create<vector::ExtractOp>( | ||
loc, source, extractOrInsertPosition); | ||
auto convertedType = | ||
VectorType::Builder(llvm::cast<VectorType>(sourceVector.getType())) | ||
.setDim(0, resultType.getShape().back()); | ||
auto convertedVector = | ||
rewriter.create<IntrOp>(loc, TypeRange{convertedType}, sourceVector); | ||
result = rewriter.create<vector::InsertOp>(loc, convertedVector, result, | ||
extractOrInsertPosition); | ||
} | ||
|
||
rewriter.replaceOp(convertOp, result); | ||
return success(); | ||
} | ||
}; | ||
|
||
using ConvertToSvboolOpLowering = | ||
SvboolConversionOpLowering<ConvertToSvboolOp, ConvertToSvboolIntrOp>; | ||
|
||
using ConvertFromSvboolOpLowering = | ||
SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>; | ||
|
||
} // namespace | ||
|
||
/// Populate the given list with patterns that convert from ArmSVE to LLVM. | ||
void mlir::populateArmSVELegalizeForLLVMExportPatterns( | ||
LLVMTypeConverter &converter, RewritePatternSet &patterns) { | ||
|
@@ -88,7 +161,9 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns( | |
ScalableMaskedMulFOpLowering, | ||
ScalableMaskedSDivIOpLowering, | ||
ScalableMaskedUDivIOpLowering, | ||
ScalableMaskedDivFOpLowering>(converter); | ||
ScalableMaskedDivFOpLowering, | ||
ConvertToSvboolOpLowering, | ||
ConvertFromSvboolOpLowering>(converter); | ||
// clang-format on | ||
} | ||
|
||
|
@@ -107,7 +182,9 @@ void mlir::configureArmSVELegalizeForExportTarget( | |
ScalableMaskedMulFIntrOp, | ||
ScalableMaskedSDivIIntrOp, | ||
ScalableMaskedUDivIIntrOp, | ||
ScalableMaskedDivFIntrOp>(); | ||
ScalableMaskedDivFIntrOp, | ||
ConvertToSvboolIntrOp, | ||
ConvertFromSvboolIntrOp>(); | ||
target.addIllegalOp<SdotOp, | ||
SmmlaOp, | ||
UdotOp, | ||
|
@@ -120,6 +197,8 @@ void mlir::configureArmSVELegalizeForExportTarget( | |
ScalableMaskedMulFOp, | ||
ScalableMaskedSDivIOp, | ||
ScalableMaskedUDivIOp, | ||
ScalableMaskedDivFOp>(); | ||
ScalableMaskedDivFOp, | ||
ConvertToSvboolOp, | ||
ConvertFromSvboolOp>(); | ||
// clang-format on | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
// RUN: mlir-opt %s -split-input-file -verify-diagnostics | ||
|
||
// ----- | ||
|
||
func.func @arm_sve_convert_from_svbool__bad_mask_type(%bool: vector<2x[16]xi1>) -> vector<2x[8]xi2> { | ||
// expected-error@+1 {{'result' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<2x[8]xi2>'}} | ||
%mask = arm_sve.convert_from_svbool %bool : vector<2x[8]xi2> | ||
return %mask : vector<2x[8]xi2> | ||
} | ||
|
||
// ----- | ||
|
||
func.func @arm_sve_convert_from_svbool__bad_mask_shape(%bool : vector<[16]xi1>) -> vector<[7]xi1> { | ||
// expected-error@+1 {{'result' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<[7]xi1>'}} | ||
%mask = arm_sve.convert_from_svbool %bool : vector<[7]xi1> | ||
return %mask : vector<[7]xi1> | ||
} | ||
|
||
// ----- | ||
|
||
func.func @arm_sve_convert_from_svbool__bad_mask_scalability(%bool : vector<[4]x[16]xi1>) -> vector<[4]x[8]xi1> { | ||
// expected-error@+1 {{'result' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<[4]x[8]xi1>'}} | ||
%mask = arm_sve.convert_from_svbool %bool : vector<[4]x[8]xi1> | ||
return %mask : vector<[4]x[8]xi1> | ||
} | ||
|
||
// ----- | ||
|
||
func.func @arm_sve_convert_to_svbool__bad_mask_type(%mask: vector<2x[8]xi2>) -> vector<2x[16]xi1> { | ||
// expected-error@+1 {{'source' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<2x[8]xi2>'}} | ||
%bool = arm_sve.convert_to_svbool %mask : vector<2x[8]xi2> | ||
return %bool : vector<2x[16]xi1> | ||
} | ||
|
||
// ----- | ||
|
||
func.func @arm_sve_convert_to_svbool__bad_mask_shape(%mask : vector<[7]xi1>) -> vector<[16]xi1> { | ||
// expected-error@+1 {{'source' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<[7]xi1>'}} | ||
%bool = arm_sve.convert_to_svbool %mask : vector<[7]xi1> | ||
return | ||
} | ||
|
||
// ----- | ||
|
||
func.func @arm_sve_convert_to_svbool__bad_mask_scalability(%mask : vector<[4]x[8]xi1>) -> vector<[4]x[16]xi1> { | ||
// expected-error@+1 {{'source' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<[4]x[8]xi1>'}} | ||
%bool = arm_sve.convert_to_svbool %mask : vector<[4]x[8]xi1> | ||
return | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It should be possible to test for this diagnostic, right?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is already implicitly tested. This is used to infer the the svbool type (from the result or argument), so there's no textual representation of the op that fails this constraint.