
Commit ea84897

[mlir][gpu] Introduce gpu.dynamic_shared_memory Op (#71546)
While the `gpu.launch` Op allows setting the size via the `dynamic_shared_memory_size` argument, accessing the dynamic shared memory is very convoluted. This PR implements the proposed `gpu.dynamic_shared_memory` Op, which aims to simplify the use of dynamic shared memory.

RFC: https://discourse.llvm.org/t/rfc-simplifying-dynamic-shared-memory-access-in-gpu/

**Proposal from RFC**

This PR introduces the `gpu.dynamic_shared_memory` Op to make efficient use of the dynamic shared memory feature. It is a powerful feature that enables the allocation of shared memory at runtime with the kernel launch on the host. Afterwards, the memory can be accessed directly from the device. I believe a similar story exists for AMDGPU.

**Current way of using dynamic shared memory with MLIR**

Let me illustrate the challenges of using dynamic shared memory in MLIR with the example below. The process involves several steps:

- `memref.global`: declare the 0-sized array that LLVM's NVPTX backend expects
- `dynamic_shared_memory_size`: set the size of dynamic shared memory
- `memref.get_global`: access the global symbol
- `reinterpret_cast` and `subview`: many Ops for pointer arithmetic

```
// Step 1. Create 0-sized global symbol. Manually set the alignment
memref.global "private" @dynamicShmem : memref<0xf16, 3> { alignment = 16 }
func.func @main() {
  // Step 2. Allocate shared memory
  gpu.launch blocks(...) threads(...) dynamic_shared_memory_size %c10000 {
    // Step 3. Access the global object
    %shmem = memref.get_global @dynamicShmem : memref<0xf16, 3>
    // Step 4. A sequence of `memref.reinterpret_cast` and `memref.subview` operations
    %4 = memref.reinterpret_cast %shmem to offset: [0], sizes: [14, 64, 128], strides: [8192,128,1] : memref<0xf16, 3> to memref<14x64x128xf16,3>
    %5 = memref.subview %4[7, 0, 0][7, 64, 128][1,1,1] : memref<14x64x128xf16,3> to memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3>
    %6 = memref.subview %5[2, 0, 0][1, 64, 128][1,1,1] : memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3> to memref<64x128xf16, strided<[128, 1], offset: 73728>, 3>
    %7 = memref.subview %6[0, 0][64, 64][1,1] : memref<64x128xf16, strided<[128, 1], offset: 73728>, 3> to memref<64x64xf16, strided<[128, 1], offset: 73728>, 3>
    %8 = memref.subview %6[32, 0][64, 64][1,1] : memref<64x128xf16, strided<[128, 1], offset: 73728>, 3> to memref<64x64xf16, strided<[128, 1], offset: 77824>, 3>
    // Step 5. Use the shared memory
    "test.use.shared.memory"(%7) : (memref<64x64xf16, strided<[128, 1], offset: 73728>, 3>) -> (index)
    "test.use.shared.memory"(%8) : (memref<64x64xf16, strided<[128, 1], offset: 77824>, 3>) -> (index)
    gpu.terminator
  }
}
```

**Proposed way with `gpu.dynamic_shared_memory`**

Let's write the program above with the new Op:

```
func.func @main() {
  gpu.launch blocks(...) threads(...) dynamic_shared_memory_size %c10000 {
    // Step 1. Obtain shared memory directly
    %shmem = gpu.dynamic_shared_memory : memref<?xi8, 3>
    %c147456 = arith.constant 147456 : index
    %c155648 = arith.constant 155648 : index
    %7 = memref.view %shmem[%c147456][] : memref<?xi8, 3> to memref<64x64xf16, 3>
    %8 = memref.view %shmem[%c155648][] : memref<?xi8, 3> to memref<64x64xf16, 3>
    // Step 2. Utilize the shared memory
    "test.use.shared.memory"(%7) : (memref<64x64xf16, 3>) -> (index)
    "test.use.shared.memory"(%8) : (memref<64x64xf16, 3>) -> (index)
    gpu.terminator
  }
}
```

This PR resolves #72513
1 parent 108380d commit ea84897

File tree: 11 files changed, +388 −22 lines


mlir/include/mlir/Dialect/GPU/IR/GPUBase.td

Lines changed: 8 additions & 0 deletions
@@ -52,6 +52,14 @@ def GPU_Dialect : Dialect {
     /// Returns the numeric value used to identify the private memory address
     /// space.
     static AddressSpace getPrivateAddressSpace() { return AddressSpace::Private; }
+
+    /// Return true if the given MemRefType has an address space that matches
+    /// the gpu::AddressSpaceAttr attribute with value 'workgroup'.
+    static bool hasWorkgroupMemoryAddressSpace(MemRefType type);
+
+    /// Return true if the given Attribute is a gpu::AddressSpaceAttr
+    /// attribute with value 'workgroup'.
+    static bool isWorkgroupMemoryAddressSpace(Attribute memorySpace);
   }];
 
   let dependentDialects = ["arith::ArithDialect"];

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 26 additions & 0 deletions
@@ -433,6 +433,32 @@ def GPU_GPUFuncOp : GPU_Op<"func", [
   let hasVerifier = 1;
 }
 
+def GPU_DynamicSharedMemoryOp : GPU_Op<"dynamic_shared_memory", [Pure]>
+{
+  let summary = "Get the memref for dynamic shared memory";
+
+  let description = [{
+    This operation provides a memref pointer to the start of dynamic shared
+    memory, often referred to as workgroup memory. It's important to note that
+    this dynamic shared memory needs to be allocated at kernel launch. One can
+    conveniently utilize the `dynamic_shared_memory_size` parameter of
+    `gpu.launch` for this purpose.
+
+    Examples:
+    ```mlir
+    %0 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+    %1 = memref.view %0[%c8192][] : memref<?xi8, #gpu.address_space<workgroup>>
+                                    to memref<32x64xf32, #gpu.address_space<workgroup>>
+    %2 = memref.view %0[%c16384][] : memref<?xi8, #gpu.address_space<workgroup>>
+                                     to memref<32x64xf32, #gpu.address_space<workgroup>>
+    ```
+  }];
+  let arguments = (ins);
+  let results = (outs Arg<MemRefRankOf<[I8], [1]>>:$resultMemref);
+  let assemblyFormat = [{ attr-dict `:` type($resultMemref) }];
+  let hasVerifier = 1;
+}
+
 def LaunchIndx : AnyTypeOf<[Index, I32, I64]>;
 
 def GPU_LaunchFuncOp :GPU_Op<"launch_func", [

mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h

Lines changed: 3 additions & 0 deletions
@@ -27,6 +27,9 @@
 namespace mlir {
 namespace NVVM {
 
+// Shared memory has 128-bit alignment
+constexpr int kSharedMemoryAlignmentBit = 128;
+
 /// NVVM memory space identifiers.
 enum NVVMMemorySpace {
   /// Global memory space identifier.
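As a side note, the lowering added in GPUOpsLowering.cpp turns this bit count into a byte alignment by dividing by the element bit width. A minimal standalone sketch of that arithmetic for the op's i8 result type (names here are local to the example):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  constexpr int kSharedMemoryAlignmentBit = 128; // from NVVMDialect.h
  const unsigned elementTypeBitWidth = 8;        // i8 result element type
  // Mirrors getDynamicSharedMemorySymbol: 128 bits / 8 bits per element
  // yields a 16-byte alignment on the generated LLVM global.
  uint64_t alignmentByte = kSharedMemoryAlignmentBit / elementTypeBitWidth;
  assert(alignmentByte == 16);
  return 0;
}
```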

mlir/include/mlir/IR/SymbolTable.h

Lines changed: 18 additions & 0 deletions
@@ -103,6 +103,24 @@ class SymbolTable {
     Nested,
   };
 
+  /// Generate a unique symbol name. Iteratively increase uniquingCounter
+  /// and use it as a suffix for symbol names until uniqueChecker does not
+  /// detect any conflict.
+  template <unsigned N, typename UniqueChecker>
+  static SmallString<N> generateSymbolName(StringRef name,
+                                           UniqueChecker uniqueChecker,
+                                           unsigned &uniquingCounter) {
+    SmallString<N> nameBuffer(name);
+    unsigned originalLength = nameBuffer.size();
+    do {
+      nameBuffer.resize(originalLength);
+      nameBuffer += '_';
+      nameBuffer += std::to_string(uniquingCounter++);
+    } while (uniqueChecker(nameBuffer));
+
+    return nameBuffer;
+  }
+
   /// Returns the name of the given symbol operation, aborting if no symbol is
   /// present.
   static StringAttr getSymbolName(Operation *symbol);
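A minimal usage sketch of the new helper, modeled on how this patch calls it from GPUOpsLowering.cpp (the set of taken names here is made up for illustration):

```cpp
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringSet.h"
#include "mlir/IR/SymbolTable.h"

// Returns "base_1": the counter first proposes "base_0", which the checker
// rejects as taken, so the loop continues until a free name appears.
llvm::SmallString<128> pickName() {
  llvm::StringSet<> taken = {"base_0"};
  unsigned counter = 0;
  return mlir::SymbolTable::generateSymbolName<128>(
      "base",
      [&](llvm::StringRef candidate) { return taken.contains(candidate); },
      counter);
}
```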

mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp

Lines changed: 99 additions & 0 deletions
@@ -14,6 +14,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "llvm/ADT/SmallVectorExtras.h"
+#include "llvm/ADT/StringSet.h"
 #include "llvm/Support/FormatVariadic.h"
 
 using namespace mlir;
@@ -549,6 +550,104 @@ static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
   return IntegerAttr::get(IntegerType::get(ctx, 64), space);
 }
 
+/// Generates a symbol with 0-sized array type for dynamic shared memory usage,
+/// or returns an existing suitable symbol.
+LLVM::GlobalOp
+getDynamicSharedMemorySymbol(ConversionPatternRewriter &rewriter,
+                             Operation *moduleOp, gpu::DynamicSharedMemoryOp op,
+                             const LLVMTypeConverter *typeConverter,
+                             MemRefType memrefType, unsigned alignmentBit) {
+  uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth();
+
+  FailureOr<unsigned> addressSpace =
+      typeConverter->getMemRefAddressSpace(memrefType);
+  if (failed(addressSpace)) {
+    op->emitError() << "conversion of memref memory space "
+                    << memrefType.getMemorySpace()
+                    << " to integer address space "
+                       "failed. Consider adding memory space conversions.";
+  }
+
+  // Step 1. Collect the symbol names of existing LLVM::GlobalOp Ops. If any
+  // LLVM::GlobalOp is already suitable for shared memory, return it.
+  llvm::StringSet<> existingGlobalNames;
+  for (auto globalOp :
+       moduleOp->getRegion(0).front().getOps<LLVM::GlobalOp>()) {
+    existingGlobalNames.insert(globalOp.getSymName());
+    if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
+      if (globalOp.getAddrSpace() == addressSpace.value() &&
+          arrayType.getNumElements() == 0 &&
+          globalOp.getAlignment().value_or(0) == alignmentByte) {
+        return globalOp;
+      }
+    }
+  }
+
+  // Step 2. Find a unique symbol name
+  unsigned uniquingCounter = 0;
+  SmallString<128> symName = SymbolTable::generateSymbolName<128>(
+      "__dynamic_shmem_",
+      [&](StringRef candidate) {
+        return existingGlobalNames.contains(candidate);
+      },
+      uniquingCounter);
+
+  // Step 3. Generate a global op
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(&moduleOp->getRegion(0).front().front());
+
+  auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
+      typeConverter->convertType(memrefType.getElementType()), 0);
+
+  return rewriter.create<LLVM::GlobalOp>(
+      op->getLoc(), zeroSizedArrayType, /*isConstant=*/false,
+      LLVM::Linkage::Internal, symName, /*value=*/Attribute(), alignmentByte,
+      addressSpace.value());
+}
+
+LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
+    gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  MemRefType memrefType = op.getResultMemref().getType();
+  Type elementType = typeConverter->convertType(memrefType.getElementType());
+
+  // Step 1. Generate a memref<0xi8> type
+  MemRefLayoutAttrInterface layout = {};
+  auto memrefType0sz =
+      MemRefType::get({0}, elementType, layout, memrefType.getMemorySpace());
+
+  // Step 2. Generate, or find an existing, global symbol for the dynamic
+  // shared memory with memref<0xi8> type
+  LLVM::LLVMFuncOp funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
+  LLVM::GlobalOp shmemOp = {};
+  Operation *moduleOp = funcOp->getParentWithTrait<OpTrait::SymbolTable>();
+  shmemOp = getDynamicSharedMemorySymbol(
+      rewriter, moduleOp, op, getTypeConverter(), memrefType0sz, alignmentBit);
+
+  // Step 3. Get address of the global symbol
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(op);
+  auto basePtr = rewriter.create<LLVM::AddressOfOp>(loc, shmemOp);
+  Type baseType = basePtr->getResultTypes().front();
+
+  // Step 4. Generate GEP using offsets
+  SmallVector<LLVM::GEPArg> gepArgs = {0};
+  Value shmemPtr = rewriter.create<LLVM::GEPOp>(loc, baseType, elementType,
+                                                basePtr, gepArgs);
+  // Step 5. Create a memref descriptor
+  SmallVector<Value> shape, strides;
+  Value sizeBytes;
+  getMemRefDescriptorSizes(loc, memrefType0sz, {}, rewriter, shape, strides,
                           sizeBytes);
+  auto memRefDescriptor = this->createMemRefDescriptor(
+      loc, memrefType0sz, shmemPtr, shmemPtr, shape, strides, rewriter);
+
+  // Step 6. Replace the op with the memref descriptor
+  rewriter.replaceOp(op, {memRefDescriptor});
+  return success();
+}
+
 void mlir::populateGpuMemorySpaceAttributeConversions(
     TypeConverter &typeConverter, const MemorySpaceMapping &mapping) {
   typeConverter.addTypeAttributeConversion(
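One consequence of Step 1 in `getDynamicSharedMemorySymbol` above: if a suitable 0-sized global already exists, every `gpu.dynamic_shared_memory` in the module reuses it rather than minting a new symbol. A schematic sketch (symbol name and attribute spelling are illustrative, not copied from the patch's tests):

```mlir
// Two ops requesting the dynamic shared memory base pointer ...
%a = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
%b = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>

// ... both lower to llvm.mlir.addressof of a single reused global such as:
// llvm.mlir.global internal @__dynamic_shmem__0() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
```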

mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h

Lines changed: 21 additions & 0 deletions
@@ -14,6 +14,27 @@
 
 namespace mlir {
 
+/// Lowering for gpu.dynamic_shared_memory to the LLVM dialect. The pattern
+/// first creates a 0-sized global array symbol in the form LLVM expects,
+/// then constructs a memref descriptor from it and returns the descriptor.
+struct GPUDynamicSharedMemoryOpLowering
+    : public ConvertOpToLLVMPattern<gpu::DynamicSharedMemoryOp> {
+  using ConvertOpToLLVMPattern<
+      gpu::DynamicSharedMemoryOp>::ConvertOpToLLVMPattern;
+  GPUDynamicSharedMemoryOpLowering(const LLVMTypeConverter &converter,
+                                   unsigned alignmentBit = 0)
+      : ConvertOpToLLVMPattern<gpu::DynamicSharedMemoryOp>(converter),
+        alignmentBit(alignmentBit) {}
+
+  LogicalResult
+  matchAndRewrite(gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+
+private:
+  // Alignment of the generated global, in bits.
+  unsigned alignmentBit;
+};
+
 struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
   GPUFuncOpLowering(const LLVMTypeConverter &converter,
                     unsigned allocaAddrSpace, unsigned workgroupAddrSpace,

mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp

Lines changed: 3 additions & 0 deletions
@@ -325,6 +325,9 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
            GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
           converter);
 
+  patterns.add<GPUDynamicSharedMemoryOpLowering>(
+      converter, NVVM::kSharedMemoryAlignmentBit);
+
   // Explicitly drop memory space when lowering private memory
   // attributions since NVVM models it as `alloca`s in the default
   // memory space and does not support `alloca`s with addrspace(5).
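The pattern itself is target-agnostic, so other backends could plug it in with their own alignment, in line with the commit message's remark about AMDGPU. A hypothetical sketch for another pipeline (both the registration site and the 128-bit value are assumptions, not part of this patch):

```cpp
// Hypothetical registration in a ROCDL-style lowering pipeline; the
// appropriate shared memory alignment for the target is an assumption.
constexpr int kAssumedSharedMemoryAlignmentBit = 128; // assumption
patterns.add<GPUDynamicSharedMemoryOpLowering>(
    converter, kAssumedSharedMemoryAlignmentBit);
```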

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

Lines changed: 34 additions & 10 deletions
@@ -23,6 +23,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/SymbolTable.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/FunctionImplementation.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
@@ -164,17 +165,18 @@ MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
 // GPUDialect
 //===----------------------------------------------------------------------===//
 
-/// GPU memory space identifiers.
-enum GPUMemorySpace {
-  /// Generic memory space identifier.
-  kGenericMemorySpace = 0,
-
-  /// Global memory space identifier.
-  kGlobalMemorySpace = 1,
+bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
+  if (!memorySpace)
+    return false;
+  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+    return gpuAttr.getValue() == getWorkgroupAddressSpace();
+  return false;
+}
 
-  /// Shared memory space identifier.
-  kSharedMemorySpace = 3
-};
+bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
+  Attribute memorySpace = type.getMemorySpace();
+  return isWorkgroupMemoryAddressSpace(memorySpace);
+}
 
 bool GPUDialect::isKernel(Operation *op) {
   UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
@@ -2047,6 +2049,28 @@ gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// DynamicSharedMemoryOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult gpu::DynamicSharedMemoryOp::verify() {
+  if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
+    return emitOpError() << "must be inside an op with symbol table";
+
+  MemRefType memrefType = getResultMemref().getType();
+  // Check address space
+  if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
+    return emitOpError() << "address space must be "
+                         << gpu::AddressSpaceAttr::getMnemonic() << "<"
+                         << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
+  }
+  if (memrefType.hasStaticShape()) {
+    return emitOpError() << "result memref type must be memref<?xi8, "
+                            "#gpu.address_space<workgroup>>";
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // GPU target options
 //===----------------------------------------------------------------------===//
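Concretely, the verifier added here rejects IR like the following (illustrative snippets based on the checks above, not taken from the patch's tests):

```mlir
// Rejected: static shape. The result must be memref<?xi8, #gpu.address_space<workgroup>>.
%0 = gpu.dynamic_shared_memory : memref<16xi8, #gpu.address_space<workgroup>>

// Rejected: missing workgroup address space.
%1 = gpu.dynamic_shared_memory : memref<?xi8>
```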

mlir/lib/IR/SymbolTable.cpp

Lines changed: 8 additions & 12 deletions
@@ -200,20 +200,16 @@ StringAttr SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
   // If the symbol was already in the table, also return.
   if (symbolTable.lookup(name) == symbol)
     return name;
-  // If a conflict was detected, then the symbol will not have been added to
-  // the symbol table. Try suffixes until we get to a unique name that works.
-  SmallString<128> nameBuffer(name.getValue());
-  unsigned originalLength = nameBuffer.size();
 
   MLIRContext *context = symbol->getContext();
-
-  // Iteratively try suffixes until we find one that isn't used.
-  do {
-    nameBuffer.resize(originalLength);
-    nameBuffer += '_';
-    nameBuffer += std::to_string(uniquingCounter++);
-  } while (!symbolTable.insert({StringAttr::get(context, nameBuffer), symbol})
-                .second);
+  SmallString<128> nameBuffer = generateSymbolName<128>(
+      name.getValue(),
+      [&](StringRef candidate) {
+        return !symbolTable
+                    .insert({StringAttr::get(context, candidate), symbol})
+                    .second;
+      },
+      uniquingCounter);
   setSymbolName(symbol, nameBuffer);
   return getSymbolName(symbol);
 }
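The refactor keeps the original probe-and-insert-in-one-step idiom: the checker lambda attempts the insertion itself, so the first successful insert both reserves the name and ends the loop. A standalone sketch of the same idiom with a plain StringSet (names here are illustrative):

```cpp
#include "llvm/ADT/StringSet.h"
#include <string>

// insert(...).second is false when the name was already present, so
// "keep looping while the insert failed" doubles as the uniqueness check.
std::string reserveUnique(llvm::StringSet<> &names, const std::string &base) {
  unsigned counter = 0;
  std::string candidate;
  do {
    candidate = base + "_" + std::to_string(counter++);
  } while (!names.insert(candidate).second);
  return candidate; // the winning name is already reserved in `names`
}
```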
