Skip to content

Commit 29d1aca

Browse files
authored
[AMDGPU][MLIR] Add shmem-optimization as an op using transform dialect (llvm#81550)
This PR adds a transform dialect op that applies the shared memory optimization.
1 parent 7180c23 commit 29d1aca

File tree

12 files changed

+356
-19
lines changed

12 files changed

+356
-19
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
add_subdirectory(IR)
2+
add_subdirectory(TransformOps)
23
add_subdirectory(Transforms)
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
//===- AMDGPUTransformOps.h - AMDGPU transform ops ---------------*- C++-*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_AMDGPU_TRANSFORMOPS_AMDGPUTRANSFORMOPS_H
10+
#define MLIR_DIALECT_AMDGPU_TRANSFORMOPS_AMDGPUTRANSFORMOPS_H
11+
12+
#include "mlir/Dialect/Func/IR/FuncOps.h"
13+
#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
14+
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
15+
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
16+
#include "mlir/IR/OpImplementation.h"
17+
#include "mlir/IR/RegionKindInterface.h"
18+
19+
// Forward declaration of the transform handle type interface.
// NOTE(review): TransformInterfaces.h is already included by this header, so
// this forward declaration may be redundant - confirm before removing.
namespace mlir {
20+
namespace transform {
21+
class TransformHandleTypeInterface;
22+
} // namespace transform
23+
} // namespace mlir
24+
25+
// Forward declarations and the public registration entry point for the AMDGPU
// transform dialect extension.
// NOTE(review): linalg::LinalgOp and scf::ForOp are not referenced by any of
// the code visible in this change; they may be leftovers from another
// dialect's TransformOps header - confirm whether they are needed.
namespace mlir {
26+
class DialectRegistry;
27+
28+
namespace linalg {
29+
class LinalgOp;
30+
} // namespace linalg
31+
32+
namespace scf {
33+
class ForOp;
34+
} // namespace scf
35+
36+
namespace amdgpu {
37+
/// Adds the AMDGPU transform dialect extension (and with it the AMDGPU
/// transform ops) to `registry`.
void registerTransformDialectExtension(DialectRegistry &registry);
38+
} // namespace amdgpu
39+
} // namespace mlir
40+
41+
//===----------------------------------------------------------------------===//
42+
// AMDGPU Transform Operations
43+
//===----------------------------------------------------------------------===//
44+
45+
#define GET_OP_CLASSES
46+
#include "mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.h.inc"
47+
48+
#endif // MLIR_DIALECT_AMDGPU_TRANSFORMOPS_AMDGPUTRANSFORMOPS_H
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
//===- AMDGPUTransformOps.td - AMDGPU transform ops --------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef AMDGPU_TRANSFORM_OPS
10+
#define AMDGPU_TRANSFORM_OPS
11+
12+
include "mlir/Dialect/Transform/IR/TransformAttrs.td"
13+
include "mlir/Dialect/Transform/IR/TransformDialect.td"
14+
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
15+
include "mlir/Dialect/Transform/IR/TransformTypes.td"
16+
include "mlir/Interfaces/SideEffectInterfaces.td"
17+
18+
//===----------------------------------------------------------------------===//
19+
// ApplyOptimizeSharedMemoryReadsAndWritesOp
20+
//===----------------------------------------------------------------------===//
21+
22+
// Transform op with mnemonic `amdgpu.optimize_shared_memory_reads_and_writes`,
// registered in the Transform dialect.
//
// Operands: a single transform handle, `$target`.
// Results:  none.
//
// The op implements TransformOpInterface through TransformEachOpTrait, so the
// C++ side only provides `applyToOne`, declared below with a
// `::mlir::func::FuncOp` payload parameter. Memory effects are declared here
// and implemented by hand in C++ (`getEffects`).
def ApplyOptimizeSharedMemoryReadsAndWritesOp :
23+
Op<Transform_Dialect, "amdgpu.optimize_shared_memory_reads_and_writes",
24+
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
25+
TransformOpInterface, TransformEachOpTrait]> {
26+
let summary = "Reduce shared memory bank conflicts";
27+
let description = [{ This op attempts to optimize GPU Shared memory
28+
reads/writes with the goal of avoiding bank conflicts.
29+
}];
30+
31+
let arguments = (ins TransformHandleTypeInterface:$target);
32+
let results = (outs);
33+
34+
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
35+
36+
let extraClassDeclaration = [{
37+
::mlir::DiagnosedSilenceableFailure applyToOne(
38+
::mlir::transform::TransformRewriter &rewriter,
39+
::mlir::func::FuncOp funcOp,
40+
::mlir::transform::ApplyToEachResultList &results,
41+
::mlir::transform::TransformState &state);
42+
}];
43+
}
44+
45+
#endif // AMDGPU_TRANSFORM_OPS
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Generate the op declarations (.h.inc) and definitions (.cpp.inc) from
# AMDGPUTransformOps.td and expose them as the MLIRAMDGPUTransformOpsIncGen
# target so that libraries including the generated files can depend on it.
set(LLVM_TARGET_DEFINITIONS AMDGPUTransformOps.td)
2+
mlir_tablegen(AMDGPUTransformOps.h.inc -gen-op-decls)
3+
mlir_tablegen(AMDGPUTransformOps.cpp.inc -gen-op-defs)
4+
add_public_tablegen_target(MLIRAMDGPUTransformOpsIncGen)

mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#ifndef MLIR_DIALECT_AMDGPU_TRANSFORMS_TRANSFORMS_H_
1515
#define MLIR_DIALECT_AMDGPU_TRANSFORMS_TRANSFORMS_H_
1616

17+
#include "mlir/Dialect/Func/IR/FuncOps.h"
1718
#include "mlir/IR/Operation.h"
1819
#include "mlir/Support/LogicalResult.h"
1920

@@ -48,6 +49,8 @@ namespace amdgpu {
4849
mlir::LogicalResult optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
4950
Value memrefValue);
5051

52+
/// Applies optimizeSharedMemoryReadsAndWrites to the result of every
/// `memref::AllocOp` nested in `funcOp` whose type is in the shared memory
/// address space. Processing stops at the first allocation for which the
/// optimization fails; the function returns void, so neither that failure nor
/// the skipped remaining allocations are reported to the caller.
void optimizeSharedMemoryReadsAndWritesOp(mlir::func::FuncOp funcOp);
53+
5154
} // namespace amdgpu
5255
} // namespace mlir
5356

mlir/include/mlir/InitAllExtensions.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
2424
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
2525
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
26+
#include "mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.h"
2627
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
2728
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
2829
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
@@ -66,6 +67,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
6667
ub::registerConvertUBToLLVMInterface(registry);
6768

6869
// Register all transform dialect extensions.
70+
amdgpu::registerTransformDialectExtension(registry);
6971
affine::registerTransformDialectExtension(registry);
7072
bufferization::registerTransformDialectExtension(registry);
7173
func::registerTransformDialectExtension(registry);
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
add_subdirectory(IR)
2-
add_subdirectory(Transforms)
32
add_subdirectory(Utils)
3+
add_subdirectory(TransformOps)
4+
add_subdirectory(Transforms)
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
//===- AMDGPUTransformOps.cpp - Implementation of AMDGPU transform ops-----===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.h"
10+
11+
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
12+
#include "mlir/Dialect/AMDGPU/Transforms/Transforms.h"
13+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
14+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
15+
16+
using namespace mlir;
17+
using namespace mlir::amdgpu;
18+
using namespace mlir::transform;
19+
using namespace mlir::func;
20+
21+
#define DEBUG_TYPE "amdgpu-transforms"
22+
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
23+
#define DBGSNL() (llvm::dbgs() << "\n")
24+
#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
25+
26+
/// Runs the shared memory read/write optimization on one payload function.
///
/// Forwards `funcOp` to amdgpu::optimizeSharedMemoryReadsAndWritesOp. That
/// helper returns void, so this method cannot tell whether anything was
/// rewritten and unconditionally reports success. `rewriter`, `results` and
/// `state` are not used.
DiagnosedSilenceableFailure
27+
ApplyOptimizeSharedMemoryReadsAndWritesOp::applyToOne(
28+
TransformRewriter &rewriter, FuncOp funcOp, ApplyToEachResultList &results,
29+
TransformState &state) {
30+
optimizeSharedMemoryReadsAndWritesOp(funcOp);
31+
return DiagnosedSilenceableFailure::success();
32+
}
33+
34+
/// Reports the memory effects of this transform op: the `target` handle is
/// only read, and the payload IR is modified.
void ApplyOptimizeSharedMemoryReadsAndWritesOp::getEffects(
35+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
36+
onlyReadsHandle(getTarget(), effects);
37+
modifiesPayload(effects);
38+
}
39+
40+
//===----------------------------------------------------------------------===//
41+
// Transform op registration
42+
//===----------------------------------------------------------------------===//
43+
44+
namespace {
45+
/// Transform dialect extension that registers the AMDGPU transform ops.
///
/// The constructor declares the arith, affine, amdgpu and vector dialects as
/// generated dialects and registers every op listed in the TableGen-generated
/// op list.
// NOTE(review): arith::ArithDialect is named below, but this file does not
// include an Arith header directly - presumably it arrives transitively
// through one of the other includes; confirm.
class AMDGPUTransformDialectExtension
46+
: public TransformDialectExtension<AMDGPUTransformDialectExtension> {
47+
public:
48+
AMDGPUTransformDialectExtension() {
49+
declareGeneratedDialect<arith::ArithDialect>();
50+
declareGeneratedDialect<affine::AffineDialect>();
51+
declareGeneratedDialect<amdgpu::AMDGPUDialect>();
52+
declareGeneratedDialect<vector::VectorDialect>();
53+
registerTransformOps<
54+
#define GET_OP_LIST
55+
#include "mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp.inc"
56+
>();
57+
}
58+
};
59+
} // namespace
60+
61+
#define GET_OP_CLASSES
62+
#include "mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp.inc"
63+
64+
/// Adds AMDGPUTransformDialectExtension to `registry`, which makes the AMDGPU
/// transform ops available to contexts created from that registry.
void amdgpu::registerTransformDialectExtension(DialectRegistry &registry) {
65+
registry.addExtensions<AMDGPUTransformDialectExtension>();
66+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Library holding the AMDGPU transform dialect extension and its ops. It
# depends on MLIRAMDGPUTransformOpsIncGen for the TableGen-generated op
# declarations and definitions.
# NOTE(review): AMDGPUTransformOps.cpp uses func::FuncOp, but no Func dialect
# library is listed under LINK_LIBS; it is presumably linked transitively -
# confirm. Conversely, nothing in AMDGPUTransformOps.cpp visibly references
# Linalg, SCF or the parser, so those entries may be unnecessary - confirm.
add_mlir_dialect_library(MLIRAMDGPUTransformOps
2+
AMDGPUTransformOps.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/TransformOps
6+
7+
DEPENDS
8+
MLIRAMDGPUTransformOpsIncGen
9+
10+
LINK_LIBS PUBLIC
11+
MLIRAffineDialect
12+
MLIRArithDialect
13+
MLIRIR
14+
MLIRLinalgDialect
15+
MLIRAMDGPUDialect
16+
MLIRAMDGPUTransforms
17+
MLIRParser
18+
MLIRSideEffectInterfaces
19+
MLIRSCFDialect
20+
MLIRSCFTransforms
21+
MLIRTransformDialect
22+
MLIRTransformDialectUtils
23+
MLIRVectorTransforms
24+
25+
)

mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@
2424
#include "mlir/Dialect/Vector/IR/VectorOps.h"
2525
#include "mlir/Interfaces/SideEffectInterfaces.h"
2626
#include "mlir/Support/LogicalResult.h"
27-
#include "llvm/ADT/STLExtras.h"
28-
#include "llvm/Support/MathExtras.h"
2927

3028
namespace mlir {
3129
namespace amdgpu {
@@ -52,12 +50,12 @@ constexpr int64_t kDefaultVectorSizeBits = 64;
5250
static Value permuteVectorOffset(OpBuilder &b, Location loc,
5351
ArrayRef<Value> indices, MemRefType memrefTy,
5452
int64_t srcDim, int64_t tgtDim) {
55-
// Adjust the src index to change how often the permutation changes
56-
// if necessary.
53+
/// Adjust the src index to change how often the permutation changes
54+
/// if necessary.
5755
Value src = indices[srcDim];
5856

59-
// We only want to permute every N iterations of the target dim where N is
60-
// ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
57+
/// We only want to permute every N iterations of the target dim where N is
58+
/// ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
6159
const int64_t permuteEveryN = std::max<int64_t>(
6260
1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
6361
memrefTy.getElementTypeBitWidth()) /
@@ -83,8 +81,8 @@ static Value permuteVectorOffset(OpBuilder &b, Location loc,
8381
Value srcBits = b.create<arith::ConstantIndexOp>(loc, mask);
8482
srcBits = b.create<arith::AndIOp>(loc, src, srcBits);
8583

86-
// Use the src bits to permute the target bits b[N:M] containing the
87-
// vector offset.
84+
/// Use the src bits to permute the target bits b[N:M] containing the
85+
/// vector offset.
8886
if (permuteEveryN > 1) {
8987
int64_t shlBits = n - llvm::Log2_64(permuteEveryN);
9088
if (shlBits > 0) {
@@ -133,8 +131,8 @@ getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
133131
writeOps.push_back(op);
134132
});
135133

136-
// Restrict to a supported set of ops. We also require at least 2D access,
137-
// although this could be relaxed.
134+
/// Restrict to a supported set of ops. We also require at least 2D access,
135+
/// although this could be relaxed.
138136
if (llvm::any_of(readOps, [](Operation *op) {
139137
return !isa<memref::LoadOp, vector::LoadOp, vector::TransferReadOp>(
140138
op) ||
@@ -159,15 +157,15 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
159157
!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(memRefType))
160158
return failure();
161159

162-
// Abort if the given value has any sub-views; we do not do any alias
163-
// analysis.
160+
/// Abort if the given value has any sub-views; we do not do any alias
161+
/// analysis.
164162
bool hasSubView = false;
165163
parentOp->walk([&](memref::SubViewOp subView) { hasSubView = true; });
166164
if (hasSubView)
167165
return failure();
168166

169-
// Check if this is necessary given the assumption of 128b accesses:
170-
// If dim[rank-1] is small enough to fit 8 rows in a 128B line.
167+
/// Check if this is necessary given the assumption of 128b accesses:
168+
/// If dim[rank-1] is small enough to fit 8 rows in a 128B line.
171169
const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
172170
const int64_t rowsPerLine =
173171
(8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
@@ -177,8 +175,8 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
177175
if (rowsPerLine >= threadGroupSize)
178176
return failure();
179177

180-
// Get sets of operations within the function that read/write to shared
181-
// memory.
178+
/// Get sets of operations within the function that read/write to shared
179+
/// memory.
182180
SmallVector<Operation *, 16> shmReadOps;
183181
SmallVector<Operation *, 16> shmWriteOps;
184182
if (failed(getShmReadAndWriteOps(parentOp, memrefValue, shmReadOps,
@@ -193,7 +191,7 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
193191
int64_t tgtDim = memRefType.getRank() - 1;
194192
int64_t srcDim = memRefType.getRank() - 2;
195193

196-
// Transform indices for the ops writing to shared memory.
194+
/// Transform indices for the ops writing to shared memory.
197195
while (!shmWriteOps.empty()) {
198196
Operation *shmWriteOp = shmWriteOps.pop_back_val();
199197
builder.setInsertionPoint(shmWriteOp);
@@ -205,7 +203,7 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
205203
amdgpu::setIndices(shmWriteOp, transformedIndices);
206204
}
207205

208-
// Transform indices for the ops reading from shared memory.
206+
/// Transform indices for the ops reading from shared memory.
209207
while (!shmReadOps.empty()) {
210208
Operation *shmReadOp = shmReadOps.pop_back_val();
211209
builder.setInsertionPoint(shmReadOp);
@@ -220,6 +218,20 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
220218
return success();
221219
}
222220

221+
/// Applies optimizeSharedMemoryReadsAndWrites to every shared memory
/// allocation nested in `funcOp`.
///
/// All `memref::AllocOp`s whose type is in the shared memory address space are
/// collected first; each one is then optimized with `funcOp` as the parent
/// operation. The loop stops at the first allocation for which the
/// optimization fails, and the failure is not reported to the caller.
void amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
222+
// Gather the candidates up front; the rewriting below only starts once the
// walk has finished.
SmallVector<memref::AllocOp> shmAllocOps;
223+
funcOp.walk([&](memref::AllocOp allocOp) {
224+
// Only allocations in the shared memory address space are of interest.
if (!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
225+
return;
226+
shmAllocOps.push_back(allocOp);
227+
});
228+
for (auto allocOp : shmAllocOps) {
229+
if (failed(amdgpu::optimizeSharedMemoryReadsAndWrites(funcOp,
230+
allocOp.getMemref())))
231+
// Give up on the first failure; later allocations are left untouched.
return;
232+
}
233+
}
234+
223235
struct OptimizeSharedMemoryPass
224236
: public amdgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
225237
public:

0 commit comments

Comments
 (0)