
Commit ea03bde

[MLIR][AMDGPU] Adding Vector transfer_read to load rewrite pattern (#131803)
This PR adds the Vector transfer_read to load rewrite pattern. The pattern lowers a vector transfer_read op to a combination of `vector.load`, `arith.select`, and `vector.broadcast` if:

- the transfer op is masked;
- the memref is in the buffer address space;
- the other conditions introduced by `TransferReadToVectorLoadLowering` hold.

The motivation for this PR is the lack of masked-load support in the AMDGPU backend: `llvm.intr.masked.load` is lowered to a series of conditional scalar loads (see the `scalarize-masked-mem-intrin` pass). With this change, a masked transfer_read can instead be lowered to a buffer load with bounds checking, giving a more optimized global load access pattern than the existing lowering of `llvm.intr.masked.load` on vectors.
1 parent 09feaa9 commit ea03bde
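
As an illustration, here is a minimal sketch of the rewrite on IR shaped like the new tests (the SSA names are made up here, and only the non-broadcasting case is shown). Input:

  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]}
      : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>

Result of the rewrite, roughly:

  %fill = vector.splat %cf0 : vector<4xf32>
  %load = vector.load %mem[%idx, %idx]
      : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
  %res = arith.select %mask, %load, %fill : vector<4xi1>, vector<4xf32>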

5 files changed: +259 -0 lines changed

mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h

Lines changed: 4 additions & 0 deletions
@@ -22,6 +22,7 @@ namespace amdgpu {
 
 #define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS
 #define GEN_PASS_DECL_AMDGPURESOLVESTRIDEDMETADATAPASS
+#define GEN_PASS_DECL_AMDGPUTRANSFERREADTOLOADPASS
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
 
@@ -30,6 +31,9 @@ void populateAmdgpuEmulateAtomicsPatterns(ConversionTarget &target,
                                           Chipset chipset);
 
 void populateAmdgpuResolveStridedMetadataPatterns(RewritePatternSet &patterns);
+
+void populateAmdgpuTransferReadToLoadPatterns(RewritePatternSet &patterns);
+
 } // namespace amdgpu
 } // namespace mlir

mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td

Lines changed: 14 additions & 0 deletions
@@ -51,4 +51,18 @@ def AmdgpuResolveStridedMetadataPass : Pass<"amdgpu-resolve-strided-metadata"> {
   ];
 }
 
+def AmdgpuTransferReadToLoadPass : Pass<"amdgpu-transfer-read-to-load"> {
+  let summary = "Lower the operations from the vector transfer_read to vector load";
+  let description = [{
+    This pass creates a transfer read op lowering. A vector transfer read op
+    will be lowered to a combination of vector.load, arith.select and
+    vector.broadcast.
+
+    This pattern will make it possible for masked transfer_read to be lowered
+    towards buffer load with bounds check, allowing a more optimized global
+    load access pattern compared with the existing implementation of
+    llvm.intr.masked.load on vectors.
+  }];
+  let dependentDialects = [];
+}
 #endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_

mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRAMDGPUTransforms
   EmulateAtomics.cpp
   ResolveStridedMetadata.cpp
+  TransferReadToLoad.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms

mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp

Lines changed: 154 additions & 0 deletions
@@ -0,0 +1,154 @@
//===- TransferReadToLoad.cpp - Lowers masked transfer read to load -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"

namespace mlir::amdgpu {
#define GEN_PASS_DEF_AMDGPUTRANSFERREADTOLOADPASS
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
} // namespace mlir::amdgpu

using namespace mlir;
using namespace mlir::amdgpu;

/// This pattern supports lowering of:
/// `vector.transfer_read` to a combination of `vector.load`, `arith.select` and
/// `vector.broadcast` if all of the following hold:
/// - The transfer op is masked.
/// - The memref is in buffer address space.
/// - Stride of most minor memref dimension must be 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   result type.
/// - The permutation map doesn't perform permutation (broadcasting is allowed).
/// Note: those conditions mostly come from TransferReadToVectorLoadLowering
/// pass.
static LogicalResult transferPreconditions(
    PatternRewriter &rewriter, VectorTransferOpInterface xferOp,
    bool &requiresBroadcasting, VectorType &unbroadcastedVectorType) {
  if (!xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "Only support masked transfer");

  // Permutations are handled by VectorToSCF or
  // populateVectorTransferPermutationMapLoweringPatterns.
  // We let the 0-d corner case pass-through as it is supported.
  SmallVector<unsigned> broadcastedDims;
  if (!xferOp.getPermutationMap().isMinorIdentityWithBroadcasting(
          &broadcastedDims))
    return rewriter.notifyMatchFailure(xferOp, "not minor identity + bcast");

  auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());
  if (!memRefType)
    return rewriter.notifyMatchFailure(xferOp, "not a memref source");

  Attribute addrSpace = memRefType.getMemorySpace();
  if (!addrSpace || !dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace))
    return rewriter.notifyMatchFailure(xferOp, "no address space");

  if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
      amdgpu::AddressSpace::FatRawBuffer)
    return rewriter.notifyMatchFailure(xferOp, "not in buffer address space");

  // Non-unit strides are handled by VectorToSCF.
  if (!memRefType.isLastDimUnitStride())
    return rewriter.notifyMatchFailure(xferOp, "!= 1 stride needs VectorToSCF");

  // If there is broadcasting involved then we first load the unbroadcasted
  // vector, and then broadcast it with `vector.broadcast`.
  ArrayRef<int64_t> vectorShape = xferOp.getVectorType().getShape();
  SmallVector<int64_t> unbroadcastedVectorShape(vectorShape);
  for (unsigned i : broadcastedDims)
    unbroadcastedVectorShape[i] = 1;
  unbroadcastedVectorType = xferOp.getVectorType().cloneWith(
      unbroadcastedVectorShape, xferOp.getVectorType().getElementType());
  requiresBroadcasting = !broadcastedDims.empty();

  // `vector.load` supports vector types as memref's elements only when the
  // resulting vector type is the same as the element type.
  auto memrefElTy = memRefType.getElementType();
  if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
    return rewriter.notifyMatchFailure(xferOp, "incompatible element type");

  // Otherwise, element types of the memref and the vector must match.
  if (!isa<VectorType>(memrefElTy) &&
      memrefElTy != xferOp.getVectorType().getElementType())
    return rewriter.notifyMatchFailure(xferOp, "non-matching element type");

  // Out-of-bounds dims are handled by MaterializeTransferMask.
  if (xferOp.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(xferOp, "out-of-bounds needs mask");

  if (xferOp.getVectorType().getRank() != 1)
    // vector.maskedload operates on 1-D vectors.
    return rewriter.notifyMatchFailure(
        xferOp, "vector type is not rank 1, can't create masked load, needs "
                "VectorToSCF");

  return success();
}

namespace {

struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {

    bool requiresBroadcasting = false;
    VectorType unbroadcastedVectorType;
    if (failed(transferPreconditions(rewriter, readOp, requiresBroadcasting,
                                     unbroadcastedVectorType))) {
      return failure();
    }

    Location loc = readOp.getLoc();
    Value fill = rewriter.create<vector::SplatOp>(loc, unbroadcastedVectorType,
                                                  readOp.getPadding());
    Value load = rewriter.create<vector::LoadOp>(
        loc, unbroadcastedVectorType, readOp.getSource(), readOp.getIndices());
    Value res = rewriter.create<arith::SelectOp>(loc, unbroadcastedVectorType,
                                                 readOp.getMask(), load, fill);

    // Insert a broadcasting op if required.
    if (requiresBroadcasting) {
      res = rewriter.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
                                                 res);
    }

    rewriter.replaceOp(readOp, res);

    return success();
  }
};

} // namespace

void mlir::amdgpu::populateAmdgpuTransferReadToLoadPatterns(
    RewritePatternSet &patterns) {
  patterns.add<TransferReadLowering>(patterns.getContext());
}

struct AmdgpuTransferReadToLoadPass final
    : amdgpu::impl::AmdgpuTransferReadToLoadPassBase<
          AmdgpuTransferReadToLoadPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateAmdgpuTransferReadToLoadPatterns(patterns);
    walkAndApplyPatterns(getOperation(), std::move(patterns));
  }
};
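
Note that downstream pipelines do not have to run the standalone pass: the populate hook declared in Passes.h can be folded into any pattern-driven rewrite. A minimal sketch, with a hypothetical helper name, that simply mirrors what AmdgpuTransferReadToLoadPass::runOnOperation above does:

  #include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
  #include "mlir/IR/PatternMatch.h"
  #include "mlir/Transforms/WalkPatternRewriteDriver.h"

  #include <utility>

  // Hypothetical helper: lowers masked transfer_read ops nested under `root`.
  static void lowerMaskedTransferReads(mlir::Operation *root) {
    mlir::RewritePatternSet patterns(root->getContext());
    mlir::amdgpu::populateAmdgpuTransferReadToLoadPatterns(patterns);
    // Walk-based driver: applies the patterns during a single IR walk,
    // without the fixed-point iteration of the greedy driver.
    mlir::walkAndApplyPatterns(root, std::move(patterns));
  }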
Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
// RUN: mlir-opt %s --amdgpu-transfer-read-to-load --split-input-file | FileCheck %s

// CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer(
// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
// CHECK-SAME: %[[ARG1:.*]]: index
// CHECK-SAME: %[[ARG2:.*]]: vector<4xi1>
func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
  %cf0 = arith.constant 0.0 : f32
  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
  return %res : vector<4xf32>
}
// CHECK: %[[CST:.*]] = arith.constant 0.0
// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
// CHECK: return %[[SELECT]] : vector<4xf32>

// -----

// CHECK-LABEL: func @transfer_to_maskedload_regular(
// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32>
// CHECK-SAME: %[[ARG1:.*]]: index
// CHECK-SAME: %[[ARG2:.*]]: vector<4xi1>
func.func @transfer_to_maskedload_regular(%mem : memref<8x8xf32>, %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
  %cf0 = arith.constant 0.0 : f32
  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>
  return %res : vector<4xf32>
}
// CHECK: %[[CST:.*]] = arith.constant 0.0
// CHECK: %[[RES:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]], %arg2 {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>
// CHECK: return %[[RES]] : vector<4xf32>

// -----

// CHECK-LABEL: func @transfer_to_maskedload_addrspace(
// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #gpu.address_space<workgroup>>
// CHECK-SAME: %[[ARG1:.*]]: index
// CHECK-SAME: %[[ARG2:.*]]: vector<4xi1>
func.func @transfer_to_maskedload_addrspace(%mem : memref<8x8xf32, #gpu.address_space<workgroup>>, %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
  %cf0 = arith.constant 0.0 : f32
  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32, #gpu.address_space<workgroup>>, vector<4xf32>
  return %res : vector<4xf32>
}
// CHECK: %[[CST:.*]] = arith.constant 0.0
// CHECK: %[[RES:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]], %arg2 {in_bounds = [true]} : memref<8x8xf32, #gpu.address_space<workgroup>>, vector<4xf32>
// CHECK: return %[[RES]] : vector<4xf32>

// -----

// CHECK-LABEL: func @transfer_broadcasting(
// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
// CHECK-SAME: %[[ARG1:.*]]: index
// CHECK-SAME: %[[ARG2:.*]]: vector<1xi1>
#broadcast_1d = affine_map<(d0, d1) -> (0)>
func.func @transfer_broadcasting(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<1xi1>) -> vector<4xf32> {
  %cf0 = arith.constant 0.0 : f32
  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask
    {in_bounds = [true], permutation_map = #broadcast_1d}
    : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
  return %res : vector<4xf32>
}
// CHECK: %[[CST:.*]] = arith.constant 0.0
// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[SELECT]] : vector<1xf32> to vector<4xf32>
// CHECK: return %[[BROADCAST]] : vector<4xf32>

// -----

// CHECK-LABEL: func @transfer_scalar(
// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
// CHECK-SAME: %[[ARG1:.*]]: index
// CHECK-SAME: %[[ARG2:.*]]: vector<1xi1>
func.func @transfer_scalar(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<1xi1>) -> vector<1xf32> {
  %cf0 = arith.constant 0.0 : f32
  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask
    {in_bounds = [true]}
    : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<1xf32>
  return %res : vector<1xf32>
}
// CHECK: %[[CST:.*]] = arith.constant 0.0
// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
// CHECK: return %[[SELECT]] : vector<1xf32>
