Skip to content

Commit 95ef8e3

Browse files
authored
[mlir][ArmSME] Support 2-way widening outer products (#78975)
This patch introduces support for 2-way widening outer products. This enables the fusion of 2 'arm_sme.outerproduct' operations that are chained via the accumulator into a 2-way widening outer product operation. Changes: - Add 'llvm.aarch64.sme.[us]mop[as].za32' intrinsics for 2-way variants. These map to instruction variants added in SME2 and use different intrinsics. Intrinsics are already implemented for widening variants from SME1. - Adds the following operations: - fmopa_2way, fmops_2way - smopa_2way, smops_2way - umopa_2way, umops_2way - Implements conversions for the above ops to intrinsics in ArmSMEToLLVM. - Adds a pass 'arm-sme-outer-product-fusion' that fuses 'arm_sme.outerproduct' operations. For a detailed description of these operations see the 'arm_sme.fmopa_2way' description. The reason for introducing many operations rather than one is the signed/unsigned variants can't be distinguished with types (e.g., ui16, si16) since 'arith.extui' and 'arith.extsi' only support signless integers. A single operation would require this information and an attribute (for example) for the sign doesn't feel right if floating-point types are also supported where this wouldn't apply. Furthermore, the SME FP8 extensions (FEAT_SME_F8F16, FEAT_SME_F8F32) introduce FMOPA 2-way (FP8 to FP16) and 4-way (FP8 to FP32) variants but no subtract variant. Whilst these are not supported in this patch, it felt simpler to have separate ops for add/subtract given this.
1 parent 488f88b commit 95ef8e3

File tree

14 files changed

+1437
-2
lines changed

14 files changed

+1437
-2
lines changed

mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,10 @@ def LLVM_aarch64_sme_sumopa_wide : ArmSME_IntrMopOverloadedOp<"sumopa.wide">;
105105
def LLVM_aarch64_sme_sumops_wide : ArmSME_IntrMopOverloadedOp<"sumops.wide">;
106106
def LLVM_aarch64_sme_usmopa_wide : ArmSME_IntrMopOverloadedOp<"usmopa.wide">;
107107
def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">;
108+
def LLVM_aarch64_sme_smopa_za32 : ArmSME_IntrMopOverloadedOp<"smopa.za32">;
109+
def LLVM_aarch64_sme_umopa_za32 : ArmSME_IntrMopOverloadedOp<"umopa.za32">;
110+
def LLVM_aarch64_sme_smops_za32 : ArmSME_IntrMopOverloadedOp<"smops.za32">;
111+
def LLVM_aarch64_sme_umops_za32 : ArmSME_IntrMopOverloadedOp<"umops.za32">;
108112

109113
class ArmSME_IntrLoadStoreOp<string mnemonic>
110114
: ArmSME_IntrOp<mnemonic,

mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td

Lines changed: 289 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -814,6 +814,295 @@ let arguments = (ins
814814
}];
815815
}
816816

817+
class OuterProductWideningBase<string mnemonic,
818+
list<Type> allowedInputVectorTypes,
819+
list<Type> allowedResultVectorTypes,
820+
int numOuterProducts> :
821+
ArmSME_Op<mnemonic, [
822+
ArmSMETileOpInterface,
823+
AttrSizedOperandSegments,
824+
AllTypesMatch<["lhs", "rhs"]>,
825+
HasMatchingMaskTypeConstraint<"lhs", "lhsMask">,
826+
HasMatchingMaskTypeConstraint<"rhs", "rhsMask">,
827+
PredOpTrait<
828+
"both `lhsMask` and `rhsMask` should be provided or neither",
829+
CPred<"bool(getLhsMask()) == bool(getRhsMask())">
830+
>,
831+
OptionalTypesMatchWith<"`result` and `acc` have the same type",
832+
"result", "acc", "::llvm::cast<Type>($_self)">,
833+
// This trait ensures the input types match the correct output type for ops
834+
// that take multiple inputs and outputs (i.e., 4-way).
835+
PredOpTrait<
836+
"tile element size equals input element size * " # numOuterProducts,
837+
CPred<"getTileType().getElementTypeBitWidth() == "
838+
"(getLhsType().getElementTypeBitWidth() * " # numOuterProducts # ")">
839+
>,
840+
]> {
841+
842+
let arguments = (ins
843+
AnyTypeOf<allowedInputVectorTypes>:$lhs, AnyVector:$rhs,
844+
Optional<AnyVector>:$lhsMask, Optional<AnyVector>:$rhsMask,
845+
Optional<AnyVector>:$acc);
846+
let results = (outs AnyTypeOf<allowedResultVectorTypes>:$result);
847+
848+
let assemblyFormat = [{
849+
$lhs `,` $rhs
850+
oilist(
851+
`acc` `` `(` $acc `)`
852+
| `masks` `` `(` $lhsMask `,` $rhsMask `)`
853+
) attr-dict `:` type($lhs) `,` type($rhs) `into` type($result)
854+
}];
855+
856+
let extraClassDeclaration = [{
857+
VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
858+
VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
859+
VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
860+
std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
861+
// The outerproduct op allocates a new tile if no accumulator is passed.
862+
if (!getAcc())
863+
return arm_sme::getSMETileType(getResultType());
864+
return std::nullopt;
865+
}
866+
VectorType getTileType() {
867+
return getResultType();
868+
}
869+
}];
870+
}
871+
872+
class OuterProduct2Way<string mnemonic,
873+
list<Type> allowedInputVectorTypes,
874+
list<Type> allowedResultVectorTypes>
875+
: OuterProductWideningBase<mnemonic, allowedInputVectorTypes,
876+
allowedResultVectorTypes, /*numOuterProducts=*/2>;
877+
878+
def FMopa2WayOp
879+
: OuterProduct2Way<"fmopa_2way",
880+
[ScalableVectorOfRankAndLengthAndType<[1], [8], [F16, BF16]>],
881+
[nxnxv4f32]> {
882+
let summary = "Floating-point sum of 2 outer products and accumulate";
883+
884+
let description = [{
885+
This operation represents a sum of 2 widened outer products. It takes two
886+
1-D scalable vectors as input and produces a 2-D scalable vector (ZA tile).
887+
888+
For example (fp16 to fp32):
889+
890+
```mlir
891+
%result = arm_sme.fmopa_2way %lhs, %rhs :
892+
vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
893+
```
894+
895+
The `lhs` encodes a matrix of shape SVLSx2 and the `rhs` a matrix of
896+
2xSVLS, where SVLS (spec [1], section B2.1) is the number of 32-bit
897+
elements in a vector of SVL bits. To illustrate, below is a breakdown of
898+
this operation for fp16 to fp32, SVL=128 (i.e., vscale=1):
899+
900+
```
901+
LHS RHS
902+
[A0 A1 A2 A3 A4 A5 A6 A7] [B0 B1 B2 B3 B4 B5 B6 B7]
903+
904+
----------------------------------------------------------------------------
905+
906+
implicit layout
907+
908+
[A0 A1] |
909+
[A2 A3] | [B0 B2 B4 B6]
910+
[A4 A5] | [B1 B3 B5 B7]
911+
[A6 A7] |
912+
913+
----------------------------------------------------------------------------
914+
915+
2 outer products
916+
917+
Acol0 ⊗ Brow0 | Acol1 ⊗ Brow1
918+
------------- | -------------
919+
|
920+
[B0 B2 B4 B6] | [B1 B3 B5 B7]
921+
|
922+
[A0 [A0B0 A0B2 A0B4 A0B6] | [A1 [A1B1 A1B3 A1B5 A1B7]
923+
A2 [A2B0 A2B2 A2B4 A2B6] | A3 [A3B1 A3B3 A3B5 A3B7]
924+
A4 [A4B0 A4B2 A4B4 A4B6] | A5 [A5B1 A5B3 A5B5 A5B7]
925+
A6] [A6B0 A6B2 A6B4 A6B6] | A7] [A7B1 A7B3 A7B5 A7B7]
926+
|
927+
928+
----------------------------------------------------------------------------
929+
930+
sum of 2 outer products
931+
932+
Acol0 ⊗ Brow0 + Acol1 ⊗ Brow1
933+
934+
[A0B0 + A1B1 A0B2 + A1B3 A0B4 + A1B5 A0B6 + A1B7]
935+
[A2B0 + A3B1 A2B2 + A3B3 A2B4 + A3B5 A2B6 + A3B7]
936+
[A4B0 + A5B1 A4B2 + A5B3 A4B4 + A5B5 A4B6 + A5B7]
937+
[A6B0 + A7B1 A6B2 + A7B3 A6B4 + A7B5 A6B6 + A7B7]
938+
939+
----------------------------------------------------------------------------
940+
```
941+
942+
This operation enables the folding of 2 outer products chained via the
943+
accumulator into a single outer product.
944+
945+
For example:
946+
947+
```mlir
948+
%a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
949+
%b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
950+
%a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
951+
%b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
952+
953+
%0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>, vector<[4]xf32>
954+
%1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xf32>, vector<[4]xf32>
955+
```
956+
957+
The 2 outer products in the example above can be fused into a single outer
958+
product as follows:
959+
960+
```mlir
961+
%a_packed = "llvm.intr.experimental.vector.interleave2"(%a0, %a1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
962+
%b_packed = "llvm.intr.experimental.vector.interleave2"(%b0, %b1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
963+
%0 = arm_sme.fmopa_2way %a_packed, %b_packed : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
964+
```
965+
966+
This is implemented in the `-arm-sme-outer-product-fusion` pass.
967+
968+
Example: FP16 to FP32
969+
```mlir
970+
%result = arm_sme.fmopa_2way %lhs, %rhs : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
971+
```
972+
973+
Example: BF16 to FP32
974+
```mlir
975+
%result = arm_sme.fmopa_2way %lhs, %rhs : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
976+
```
977+
978+
| Spec | Features |
979+
| ---- | -------- |
980+
| [FMOPA (widening, 2-way, FP16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/FMOPA--widening--2-way--FP16-to-FP32---Half-precision-floating-point-sum-of-outer-products-and-accumulate-) | +sme |
981+
| [BFMOPA (widening, 2-way, BF16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/BFMOPA--widening---BFloat16-sum-of-outer-products-and-accumulate-) | +sme |
982+
983+
[1] https://developer.arm.com/documentation/ddi0616
984+
}];
985+
}
986+
987+
// TODO: support:
988+
// - FMOPA 2-way FP8 to FP16
989+
// - FMOPA 4-way FP8 to FP32
990+
// once intrinsic support lands in the backend.
991+
992+
def FMops2WayOp
993+
: OuterProduct2Way<"fmops_2way",
994+
[ScalableVectorOfRankAndLengthAndType<[1], [8], [F16, BF16]>],
995+
[nxnxv4f32]> {
996+
let summary = "Floating-point sum of 2 outer products and subtract";
997+
let description = [{
998+
Equivalent to `fmopa_2way` but outer products are subtracted from
999+
destination `result`.
1000+
1001+
Example: FP16 to FP32
1002+
```mlir
1003+
%result = arm_sme.fmops_2way %lhs, %rhs : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
1004+
```
1005+
1006+
Example: BF16 to FP32
1007+
```mlir
1008+
%result = arm_sme.fmops_2way %lhs, %rhs : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
1009+
```
Refer to
1011+
[fmopa_2way](#arm_smefmopa_2way-arm_smefmopa_2wayop) for a detailed
1012+
description of 2-way outer products.
1013+
1014+
| Spec | Features |
1015+
| ---- | -------- |
1016+
| [FMOPS (widening, 2-way, FP16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/FMOPS--widening---Half-precision-floating-point-sum-of-outer-products-and-subtract-) | +sme |
1017+
| [BFMOPS (widening, 2-way, BF16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/BFMOPS--widening---BFloat16-sum-of-outer-products-and-subtract-) | +sme |
1018+

1019+
}];
1020+
}
1021+
1022+
def SMopa2WayOp
1023+
: OuterProduct2Way<"smopa_2way",
1024+
[ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
1025+
[nxnxv4i32]> {
1026+
let summary = "Signed integer sum of 2 outer products and accumulate";
1027+
let description = [{
1028+
Example:
1029+
```mlir
1030+
%result = arm_sme.smopa_2way %lhs, %rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
1031+
```
Refer to
1033+
[fmopa_2way](#arm_smefmopa_2way-arm_smefmopa_2wayop) for a detailed
1034+
description of 2-way outer products.
1035+
1036+
| Spec | Features |
1037+
| ---- | -------- |
1038+
| [SMOPA (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPA--2-way---Signed-integer-sum-of-outer-products-and-accumulate-) | +sme2 |
1039+

1040+
}];
1041+
}
1042+
1043+
def SMops2WayOp
1044+
: OuterProduct2Way<"smops_2way",
1045+
[ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
1046+
[nxnxv4i32]> {
1047+
let summary = "Signed integer sum of 2 outer products and subtract";
1048+
let description = [{
1049+
Example:
1050+
```mlir
1051+
%result = arm_sme.smops_2way %lhs, %rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
1052+
```
Refer to
1054+
[fmopa_2way](#arm_smefmopa_2way-arm_smefmopa_2wayop) for a detailed
1055+
description of 2-way outer products.
1056+
1057+
| Spec | Features |
1058+
| ---- | -------- |
1059+
| [SMOPS (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPS--2-way---Signed-integer-sum-of-outer-products-and-subtract-) | +sme2 |
1060+

1061+
}];
1062+
}
1063+
1064+
def UMopa2WayOp
1065+
: OuterProduct2Way<"umopa_2way",
1066+
[ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
1067+
[nxnxv4i32]> {
1068+
let summary = "Unsigned integer sum of 2 outer products and accumulate";
1069+
let description = [{
1070+
Example:
1071+
```mlir
1072+
%result = arm_sme.umopa_2way %lhs, %rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
1073+
```
Refer to
1075+
[fmopa_2way](#arm_smefmopa_2way-arm_smefmopa_2wayop) for a detailed
1076+
description of 2-way outer products.
1077+
1078+
| Spec | Features |
1079+
| ---- | -------- |
1080+
| [UMOPA (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPA--2-way---Unsigned-integer-sum-of-outer-products-and-accumulate-) | +sme2 |
1081+

1082+
}];
1083+
}
1084+
1085+
def UMops2WayOp
1086+
: OuterProduct2Way<"umops_2way",
1087+
[ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
1088+
[nxnxv4i32]> {
1089+
let summary = "Unsigned integer sum of 2 outer products and subtract";
1090+
let description = [{
1091+
Example:
1092+
```mlir
1093+
%result = arm_sme.umops_2way %lhs, %rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
1094+
```
Refer to
1096+
[fmopa_2way](#arm_smefmopa_2way-arm_smefmopa_2wayop) for a detailed
1097+
description of 2-way outer products.
1098+
1099+
| Spec | Features |
1100+
| ---- | -------- |
1101+
| [UMOPS (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPS--2-way---Unsigned-integer-sum-of-outer-products-and-subtract-) | +sme2 |
1102+

1103+
}];
1104+
}
1105+
8171106
def StreamingVLOp : ArmSME_Op<"streaming_vl", [Pure]>
8181107
{
8191108
let summary = "Query the streaming vector length";

mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ std::unique_ptr<Pass> createEnableArmStreamingPass(
3232
/// Pass that allocates tile IDs to ArmSME operations.
3333
std::unique_ptr<Pass> createTileAllocationPass();
3434

35+
/// Pass that fuses 'arm_sme.outerproduct' ops into 2-way or 4-way widening
36+
/// variants.
37+
std::unique_ptr<Pass> createOuterProductFusionPass();
38+
3539
//===----------------------------------------------------------------------===//
3640
// Registration
3741
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,4 +122,38 @@ def TileAllocation
122122
let dependentDialects = ["func::FuncDialect"];
123123
}
124124

125+
def OuterProductFusion
126+
: Pass<"arm-sme-outer-product-fusion", "mlir::func::FuncOp"> {
127+
let summary = "Fuse 'arm_sme.outerproduct' operations into 2-way or 4-way widening variants";
128+
let description = [{
129+
This pass fuses 'arm_sme.outerproduct' operations that are chained via the
130+
accumulator into 2-way or 4-way ArmSME outer product operations.
131+
132+
For example:
133+
```mlir
134+
%a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
135+
%b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
136+
%a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
137+
%b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
138+
139+
%0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>, vector<[4]xf32>
140+
%1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xf32>, vector<[4]xf32>
141+
```
142+
143+
Becomes:
144+
145+
```mlir
146+
%a_packed = "llvm.intr.experimental.vector.interleave2"(%a0, %a1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
147+
%b_packed = "llvm.intr.experimental.vector.interleave2"(%b0, %b1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
148+
%0 = arm_sme.fmopa_2way %a_packed, %b_packed : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
149+
```
150+
151+
For further information on the 2-way or 4-way widening ops see:
152+
https://mlir.llvm.org/docs/Dialects/ArmSME/#arm_smefmopa_2way-arm_smefmopa_2wayop
153+
https://mlir.llvm.org/docs/Dialects/ArmSME/#arm_smesmopa_4way-arm_smesmopa_4wayop
154+
}];
155+
let constructor = "mlir::arm_sme::createOuterProductFusionPass()";
156+
let dependentDialects = ["func::FuncDialect", "arm_sme::ArmSMEDialect", "LLVM::LLVMDialect"];
157+
}
158+
125159
#endif // MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_TD

mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ class LLVMConversionTarget;
1515
class LLVMTypeConverter;
1616
class RewritePatternSet;
1717

18+
namespace arm_sme {
19+
void populateOuterProductFusionPatterns(RewritePatternSet &patterns);
20+
} // namespace arm_sme
21+
1822
} // namespace mlir
1923

2024
#endif // MLIR_DIALECT_ARMSME_TRANSFORMS_H

0 commit comments

Comments
 (0)