[mlir][amdgpu] Shared memory access optimization pass #75627


Merged (12 commits, Jan 19, 2024)

Conversation

@erman-gurses (Contributor) commented on Dec 15, 2023

This PR implements a transformation to optimize accesses to shared memory.

Reference: https://reviews.llvm.org/D127457

This change adds a transformation and pass to the AMDGPU dialect that
attempts to optimize reads/writes from a memref representing GPU shared
memory in order to avoid bank conflicts. Given a value representing a
shared memory memref, it traverses all reads/writes within the parent op
and, subject to suitable conditions, rewrites all last-dimension index
values such that element locations in the final (col) dimension are
given by newColIdx = col % vecSize + perm[row](col / vecSize, row),
where perm is a permutation function indexed by row and vecSize
is the vector access size in elements (the vector access width is
currently fixed, 64 bits in this patch, but could be made a parameter).
This transformation can help optimize the typical distributed and
vectorized accesses used to load matrix multiplication operands to/from
shared memory.
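
For illustration only, and not taken from the patch, the remapping above can be sketched with plain integer arithmetic. The XOR-based perm here is one concrete choice that matches the shape of the implementation; every name is hypothetical.

#include <cstdint>

// Hedged sketch of
//   newColIdx = col % vecSize + perm[row](col / vecSize, row)
// with perm chosen as an XOR of the (wrapped) row into the vector index.
int64_t newColIdx(int64_t row, int64_t col, int64_t vecSize,
                  int64_t vecsPerRow /* assumed power of two */) {
  int64_t vecIdx = col / vecSize;                     // which vector within the row
  int64_t elemOffset = col % vecSize;                 // element within that vector
  int64_t permutedVec = vecIdx ^ (row % vecsPerRow);  // illustrative perm[row]
  return permutedVec * vecSize + elemOffset;
}
// Example: vecSize = 4, vecsPerRow = 4 (16 elements per row). Row 0 keeps
// columns in place; row 1 swaps columns {0..3} with {4..7} and {8..11} with
// {12..15}, so vectors that previously hit the same banks are spread out.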


@llvmbot (Member) commented on Dec 18, 2023

@llvm/pr-subscribers-backend-amdgpu, @llvm/pr-subscribers-mlir-gpu, @llvm/pr-subscribers-mlir

Author: None (erman-gurses)

Changes

Patch is 25.04 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/75627.diff

10 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+27)
  • (modified) mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h (+4)
  • (modified) mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td (+8)
  • (added) mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h (+54)
  • (added) mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h (+21)
  • (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+15)
  • (modified) mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt (+2)
  • (added) mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp (+252)
  • (added) mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp (+48)
  • (added) mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir (+57)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index ffb302fcedd732..324c656f47599e 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -29,6 +29,33 @@ def AMDGPU_Dialect : Dialect {
     "gpu::GPUDialect"
   ];
   let useDefaultAttributePrinterParser = 1;
+
+  let extraClassDeclaration = [{
+    /// Return true if the given MemRefType has an integer address
+    /// space that matches the ROCDL shared memory address space or
+    /// is a gpu::AddressSpaceAttr attribute with value 'workgroup`.
+    static bool hasSharedMemoryAddressSpace(MemRefType type);
+
+    /// Return true if the given Attribute has an integer address
+    /// space that matches the ROCDL shared memory address space or
+    /// is a gpu::AddressSpaceAttr attribute with value 'workgroup`.
+    static bool isSharedMemoryAddressSpace(Attribute type);
+
+    /// Defines the MemRef memory space attribute numeric value that indicates
+    /// a memref is located in global memory. This should correspond to the
+    /// value used in ROCDL.
+    static constexpr unsigned kGlobalMemoryAddressSpace = 1;
+
+    /// Defines the MemRef memory space attribute numeric value that indicates
+    /// a memref is located in private memory. This should correspond to the
+    /// value used in ROCDL.
+    static constexpr unsigned kPrivateMemoryAddressSpace = 2;
+
+    /// Defines the MemRef memory space attribute numeric value that indicates
+    /// a memref is located in shared memory. This should correspond to the
+    /// value used in ROCDL.
+    static constexpr unsigned kSharedMemoryAddressSpace = 3;
+  }];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
index 8dd5ff1a4b198a..752078cd6930e3 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
@@ -21,6 +21,10 @@ class ConversionTarget;
 namespace amdgpu {
 
 #define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS
+
+/// Create a pass to optimize shared memory reads and writes.
+std::unique_ptr<Pass> createOptimizeSharedMemoryPass();
+
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
 
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
index e6b27aa842dfcd..1c12ca98271127 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
@@ -30,4 +30,12 @@ def AmdgpuEmulateAtomicsPass : Pass<"amdgpu-emulate-atomics"> {
                         "Chipset that these operations will run on">];
 }
 
+def OptimizeSharedMemory : Pass<"amdgpu-optimize-shared-memory"> {
+  let summary = "Optimizes accesses to shared memory memrefs in order to reduce bank conflicts.";
+  let constructor = "mlir::amdgpu::createOptimizeSharedMemoryPass()";
+  let dependentDialects = [
+    "memref::MemRefDialect", "vector::VectorDialect"
+  ];
+}
+
 #endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_
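
As a hedged usage sketch, not part of the patch: with the pass defined above, it could presumably be scheduled from C++ roughly as follows, assuming the standard mlir::PassManager API and that the MLIRAMDGPUTransforms library is linked (the pass is also exposed under the amdgpu-optimize-shared-memory flag).

#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

// Hypothetical driver: run the shared-memory optimization over a module.
mlir::LogicalResult runAmdgpuShmemOpt(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::amdgpu::createOptimizeSharedMemoryPass());
  return pm.run(module);
}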
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
new file mode 100644
index 00000000000000..140bc12deed690
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
@@ -0,0 +1,54 @@
+//===- Transforms.h - AMDGPU Dialect transformations --------------*-
+// C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares functions that assist transformations for the amdgpu
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_AMDGPU_TRANSFORMS_TRANSFORMS_H_
+#define MLIR_DIALECT_AMDGPU_TRANSFORMS_TRANSFORMS_H_
+
+#include "mlir/IR/Operation.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+class RewriterBase;
+
+namespace amdgpu {
+
+///
+/// Passes
+///
+
+/// Optimizes vectorized accesses to a shared memory buffer specified by
+/// memrefValue. This transformation assumes the following:
+/// 1) All relevant accesses to `memrefValue` are contained with `parentOp`.
+/// 2) The function will fail precondition checks if any subviews are
+/// taken of `memrefValue`. All reads/writes to `memrefValue` should occur
+/// through `memrefValue` directly.
+///
+/// Shared memory bank conflicts occur when multiple threads attempt to read or
+/// write locations assigned to the same shared memory bank. For `2^N` byte
+/// vectorized accesses, we need to be concerned with conflicts among threads
+/// identified as `(tid) -> tid.floordiv(2^{7-N})`. As such, this transformation
+/// changes any indexed memory access (vector.load, memref.load, etc)
+/// such that the final dimension's index value is permuted such that
+/// `newColIndex = oldColIndex % vectorSize +
+/// perm[rowIndex](oldColIndex/vectorSize, rowIndex)` where `rowIndex` is the
+/// index for the second-to last dimension and `perm[rowIndex]` is a permutation
+/// function that depends on the row Index. The permutation function is chosen
+/// to ensure that sequential distributed+vectorized reads/writes down a single
+/// dimension of the memref have minimal conflicts.
+mlir::LogicalResult optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
+                                                       Value memrefValue);
+
+} // namespace amdgpu
+} // namespace mlir
+
+#endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_TRANSFORMS_H_
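
To make the `(tid) -> tid.floordiv(2^{7-N})` grouping in the comment above concrete, here is a small hedged illustration, not part of the patch, using the 64-bit default vector width that the implementation hard-codes.

#include <cstdint>

// 64-bit accesses are 2^N bytes with N = 3, so conflicts matter within
// groups of 2^(7-N) = 16 consecutive threads.
constexpr int64_t kDefaultVectorSizeBits = 64;
constexpr int64_t n = 3; // log2(kDefaultVectorSizeBits / 8)
constexpr int64_t threadGroupSize = int64_t{1} << (7 - n);
static_assert(kDefaultVectorSizeBits / 8 == (int64_t{1} << n), "N = 3 for 8-byte accesses");
static_assert(threadGroupSize == 16, "conflicts matter within groups of 16 threads");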
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h
new file mode 100644
index 00000000000000..bee3af1914feef
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h
@@ -0,0 +1,21 @@
+//===- Utils.h - Transform utilities -----------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Operation.h"
+
+namespace mlir {
+namespace amdgpu {
+
+/// Get the indices that the given load/store operation is operating on.
+Operation::operand_range getIndices(Operation *op);
+
+/// Set the indices that the given load/store operation is operating on.
+void setIndices(Operation *op, ArrayRef<Value> indices);
+
+} // namespace amdgpu
+} // namespace mlir
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 2575ad4984814b..4e72fbf56b80a4 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -43,6 +43,21 @@ void AMDGPUDialect::initialize() {
       >();
 }
 
+bool amdgpu::AMDGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
+  if (!memorySpace)
+    return false;
+  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace))
+    return intAttr.getInt() == AMDGPUDialect::kSharedMemoryAddressSpace;
+  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+    return gpuAttr.getValue() == gpu::AddressSpace::Workgroup;
+  return false;
+}
+
+bool amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) {
+  Attribute memorySpace = type.getMemorySpace();
+  return isSharedMemoryAddressSpace(memorySpace);
+}
+
 //===----------------------------------------------------------------------===//
 // 8-bit float ops
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
index e11b6cc88bf224..a1a91270bc55c4 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
@@ -1,5 +1,7 @@
 add_mlir_dialect_library(MLIRAMDGPUTransforms
   EmulateAtomics.cpp
+  OptimizeSharedMemory.cpp
+  Utils.cpp
 
   ADDITIONAL_HEADER_DIRS
   {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
new file mode 100644
index 00000000000000..0a2f04f4e6487f
--- /dev/null
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -0,0 +1,252 @@
+//===- OptimizeSharedMemory.cpp - MLIR AMDGPU pass implementation
+//----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements transforms to optimize accesses to shared memory.
+// It is inspired by
+// https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/AMDGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/AMDGPU/Transforms/Utils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/MathExtras.h"
+
+namespace mlir {
+namespace amdgpu {
+#define GEN_PASS_DEF_OPTIMIZESHAREDMEMORY
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
+} // namespace amdgpu
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::amdgpu;
+
+/// The size of a shared memory line according to AMD documentation.
+/// https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/instinct-mi200-cdna2-instruction-set-architecture.pdf
+constexpr int64_t kSharedMemoryLineSizeBytes = 64;
+/// We optimize for 64bit accesses, but this can be made an argument in the
+/// future.
+constexpr int64_t kDefaultVectorSizeBits = 64;
+
+/// Uses `srcIndexValue` to permute `tgtIndexValue` via
+/// `result = xor(floordiv(srcIdxVal,permuteEveryN),
+///               floordiv(tgtIdxVal,vectorSize)))
+///            + tgtIdxVal % vectorSize`
+/// This is done using an optimized sequence of `arith` operations.
+static Value permuteVectorOffset(OpBuilder &b, Location loc,
+                                 ArrayRef<Value> indices, MemRefType memrefTy,
+                                 int64_t srcDim, int64_t tgtDim) {
+  // Adjust the src index to change how often the permutation changes
+  // if necessary.
+  Value src = indices[srcDim];
+
+  // We only want to permute every N iterations of the target dim where N is
+  // ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
+  const int64_t permuteEveryN = std::max<int64_t>(
+      1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
+                                        memrefTy.getElementTypeBitWidth()) /
+                                       8));
+
+  // clang-format off
+  // Index bit representation (b0 = least significant bit) for dim(1)
+  // of a `memref<?x?xDT>` is as follows:
+  // N := log2(128/elementSizeBits)
+  // M := log2(dimSize(1))
+  // then
+  // bits[0:N] = sub-vector element offset
+  // bits[N:M] = vector index
+  // clang-format on
+  int64_t n =
+      llvm::Log2_64(kDefaultVectorSizeBits / memrefTy.getElementTypeBitWidth());
+  int64_t m = llvm::Log2_64(memrefTy.getDimSize(tgtDim));
+
+  // Capture bits[0:(M-N)] of src by first creating a (M-N) mask.
+  int64_t mask = (1LL << (m - n)) - 1;
+  if (permuteEveryN > 1)
+    mask = mask << llvm::Log2_64(permuteEveryN);
+  Value srcBits = b.create<arith::ConstantIndexOp>(loc, mask);
+  srcBits = b.create<arith::AndIOp>(loc, src, srcBits);
+
+  // Use the src bits to permute the target bits b[N:M] containing the
+  // vector offset.
+  if (permuteEveryN > 1) {
+    int64_t shlBits = n - llvm::Log2_64(permuteEveryN);
+    if (shlBits > 0) {
+      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, shlBits);
+      srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
+    } else if (shlBits < 0) {
+      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, -1 * shlBits);
+      srcBits = b.createOrFold<arith::ShRUIOp>(loc, srcBits, finalShiftVal);
+    }
+  } else {
+    Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, n);
+    srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
+  }
+
+  Value permutedVectorIdx =
+      b.create<arith::XOrIOp>(loc, indices[tgtDim], srcBits);
+  return permutedVectorIdx;
+}
+
+static void transformIndices(OpBuilder &builder, Location loc,
+                             SmallVector<Value, 4> &indices,
+                             MemRefType memrefTy, int64_t srcDim,
+                             int64_t tgtDim) {
+  indices[tgtDim] =
+      permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim);
+}
+
+/// Return all operations within `parentOp` that read from or write to
+/// `shmMemRef`.
+static LogicalResult
+getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
+                      SmallVector<Operation *, 16> &readOps,
+                      SmallVector<Operation *, 16> &writeOps) {
+  parentOp->walk([&](Operation *op) {
+    MemoryEffectOpInterface iface = dyn_cast<MemoryEffectOpInterface>(op);
+    if (!iface)
+      return;
+    std::optional<MemoryEffects::EffectInstance> effect =
+        iface.getEffectOnValue<MemoryEffects::Read>(shmMemRef);
+    if (effect) {
+      readOps.push_back(op);
+      return;
+    }
+    effect = iface.getEffectOnValue<MemoryEffects::Write>(shmMemRef);
+    if (effect)
+      writeOps.push_back(op);
+  });
+
+  // Restrict to a supported set of ops. We also require at least 2D access,
+  // although this could be relaxed.
+  if (llvm::any_of(readOps, [](Operation *op) {
+        return !isa<memref::LoadOp, vector::LoadOp, vector::TransferReadOp>(
+                   op) ||
+               amdgpu::getIndices(op).size() < 2;
+      }))
+    return failure();
+  if (llvm::any_of(writeOps, [](Operation *op) {
+        return !isa<memref::StoreOp, vector::StoreOp, vector::TransferWriteOp>(
+                   op) ||
+               amdgpu::getIndices(op).size() < 2;
+      }))
+    return failure();
+
+  return success();
+}
+
+mlir::LogicalResult
+mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
+                                                 Value memrefValue) {
+  auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
+  if (!memRefType ||
+      !amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(memRefType))
+    return failure();
+
+  // Abort if the given value has any sub-views; we do not do any alias
+  // analysis.
+  bool hasSubView = false;
+  parentOp->walk([&](memref::SubViewOp subView) { hasSubView = true; });
+  if (hasSubView)
+    return failure();
+
+  // Check if this is necessary given the assumption of 128b accesses:
+  // If dim[rank-1] is small enough to fit 8 rows in a 128B line.
+  const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
+  const int64_t rowsPerLine =
+      (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
+      rowSize;
+  const int64_t threadGroupSize =
+      1LL << (7 - llvm::Log2_64(kDefaultVectorSizeBits / 8));
+  if (rowsPerLine >= threadGroupSize)
+    return failure();
+
+  // Get sets of operations within the function that read/write to shared
+  // memory.
+  SmallVector<Operation *, 16> shmReadOps;
+  SmallVector<Operation *, 16> shmWriteOps;
+  if (failed(getShmReadAndWriteOps(parentOp, memrefValue, shmReadOps,
+                                   shmWriteOps)))
+    return failure();
+
+  if (shmReadOps.empty() || shmWriteOps.empty())
+    return failure();
+
+  OpBuilder builder(parentOp->getContext());
+
+  int64_t tgtDim = memRefType.getRank() - 1;
+  int64_t srcDim = memRefType.getRank() - 2;
+
+  // Transform indices for the ops writing to shared memory.
+  while (!shmWriteOps.empty()) {
+    Operation *shmWriteOp = shmWriteOps.back();
+    shmWriteOps.pop_back();
+    builder.setInsertionPoint(shmWriteOp);
+
+    auto indices = amdgpu::getIndices(shmWriteOp);
+    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
+    transformIndices(builder, shmWriteOp->getLoc(), transformedIndices,
+                     memRefType, srcDim, tgtDim);
+    amdgpu::setIndices(shmWriteOp, transformedIndices);
+  }
+
+  // Transform indices for the ops reading from shared memory.
+  while (!shmReadOps.empty()) {
+    Operation *shmReadOp = shmReadOps.back();
+    shmReadOps.pop_back();
+    builder.setInsertionPoint(shmReadOp);
+
+    auto indices = amdgpu::getIndices(shmReadOp);
+    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
+    transformIndices(builder, shmReadOp->getLoc(), transformedIndices,
+                     memRefType, srcDim, tgtDim);
+    amdgpu::setIndices(shmReadOp, transformedIndices);
+  }
+
+  return success();
+}
+
+namespace {
+class OptimizeSharedMemoryPass
+    : public amdgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
+public:
+  OptimizeSharedMemoryPass() = default;
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    SmallVector<memref::AllocOp> shmAllocOps;
+    op->walk([&](memref::AllocOp allocOp) {
+      if (!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(
+              allocOp.getType()))
+        return;
+      shmAllocOps.push_back(allocOp);
+    });
+    for (auto allocOp : shmAllocOps) {
+      if (failed(optimizeSharedMemoryReadsAndWrites(getOperation(),
+                                                    allocOp.getMemref())))
+        return;
+    }
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::amdgpu::createOptimizeSharedMemoryPass() {
+  return std::make_unique<OptimizeSharedMemoryPass>();
+}
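
For readers tracing the bit manipulation in permuteVectorOffset, the following hedged scalar model, not part of the patch, reproduces the same masks and shifts on plain integers. All helper names are made up, and power-of-two sizes are assumed, as in the pass.

#include <algorithm>
#include <cstdint>

namespace {
constexpr int64_t kLineBytes = 64; // mirrors kSharedMemoryLineSizeBytes
constexpr int64_t kVecBits = 64;   // mirrors kDefaultVectorSizeBits

// Floor log2 for positive, power-of-two-ish inputs (illustrative helper).
int64_t log2i(int64_t v) {
  int64_t r = 0;
  while (v > 1) { v >>= 1; ++r; }
  return r;
}

// Scalar analogue of permuteVectorOffset: XOR row-derived bits into the
// vector-index bits of the column index.
int64_t permuteColIdx(int64_t rowIdx, int64_t colIdx, int64_t tgtDimSize,
                      int64_t elemBits) {
  int64_t permuteEveryN = std::max<int64_t>(
      1, kLineBytes / (tgtDimSize * elemBits / 8));
  int64_t n = log2i(kVecBits / elemBits); // bits below the vector index
  int64_t m = log2i(tgtDimSize);          // bits of the column index
  int64_t mask = (int64_t{1} << (m - n)) - 1;
  if (permuteEveryN > 1)
    mask <<= log2i(permuteEveryN);
  int64_t srcBits = rowIdx & mask;
  if (permuteEveryN > 1) {
    int64_t shl = n - log2i(permuteEveryN);
    srcBits = shl >= 0 ? srcBits << shl : srcBits >> -shl;
  } else {
    srcBits <<= n;
  }
  return colIdx ^ srcBits;
}
} // namespace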
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp
new file mode 100644
index 00000000000000..a1dc6cf70e7bf8
--- /dev/null
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp
@@ -0,0 +1,48 @@
+//===- Utils.cpp - Transform utilities ------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMDGPU/Transforms/Utils.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+using namespace mlir;
+using namespace mlir::amdgpu;
+
+Operation::operand_range amdgpu::getIndices(Operation *op) {
+  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+    return loadOp.getIndices();
+  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+    return storeOp.getIndices();
+  if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
+    return vectorReadOp.getIndices();
+  if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
+    return vectorStoreOp.getIndices();
+  if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op))
+    return transferReadOp.getIndices();
+  if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op))
+    return transferWriteOp.getIndices();
+  llvm_unreachable("unsupported op type");
+}
+
+void amdgpu::setIndices(Operation *op, ArrayRef<Value> indices) {
+  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+    return loadOp.getIndicesMutable().assign(indices);
+  if (auto storeOp = dyn_cast<memref:...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Dec 18, 2023

@llvm/pr-subscribers-mlir-gpu

Author: None (erman-gurses)

Changes

It implements transformation to optimize accesses to shared memory.

Reference: https://reviews.llvm.org/D127457

This change adds a transformation and pass to the NvGPU dialect that
attempts to optimize reads/writes from a memref representing GPU shared
memory in order to avoid bank conflicts. Given a value representing a
shared memory memref, it traverses all reads/writes within the parent op
and, subject to suitable conditions, rewrites all last dimension index
values such that element locations in the final (col) dimension are
given by newColIdx = col % vecSize + perm[row](col / vecSize, row)
where perm is a permutation function indexed by row and vecSize
is the vector access size in elements (currently assumes 128bit
vectorized accesses, but this can be made a parameter). This specific
transformation can help optimize typical distributed & vectorized accesses
common to loading matrix multiplication operands to/from shared memory.


Patch is 25.04 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/75627.diff

10 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+27)
  • (modified) mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h (+4)
  • (modified) mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td (+8)
  • (added) mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h (+54)
  • (added) mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h (+21)
  • (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+15)
  • (modified) mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt (+2)
  • (added) mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp (+252)
  • (added) mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp (+48)
  • (added) mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir (+57)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index ffb302fcedd732..324c656f47599e 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -29,6 +29,33 @@ def AMDGPU_Dialect : Dialect {
     "gpu::GPUDialect"
   ];
   let useDefaultAttributePrinterParser = 1;
+
+  let extraClassDeclaration = [{
+    /// Return true if the given MemRefType has an integer address
+    /// space that matches the ROCDL shared memory address space or
+    /// is a gpu::AddressSpaceAttr attribute with value 'workgroup`.
+    static bool hasSharedMemoryAddressSpace(MemRefType type);
+
+    /// Return true if the given Attribute has an integer address
+    /// space that matches the ROCDL shared memory address space or
+    /// is a gpu::AddressSpaceAttr attribute with value 'workgroup`.
+    static bool isSharedMemoryAddressSpace(Attribute type);
+
+    /// Defines the MemRef memory space attribute numeric value that indicates
+    /// a memref is located in global memory. This should correspond to the
+    /// value used in ROCDL.
+    static constexpr unsigned kGlobalMemoryAddressSpace = 1;
+
+    /// Defines the MemRef memory space attribute numeric value that indicates
+    /// a memref is located in private memory. This should correspond to the
+    /// value used in ROCDL.
+    static constexpr unsigned kPrivateMemoryAddressSpace = 2;
+
+    /// Defines the MemRef memory space attribute numeric value that indicates
+    /// a memref is located in shared memory. This should correspond to the
+    /// value used in ROCDL.
+    static constexpr unsigned kSharedMemoryAddressSpace = 3;
+  }];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
index 8dd5ff1a4b198a..752078cd6930e3 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
@@ -21,6 +21,10 @@ class ConversionTarget;
 namespace amdgpu {
 
 #define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS
+
+/// Create a pass to optimize shared memory reads and writes.
+std::unique_ptr<Pass> createOptimizeSharedMemoryPass();
+
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
 
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
index e6b27aa842dfcd..1c12ca98271127 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
@@ -30,4 +30,12 @@ def AmdgpuEmulateAtomicsPass : Pass<"amdgpu-emulate-atomics"> {
                         "Chipset that these operations will run on">];
 }
 
+def OptimizeSharedMemory : Pass<"amdgpu-optimize-shared-memory"> {
+  let summary = "Optimizes accesses to shared memory memrefs in order to reduce bank conflicts.";
+  let constructor = "mlir::amdgpu::createOptimizeSharedMemoryPass()";
+  let dependentDialects = [
+    "memref::MemRefDialect", "vector::VectorDialect"
+  ];
+}
+
 #endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
new file mode 100644
index 00000000000000..140bc12deed690
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
@@ -0,0 +1,54 @@
+//===- Transforms.h - AMDGPU Dialect transformations --------------*-
+// C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares functions that assist transformations for the amdgpu
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_AMDGPU_TRANSFORMS_TRANSFORMS_H_
+#define MLIR_DIALECT_AMDGPU_TRANSFORMS_TRANSFORMS_H_
+
+#include "mlir/IR/Operation.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+class RewriterBase;
+
+namespace amdgpu {
+
+///
+/// Passes
+///
+
+/// Optimizes vectorized accesses to a shared memory buffer specified by
+/// memrefValue. This transformation assumes the following:
+/// 1) All relevant accesses to `memrefValue` are contained with `parentOp`.
+/// 2) The function will fail precondition checks if any subviews are
+/// taken of `memrefValue`. All reads/writes to `memrefValue` should occur
+/// through `memrefValue` directly.
+///
+/// Shared memory bank conflicts occur when multiple threads attempt to read or
+/// write locations assigned to the same shared memory bank. For `2^N` byte
+/// vectorized accesses, we need to be concerned with conflicts among threads
+/// identified as `(tid) -> tid.floordiv(2^{7-N})`. As such, this transformation
+/// changes any indexed memory access (vector.load, memref.load, etc)
+/// such that the final dimension's index value is permuted such that
+/// `newColIndex = oldColIndex % vectorSize +
+/// perm[rowIndex](oldColIndex/vectorSize, rowIndex)` where `rowIndex` is the
+/// index for the second-to last dimension and `perm[rowIndex]` is a permutation
+/// function that depends on the row Index. The permutation function is chosen
+/// to ensure that sequential distributed+vectorized reads/writes down a single
+/// dimension of the memref have minimal conflicts.
+mlir::LogicalResult optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
+                                                       Value memrefValue);
+
+} // namespace amdgpu
+} // namespace mlir
+
+#endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_TRANSFORMS_H_
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h
new file mode 100644
index 00000000000000..bee3af1914feef
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h
@@ -0,0 +1,21 @@
+//===- Utils.h - Transform utilities -----------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Operation.h"
+
+namespace mlir {
+namespace amdgpu {
+
+/// Get the indices that the given load/store operation is operating on.
+Operation::operand_range getIndices(Operation *op);
+
+/// Set the indices that the given load/store operation is operating on.
+void setIndices(Operation *op, ArrayRef<Value> indices);
+
+} // namespace amdgpu
+} // namespace mlir
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 2575ad4984814b..4e72fbf56b80a4 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -43,6 +43,21 @@ void AMDGPUDialect::initialize() {
       >();
 }
 
+bool amdgpu::AMDGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
+  if (!memorySpace)
+    return false;
+  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace))
+    return intAttr.getInt() == AMDGPUDialect::kSharedMemoryAddressSpace;
+  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+    return gpuAttr.getValue() == gpu::AddressSpace::Workgroup;
+  return false;
+}
+
+bool amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) {
+  Attribute memorySpace = type.getMemorySpace();
+  return isSharedMemoryAddressSpace(memorySpace);
+}
+
 //===----------------------------------------------------------------------===//
 // 8-bit float ops
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
index e11b6cc88bf224..a1a91270bc55c4 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
@@ -1,5 +1,7 @@
 add_mlir_dialect_library(MLIRAMDGPUTransforms
   EmulateAtomics.cpp
+  OptimizeSharedMemory.cpp
+  Utils.cpp
 
   ADDITIONAL_HEADER_DIRS
   {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
new file mode 100644
index 00000000000000..0a2f04f4e6487f
--- /dev/null
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -0,0 +1,252 @@
+//===- OptimizeSharedMemory.cpp - MLIR AMDGPU pass implementation
+//----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements transforms to optimize accesses to shared memory.
+// It is inspired by
+// https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/AMDGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/AMDGPU/Transforms/Utils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/MathExtras.h"
+
+namespace mlir {
+namespace amdgpu {
+#define GEN_PASS_DEF_OPTIMIZESHAREDMEMORY
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
+} // namespace amdgpu
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::amdgpu;
+
+/// The size of a shared memory line according to AMD documentation.
+/// https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/instinct-mi200-cdna2-instruction-set-architecture.pdf
+constexpr int64_t kSharedMemoryLineSizeBytes = 64;
+/// We optimize for 64bit accesses, but this can be made an argument in the
+/// future.
+constexpr int64_t kDefaultVectorSizeBits = 64;
+
+/// Uses `srcIndexValue` to permute `tgtIndexValue` via
+/// `result = xor(floordiv(srcIdxVal,permuteEveryN),
+///               floordiv(tgtIdxVal,vectorSize)))
+///            + tgtIdxVal % vectorSize`
+/// This is done using an optimized sequence of `arith` operations.
+static Value permuteVectorOffset(OpBuilder &b, Location loc,
+                                 ArrayRef<Value> indices, MemRefType memrefTy,
+                                 int64_t srcDim, int64_t tgtDim) {
+  // Adjust the src index to change how often the permutation changes
+  // if necessary.
+  Value src = indices[srcDim];
+
+  // We only want to permute every N iterations of the target dim where N is
+  // ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
+  const int64_t permuteEveryN = std::max<int64_t>(
+      1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
+                                        memrefTy.getElementTypeBitWidth()) /
+                                       8));
+
+  // clang-format off
+  // Index bit representation (b0 = least significant bit) for dim(1)
+  // of a `memref<?x?xDT>` is as follows:
+  // N := log2(128/elementSizeBits)
+  // M := log2(dimSize(1))
+  // then
+  // bits[0:N] = sub-vector element offset
+  // bits[N:M] = vector index
+  // clang-format on
+  int64_t n =
+      llvm::Log2_64(kDefaultVectorSizeBits / memrefTy.getElementTypeBitWidth());
+  int64_t m = llvm::Log2_64(memrefTy.getDimSize(tgtDim));
+
+  // Capture bits[0:(M-N)] of src by first creating a (M-N) mask.
+  int64_t mask = (1LL << (m - n)) - 1;
+  if (permuteEveryN > 1)
+    mask = mask << llvm::Log2_64(permuteEveryN);
+  Value srcBits = b.create<arith::ConstantIndexOp>(loc, mask);
+  srcBits = b.create<arith::AndIOp>(loc, src, srcBits);
+
+  // Use the src bits to permute the target bits b[N:M] containing the
+  // vector offset.
+  if (permuteEveryN > 1) {
+    int64_t shlBits = n - llvm::Log2_64(permuteEveryN);
+    if (shlBits > 0) {
+      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, shlBits);
+      srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
+    } else if (shlBits < 0) {
+      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, -1 * shlBits);
+      srcBits = b.createOrFold<arith::ShRUIOp>(loc, srcBits, finalShiftVal);
+    }
+  } else {
+    Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, n);
+    srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
+  }
+
+  Value permutedVectorIdx =
+      b.create<arith::XOrIOp>(loc, indices[tgtDim], srcBits);
+  return permutedVectorIdx;
+}
+
+static void transformIndices(OpBuilder &builder, Location loc,
+                             SmallVector<Value, 4> &indices,
+                             MemRefType memrefTy, int64_t srcDim,
+                             int64_t tgtDim) {
+  indices[tgtDim] =
+      permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim);
+}
+
+/// Return all operations within `parentOp` that read from or write to
+/// `shmMemRef`.
+static LogicalResult
+getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
+                      SmallVector<Operation *, 16> &readOps,
+                      SmallVector<Operation *, 16> &writeOps) {
+  parentOp->walk([&](Operation *op) {
+    MemoryEffectOpInterface iface = dyn_cast<MemoryEffectOpInterface>(op);
+    if (!iface)
+      return;
+    std::optional<MemoryEffects::EffectInstance> effect =
+        iface.getEffectOnValue<MemoryEffects::Read>(shmMemRef);
+    if (effect) {
+      readOps.push_back(op);
+      return;
+    }
+    effect = iface.getEffectOnValue<MemoryEffects::Write>(shmMemRef);
+    if (effect)
+      writeOps.push_back(op);
+  });
+
+  // Restrict to a supported set of ops. We also require at least 2D access,
+  // although this could be relaxed.
+  if (llvm::any_of(readOps, [](Operation *op) {
+        return !isa<memref::LoadOp, vector::LoadOp, vector::TransferReadOp>(
+                   op) ||
+               amdgpu::getIndices(op).size() < 2;
+      }))
+    return failure();
+  if (llvm::any_of(writeOps, [](Operation *op) {
+        return !isa<memref::StoreOp, vector::StoreOp, vector::TransferWriteOp>(
+                   op) ||
+               amdgpu::getIndices(op).size() < 2;
+      }))
+    return failure();
+
+  return success();
+}
+
+mlir::LogicalResult
+mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
+                                                 Value memrefValue) {
+  auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
+  if (!memRefType ||
+      !amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(memRefType))
+    return failure();
+
+  // Abort if the given value has any sub-views; we do not do any alias
+  // analysis.
+  bool hasSubView = false;
+  parentOp->walk([&](memref::SubViewOp subView) { hasSubView = true; });
+  if (hasSubView)
+    return failure();
+
+  // Check if this is necessary given the assumption of 128b accesses:
+  // If dim[rank-1] is small enough to fit 8 rows in a 128B line.
+  const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
+  const int64_t rowsPerLine =
+      (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
+      rowSize;
+  const int64_t threadGroupSize =
+      1LL << (7 - llvm::Log2_64(kDefaultVectorSizeBits / 8));
+  if (rowsPerLine >= threadGroupSize)
+    return failure();
+
+  // Get sets of operations within the function that read/write to shared
+  // memory.
+  SmallVector<Operation *, 16> shmReadOps;
+  SmallVector<Operation *, 16> shmWriteOps;
+  if (failed(getShmReadAndWriteOps(parentOp, memrefValue, shmReadOps,
+                                   shmWriteOps)))
+    return failure();
+
+  if (shmReadOps.empty() || shmWriteOps.empty())
+    return failure();
+
+  OpBuilder builder(parentOp->getContext());
+
+  int64_t tgtDim = memRefType.getRank() - 1;
+  int64_t srcDim = memRefType.getRank() - 2;
+
+  // Transform indices for the ops writing to shared memory.
+  while (!shmWriteOps.empty()) {
+    Operation *shmWriteOp = shmWriteOps.back();
+    shmWriteOps.pop_back();
+    builder.setInsertionPoint(shmWriteOp);
+
+    auto indices = amdgpu::getIndices(shmWriteOp);
+    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
+    transformIndices(builder, shmWriteOp->getLoc(), transformedIndices,
+                     memRefType, srcDim, tgtDim);
+    amdgpu::setIndices(shmWriteOp, transformedIndices);
+  }
+
+  // Transform indices for the ops reading from shared memory.
+  while (!shmReadOps.empty()) {
+    Operation *shmReadOp = shmReadOps.back();
+    shmReadOps.pop_back();
+    builder.setInsertionPoint(shmReadOp);
+
+    auto indices = amdgpu::getIndices(shmReadOp);
+    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
+    transformIndices(builder, shmReadOp->getLoc(), transformedIndices,
+                     memRefType, srcDim, tgtDim);
+    amdgpu::setIndices(shmReadOp, transformedIndices);
+  }
+
+  return success();
+}
+
+namespace {
+class OptimizeSharedMemoryPass
+    : public amdgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
+public:
+  OptimizeSharedMemoryPass() = default;
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    SmallVector<memref::AllocOp> shmAllocOps;
+    op->walk([&](memref::AllocOp allocOp) {
+      if (!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(
+              allocOp.getType()))
+        return;
+      shmAllocOps.push_back(allocOp);
+    });
+    for (auto allocOp : shmAllocOps) {
+      if (failed(optimizeSharedMemoryReadsAndWrites(getOperation(),
+                                                    allocOp.getMemref())))
+        return;
+    }
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::amdgpu::createOptimizeSharedMemoryPass() {
+  return std::make_unique<OptimizeSharedMemoryPass>();
+}
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp
new file mode 100644
index 00000000000000..a1dc6cf70e7bf8
--- /dev/null
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp
@@ -0,0 +1,48 @@
+//===- Utils.cpp - Transform utilities ------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMDGPU/Transforms/Utils.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+using namespace mlir;
+using namespace mlir::amdgpu;
+
+Operation::operand_range amdgpu::getIndices(Operation *op) {
+  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+    return loadOp.getIndices();
+  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+    return storeOp.getIndices();
+  if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
+    return vectorReadOp.getIndices();
+  if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
+    return vectorStoreOp.getIndices();
+  if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op))
+    return transferReadOp.getIndices();
+  if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op))
+    return transferWriteOp.getIndices();
+  llvm_unreachable("unsupported op type");
+}
+
+void amdgpu::setIndices(Operation *op, ArrayRef<Value> indices) {
+  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+    return loadOp.getIndicesMutable().assign(indices);
+  if (auto storeOp = dyn_cast<memref:...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Dec 18, 2023

@llvm/pr-subscribers-mlir

Author: None (erman-gurses)

Changes

It implements transformation to optimize accesses to shared memory.

Reference: https://reviews.llvm.org/D127457

This change adds a transformation and pass to the NvGPU dialect that
attempts to optimize reads/writes from a memref representing GPU shared
memory in order to avoid bank conflicts. Given a value representing a
shared memory memref, it traverses all reads/writes within the parent op
and, subject to suitable conditions, rewrites all last dimension index
values such that element locations in the final (col) dimension are
given by newColIdx = col % vecSize + perm[row](col / vecSize, row)
where perm is a permutation function indexed by row and vecSize
is the vector access size in elements (currently assumes 128bit
vectorized accesses, but this can be made a parameter). This specific
transformation can help optimize typical distributed & vectorized accesses
common to loading matrix multiplication operands to/from shared memory.


Patch is 25.04 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/75627.diff

10 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+27)
  • (modified) mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h (+4)
  • (modified) mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td (+8)
  • (added) mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h (+54)
  • (added) mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h (+21)
  • (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+15)
  • (modified) mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt (+2)
  • (added) mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp (+252)
  • (added) mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp (+48)
  • (added) mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir (+57)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index ffb302fcedd732..324c656f47599e 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -29,6 +29,33 @@ def AMDGPU_Dialect : Dialect {
     "gpu::GPUDialect"
   ];
   let useDefaultAttributePrinterParser = 1;
+
+  let extraClassDeclaration = [{
+    /// Return true if the given MemRefType has an integer address
+    /// space that matches the ROCDL shared memory address space or
+    /// is a gpu::AddressSpaceAttr attribute with value 'workgroup`.
+    static bool hasSharedMemoryAddressSpace(MemRefType type);
+
+    /// Return true if the given Attribute has an integer address
+    /// space that matches the ROCDL shared memory address space or
+    /// is a gpu::AddressSpaceAttr attribute with value 'workgroup`.
+    static bool isSharedMemoryAddressSpace(Attribute type);
+
+    /// Defines the MemRef memory space attribute numeric value that indicates
+    /// a memref is located in global memory. This should correspond to the
+    /// value used in ROCDL.
+    static constexpr unsigned kGlobalMemoryAddressSpace = 1;
+
+    /// Defines the MemRef memory space attribute numeric value that indicates
+    /// a memref is located in private memory. This should correspond to the
+    /// value used in ROCDL.
+    static constexpr unsigned kPrivateMemoryAddressSpace = 2;
+
+    /// Defines the MemRef memory space attribute numeric value that indicates
+    /// a memref is located in shared memory. This should correspond to the
+    /// value used in ROCDL.
+    static constexpr unsigned kSharedMemoryAddressSpace = 3;
+  }];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
index 8dd5ff1a4b198a..752078cd6930e3 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
@@ -21,6 +21,10 @@ class ConversionTarget;
 namespace amdgpu {
 
 #define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS
+
+/// Create a pass to optimize shared memory reads and writes.
+std::unique_ptr<Pass> createOptimizeSharedMemoryPass();
+
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
 
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
index e6b27aa842dfcd..1c12ca98271127 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
@@ -30,4 +30,12 @@ def AmdgpuEmulateAtomicsPass : Pass<"amdgpu-emulate-atomics"> {
                         "Chipset that these operations will run on">];
 }
 
+def OptimizeSharedMemory : Pass<"amdgpu-optimize-shared-memory"> {
+  let summary = "Optimizes accesses to shared memory memrefs in order to reduce bank conflicts.";
+  let constructor = "mlir::amdgpu::createOptimizeSharedMemoryPass()";
+  let dependentDialects = [
+    "memref::MemRefDialect", "vector::VectorDialect"
+  ];
+}
+
 #endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
new file mode 100644
index 00000000000000..140bc12deed690
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
@@ -0,0 +1,54 @@
+//===- Transforms.h - AMDGPU Dialect transformations --------------*-
+// C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares functions that assist transformations for the amdgpu
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_AMDGPU_TRANSFORMS_TRANSFORMS_H_
+#define MLIR_DIALECT_AMDGPU_TRANSFORMS_TRANSFORMS_H_
+
+#include "mlir/IR/Operation.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+class RewriterBase;
+
+namespace amdgpu {
+
+///
+/// Passes
+///
+
+/// Optimizes vectorized accesses to a shared memory buffer specified by
+/// memrefValue. This transformation assumes the following:
+/// 1) All relevant accesses to `memrefValue` are contained with `parentOp`.
+/// 2) The function will fail precondition checks if any subviews are
+/// taken of `memrefValue`. All reads/writes to `memrefValue` should occur
+/// through `memrefValue` directly.
+///
+/// Shared memory bank conflicts occur when multiple threads attempt to read or
+/// write locations assigned to the same shared memory bank. For `2^N` byte
+/// vectorized accesses, we need to be concerned with conflicts among threads
+/// identified as `(tid) -> tid.floordiv(2^{7-N})`. As such, this transformation
+/// changes any indexed memory access (vector.load, memref.load, etc)
+/// such that the final dimension's index value is permuted such that
+/// `newColIndex = oldColIndex % vectorSize +
+/// perm[rowIndex](oldColIndex/vectorSize, rowIndex)` where `rowIndex` is the
+/// index for the second-to last dimension and `perm[rowIndex]` is a permutation
+/// function that depends on the row Index. The permutation function is chosen
+/// to ensure that sequential distributed+vectorized reads/writes down a single
+/// dimension of the memref have minimal conflicts.
+mlir::LogicalResult optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
+                                                       Value memrefValue);
+
+} // namespace amdgpu
+} // namespace mlir
+
+#endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_TRANSFORMS_H_
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h
new file mode 100644
index 00000000000000..bee3af1914feef
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h
@@ -0,0 +1,21 @@
+//===- Utils.h - Transform utilities -----------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Operation.h"
+
+namespace mlir {
+namespace amdgpu {
+
+/// Get the indices that the given load/store operation is operating on.
+Operation::operand_range getIndices(Operation *op);
+
+/// Set the indices that the given load/store operation is operating on.
+void setIndices(Operation *op, ArrayRef<Value> indices);
+
+} // namespace amdgpu
+} // namespace mlir
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 2575ad4984814b..4e72fbf56b80a4 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -43,6 +43,21 @@ void AMDGPUDialect::initialize() {
       >();
 }
 
+bool amdgpu::AMDGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
+  if (!memorySpace)
+    return false;
+  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace))
+    return intAttr.getInt() == AMDGPUDialect::kSharedMemoryAddressSpace;
+  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+    return gpuAttr.getValue() == gpu::AddressSpace::Workgroup;
+  return false;
+}
+
+bool amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) {
+  Attribute memorySpace = type.getMemorySpace();
+  return isSharedMemoryAddressSpace(memorySpace);
+}
+
 //===----------------------------------------------------------------------===//
 // 8-bit float ops
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
index e11b6cc88bf224..a1a91270bc55c4 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
@@ -1,5 +1,7 @@
 add_mlir_dialect_library(MLIRAMDGPUTransforms
   EmulateAtomics.cpp
+  OptimizeSharedMemory.cpp
+  Utils.cpp
 
   ADDITIONAL_HEADER_DIRS
   {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
new file mode 100644
index 00000000000000..0a2f04f4e6487f
--- /dev/null
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -0,0 +1,252 @@
+//===- OptimizeSharedMemory.cpp - MLIR AMDGPU pass implementation ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements transforms to optimize accesses to shared memory.
+// It is inspired by
+// https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/AMDGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/AMDGPU/Transforms/Utils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/MathExtras.h"
+
+namespace mlir {
+namespace amdgpu {
+#define GEN_PASS_DEF_OPTIMIZESHAREDMEMORY
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
+} // namespace amdgpu
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::amdgpu;
+
+/// The size of a shared memory line according to AMD documentation.
+/// https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/instinct-mi200-cdna2-instruction-set-architecture.pdf
+constexpr int64_t kSharedMemoryLineSizeBytes = 64;
+/// We optimize for 64bit accesses, but this can be made an argument in the
+/// future.
+constexpr int64_t kDefaultVectorSizeBits = 64;
+
+/// Uses `srcIndexValue` to permute `tgtIndexValue` via
+/// `result = xor(floordiv(srcIdxVal,permuteEveryN),
+///               floordiv(tgtIdxVal,vectorSize))
+///            + tgtIdxVal % vectorSize`
+/// This is done using an optimized sequence of `arith` operations.
+static Value permuteVectorOffset(OpBuilder &b, Location loc,
+                                 ArrayRef<Value> indices, MemRefType memrefTy,
+                                 int64_t srcDim, int64_t tgtDim) {
+  // Adjust the src index to change how often the permutation changes
+  // if necessary.
+  Value src = indices[srcDim];
+
+  // We only want to permute every N iterations of the target dim where N is
+  // ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
+  const int64_t permuteEveryN = std::max<int64_t>(
+      1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
+                                        memrefTy.getElementTypeBitWidth()) /
+                                       8));
+
+  // clang-format off
+  // Index bit representation (b0 = least significant bit) for dim(1)
+  // of a `memref<?x?xDT>` is as follows:
+  // N := log2(kDefaultVectorSizeBits / elementSizeBits)
+  // M := log2(dimSize(1))
+  // then
+  // bits[0:N] = sub-vector element offset
+  // bits[N:M] = vector index
+  // clang-format on
+  int64_t n =
+      llvm::Log2_64(kDefaultVectorSizeBits / memrefTy.getElementTypeBitWidth());
+  int64_t m = llvm::Log2_64(memrefTy.getDimSize(tgtDim));
+
+  // Capture bits[0:(M-N)] of src by first creating a (M-N) mask.
+  int64_t mask = (1LL << (m - n)) - 1;
+  if (permuteEveryN > 1)
+    mask = mask << llvm::Log2_64(permuteEveryN);
+  Value srcBits = b.create<arith::ConstantIndexOp>(loc, mask);
+  srcBits = b.create<arith::AndIOp>(loc, src, srcBits);
+
+  // Use the src bits to permute the target bits b[N:M] containing the
+  // vector offset.
+  if (permuteEveryN > 1) {
+    int64_t shlBits = n - llvm::Log2_64(permuteEveryN);
+    if (shlBits > 0) {
+      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, shlBits);
+      srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
+    } else if (shlBits < 0) {
+      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, -1 * shlBits);
+      srcBits = b.createOrFold<arith::ShRUIOp>(loc, srcBits, finalShiftVal);
+    }
+  } else {
+    Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, n);
+    srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
+  }
+
+  Value permutedVectorIdx =
+      b.create<arith::XOrIOp>(loc, indices[tgtDim], srcBits);
+  return permutedVectorIdx;
+}
+
+static void transformIndices(OpBuilder &builder, Location loc,
+                             SmallVector<Value, 4> &indices,
+                             MemRefType memrefTy, int64_t srcDim,
+                             int64_t tgtDim) {
+  indices[tgtDim] =
+      permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim);
+}
+
+/// Return all operations within `parentOp` that read from or write to
+/// `shmMemRef`.
+static LogicalResult
+getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
+                      SmallVector<Operation *, 16> &readOps,
+                      SmallVector<Operation *, 16> &writeOps) {
+  parentOp->walk([&](Operation *op) {
+    MemoryEffectOpInterface iface = dyn_cast<MemoryEffectOpInterface>(op);
+    if (!iface)
+      return;
+    std::optional<MemoryEffects::EffectInstance> effect =
+        iface.getEffectOnValue<MemoryEffects::Read>(shmMemRef);
+    if (effect) {
+      readOps.push_back(op);
+      return;
+    }
+    effect = iface.getEffectOnValue<MemoryEffects::Write>(shmMemRef);
+    if (effect)
+      writeOps.push_back(op);
+  });
+
+  // Restrict to a supported set of ops. We also require at least 2D access,
+  // although this could be relaxed.
+  if (llvm::any_of(readOps, [](Operation *op) {
+        return !isa<memref::LoadOp, vector::LoadOp, vector::TransferReadOp>(
+                   op) ||
+               amdgpu::getIndices(op).size() < 2;
+      }))
+    return failure();
+  if (llvm::any_of(writeOps, [](Operation *op) {
+        return !isa<memref::StoreOp, vector::StoreOp, vector::TransferWriteOp>(
+                   op) ||
+               amdgpu::getIndices(op).size() < 2;
+      }))
+    return failure();
+
+  return success();
+}
+
+mlir::LogicalResult
+mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
+                                                 Value memrefValue) {
+  auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
+  if (!memRefType ||
+      !amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(memRefType))
+    return failure();
+
+  // Abort if the given value has any sub-views; we do not do any alias
+  // analysis.
+  bool hasSubView = false;
+  parentOp->walk([&](memref::SubViewOp subView) { hasSubView = true; });
+  if (hasSubView)
+    return failure();
+
+  // Check if the transformation is necessary given the default vector access
+  // size: bail out if dim[rank-1] is small enough that a whole thread group's
+  // rows fit within a single shared memory line.
+  const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
+  const int64_t rowsPerLine =
+      (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
+      rowSize;
+  const int64_t threadGroupSize =
+      1LL << (7 - llvm::Log2_64(kDefaultVectorSizeBits / 8));
+  if (rowsPerLine >= threadGroupSize)
+    return failure();
+
+  // Get sets of operations within the function that read/write to shared
+  // memory.
+  SmallVector<Operation *, 16> shmReadOps;
+  SmallVector<Operation *, 16> shmWriteOps;
+  if (failed(getShmReadAndWriteOps(parentOp, memrefValue, shmReadOps,
+                                   shmWriteOps)))
+    return failure();
+
+  if (shmReadOps.empty() || shmWriteOps.empty())
+    return failure();
+
+  OpBuilder builder(parentOp->getContext());
+
+  int64_t tgtDim = memRefType.getRank() - 1;
+  int64_t srcDim = memRefType.getRank() - 2;
+
+  // Transform indices for the ops writing to shared memory.
+  while (!shmWriteOps.empty()) {
+    Operation *shmWriteOp = shmWriteOps.back();
+    shmWriteOps.pop_back();
+    builder.setInsertionPoint(shmWriteOp);
+
+    auto indices = amdgpu::getIndices(shmWriteOp);
+    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
+    transformIndices(builder, shmWriteOp->getLoc(), transformedIndices,
+                     memRefType, srcDim, tgtDim);
+    amdgpu::setIndices(shmWriteOp, transformedIndices);
+  }
+
+  // Transform indices for the ops reading from shared memory.
+  while (!shmReadOps.empty()) {
+    Operation *shmReadOp = shmReadOps.back();
+    shmReadOps.pop_back();
+    builder.setInsertionPoint(shmReadOp);
+
+    auto indices = amdgpu::getIndices(shmReadOp);
+    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
+    transformIndices(builder, shmReadOp->getLoc(), transformedIndices,
+                     memRefType, srcDim, tgtDim);
+    amdgpu::setIndices(shmReadOp, transformedIndices);
+  }
+
+  return success();
+}
+
+namespace {
+class OptimizeSharedMemoryPass
+    : public amdgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
+public:
+  OptimizeSharedMemoryPass() = default;
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    SmallVector<memref::AllocOp> shmAllocOps;
+    op->walk([&](memref::AllocOp allocOp) {
+      if (!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(
+              allocOp.getType()))
+        return;
+      shmAllocOps.push_back(allocOp);
+    });
+    for (auto allocOp : shmAllocOps) {
+      if (failed(optimizeSharedMemoryReadsAndWrites(getOperation(),
+                                                    allocOp.getMemref())))
+        return;
+    }
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::amdgpu::createOptimizeSharedMemoryPass() {
+  return std::make_unique<OptimizeSharedMemoryPass>();
+}
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp
new file mode 100644
index 00000000000000..a1dc6cf70e7bf8
--- /dev/null
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp
@@ -0,0 +1,48 @@
+//===- Utils.cpp - Transform utilities ------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMDGPU/Transforms/Utils.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+using namespace mlir;
+using namespace mlir::amdgpu;
+
+Operation::operand_range amdgpu::getIndices(Operation *op) {
+  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+    return loadOp.getIndices();
+  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+    return storeOp.getIndices();
+  if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
+    return vectorReadOp.getIndices();
+  if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
+    return vectorStoreOp.getIndices();
+  if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op))
+    return transferReadOp.getIndices();
+  if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op))
+    return transferWriteOp.getIndices();
+  llvm_unreachable("unsupported op type");
+}
+
+void amdgpu::setIndices(Operation *op, ArrayRef<Value> indices) {
+  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+    return loadOp.getIndicesMutable().assign(indices);
+  if (auto storeOp = dyn_cast<memref:...
[truncated]
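
As a rough way to follow the index arithmetic in `permuteVectorOffset` above, the bit manipulation can be modeled on plain integers outside of MLIR. The sketch below mirrors the pass constants (64-byte shared memory line, 64-bit vector accesses) for an assumed 64x64 f16 tile; the tile shape, the `permuteColumn` and `ilog2` helper names, and the standalone setting are illustrative assumptions, not code from this patch:

// Standalone model of the swizzle in permuteVectorOffset. Illustration only;
// the real pass builds the equivalent arith ops on index values.
#include <algorithm>
#include <cstdint>
#include <iostream>

constexpr int64_t kSharedMemoryLineSizeBytes = 64; // mirrored from the pass
constexpr int64_t kDefaultVectorSizeBits = 64;     // mirrored from the pass

// floor(log2(x)) for x > 0, standing in for llvm::Log2_64.
static int64_t ilog2(int64_t x) {
  int64_t r = -1;
  while (x) {
    x >>= 1;
    ++r;
  }
  return r;
}

// Returns the permuted index of the last (column) dimension for a 2-D tile,
// given the row index, the row length in elements, and the element bit width.
static int64_t permuteColumn(int64_t row, int64_t col, int64_t rowSizeElems,
                             int64_t elemBitWidth) {
  const int64_t rowBytes = rowSizeElems * elemBitWidth / 8;
  const int64_t permuteEveryN =
      std::max<int64_t>(1, kSharedMemoryLineSizeBytes / rowBytes);
  const int64_t n = ilog2(kDefaultVectorSizeBits / elemBitWidth); // sub-vector bits
  const int64_t m = ilog2(rowSizeElems);                          // column bits
  int64_t mask = (1LL << (m - n)) - 1;
  if (permuteEveryN > 1)
    mask <<= ilog2(permuteEveryN);
  int64_t srcBits = row & mask;
  if (permuteEveryN > 1) {
    const int64_t shl = n - ilog2(permuteEveryN);
    srcBits = shl >= 0 ? srcBits << shl : srcBits >> -shl;
  } else {
    srcBits <<= n;
  }
  // XOR permutes the vector index; the sub-vector offset (low n bits of col)
  // is untouched because srcBits has zeros in those positions.
  return col ^ srcBits;
}

int main() {
  // Assumed tile: 64x64 f16 values, so each 64-bit access covers 4 elements.
  for (int64_t row = 0; row < 4; ++row)
    std::cout << "row " << row << ": col 0 -> " << permuteColumn(row, 0, 64, 16)
              << ", col 4 -> " << permuteColumn(row, 4, 64, 16) << "\n";
}

With these assumed values `permuteEveryN` is 1, `n` is 2 and `m` is 6, so the low four bits of the row index are XOR-ed into the vector-index bits of the column: the model shows, for example, row 1 mapping column 0 to 4 and column 4 to 0, which is the per-row rotation of 64-bit vectors that the transformation relies on to spread accesses across banks.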


@harsh-nod harsh-nod requested review from qedawkins, krzysz00 and ThomasRaoux and removed request for qedawkins December 18, 2023 22:08
@erman-gurses
Copy link
Contributor Author

Hi Everyone, please let me know what you think about this PR. Thanks.

Copy link
Contributor

@krzysz00 krzysz00 left a comment


I'll have actual substantial comments later, but, as an initial matter, is a name as generic as "optimize shared memory access" appropriate here?

Also, I think, if you're trying to prevent aliasing issues around subviews, you might want to run a general alias analysis.
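
For reference, the current patch simply bails out of `optimizeSharedMemoryReadsAndWrites` whenever any `memref.subview` appears under `parentOp`. A more permissive precondition along the lines suggested here could query MLIR's `AliasAnalysis` utility instead; the following is only a rough sketch of that direction (assuming `mlir/Analysis/AliasAnalysis.h`), not code from this PR:

#include "mlir/Analysis/AliasAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LogicalResult.h"

using namespace mlir;

// Sketch: instead of rejecting every subview, only give up when a subview may
// actually alias the shared memory buffer being rewritten.
static LogicalResult checkNoAliasingViews(Operation *parentOp, Value memrefValue,
                                          AliasAnalysis &aliasAnalysis) {
  WalkResult result = parentOp->walk([&](memref::SubViewOp subView) {
    if (!aliasAnalysis.alias(subView.getResult(), memrefValue).isNo())
      return WalkResult::interrupt();
    return WalkResult::advance();
  });
  return failure(result.wasInterrupted());
}

Inside a pass, the analysis would typically be obtained with `getAnalysis<AliasAnalysis>()`; whether that extra precision is worth it here is the trade-off being raised in this review comment.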

@erman-gurses
Copy link
Contributor Author

erman-gurses commented Dec 18, 2023

I'll have actual substantial comments later, but, as an initial matter, is a name as generic as "optimize shared memory access" appropriate here?

Can I get a suggestion for that? I used this reference https://reviews.llvm.org/D127457 for the implementation and naming.

Copy link

github-actions bot commented Jan 8, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@erman-gurses erman-gurses force-pushed the eg_swizzling branch 3 times, most recently from ff49048 to 4c7c76a Compare January 10, 2024 16:33
@erman-gurses erman-gurses force-pushed the eg_swizzling branch 2 times, most recently from ad6ad24 to 14a07c5 Compare January 12, 2024 19:09
@krzysz00
Copy link
Contributor

@joker-eph I think we should land this even with the hardcoded list of memory operations because there's existing precedent, given that there's a followup plan for something like MemRefAccessInterface or the like.

@joker-eph
Copy link
Collaborator

"existing precedent" isn't a reason IMO. It's not because a mistake slipped that it justified continuing a bad pattern...
But if @erman-gurses wants to tackle this as a separate PR that's fine by me.

@erman-gurses
Copy link
Contributor Author

"existing precedent" isn't a reason IMO. It's not because a mistake slipped that it justified continuing a bad pattern... But if @erman-gurses wants to tackle this as a separate PR that's fine by me.

@joker-eph, yes, I would like to tackle this as a separate PR.

@erman-gurses
Copy link
Contributor Author

"existing precedent" isn't a reason IMO. It's not because a mistake slipped that it justified continuing a bad pattern... But if @erman-gurses wants to tackle this as a separate PR that's fine by me.

"existing precedent" isn't a reason IMO. It's not because a mistake slipped that it justified continuing a bad pattern... But if @erman-gurses wants to tackle this as a separate PR that's fine by me.

@joker-eph, yes, I would like to tackle this as a separate PR.

and @joker-eph, please let me know when you can approve this.

@joker-eph
Copy link
Collaborator

Probably best if someone more involved with AMDGPU reviews and approves?

@erman-gurses
Copy link
Contributor Author

erman-gurses commented Jan 19, 2024

Probably best if someone more involved with AMDGPU reviews and approves?

Sounds good, @krzysz00 has already reviewed it. Krzysztof, please let me know when you can approve this. Thanks.

Copy link
Contributor

@krzysz00 krzysz00 left a comment


This looks reasonable and we'll clean this and its NVPTX equivalent up in a followup, it sounds like.

@joker-eph
Copy link
Collaborator

Sorry, had to revert because of the bot breakage. Please see: https://lab.llvm.org/buildbot/#/builders/61/builds/53218 and feel free to reland when you have a fix!

@erman-gurses erman-gurses restored the eg_swizzling branch January 20, 2024 01:36
@erman-gurses
Copy link
Contributor Author

Sorry, had to revert because of the bot breakage. Please see: https://lab.llvm.org/buildbot/#/builders/61/builds/53218 and feel free to reland when you have a fix!

Sure, I will investigate and reland later.

@erman-gurses erman-gurses deleted the eg_swizzling branch January 23, 2024 16:14
harsh-nod pushed a commit that referenced this pull request Jan 25, 2024
- Reland: #75627

- Reproduced then fixed the build issue