-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][MemRef] Add runtime bounds checking #75817
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write If you have received no comments on your PR for a week, you can request a review If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
@llvm/pr-subscribers-mlir-memref @llvm/pr-subscribers-mlir Author: Ryan Holt (ryan-holt-1) ChangesThis change adds (runtime) bounds checks for Patch is 20.70 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/75817.diff 5 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 05a069d98ef35f..c20c188703d392 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -8,10 +8,14 @@
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
using namespace mlir;
@@ -21,6 +25,13 @@ static std::string generateErrorMessage(Operation *op, const std::string &msg) {
std::string buffer;
llvm::raw_string_ostream stream(buffer);
OpPrintingFlags flags;
+ // We may generate a lot of error messages and so we need to ensure the
+ // printing is fast.
+ flags.assumeVerified();
+ flags.elideLargeElementsAttrs();
+ flags.printGenericOpForm();
+ flags.skipRegions();
+ flags.useLocalScope();
stream << "ERROR: Runtime op verification failed\n";
op->print(stream, flags);
stream << "\n^ " << msg;
@@ -133,6 +144,179 @@ struct CastOpInterface
}
};
+template <typename LoadStoreOp>
+struct LoadStoreOpInterface
+ : public RuntimeVerifiableOpInterface::ExternalModel<
+ LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
+ void generateRuntimeVerification(Operation *op, OpBuilder &builder,
+ Location loc) const {
+ auto loadStoreOp = cast<LoadStoreOp>(op);
+
+ // Verify that the indices on the load/store are in-bounds of the memref's
+ // index space
+
+ auto memref = loadStoreOp.getMemref();
+ auto rank = memref.getType().getRank();
+ if (rank == 0) {
+ return;
+ }
+ auto indices = loadStoreOp.getIndices();
+
+ auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+ Value assertCond;
+ for (auto i : llvm::seq<int64_t>(0, rank)) {
+ auto index = indices[i];
+
+ auto dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i);
+
+ auto geLow = builder.createOrFold<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::sge, index, zero);
+ auto ltHigh = builder.createOrFold<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::slt, index, dimOp);
+ auto andOp = builder.createOrFold<arith::AndIOp>(loc, geLow, ltHigh);
+
+ assertCond =
+ i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, andOp)
+ : andOp;
+ }
+ builder.create<cf::AssertOp>(
+ loc, assertCond, generateErrorMessage(op, "out-of-bounds access"));
+ }
+};
+
+// Compute the linear index for the provided strided layout and indices.
+Value computeLinearIndex(OpBuilder &builder, Location loc, OpFoldResult offset,
+ ArrayRef<OpFoldResult> strides,
+ ArrayRef<OpFoldResult> indices) {
+ auto [expr, values] = computeLinearIndex(offset, strides, indices);
+ auto index =
+ affine::makeComposedFoldedAffineApply(builder, loc, expr, values);
+ return getValueOrCreateConstantIndexOp(builder, loc, index);
+}
+
+// Returns two Values representing the bounds of the provided strided layout
+// metadata. The bounds are returned as a half open interval -- [low, high).
+std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
+ OpFoldResult offset,
+ ArrayRef<OpFoldResult> strides,
+ ArrayRef<OpFoldResult> sizes) {
+ auto zeros = SmallVector<int64_t>(sizes.size(), 0);
+ auto indices = getAsIndexOpFoldResult(builder.getContext(), zeros);
+ auto lowerBound = computeLinearIndex(builder, loc, offset, strides, indices);
+ auto upperBound = computeLinearIndex(builder, loc, offset, strides, sizes);
+ return {lowerBound, upperBound};
+}
+
+// Returns two Values representing the bounds of the address space of the
+// memref. The bounds are returned as a half open interval -- [low, high).
+std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
+ TypedValue<BaseMemRefType> memref) {
+ auto runtimeMetadata = builder.create<ExtractStridedMetadataOp>(loc, memref);
+ auto offset = runtimeMetadata.getConstifiedMixedOffset();
+ auto strides = runtimeMetadata.getConstifiedMixedStrides();
+ auto sizes = runtimeMetadata.getConstifiedMixedSizes();
+ return computeLinearBounds(builder, loc, offset, strides, sizes);
+}
+
+struct ReinterpretCastOpInterface
+ : public RuntimeVerifiableOpInterface::ExternalModel<
+ ReinterpretCastOpInterface, ReinterpretCastOp> {
+ void generateRuntimeVerification(Operation *op, OpBuilder &builder,
+ Location loc) const {
+ auto reinterpretCast = cast<ReinterpretCastOp>(op);
+
+ // Verify that the resulting address space is in-bounds of the base memref's
+ // address space.
+
+ auto baseMemref = reinterpretCast.getSource();
+
+ auto castOffset = reinterpretCast.getMixedOffsets().front();
+ auto castStrides = reinterpretCast.getMixedStrides();
+ auto castSizes = reinterpretCast.getMixedSizes();
+
+ // Compute the bounds of the base memref's address space
+ auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
+
+ // Compute the bounds of the resulting memref's address space
+ auto [low, high] =
+ computeLinearBounds(builder, loc, castOffset, castStrides, castSizes);
+
+ auto geLow = builder.createOrFold<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::sge, low, baseLow);
+
+ auto leHigh = builder.createOrFold<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::sle, high, baseHigh);
+
+ auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);
+
+ builder.create<cf::AssertOp>(
+ loc, assertCond,
+ generateErrorMessage(
+ op,
+ "result of reinterpret_cast is out-of-bounds of the base memref"));
+ }
+};
+
+struct SubViewOpInterface
+ : public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
+ SubViewOp> {
+ void generateRuntimeVerification(Operation *op, OpBuilder &builder,
+ Location loc) const {
+ auto subView = cast<SubViewOp>(op);
+
+ // Verify that the resulting address space is in-bounds of the base memref's
+ // address space.
+
+ auto baseMemref = subView.getSource().cast<TypedValue<BaseMemRefType>>();
+
+ auto metadata = builder.create<ExtractStridedMetadataOp>(loc, baseMemref);
+
+ auto baseStrides = metadata.getConstifiedMixedStrides();
+ auto baseOffset = metadata.getConstifiedMixedOffset();
+
+ auto subStrides = subView.getMixedStrides();
+ auto subSizes = subView.getMixedSizes();
+ auto subOffsets = subView.getMixedOffsets();
+
+ // result_strides#i = baseStrides#i * subSizes#i
+ SmallVector<OpFoldResult> strides;
+ strides.reserve(baseStrides.size());
+ AffineExpr s0 = builder.getAffineSymbolExpr(0);
+ AffineExpr s1 = builder.getAffineSymbolExpr(1);
+ for (auto i : llvm::seq<int64_t>(0, baseMemref.getType().getRank())) {
+ strides.push_back(affine::makeComposedFoldedAffineApply(
+ builder, loc, s0 * s1, {subStrides[i], baseStrides[i]}));
+ }
+
+ // result_offset = baseOffset + sum(subOffset#i * baseStrides#i)
+ auto offset =
+ computeLinearIndex(builder, loc, baseOffset, subOffsets, baseStrides);
+
+ // result_sizes = subSizes
+ auto &sizes = subSizes;
+
+ // Compute the bounds of the base memref's address space
+ auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
+
+ // Compute the bounds of the resulting memref's address space
+ auto [low, high] =
+ computeLinearBounds(builder, loc, offset, strides, sizes);
+
+ auto geLow = builder.createOrFold<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::sge, low, baseLow);
+
+ auto leHigh = builder.createOrFold<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::sle, high, baseHigh);
+
+ auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);
+
+ builder.create<cf::AssertOp>(
+ loc, assertCond,
+ generateErrorMessage(op,
+ "subview is out-of-bounds of the base memref"));
+ }
+};
+
struct ExpandShapeOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
ExpandShapeOp> {
@@ -183,8 +367,13 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
CastOp::attachInterface<CastOpInterface>(*ctx);
ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
+ LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
+ ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
+ StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
+ SubViewOp::attachInterface<SubViewOpInterface>(*ctx);
// Load additional dialects of which ops may get created.
- ctx->loadDialect<arith::ArithDialect, cf::ControlFlowDialect>();
+ ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
+ cf::ControlFlowDialect>();
});
}
diff --git a/mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir
index 6ad817a73408cd..52b8c16d753da7 100644
--- a/mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir
@@ -33,26 +33,26 @@ func.func @main() {
%alloc = memref.alloc() : memref<5xf32>
// CHECK: ERROR: Runtime op verification failed
- // CHECK-NEXT: memref.cast %{{.*}} : memref<?xf32> to memref<10xf32>
+ // CHECK-NEXT: "memref.cast"(%{{.*}}) : (memref<?xf32>) -> memref<10xf32>
// CHECK-NEXT: ^ size mismatch of dim 0
// CHECK-NEXT: Location: loc({{.*}})
%1 = memref.cast %alloc : memref<5xf32> to memref<?xf32>
func.call @cast_to_static_dim(%1) : (memref<?xf32>) -> (memref<10xf32>)
// CHECK-NEXT: ERROR: Runtime op verification failed
- // CHECK-NEXT: memref.cast %{{.*}} : memref<*xf32> to memref<f32>
+ // CHECK-NEXT: "memref.cast"(%{{.*}}) : (memref<*xf32>) -> memref<f32>
// CHECK-NEXT: ^ rank mismatch
// CHECK-NEXT: Location: loc({{.*}})
%3 = memref.cast %alloc : memref<5xf32> to memref<*xf32>
func.call @cast_to_ranked(%3) : (memref<*xf32>) -> (memref<f32>)
// CHECK-NEXT: ERROR: Runtime op verification failed
- // CHECK-NEXT: memref.cast %{{.*}} : memref<?xf32, strided<[?], offset: ?>> to memref<?xf32, strided<[9], offset: 5>>
+ // CHECK-NEXT: "memref.cast"(%{{.*}}) : (memref<?xf32, strided<[?], offset: ?>>) -> memref<?xf32, strided<[9], offset: 5>>
// CHECK-NEXT: ^ offset mismatch
// CHECK-NEXT: Location: loc({{.*}})
// CHECK-NEXT: ERROR: Runtime op verification failed
- // CHECK-NEXT: memref.cast %{{.*}} : memref<?xf32, strided<[?], offset: ?>> to memref<?xf32, strided<[9], offset: 5>>
+ // CHECK-NEXT: "memref.cast"(%{{.*}}) : (memref<?xf32, strided<[?], offset: ?>>) -> memref<?xf32, strided<[9], offset: 5>>
// CHECK-NEXT: ^ stride mismatch of dim 0
// CHECK-NEXT: Location: loc({{.*}})
%4 = memref.cast %alloc
diff --git a/mlir/test/Integration/Dialect/Memref/load-runtime-verification.mlir b/mlir/test/Integration/Dialect/Memref/load-runtime-verification.mlir
new file mode 100644
index 00000000000000..169dfd70564594
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Memref/load-runtime-verification.mlir
@@ -0,0 +1,67 @@
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN: -expand-strided-metadata \
+// RUN: -finalize-memref-to-llvm \
+// RUN: -test-cf-assert \
+// RUN: -convert-func-to-llvm \
+// RUN: -reconcile-unrealized-casts | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
+func.func @load(%memref: memref<1xf32>, %index: index) {
+ memref.load %memref[%index] : memref<1xf32>
+ return
+}
+
+func.func @load_dynamic(%memref: memref<?xf32>, %index: index) {
+ memref.load %memref[%index] : memref<?xf32>
+ return
+}
+
+func.func @load_nd_dynamic(%memref: memref<?x?x?xf32>, %index0: index, %index1: index, %index2: index) {
+ memref.load %memref[%index0, %index1, %index2] : memref<?x?x?xf32>
+ return
+}
+
+func.func @main() {
+ %0 = arith.constant 0 : index
+ %1 = arith.constant 1 : index
+ %n1 = arith.constant -1 : index
+ %2 = arith.constant 2 : index
+ %alloca_1 = memref.alloca() : memref<1xf32>
+ %alloc_1 = memref.alloc(%1) : memref<?xf32>
+ %alloc_2x2x2 = memref.alloc(%2, %2, %2) : memref<?x?x?xf32>
+
+ // CHECK: ERROR: Runtime op verification failed
+ // CHECK-NEXT: "memref.load"(%{{.*}}, %{{.*}}) : (memref<1xf32>, index) -> f32
+ // CHECK-NEXT: ^ out-of-bounds access
+ // CHECK-NEXT: Location: loc({{.*}})
+ func.call @load(%alloca_1, %1) : (memref<1xf32>, index) -> ()
+
+ // CHECK: ERROR: Runtime op verification failed
+ // CHECK-NEXT: "memref.load"(%{{.*}}, %{{.*}}) : (memref<?xf32>, index) -> f32
+ // CHECK-NEXT: ^ out-of-bounds access
+ // CHECK-NEXT: Location: loc({{.*}})
+ func.call @load_dynamic(%alloc_1, %1) : (memref<?xf32>, index) -> ()
+
+ // CHECK: ERROR: Runtime op verification failed
+ // CHECK-NEXT: "memref.load"(%{{.*}}, %{{.*}}) : (memref<?x?x?xf32>, index, index, index) -> f32
+ // CHECK-NEXT: ^ out-of-bounds access
+ // CHECK-NEXT: Location: loc({{.*}})
+ func.call @load_nd_dynamic(%alloc_2x2x2, %1, %n1, %0) : (memref<?x?x?xf32>, index, index, index) -> ()
+
+ // CHECK-NOT: ERROR: Runtime op verification failed
+ func.call @load(%alloca_1, %0) : (memref<1xf32>, index) -> ()
+
+ // CHECK-NOT: ERROR: Runtime op verification failed
+ func.call @load_dynamic(%alloc_1, %0) : (memref<?xf32>, index) -> ()
+
+ // CHECK-NOT: ERROR: Runtime op verification failed
+ func.call @load_nd_dynamic(%alloc_2x2x2, %1, %1, %0) : (memref<?x?x?xf32>, index, index, index) -> ()
+
+ memref.dealloc %alloc_1 : memref<?xf32>
+ memref.dealloc %alloc_2x2x2 : memref<?x?x?xf32>
+
+ return
+}
+
diff --git a/mlir/test/Integration/Dialect/Memref/reinterpret-cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/Memref/reinterpret-cast-runtime-verification.mlir
new file mode 100644
index 00000000000000..37002915405478
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Memref/reinterpret-cast-runtime-verification.mlir
@@ -0,0 +1,74 @@
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN: -lower-affine \
+// RUN: -finalize-memref-to-llvm \
+// RUN: -test-cf-assert \
+// RUN: -convert-func-to-llvm \
+// RUN: -reconcile-unrealized-casts | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
+func.func @reinterpret_cast(%memref: memref<1xf32>, %offset: index) {
+ memref.reinterpret_cast %memref to
+ offset: [%offset],
+ sizes: [1],
+ strides: [1]
+ : memref<1xf32> to memref<1xf32, strided<[1], offset: ?>>
+ return
+}
+
+func.func @reinterpret_cast_fully_dynamic(%memref: memref<?xf32>, %offset: index, %size: index, %stride: index) {
+ memref.reinterpret_cast %memref to
+ offset: [%offset],
+ sizes: [%size],
+ strides: [%stride]
+ : memref<?xf32> to memref<?xf32, strided<[?], offset: ?>>
+ return
+}
+
+func.func @main() {
+ %0 = arith.constant 0 : index
+ %1 = arith.constant 1 : index
+ %n1 = arith.constant -1 : index
+ %4 = arith.constant 4 : index
+ %5 = arith.constant 5 : index
+
+ %alloca_1 = memref.alloca() : memref<1xf32>
+ %alloc_4 = memref.alloc(%4) : memref<?xf32>
+
+ // Offset is out-of-bounds
+ // CHECK: ERROR: Runtime op verification failed
+ // CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}})
+ // CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref
+ // CHECK-NEXT: Location: loc({{.*}})
+ func.call @reinterpret_cast(%alloca_1, %1) : (memref<1xf32>, index) -> ()
+
+ // Offset is out-of-bounds
+ // CHECK: ERROR: Runtime op verification failed
+ // CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}})
+ // CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref
+ // CHECK-NEXT: Location: loc({{.*}})
+ func.call @reinterpret_cast(%alloca_1, %n1) : (memref<1xf32>, index) -> ()
+
+ // Size is out-of-bounds
+ // CHECK: ERROR: Runtime op verification failed
+ // CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}})
+ // CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref
+ // CHECK-NEXT: Location: loc({{.*}})
+ func.call @reinterpret_cast_fully_dynamic(%alloc_4, %0, %5, %1) : (memref<?xf32>, index, index, index) -> ()
+
+ // Stride is out-of-bounds
+ // CHECK: ERROR: Runtime op verification failed
+ // CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}})
+ // CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref
+ // CHECK-NEXT: Location: loc({{.*}})
+ func.call @reinterpret_cast_fully_dynamic(%alloc_4, %0, %4, %4) : (memref<?xf32>, index, index, index) -> ()
+
+ // CHECK-NOT: ERROR: Runtime op verification failed
+ func.call @reinterpret_cast(%alloca_1, %0) : (memref<1xf32>, index) -> ()
+
+ // CHECK-NOT: ERROR: Runtime op verification failed
+ func.call @reinterpret_cast_fully_dynamic(%alloc_4, %0, %4, %1) : (memref<?xf32>, index, index, index) -> ()
+
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Memref/subview-runtime-verification.mlir b/mlir/test/Integration/Dialect/Memref/subview-runtime-verification.mlir
new file mode 100644
index 00000000000000..57fbb41e35bc1b
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Memref/subview-runtime-verification.mlir
@@ -0,0 +1,71 @@
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN: -lower-affine \
+// RUN: -expand-strided-metadata \
+// RUN: -finalize-memref-to-llvm \
+// RUN: -test-cf-assert \
+// RUN: -convert-func-to-llvm \
+// RUN: -reconcile-unrealized-casts | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
+func.func @subview(%memref: memref<1xf32>, %offset: index) {
+ memref.subview %memref[%offset] [1] [1] :
+ memref<1xf32> to
+ memref<1xf32, strided<[1], offset: ?>>
+ return
+}
+
+func.func @subview_dynamic(%memref: memref<?x4xf32>, %offset: index, %size: index, %stride: index) {
+ memref.subview %memref[%offset, 0] [%size, 4] [%stride, 1] :
+ memref<?x4xf32> to
+ memref<?x4xf32, strided<[?, 1], offset: ?>>
+ return
+}
+
+func.func @main() {
+ %0 = arith.constant 0 : index
+ %1 = arith.constant 1 : index
+ %n1 = arith.constant -1 : index
+ %4 = arith.constant 4 : index
+ %5 = arith.constant 5 : index
+
+ %alloca = memref.alloca() : memref<1xf32>
+ %alloc = memref.alloc(%4) : memref<?x4xf32>
+
+ // Offset is out-of-bounds
+ // CHECK: ERROR: Runtime op verification failed
+ // CHECK-NEXT: "memref.subview"
+ // CHECK-NEXT: ^ subview is out-of-bounds of the base memref
+ // CHECK-NEXT: Location: loc({{.*}})
+ func.call @subview(%alloca, %1) : (memref<1xf32>, index) -> ()
+
+ // Offset is out-of-bounds
+ // CHECK: ERROR: Runtime op verification failed
+ // CHECK-NEXT: "memref.subview"
+ // CHECK-NEXT: ^ subview is out-of-bounds of the base memref
+ // CHECK-NEXT: Location: loc({{.*}})
+ func.call @subview(%alloca, %n1) : (memref<1xf32>, index) -> ()
+
+ // Size is out-of-bounds
+ // CHECK: ERROR: Runtime op verification failed
+ // CHECK-NEXT: "memref.subview"
+ // CHECK-NEXT: ^ subview is out-of-bounds of the base memref
+ // CHECK-NEXT: Location: loc({{...
[truncated]
|
Thanks! LG with one comment, but I'll let @matthias-springer also have a look and merge. |
This change adds (runtime) bounds checks for `memref` ops using the existing `RuntimeVerifiableOpInterface`. For `memref.load` and `memref.store`, we check that the indices are in-bounds of the memref's index space. For `memref.reinterpret_cast` and `memref.subview` we check that the linear bounds of the resulting memref are within the linear bounds of the base memref. Note that this does not implement full runtime verification for `memref.subview`. Future work would verify the bounds on a per-dimension basis.
eab5f8d
to
1fe0373
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, thanks!
This change adds (runtime) bounds checks for
memref
ops using the existingRuntimeVerifiableOpInterface
. Formemref.load
andmemref.store
, we check that the indices are in-bounds of the memref's index space. Formemref.reinterpret_cast
andmemref.subview
we check that the resulting address space is in-bounds of the input memref's address space.