Skip to content

Commit 066b4fc

Browse files
committed
[mlir] Update VectorToGPU to new memory space
GPU memory space have changed to new attributes. Update VectorToGPU pass to use those. Differential Revision: https://reviews.llvm.org/D142105
1 parent 9ef7ae5 commit 066b4fc

File tree

2 files changed

+73
-63
lines changed

2 files changed

+73
-63
lines changed

mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -669,6 +669,16 @@ createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder,
669669
return success();
670670
}
671671

672+
/// Return true if this is a shared memory memref type.
673+
static bool isSharedMemory(MemRefType type) {
674+
auto addressSpace =
675+
type.getMemorySpace().dyn_cast_or_null<gpu::AddressSpaceAttr>();
676+
if (addressSpace &&
677+
addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace())
678+
return true;
679+
return false;
680+
}
681+
672682
/// Converts a `vector.transfer_read` operation directly to either a
673683
/// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be
674684
/// used when converting to `nvgpu.mma.sync` operations.
@@ -683,7 +693,7 @@ convertTransferReadToLoads(vector::TransferReadOp op,
683693
return failure();
684694

685695
bool isLdMatrixCompatible =
686-
op.getSource().getType().cast<MemRefType>().getMemorySpaceAsInt() == 3 &&
696+
isSharedMemory(op.getSource().getType().cast<MemRefType>()) &&
687697
nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;
688698

689699
VectorType vecTy = op.getVectorType();

0 commit comments

Comments
 (0)