
Commit 9a795f0

Author: Manish Gupta
[mlir][Vector] Adds a pattern to fold arith.extf into vector.contract
Consider a mixed-precision data type, i.e., F16 input lhs, F16 input rhs, F32 accumulation, and F32 output. This is typically written as F32 <= F16*F16 + F32. During vectorization from linalg to vector for this mixed-precision case, linalg.matmul introduces arith.extf on the lhs and rhs input operands:

  "linalg.matmul"(%lhs, %rhs, %acc) ({
  ^bb0(%arg1: f16, %arg2: f16, %arg3: f32):
    %lhs_f32 = "arith.extf"(%arg1) : (f16) -> f32
    %rhs_f32 = "arith.extf"(%arg2) : (f16) -> f32
    %mul = "arith.mulf"(%lhs_f32, %rhs_f32) : (f32, f32) -> f32
    %acc = "arith.addf"(%arg3, %mul) : (f32, f32) -> f32
    "linalg.yield"(%acc) : (f32) -> ()
  })

Some backends natively support mixed-precision data types and do not need the arith.extf. For example, the NVIDIA A100 GPU has mma.sync.aligned.*.f32.f16.f16.f32, which handles mixed precision directly. However, the presence of arith.extf in the IR introduces unnecessary casting and makes the NVIDIA backend target F32 Tensor Cores instead of F16 Tensor Cores.

This patch adds a pattern that folds arith.extf into vector.contract.

Differential Revision: https://reviews.llvm.org/D151918
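At the vector level, the fold rewrites a contraction whose operands are produced by arith.extf so that it consumes the narrower operands directly. A minimal before/after sketch (the %lhs, %rhs, %acc values and the 64x64 shapes are illustrative; the second test in this commit checks the same rewrite):

  // Before: both operands are extended to f32 and the contraction runs in f32.
  %lhs_f32 = arith.extf %lhs : vector<64x64xf16> to vector<64x64xf32>
  %rhs_f32 = arith.extf %rhs : vector<64x64xf16> to vector<64x64xf32>
  %d = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                                         affine_map<(d0, d1, d2) -> (d2, d1)>,
                                         affine_map<(d0, d1, d2) -> (d0, d1)>],
                        iterator_types = ["parallel", "parallel", "reduction"],
                        kind = #vector.kind<add>}
        %lhs_f32, %rhs_f32, %acc
      : vector<64x64xf32>, vector<64x64xf32> into vector<64x64xf32>

  // After: the contraction consumes the f16 operands directly; the f32
  // accumulator, indexing maps, and iterator types are unchanged.
  %d = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                                         affine_map<(d0, d1, d2) -> (d2, d1)>,
                                         affine_map<(d0, d1, d2) -> (d0, d1)>],
                        iterator_types = ["parallel", "parallel", "reduction"],
                        kind = #vector.kind<add>}
        %lhs, %rhs, %acc
      : vector<64x64xf16>, vector<64x64xf16> into vector<64x64xf32>

With the f16 operands exposed on the contraction, the NVGPU lowering can pick the native mixed-precision mma.sync path, as exercised by the first test in this commit.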
1 parent f04cf6b commit 9a795f0

File tree

5 files changed: 143 additions & 0 deletions


mlir/include/mlir/Dialect/Vector/IR/VectorOps.h

Lines changed: 4 additions & 0 deletions
@@ -74,6 +74,10 @@ isBroadcastableTo(Type srcType, VectorType dstVectorType,
 void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns,
                                                     PatternBenefit benefit = 1);
 
+/// Collect a set of patterns that fold arithmetic extension on floating point
+/// into vector contract for the backends with native support.
+void populateFoldArithExtensionPatterns(RewritePatternSet &patterns);
+
 /// Returns the integer type required for subscripts in the vector dialect.
 IntegerType getVectorSubscriptType(Builder &builder);

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 46 additions & 0 deletions
@@ -1212,8 +1212,54 @@ struct CanonicalizeContractMatmulToMMT final
   FilterConstraintType filter;
 };
 
+/// Pattern to fold arithmetic extensions on floating point data types into
+/// vector contraction operations. linalg.matmul introduces arithmetic
+/// extensions on its operands. Please see the mlir snippet below for more
+/// details.
+/// ```mlir
+///   "linalg.matmul"(%lhs, %rhs, %acc) ({
+///   ^bb0(%arg1: f16, %arg2: f16, %arg3: f32):
+///     %lhs_f32 = "arith.extf"(%arg1) : (f16) -> f32
+///     %rhs_f32 = "arith.extf"(%arg2) : (f16) -> f32
+///     %mul = "arith.mulf"(%lhs_f32, %rhs_f32) : (f32, f32) -> f32
+///     %acc = "arith.addf"(%arg3, %mul) : (f32, f32) -> f32
+///     "linalg.yield"(%acc) : (f32) -> ()
+///   })
+/// ```
+/// This restricts the native usage of mixed-precision NVIDIA Ampere Tensor
+/// Cores, i.e., `mma.sync.*.f32.f16.f16.f32` and `mma.sync.*.f32.bf16.bf16.f32`.
+/// This pattern folds the arithmetic extensions into the vector contraction and
+/// enables the usage of native mixed-precision Tensor Core instructions.
+struct FoldArithExtIntoContractionOp
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    auto lhsDefOp = contractOp.getLhs().getDefiningOp<arith::ExtFOp>();
+    auto rhsDefOp = contractOp.getRhs().getDefiningOp<arith::ExtFOp>();
+
+    if (!lhsDefOp || !rhsDefOp) {
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "no defining op on contract operands");
+    }
+
+    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
+        contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0),
+        contractOp.getAcc(), contractOp.getIndexingMapsAttr(),
+        contractOp.getIteratorTypesAttr());
+
+    return success();
+  }
+};
+
 } // namespace
 
+void mlir::vector::populateFoldArithExtensionPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<FoldArithExtIntoContractionOp>(patterns.getContext());
+}
+
 void mlir::vector::populateVectorMaskMaterializationPatterns(
     RewritePatternSet &patterns, bool force32BitVectorIndices,
     PatternBenefit benefit) {
Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+// RUN: mlir-opt %s -split-input-file -pass-pipeline="builtin.module(func.func(test-fold-arith-extf-into-vector-contract-patterns,convert-vector-to-gpu{use-nvgpu=true},cse))" | FileCheck %s
+
+//###############################################################################################
+// FP16 input, F32 accumulation row-row-row (ldmatrix x4 for matrixA and ldmatrix x4 for matrixB)
+//###############################################################################################
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @m16n8k16_mmasync16816_f16_f16_f32_row_row_row
+func.func @m16n8k16_mmasync16816_f16_f16_f32_row_row_row(%arg0: memref<42x32xf16, #gpu.address_space<workgroup>>, %arg1: memref<32x64xf16, #gpu.address_space<workgroup>>, %arg2: memref<42x64xf32, #gpu.address_space<workgroup>>) {
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %cst_f16 = arith.constant 0.000000e+00 : f16
+  %cst_f32 = arith.constant 0.000000e+00 : f32
+
+  // CHECK-DAG: nvgpu.ldmatrix %arg0[%{{.*}}, %{{.*}}] {numTiles = 4 : i32, transpose = false}
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst_f16 {in_bounds = [true, true]} : memref<42x32xf16, #gpu.address_space<workgroup>>, vector<16x16xf16>
+  %A_f32 = arith.extf %A : vector<16x16xf16> to vector<16x16xf32>
+
+  // CHECK-DAG: nvgpu.ldmatrix %arg1[%{{.*}}, %{{.*}}] {numTiles = 4 : i32, transpose = true}
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst_f16 {permutation_map = #map0, in_bounds = [true, true]} : memref<32x64xf16, #gpu.address_space<workgroup>>, vector<16x16xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst_f32 {in_bounds = [true, true]} : memref<42x64xf32, #gpu.address_space<workgroup>>, vector<16x16xf32>
+
+  %B0 = vector.extract_strided_slice %B {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16>
+  %B0_f32 = arith.extf %B0 : vector<8x16xf16> to vector<8x16xf32>
+  %C0 = vector.extract_strided_slice %C {offsets = [0, 0], sizes = [16, 8], strides = [1, 1]} : vector<16x16xf32> to vector<16x8xf32>
+
+  // CHECK-DAG: nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
+  %D0 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A_f32, %B0_f32, %C0 : vector<16x16xf32>, vector<8x16xf32> into vector<16x8xf32>
+  vector.transfer_write %D0, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf32>, memref<42x64xf32, #gpu.address_space<workgroup>>
+
+  %B1 = vector.extract_strided_slice %B {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16>
+  %B1_f32 = arith.extf %B1 : vector<8x16xf16> to vector<8x16xf32>
+  %C1 = vector.extract_strided_slice %C {offsets = [0, 8], sizes = [16, 8], strides = [1, 1]} : vector<16x16xf32> to vector<16x8xf32>
+
+  // CHECK-DAG: nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
+  %D1 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A_f32, %B1_f32, %C1 : vector<16x16xf32>, vector<8x16xf32> into vector<16x8xf32>
+  vector.transfer_write %D1, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf32>, memref<42x64xf32, #gpu.address_space<workgroup>>
+
+  return
+}
Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+// RUN: mlir-opt -split-input-file -test-fold-arith-extf-into-vector-contract-patterns %s | FileCheck %s
+
+
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @fold_arith_extf_into_contract
+// CHECK-SAME:    (%[[ARG0:.*]]: vector<64x64xf16>, %[[ARG1:.*]]: vector<64x64xf16>, %[[ARG2:.*]]: vector<64x64xf32>)
+// CHECK-NEXT:    %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
+// CHECK-SAME:      iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+// CHECK-SAME:      %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<64x64xf16>, vector<64x64xf16> into vector<64x64xf32>
+// CHECK-NEXT:    return %[[R]] : vector<64x64xf32>
+func.func @fold_arith_extf_into_contract(%arg0: vector<64x64xf16>, %arg1: vector<64x64xf16>, %arg2: vector<64x64xf32>) -> vector<64x64xf32> {
+  %lhs_f32 = arith.extf %arg0 : vector<64x64xf16> to vector<64x64xf32>
+  %rhs_f32 = arith.extf %arg1 : vector<64x64xf16> to vector<64x64xf32>
+  %result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_f32, %rhs_f32, %arg2 : vector<64x64xf32>, vector<64x64xf32> into vector<64x64xf32>
+  return %result : vector<64x64xf32>
+}

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 29 additions & 0 deletions
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"

@@ -709,6 +710,32 @@ struct TestVectorTransferTensorSlicePatterns
   }
 };
 
+struct TestFoldArithExtensionIntoVectorContractPatterns
+    : public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestFoldArithExtensionIntoVectorContractPatterns)
+
+  StringRef getArgument() const final {
+    return "test-fold-arith-extf-into-vector-contract-patterns";
+  }
+  StringRef getDescription() const final {
+    return "Test patterns that fold arithmetic extension ops into vector "
+           "contract ops";
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect, func::FuncDialect, nvgpu::NVGPUDialect,
+                    memref::MemRefDialect, scf::SCFDialect,
+                    tensor::TensorDialect, vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateFoldArithExtensionPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
 } // namespace
 
 namespace mlir {

@@ -745,6 +772,8 @@ void registerTestVectorLowerings() {
   PassRegistration<TestVectorGatherLowering>();
 
   PassRegistration<TestVectorTransferTensorSlicePatterns>();
+
+  PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
 }
 } // namespace test
 } // namespace mlir
