[mlir][MemRef] Add runtime bounds checking #75817


Merged: 1 commit into llvm:main on Dec 22, 2023

Conversation

ryan-holt-1
Contributor

This change adds (runtime) bounds checks for memref ops using the existing RuntimeVerifiableOpInterface. For memref.load and memref.store, we check that the indices are in-bounds of the memref's index space. For memref.reinterpret_cast and memref.subview we check that the resulting address space is in-bounds of the input memref's address space.
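For illustration, the guard generated before each load/store is equivalent to the following check (a minimal Python sketch; the pass itself builds `arith.cmpi`/`arith.andi` ops feeding a `cf.assert`, and it skips rank-0 memrefs, which take no indices):

```python
def in_bounds(indices, shape):
    """Each index must lie in the memref's index space: 0 <= idx < dim
    (signed comparisons, matching the sge/slt predicates in the pass)."""
    return all(0 <= idx < dim for idx, dim in zip(indices, shape))

# memref.load %m[%i, %j, %k] : memref<2x2x2xf32> asserts this condition:
assert in_bounds((1, 1, 0), (2, 2, 2))       # in-bounds
assert not in_bounds((1, -1, 0), (2, 2, 2))  # negative index trips the check
```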


Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be
notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write
permissions for the repository, in which case you can instead tag reviewers by
name in a comment using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review
by pinging the PR with a comment saying "Ping". The common courtesy ping rate is
once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot
Member

llvmbot commented Dec 18, 2023

@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir

Author: Ryan Holt (ryan-holt-1)


Patch is 20.70 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/75817.diff

5 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp (+190-1)
  • (modified) mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir (+4-4)
  • (added) mlir/test/Integration/Dialect/Memref/load-runtime-verification.mlir (+67)
  • (added) mlir/test/Integration/Dialect/Memref/reinterpret-cast-runtime-verification.mlir (+74)
  • (added) mlir/test/Integration/Dialect/Memref/subview-runtime-verification.mlir (+71)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 05a069d98ef35f..c20c188703d392 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -8,10 +8,14 @@
 
 #include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
 
 using namespace mlir;
@@ -21,6 +25,13 @@ static std::string generateErrorMessage(Operation *op, const std::string &msg) {
   std::string buffer;
   llvm::raw_string_ostream stream(buffer);
   OpPrintingFlags flags;
+  // We may generate a lot of error messages and so we need to ensure the
+  // printing is fast.
+  flags.assumeVerified();
+  flags.elideLargeElementsAttrs();
+  flags.printGenericOpForm();
+  flags.skipRegions();
+  flags.useLocalScope();
   stream << "ERROR: Runtime op verification failed\n";
   op->print(stream, flags);
   stream << "\n^ " << msg;
@@ -133,6 +144,179 @@ struct CastOpInterface
   }
 };
 
+template <typename LoadStoreOp>
+struct LoadStoreOpInterface
+    : public RuntimeVerifiableOpInterface::ExternalModel<
+          LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
+  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
+                                   Location loc) const {
+    auto loadStoreOp = cast<LoadStoreOp>(op);
+
+    // Verify that the indices on the load/store are in-bounds of the memref's
+    // index space
+
+    auto memref = loadStoreOp.getMemref();
+    auto rank = memref.getType().getRank();
+    if (rank == 0) {
+      return;
+    }
+    auto indices = loadStoreOp.getIndices();
+
+    auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+    Value assertCond;
+    for (auto i : llvm::seq<int64_t>(0, rank)) {
+      auto index = indices[i];
+
+      auto dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i);
+
+      auto geLow = builder.createOrFold<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::sge, index, zero);
+      auto ltHigh = builder.createOrFold<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::slt, index, dimOp);
+      auto andOp = builder.createOrFold<arith::AndIOp>(loc, geLow, ltHigh);
+
+      assertCond =
+          i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, andOp)
+                : andOp;
+    }
+    builder.create<cf::AssertOp>(
+        loc, assertCond, generateErrorMessage(op, "out-of-bounds access"));
+  }
+};
+
+// Compute the linear index for the provided strided layout and indices.
+Value computeLinearIndex(OpBuilder &builder, Location loc, OpFoldResult offset,
+                         ArrayRef<OpFoldResult> strides,
+                         ArrayRef<OpFoldResult> indices) {
+  auto [expr, values] = computeLinearIndex(offset, strides, indices);
+  auto index =
+      affine::makeComposedFoldedAffineApply(builder, loc, expr, values);
+  return getValueOrCreateConstantIndexOp(builder, loc, index);
+}
+
+// Returns two Values representing the bounds of the provided strided layout
+// metadata. The bounds are returned as a half open interval -- [low, high).
+std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
+                                            OpFoldResult offset,
+                                            ArrayRef<OpFoldResult> strides,
+                                            ArrayRef<OpFoldResult> sizes) {
+  auto zeros = SmallVector<int64_t>(sizes.size(), 0);
+  auto indices = getAsIndexOpFoldResult(builder.getContext(), zeros);
+  auto lowerBound = computeLinearIndex(builder, loc, offset, strides, indices);
+  auto upperBound = computeLinearIndex(builder, loc, offset, strides, sizes);
+  return {lowerBound, upperBound};
+}
+
+// Returns two Values representing the bounds of the address space of the
+// memref. The bounds are returned as a half open interval -- [low, high).
+std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
+                                            TypedValue<BaseMemRefType> memref) {
+  auto runtimeMetadata = builder.create<ExtractStridedMetadataOp>(loc, memref);
+  auto offset = runtimeMetadata.getConstifiedMixedOffset();
+  auto strides = runtimeMetadata.getConstifiedMixedStrides();
+  auto sizes = runtimeMetadata.getConstifiedMixedSizes();
+  return computeLinearBounds(builder, loc, offset, strides, sizes);
+}
+
+struct ReinterpretCastOpInterface
+    : public RuntimeVerifiableOpInterface::ExternalModel<
+          ReinterpretCastOpInterface, ReinterpretCastOp> {
+  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
+                                   Location loc) const {
+    auto reinterpretCast = cast<ReinterpretCastOp>(op);
+
+    // Verify that the resulting address space is in-bounds of the base memref's
+    // address space.
+
+    auto baseMemref = reinterpretCast.getSource();
+
+    auto castOffset = reinterpretCast.getMixedOffsets().front();
+    auto castStrides = reinterpretCast.getMixedStrides();
+    auto castSizes = reinterpretCast.getMixedSizes();
+
+    // Compute the bounds of the base memref's address space
+    auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
+
+    // Compute the bounds of the resulting memref's address space
+    auto [low, high] =
+        computeLinearBounds(builder, loc, castOffset, castStrides, castSizes);
+
+    auto geLow = builder.createOrFold<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::sge, low, baseLow);
+
+    auto leHigh = builder.createOrFold<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::sle, high, baseHigh);
+
+    auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);
+
+    builder.create<cf::AssertOp>(
+        loc, assertCond,
+        generateErrorMessage(
+            op,
+            "result of reinterpret_cast is out-of-bounds of the base memref"));
+  }
+};
+
+struct SubViewOpInterface
+    : public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
+                                                         SubViewOp> {
+  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
+                                   Location loc) const {
+    auto subView = cast<SubViewOp>(op);
+
+    // Verify that the resulting address space is in-bounds of the base memref's
+    // address space.
+
+    auto baseMemref = subView.getSource().cast<TypedValue<BaseMemRefType>>();
+
+    auto metadata = builder.create<ExtractStridedMetadataOp>(loc, baseMemref);
+
+    auto baseStrides = metadata.getConstifiedMixedStrides();
+    auto baseOffset = metadata.getConstifiedMixedOffset();
+
+    auto subStrides = subView.getMixedStrides();
+    auto subSizes = subView.getMixedSizes();
+    auto subOffsets = subView.getMixedOffsets();
+
+    // result_strides#i = baseStrides#i * subStrides#i
+    SmallVector<OpFoldResult> strides;
+    strides.reserve(baseStrides.size());
+    AffineExpr s0 = builder.getAffineSymbolExpr(0);
+    AffineExpr s1 = builder.getAffineSymbolExpr(1);
+    for (auto i : llvm::seq<int64_t>(0, baseMemref.getType().getRank())) {
+      strides.push_back(affine::makeComposedFoldedAffineApply(
+          builder, loc, s0 * s1, {subStrides[i], baseStrides[i]}));
+    }
+
+    // result_offset = baseOffset + sum(subOffset#i * baseStrides#i)
+    auto offset =
+        computeLinearIndex(builder, loc, baseOffset, subOffsets, baseStrides);
+
+    // result_sizes = subSizes
+    auto &sizes = subSizes;
+
+    // Compute the bounds of the base memref's address space
+    auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
+
+    // Compute the bounds of the resulting memref's address space
+    auto [low, high] =
+        computeLinearBounds(builder, loc, offset, strides, sizes);
+
+    auto geLow = builder.createOrFold<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::sge, low, baseLow);
+
+    auto leHigh = builder.createOrFold<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::sle, high, baseHigh);
+
+    auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);
+
+    builder.create<cf::AssertOp>(
+        loc, assertCond,
+        generateErrorMessage(op,
+                             "subview is out-of-bounds of the base memref"));
+  }
+};
+
 struct ExpandShapeOpInterface
     : public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                          ExpandShapeOp> {
@@ -183,8 +367,13 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
   registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
     CastOp::attachInterface<CastOpInterface>(*ctx);
     ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
+    LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
+    ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
+    StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
+    SubViewOp::attachInterface<SubViewOpInterface>(*ctx);
 
     // Load additional dialects of which ops may get created.
-    ctx->loadDialect<arith::ArithDialect, cf::ControlFlowDialect>();
+    ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
+                     cf::ControlFlowDialect>();
   });
 }
diff --git a/mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir
index 6ad817a73408cd..52b8c16d753da7 100644
--- a/mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir
@@ -33,26 +33,26 @@ func.func @main() {
   %alloc = memref.alloc() : memref<5xf32>
 
   //      CHECK: ERROR: Runtime op verification failed
-  // CHECK-NEXT: memref.cast %{{.*}} : memref<?xf32> to memref<10xf32>
+  // CHECK-NEXT: "memref.cast"(%{{.*}}) : (memref<?xf32>) -> memref<10xf32>
   // CHECK-NEXT: ^ size mismatch of dim 0
   // CHECK-NEXT: Location: loc({{.*}})
   %1 = memref.cast %alloc : memref<5xf32> to memref<?xf32>
   func.call @cast_to_static_dim(%1) : (memref<?xf32>) -> (memref<10xf32>)
 
   // CHECK-NEXT: ERROR: Runtime op verification failed
-  // CHECK-NEXT: memref.cast %{{.*}} : memref<*xf32> to memref<f32>
+  // CHECK-NEXT: "memref.cast"(%{{.*}}) : (memref<*xf32>) -> memref<f32>
   // CHECK-NEXT: ^ rank mismatch
   // CHECK-NEXT: Location: loc({{.*}})
   %3 = memref.cast %alloc : memref<5xf32> to memref<*xf32>
   func.call @cast_to_ranked(%3) : (memref<*xf32>) -> (memref<f32>)
 
   // CHECK-NEXT: ERROR: Runtime op verification failed
-  // CHECK-NEXT: memref.cast %{{.*}} : memref<?xf32, strided<[?], offset: ?>> to memref<?xf32, strided<[9], offset: 5>>
+  // CHECK-NEXT: "memref.cast"(%{{.*}}) : (memref<?xf32, strided<[?], offset: ?>>) -> memref<?xf32, strided<[9], offset: 5>>
   // CHECK-NEXT: ^ offset mismatch
   // CHECK-NEXT: Location: loc({{.*}})
 
   // CHECK-NEXT: ERROR: Runtime op verification failed
-  // CHECK-NEXT: memref.cast %{{.*}} : memref<?xf32, strided<[?], offset: ?>> to memref<?xf32, strided<[9], offset: 5>>
+  // CHECK-NEXT: "memref.cast"(%{{.*}}) : (memref<?xf32, strided<[?], offset: ?>>) -> memref<?xf32, strided<[9], offset: 5>>
   // CHECK-NEXT: ^ stride mismatch of dim 0
   // CHECK-NEXT: Location: loc({{.*}})
   %4 = memref.cast %alloc
diff --git a/mlir/test/Integration/Dialect/Memref/load-runtime-verification.mlir b/mlir/test/Integration/Dialect/Memref/load-runtime-verification.mlir
new file mode 100644
index 00000000000000..169dfd70564594
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Memref/load-runtime-verification.mlir
@@ -0,0 +1,67 @@
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN:     -expand-strided-metadata \
+// RUN:     -finalize-memref-to-llvm \
+// RUN:     -test-cf-assert \
+// RUN:     -convert-func-to-llvm \
+// RUN:     -reconcile-unrealized-casts | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
+func.func @load(%memref: memref<1xf32>, %index: index) {
+    memref.load %memref[%index] :  memref<1xf32>
+    return
+}
+
+func.func @load_dynamic(%memref: memref<?xf32>, %index: index) {
+    memref.load %memref[%index] :  memref<?xf32>
+    return
+}
+
+func.func @load_nd_dynamic(%memref: memref<?x?x?xf32>, %index0: index, %index1: index, %index2: index) {
+    memref.load %memref[%index0, %index1, %index2] :  memref<?x?x?xf32>
+    return
+}
+
+func.func @main() {
+  %0 = arith.constant 0 : index
+  %1 = arith.constant 1 : index
+  %n1 = arith.constant -1 : index
+  %2 = arith.constant 2 : index
+  %alloca_1 = memref.alloca() : memref<1xf32>
+  %alloc_1 = memref.alloc(%1) : memref<?xf32>
+  %alloc_2x2x2 = memref.alloc(%2, %2, %2) : memref<?x?x?xf32>
+
+  //      CHECK: ERROR: Runtime op verification failed
+  // CHECK-NEXT: "memref.load"(%{{.*}}, %{{.*}}) : (memref<1xf32>, index) -> f32
+  // CHECK-NEXT: ^ out-of-bounds access
+  // CHECK-NEXT: Location: loc({{.*}})
+  func.call @load(%alloca_1, %1) : (memref<1xf32>, index) -> ()
+
+  //      CHECK: ERROR: Runtime op verification failed
+  // CHECK-NEXT: "memref.load"(%{{.*}}, %{{.*}}) : (memref<?xf32>, index) -> f32
+  // CHECK-NEXT: ^ out-of-bounds access
+  // CHECK-NEXT: Location: loc({{.*}})
+  func.call @load_dynamic(%alloc_1, %1) : (memref<?xf32>, index) -> ()
+
+  //      CHECK: ERROR: Runtime op verification failed
+  // CHECK-NEXT: "memref.load"(%{{.*}}, %{{.*}}) : (memref<?x?x?xf32>, index, index, index) -> f32
+  // CHECK-NEXT: ^ out-of-bounds access
+  // CHECK-NEXT: Location: loc({{.*}})
+  func.call @load_nd_dynamic(%alloc_2x2x2, %1, %n1, %0) : (memref<?x?x?xf32>, index, index, index) -> ()
+
+  // CHECK-NOT: ERROR: Runtime op verification failed
+  func.call @load(%alloca_1, %0) : (memref<1xf32>, index) -> ()
+
+  // CHECK-NOT: ERROR: Runtime op verification failed
+  func.call @load_dynamic(%alloc_1, %0) : (memref<?xf32>, index) -> ()
+
+  // CHECK-NOT: ERROR: Runtime op verification failed
+  func.call @load_nd_dynamic(%alloc_2x2x2, %1, %1, %0) : (memref<?x?x?xf32>, index, index, index) -> ()
+
+  memref.dealloc %alloc_1 : memref<?xf32>
+  memref.dealloc %alloc_2x2x2 : memref<?x?x?xf32>
+
+  return
+}
+
diff --git a/mlir/test/Integration/Dialect/Memref/reinterpret-cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/Memref/reinterpret-cast-runtime-verification.mlir
new file mode 100644
index 00000000000000..37002915405478
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Memref/reinterpret-cast-runtime-verification.mlir
@@ -0,0 +1,74 @@
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN:     -lower-affine \
+// RUN:     -finalize-memref-to-llvm \
+// RUN:     -test-cf-assert \
+// RUN:     -convert-func-to-llvm \
+// RUN:     -reconcile-unrealized-casts | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
+func.func @reinterpret_cast(%memref: memref<1xf32>, %offset: index) {
+    memref.reinterpret_cast %memref to
+                    offset: [%offset],
+                    sizes: [1],
+                    strides: [1]
+                  : memref<1xf32> to  memref<1xf32, strided<[1], offset: ?>>
+    return
+}
+
+func.func @reinterpret_cast_fully_dynamic(%memref: memref<?xf32>, %offset: index, %size: index, %stride: index)  {
+    memref.reinterpret_cast %memref to
+                    offset: [%offset],
+                    sizes: [%size],
+                    strides: [%stride]
+                  : memref<?xf32> to  memref<?xf32, strided<[?], offset: ?>>
+    return
+}
+
+func.func @main() {
+  %0 = arith.constant 0 : index
+  %1 = arith.constant 1 : index
+  %n1 = arith.constant -1 : index
+  %4 = arith.constant 4 : index
+  %5 = arith.constant 5 : index
+
+  %alloca_1 = memref.alloca() : memref<1xf32>
+  %alloc_4 = memref.alloc(%4) : memref<?xf32>
+
+  // Offset is out-of-bounds
+  //      CHECK: ERROR: Runtime op verification failed
+  // CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}})
+  // CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref
+  // CHECK-NEXT: Location: loc({{.*}})
+  func.call @reinterpret_cast(%alloca_1, %1) : (memref<1xf32>, index) -> ()
+
+  // Offset is out-of-bounds
+  //      CHECK: ERROR: Runtime op verification failed
+  // CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}})
+  // CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref
+  // CHECK-NEXT: Location: loc({{.*}})
+  func.call @reinterpret_cast(%alloca_1, %n1) : (memref<1xf32>, index) -> ()
+
+  // Size is out-of-bounds
+  //      CHECK: ERROR: Runtime op verification failed
+  // CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}})
+  // CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref
+  // CHECK-NEXT: Location: loc({{.*}})
+  func.call @reinterpret_cast_fully_dynamic(%alloc_4, %0, %5, %1) : (memref<?xf32>, index, index, index) -> ()
+
+  // Stride is out-of-bounds
+  //      CHECK: ERROR: Runtime op verification failed
+  // CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}})
+  // CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref
+  // CHECK-NEXT: Location: loc({{.*}})
+  func.call @reinterpret_cast_fully_dynamic(%alloc_4, %0, %4, %4) : (memref<?xf32>, index, index, index) -> ()
+
+  //  CHECK-NOT: ERROR: Runtime op verification failed
+  func.call @reinterpret_cast(%alloca_1, %0) : (memref<1xf32>, index) -> ()
+
+  //  CHECK-NOT: ERROR: Runtime op verification failed
+  func.call @reinterpret_cast_fully_dynamic(%alloc_4, %0, %4, %1) : (memref<?xf32>, index, index, index) -> ()
+
+  return
+}
diff --git a/mlir/test/Integration/Dialect/Memref/subview-runtime-verification.mlir b/mlir/test/Integration/Dialect/Memref/subview-runtime-verification.mlir
new file mode 100644
index 00000000000000..57fbb41e35bc1b
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Memref/subview-runtime-verification.mlir
@@ -0,0 +1,71 @@
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN:     -lower-affine \
+// RUN:     -expand-strided-metadata \
+// RUN:     -finalize-memref-to-llvm \
+// RUN:     -test-cf-assert \
+// RUN:     -convert-func-to-llvm \
+// RUN:     -reconcile-unrealized-casts | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
+func.func @subview(%memref: memref<1xf32>, %offset: index) {
+    memref.subview %memref[%offset] [1] [1] : 
+        memref<1xf32> to 
+        memref<1xf32, strided<[1], offset: ?>>
+    return
+}
+
+func.func @subview_dynamic(%memref: memref<?x4xf32>, %offset: index, %size: index, %stride: index) {
+    memref.subview %memref[%offset, 0] [%size, 4] [%stride, 1] : 
+        memref<?x4xf32> to 
+        memref<?x4xf32, strided<[?, 1], offset: ?>>
+    return
+}
+
+func.func @main() {
+  %0 = arith.constant 0 : index
+  %1 = arith.constant 1 : index
+  %n1 = arith.constant -1 : index
+  %4 = arith.constant 4 : index
+  %5 = arith.constant 5 : index
+
+  %alloca = memref.alloca() : memref<1xf32>
+  %alloc = memref.alloc(%4) : memref<?x4xf32>
+
+  // Offset is out-of-bounds
+  //      CHECK: ERROR: Runtime op verification failed
+  // CHECK-NEXT: "memref.subview"
+  // CHECK-NEXT: ^ subview is out-of-bounds of the base memref
+  // CHECK-NEXT: Location: loc({{.*}})
+  func.call @subview(%alloca, %1) : (memref<1xf32>, index) -> ()
+
+  // Offset is out-of-bounds
+  //      CHECK: ERROR: Runtime op verification failed
+  // CHECK-NEXT: "memref.subview"
+  // CHECK-NEXT: ^ subview is out-of-bounds of the base memref
+  // CHECK-NEXT: Location: loc({{.*}})
+  func.call @subview(%alloca, %n1) : (memref<1xf32>, index) -> ()
+
+  // Size is out-of-bounds
+  //      CHECK: ERROR: Runtime op verification failed
+  // CHECK-NEXT: "memref.subview"
+  // CHECK-NEXT: ^ subview is out-of-bounds of the base memref
+  // CHECK-NEXT: Location: loc({{...
[truncated]
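The half-open linear bounds used by the `reinterpret_cast` and `subview` checks in the patch above can be sketched as follows (illustrative Python mirroring `computeLinearBounds`; the helper names echo the C++ but this is a simplification, and like the pass it linearizes with the declared strides):

```python
def linear_bounds(offset, strides, sizes):
    """[low, high): linearized address interval of a strided layout.
    low is the linear index at all-zero indices; high is the linear
    index evaluated at indices == sizes."""
    low = offset
    high = offset + sum(s * n for s, n in zip(strides, sizes))
    return low, high

def cast_in_bounds(base, result):
    """The generated assertion: the result memref's interval must be
    contained in the base memref's interval (base_low <= low, high <= base_high)."""
    base_low, base_high = linear_bounds(*base)
    low, high = linear_bounds(*result)
    return base_low <= low and high <= base_high

# A base memref<4xf32> spans [0, 4); reinterpreting it with sizes: [5]
# spans [0, 5) and must trip the assertion:
assert not cast_in_bounds((0, [1], [4]), (0, [1], [5]))  # size out-of-bounds
assert cast_in_bounds((0, [1], [4]), (1, [1], [3]))      # [1, 4) fits
```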

@joker-eph
Collaborator

joker-eph commented Dec 19, 2023

Thanks! LG with one comment, but I'll let @matthias-springer also have a look and merge.

This change adds (runtime) bounds checks for `memref` ops using the
existing `RuntimeVerifiableOpInterface`. For `memref.load` and
`memref.store`, we check that the indices are in-bounds of the memref's
index space. For `memref.reinterpret_cast` and `memref.subview` we check
that the linear bounds of the resulting memref are within the linear
bounds of the base memref. Note that this does not implement full
runtime verification for `memref.subview`. Future work would verify the
bounds on a per-dimension basis.
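The strided-metadata composition the subview check relies on can be restated like so (a hypothetical Python helper expressing the formulas from `SubViewOpInterface`; the real pass builds these values as affine ops):

```python
def compose_subview(base_offset, base_strides, offsets, sizes, strides):
    """Strided metadata of a subview result:
      result_offset     = base_offset + sum(offsets[i] * base_strides[i])
      result_strides[i] = base_strides[i] * strides[i]
      result_sizes      = sizes
    """
    offset = base_offset + sum(o * bs for o, bs in zip(offsets, base_strides))
    result_strides = [bs * s for bs, s in zip(base_strides, strides)]
    return offset, result_strides, list(sizes)

# memref.subview %m[1, 0] [2, 4] [1, 1] on a base with offset 0 and
# strides [4, 1] (e.g. memref<4x4xf32>):
assert compose_subview(0, [4, 1], [1, 0], [2, 4], [1, 1]) == (4, [4, 1], [2, 4])
```

The composed metadata is then fed through the same linear-bounds containment check as `reinterpret_cast`.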
@matthias-springer (Member) left a comment


Looks good, thanks!

@matthias-springer merged commit 847a6f8 into llvm:main on Dec 22, 2023