Skip to content

[MLIR][ArmSVE] Add lowering of vector.contract to SVE *MMLA instructions #135359

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -1406,6 +1406,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
"bool", /*default=*/"false",
"Enables the use of ArmSVE dialect while lowering the vector "
"dialect.">,
Option<"armI8MM", "enable-arm-i8mm",
"bool", /*default=*/"false",
"Enables the use of Arm FEAT_I8MM instructions while lowering "
"the vector dialect.">,
Option<"x86Vector", "enable-x86vector",
"bool", /*default=*/"false",
"Enables the use of X86Vector dialect while lowering the vector "
Expand Down
96 changes: 94 additions & 2 deletions mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ class Scalable1DVectorOfLength<int length, list<Type> elementTypes> : ShapedCont
"a 1-D scalable vector with length " # length,
"::mlir::VectorType">;

// Any single SVE vector register worth of data: for each element width the
// minimum lane count multiplied by the element size is 128 bits (one SVE
// vector granule). Used by ops that operate on a full register regardless of
// element type (e.g. dupq_lane below).
def SVEVector : AnyTypeOf<[
  Scalable1DVectorOfLength<2, [I64, F64]>,
  Scalable1DVectorOfLength<4, [I32, F32]>,
  Scalable1DVectorOfLength<8, [I16, F16, BF16]>,
  Scalable1DVectorOfLength<16, [I8]>],
  "an SVE vector with element size <= 64-bit">;

//===----------------------------------------------------------------------===//
// ArmSVE op definitions
//===----------------------------------------------------------------------===//
Expand All @@ -72,14 +79,22 @@ class ArmSVE_IntrOp<string mnemonic,
list<Trait> traits = [],
list<int> overloadedOperands = [],
list<int> overloadedResults = [],
int numResults = 1> :
int numResults = 1,
list<int> immArgPositions = [],
list<string> immArgAttrNames = []> :
LLVM_IntrOpBase</*Dialect dialect=*/ArmSVE_Dialect,
/*string opName=*/"intr." # mnemonic,
/*string enumName=*/"aarch64_sve_" # !subst(".", "_", mnemonic),
/*list<int> overloadedResults=*/overloadedResults,
/*list<int> overloadedOperands=*/overloadedOperands,
/*list<Trait> traits=*/traits,
/*int numResults=*/numResults>;
/*int numResults=*/numResults,
/*bit requiresAccessGroup=*/0,
/*bit requiresAliasAnalysis=*/0,
/*bit requiresFastmath=*/0,
/*bit requiresOpBundles=*/0,
/*list<int> immArgPositions=*/immArgPositions,
/*list<string> immArgAttrNames=*/immArgAttrNames>;

class ArmSVE_IntrBinaryOverloadedOp<string mnemonic,
list<Trait> traits = []>:
Expand Down Expand Up @@ -258,6 +273,34 @@ def UmmlaOp : ArmSVE_Op<"ummla",
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
}

def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure,
                                    AllTypesMatch<["src1", "src2"]>,
                                    AllTypesMatch<["acc", "dst"]>]> {
  let summary = "Matrix-matrix multiply and accumulate op";
  let description = [{
    USMMLA: Unsigned by signed integer matrix multiply-accumulate.

    The unsigned by signed integer matrix multiply-accumulate operation
    multiplies the 2×8 matrix of unsigned 8-bit integer values held in
    the first source vector by the 8×2 matrix of signed 8-bit integer
    values in the second source vector. The resulting 2×2 widened 32-bit
    integer matrix product is then added to the 32-bit integer matrix
    accumulator.

    Source:
    https://developer.arm.com/documentation/100987/0000
  }];
  // Supports (vector<[16]xi8>, vector<[16]xi8>) -> (vector<[4]xi32>)
  let arguments = (ins
          ScalableVectorOfLengthAndType<[4], [I32]>:$acc,
          ScalableVectorOfLengthAndType<[16], [I8]>:$src1,
          ScalableVectorOfLengthAndType<[16], [I8]>:$src2
  );
  let results = (outs ScalableVectorOfLengthAndType<[4], [I32]>:$dst);
  let assemblyFormat =
    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
}

class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith<
"expected corresponding svbool type widened to [16]xi1",
lhsArg, rhsArg,
Expand Down Expand Up @@ -509,6 +552,41 @@ def ScalableMaskedUDivIOp : ScalableMaskedIOp<"masked.divi_unsigned",

def ScalableMaskedDivFOp : ScalableMaskedFOp<"masked.divf", "division">;

def DupQLaneOp : ArmSVE_Op<"dupq_lane", [Pure, AllTypesMatch<["src", "dst"]>]> {
  let summary = "Broadcast indexed 128-bit segment to vector";

  let description = [{
    This operation fills each 128-bit segment of a vector with the elements
    from the indexed 128-bit segment of the source vector. If the VL is
    128 bits the operation is a NOP.

    Example:
    ```mlir
    // VL == 256
    // %X = [A B C D x x x x]
    %Y = arm_sve.dupq_lane %X[0] : vector<[4]xi32>
    // %Y = [A B C D A B C D]

    // %U = [x x x x x x x x A B C D E F G H]
    %V = arm_sve.dupq_lane %U[1] : vector<[8]xf16>
    // %V = [A B C D E F G H A B C D E F G H]
    ```
  }];

  let arguments = (ins SVEVector:$src,
                       I64Attr:$lane);
  let results = (outs SVEVector:$dst);

  let builders = [
    // Convenience builder: the result type always equals the source type.
    OpBuilder<(ins "Value":$src, "int64_t":$lane), [{
      build($_builder, $_state, src.getType(), src, lane);
    }]>];

  let assemblyFormat = [{
    $src `[` $lane `]` attr-dict `:` type($dst)
  }];
}

def UmmlaIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"ummla">,
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
Expand All @@ -517,6 +595,10 @@ def SmmlaIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"smmla">,
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;

// Binary overloaded intrinsic wrapping llvm.aarch64.sve.usmmla; UsmmlaOp is
// converted one-to-one to this op during LLVM export.
def UsmmlaIntrOp :
  ArmSVE_IntrBinaryOverloadedOp<"usmmla">,
  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;

def SdotIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"sdot">,
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
Expand Down Expand Up @@ -610,4 +692,14 @@ def WhileLTIntrOp :
/*overloadedResults=*/[0]>,
Arguments<(ins I64:$base, I64:$n)>;

// Intrinsic wrapping the aarch64_sve_dupq_lane LLVM intrinsic; DupQLaneOp is
// converted one-to-one to this op during LLVM export. The lane index must be
// an immediate, so operand position 1 is passed through
// immArgPositions/immArgAttrNames as the `lane` attribute.
def DupQLaneIntrOp : ArmSVE_IntrOp<"dupq_lane",
    /*traits=*/[],
    /*overloadedOperands=*/[0],
    /*overloadedResults=*/[],
    /*numResults=*/1,
    /*immArgPositions=*/[1],
    /*immArgAttrNames=*/["lane"]>,
    Arguments<(ins Arg<ScalableVectorOfRank<[1]>, "v">:$v,
                   Arg<I64Attr, "lane">:$lane)>;

#endif // ARMSVE_OPS
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ class RewritePatternSet;
void populateArmSVELegalizeForLLVMExportPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns);

void populateLowerContractionToSVEI8MMPatternPatterns(
RewritePatternSet &patterns);

/// Configure the target to support lowering ArmSVE ops to ops that map to LLVM
/// intrinsics.
void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target);
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
MLIRVectorToLLVM

MLIRArmNeonDialect
MLIRArmNeonTransforms
MLIRArmSVEDialect
MLIRArmSVETransforms
MLIRAMXDialect
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmNeon/Transforms.h"
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
Expand Down Expand Up @@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorStepLoweringPatterns(patterns);
populateVectorRankReducingFMAPattern(patterns);
populateVectorGatherLoweringPatterns(patterns);
if (armI8MM) {
if (armNeon)
arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns);
if (armSVE)
populateLowerContractionToSVEI8MMPatternPatterns(patterns);
}
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ class LowerContractionToSMMLAPattern
// Avoid 0-D vectors and 1-D rhs:
if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
return failure();
// Avoid scalable vectors.
if (lhsType.isScalable() || rhsType.isScalable())
return failure();
auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
auto dimN = rhsType.getDimSize(0);
auto dimK = rhsType.getDimSize(1);
Expand Down Expand Up @@ -238,5 +241,5 @@ class LowerContractionToSMMLAPattern
void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns(
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/1);
patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/2);
}
1 change: 1 addition & 0 deletions mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRArmSVETransforms
LegalizeForLLVMExport.cpp
LegalizeVectorStorage.cpp
LowerContractionToSVEI8MMPattern.cpp

DEPENDS
MLIRArmSVEConversionsIncGen
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
using UsmmlaOpLowering = OneToOneConvertToLLVMPattern<UsmmlaOp, UsmmlaIntrOp>;
using DupQLaneLowering = OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
using ScalableMaskedAddIOpLowering =
OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
ScalableMaskedAddIIntrOp>;
Expand Down Expand Up @@ -192,6 +194,8 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
SmmlaOpLowering,
UdotOpLowering,
UmmlaOpLowering,
UsmmlaOpLowering,
DupQLaneLowering,
ScalableMaskedAddIOpLowering,
ScalableMaskedAddFOpLowering,
ScalableMaskedSubIOpLowering,
Expand Down Expand Up @@ -219,6 +223,8 @@ void mlir::configureArmSVELegalizeForExportTarget(
SmmlaIntrOp,
UdotIntrOp,
UmmlaIntrOp,
UsmmlaIntrOp,
DupQLaneIntrOp,
ScalableMaskedAddIIntrOp,
ScalableMaskedAddFIntrOp,
ScalableMaskedSubIIntrOp,
Expand All @@ -238,6 +244,8 @@ void mlir::configureArmSVELegalizeForExportTarget(
SmmlaOp,
UdotOp,
UmmlaOp,
UsmmlaOp,
DupQLaneOp,
ScalableMaskedAddIOp,
ScalableMaskedAddFOp,
ScalableMaskedSubIOp,
Expand Down
Loading
Loading