@@ -814,6 +814,295 @@ let arguments = (ins
814
814
}];
815
815
}
816
816
817
+ class OuterProductWideningBase<string mnemonic,
818
+ list<Type> allowedInputVectorTypes,
819
+ list<Type> allowedResultVectorTypes,
820
+ int numOuterProducts> :
821
+ ArmSME_Op<mnemonic, [
822
+ ArmSMETileOpInterface,
823
+ AttrSizedOperandSegments,
824
+ AllTypesMatch<["lhs", "rhs"]>,
825
+ HasMatchingMaskTypeConstraint<"lhs", "lhsMask">,
826
+ HasMatchingMaskTypeConstraint<"rhs", "rhsMask">,
827
+ PredOpTrait<
828
+ "both `lhsMask` and `rhsMask` should be provided or neither",
829
+ CPred<"bool(getLhsMask()) == bool(getRhsMask())">
830
+ >,
831
+ OptionalTypesMatchWith<"`result` and `acc` have the same type",
832
+ "result", "acc", "::llvm::cast<Type>($_self)">,
833
+ // This trait ensures the input types match the correct output type for ops
834
+ // that take multiple inputs and outputs (i.e., 4-way).
835
+ PredOpTrait<
836
+ "tile element size equals input element size * " # numOuterProducts,
837
+ CPred<"getTileType().getElementTypeBitWidth() == "
838
+ "(getLhsType().getElementTypeBitWidth() * " # numOuterProducts # ")">
839
+ >,
840
+ ]> {
841
+
842
+ let arguments = (ins
843
+ AnyTypeOf<allowedInputVectorTypes>:$lhs, AnyVector:$rhs,
844
+ Optional<AnyVector>:$lhsMask, Optional<AnyVector>:$rhsMask,
845
+ Optional<AnyVector>:$acc);
846
+ let results = (outs AnyTypeOf<allowedResultVectorTypes>:$result);
847
+
848
+ let assemblyFormat = [{
849
+ $lhs `,` $rhs
850
+ oilist(
851
+ `acc` `` `(` $acc `)`
852
+ | `masks` `` `(` $lhsMask `,` $rhsMask `)`
853
+ ) attr-dict `:` type($lhs) `,` type($rhs) `into` type($result)
854
+ }];
855
+
856
+ let extraClassDeclaration = [{
857
+ VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
858
+ VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
859
+ VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
860
+ std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
861
+ // The outerproduct op allocates a new tile if no accumulator is passed.
862
+ if (!getAcc())
863
+ return arm_sme::getSMETileType(getResultType());
864
+ return std::nullopt;
865
+ }
866
+ VectorType getTileType() {
867
+ return getResultType();
868
+ }
869
+ }];
870
+ }
871
+
872
+ class OuterProduct2Way<string mnemonic,
873
+ list<Type> allowedInputVectorTypes,
874
+ list<Type> allowedResultVectorTypes>
875
+ : OuterProductWideningBase<mnemonic, allowedInputVectorTypes,
876
+ allowedResultVectorTypes, /*numOuterProducts=*/2>;
877
+
878
+ def FMopa2WayOp
879
+ : OuterProduct2Way<"fmopa_2way",
880
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [F16, BF16]>],
881
+ [nxnxv4f32]> {
882
+ let summary = "Floating-point sum of 2 outer products and accumulate";
883
+
884
+ let description = [{
885
+ This operation represents a sum of 2 widened outer products. It takes two 1-D
885
+ scalable vectors as input and produces a 2-D scalable vector (ZA tile).
887
+
888
+ For example (fp16 to fp32):
889
+
890
+ ```mlir
891
+ %result = arm_sme.fmopa_2way %lhs, %rhs :
892
+ vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
893
+ ```
894
+
895
+ The `lhs` encodes a matrix of shape SVLSx2 and the `rhs` a matrix of
896
+ 2xSVLS, where SVLS (spec [1], section B2.1) is the number of 32-bit
897
+ elements in a vector of SVL bits. To illustrate, below is a breakdown of
898
+ this operation for fp16 to fp32, SVL=128 (i.e., vscale=1):
899
+
900
+ ```
901
+ LHS RHS
902
+ [A0 A1 A2 A3 A4 A5 A6 A7] [B0 B1 B2 B3 B4 B5 B6 B7]
903
+
904
+ ----------------------------------------------------------------------------
905
+
906
+ implicit layout
907
+
908
+ [A0 A1] |
909
+ [A2 A3] | [B0 B2 B4 B6]
910
+ [A4 A5] | [B1 B3 B5 B7]
911
+ [A6 A7] |
912
+
913
+ ----------------------------------------------------------------------------
914
+
915
+ 2 outer products
916
+
917
+ Acol0 ⊗ Brow0 | Acol1 ⊗ Brow1
918
+ ------------- | -------------
919
+ |
920
+ [B0 B2 B4 B6] | [B1 B3 B5 B7]
921
+ |
922
+ [A0 [A0B0 A0B2 A0B4 A0B6] | [A1 [A1B1 A1B3 A1B5 A1B7]
923
+ A2 [A2B0 A2B2 A2B4 A2B6] | A3 [A3B1 A3B3 A3B5 A3B7]
924
+ A4 [A4B0 A4B2 A4B4 A4B6] | A5 [A5B1 A5B3 A5B5 A5B7]
925
+ A6] [A6B0 A6B2 A6B4 A6B6] | A7] [A7B1 A7B3 A7B5 A7B7]
926
+ |
927
+
928
+ ----------------------------------------------------------------------------
929
+
930
+ sum of 2 outer products
931
+
932
+ Acol0 ⊗ Brow0 + Acol1 ⊗ Brow1
933
+
934
+ [A0B0 + A1B1 A0B2 + A1B3 A0B4 + A1B5 A0B6 + A1B7]
935
+ [A2B0 + A3B1 A2B2 + A3B3 A2B4 + A3B5 A2B6 + A3B7]
936
+ [A4B0 + A5B1 A4B2 + A5B3 A4B4 + A5B5 A4B6 + A5B7]
937
+ [A6B0 + A7B1 A6B2 + A7B3 A6B4 + A7B5 A6B6 + A7B7]
938
+
939
+ ----------------------------------------------------------------------------
940
+ ```
941
+
942
+ This operation enables the folding of 2 outer products chained via the
943
+ accumulator into a single outer product.
944
+
945
+ For example:
946
+
947
+ ```mlir
948
+ %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
949
+ %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
950
+ %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
951
+ %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
952
+
953
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>, vector<[4]xf32>
954
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xf32>, vector<[4]xf32>
955
+ ```
956
+
957
+ The 2 outer products in the example above can be fused into a single outer
958
+ product as follows:
959
+
960
+ ```mlir
961
+ %a_packed = "llvm.intr.experimental.vector.interleave2"(%a0, %a1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
962
+ %b_packed = "llvm.intr.experimental.vector.interleave2"(%b0, %b1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
963
+ %0 = arm_sme.fmopa_2way %a_packed, %b_packed : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
964
+ ```
965
+
966
+ This is implemented in the `-arm-sme-outer-product-fusion` pass.
967
+
968
+ Example: FP16 to FP32
969
+ ```mlir
970
+ %result = arm_sme.fmopa_2way %lhs, %rhs : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
971
+ ```
972
+
973
+ Example: BF16 to FP32
974
+ ```mlir
975
+ %result = arm_sme.fmopa_2way %lhs, %rhs : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
976
+ ```
977
+
978
+ | Spec | Features |
979
+ | ---- | -------- |
980
+ | [FMOPA (widening, 2-way, FP16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/FMOPA--widening--2-way--FP16-to-FP32---Half-precision-floating-point-sum-of-outer-products-and-accumulate-) | +sme |
981
+ | [BFMOPA (widening, 2-way, BF16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/BFMOPA--widening---BFloat16-sum-of-outer-products-and-accumulate-) | +sme |
982
+
983
+ [1] https://developer.arm.com/documentation/ddi0616
984
+ }];
985
+ }
986
+
987
+ // TODO: support:
988
+ // - FMOPA 2-way FP8 to FP16
989
+ // - FMOPA 4-way FP8 to FP32
990
+ // once intrinsic support lands in the backend.
991
+
992
+ def FMops2WayOp
993
+ : OuterProduct2Way<"fmops_2way",
994
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [F16, BF16]>],
995
+ [nxnxv4f32]> {
996
+ let summary = "Floating-point sum of 2 outer products and subtract";
997
+ let description = [{
998
+ Equivalent to `fmopa_2way` but outer products are subtracted from
999
+ destination `result`.
1000
+
1001
+ Example: FP16 to FP32
1002
+ ```mlir
1003
+ %result = arm_sme.fmops_2way %lhs, %rhs : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
1004
+ ```
1005
+
1006
+ Example: BF16 to FP32
1007
+ ```mlir
1008
+ %result = arm_sme.fmops_2way %lhs, %rhs : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
1008
+ ```
1009
+
1010
+ Refer to
1011
+ [fmopa_2way](#arm_smefmopa_2way-arm_smefmopa_2wayop) for a detailed
1012
+ description of 2-way outer products.
1013
+
1014
+ | Spec | Features |
1015
+ | ---- | -------- |
1016
+ | [FMOPS (widening, 2-way, FP16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/FMOPS--widening---Half-precision-floating-point-sum-of-outer-products-and-subtract-) | +sme |
1017
+ | [BFMOPS (widening, 2-way, BF16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/BFMOPS--widening---BFloat16-sum-of-outer-products-and-subtract-) | +sme |
1019
+ }];
1020
+ }
1021
+
1022
+ def SMopa2WayOp
1023
+ : OuterProduct2Way<"smopa_2way",
1024
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
1025
+ [nxnxv4i32]> {
1026
+ let summary = "Signed integer sum of 2 outer products and accumulate";
1027
+ let description = [{
1028
+ Example:
1029
+ ```mlir
1030
+ %result = arm_sme.smopa_2way %lhs, %rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
1030
+ ```
1031
+
1032
+ Refer to
1033
+ [fmopa_2way](#arm_smefmopa_2way-arm_smefmopa_2wayop) for a detailed
1034
+ description of 2-way outer products.
1035
+
1036
+ | Spec | Features |
1037
+ | ---- | -------- |
1038
+ | [SMOPA (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPA--2-way---Signed-integer-sum-of-outer-products-and-accumulate-) | +sme2 |
1040
+ }];
1041
+ }
1042
+
1043
+ def SMops2WayOp
1044
+ : OuterProduct2Way<"smops_2way",
1045
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
1046
+ [nxnxv4i32]> {
1047
+ let summary = "Signed integer sum of 2 outer products and subtract";
1048
+ let description = [{
1049
+ Example:
1050
+ ```mlir
1051
+ %result = arm_sme.smops_2way %lhs, %rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
1051
+ ```
1052
+
1053
+ Refer to
1054
+ [fmopa_2way](#arm_smefmopa_2way-arm_smefmopa_2wayop) for a detailed
1055
+ description of 2-way outer products.
1056
+
1057
+ | Spec | Features |
1058
+ | ---- | -------- |
1059
+ | [SMOPS (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPS--2-way---Signed-integer-sum-of-outer-products-and-subtract-) | +sme2 |
1061
+ }];
1062
+ }
1063
+
1064
+ def UMopa2WayOp
1065
+ : OuterProduct2Way<"umopa_2way",
1066
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
1067
+ [nxnxv4i32]> {
1068
+ let summary = "Unsigned integer sum of 2 outer products and accumulate";
1069
+ let description = [{
1070
+ Example:
1071
+ ```mlir
1072
+ %result = arm_sme.umopa_2way %lhs, %rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
1072
+ ```
1073
+
1074
+ Refer to
1075
+ [fmopa_2way](#arm_smefmopa_2way-arm_smefmopa_2wayop) for a detailed
1076
+ description of 2-way outer products.
1077
+
1078
+ | Spec | Features |
1079
+ | ---- | -------- |
1080
+ | [UMOPA (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPA--2-way---Unsigned-integer-sum-of-outer-products-and-accumulate-) | +sme2 |
1082
+ }];
1083
+ }
1084
+
1085
+ def UMops2WayOp
1086
+ : OuterProduct2Way<"umops_2way",
1087
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
1088
+ [nxnxv4i32]> {
1089
+ let summary = "Unsigned integer sum of 2 outer products and subtract";
1090
+ let description = [{
1091
+ Example:
1092
+ ```mlir
1093
+ %result = arm_sme.umops_2way %lhs, %rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
1093
+ ```
1094
+
1095
+ Refer to
1096
+ [fmopa_2way](#arm_smefmopa_2way-arm_smefmopa_2wayop) for a detailed
1097
+ description of 2-way outer products.
1098
+
1099
+ | Spec | Features |
1100
+ | ---- | -------- |
1101
+ | [UMOPS (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPS--2-way---Unsigned-integer-sum-of-outer-products-and-subtract-) | +sme2 |
1103
+ }];
1104
+ }
1105
+
817
1106
def StreamingVLOp : ArmSME_Op<"streaming_vl", [Pure]>
818
1107
{
819
1108
let summary = "Query the streaming vector length";
0 commit comments