[mlir][MemRef] Add runtime bounds checking #75817

Merged: 1 commit, Dec 22, 2023
172 changes: 171 additions & 1 deletion mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -8,10 +8,14 @@

#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"

using namespace mlir;
@@ -21,6 +25,12 @@ static std::string generateErrorMessage(Operation *op, const std::string &msg) {
std::string buffer;
llvm::raw_string_ostream stream(buffer);
OpPrintingFlags flags;
// We may generate a lot of error messages and so we need to ensure the
// printing is fast.
flags.elideLargeElementsAttrs();
flags.printGenericOpForm();
flags.skipRegions();
flags.useLocalScope();
stream << "ERROR: Runtime op verification failed\n";
op->print(stream, flags);
stream << "\n^ " << msg;
@@ -133,6 +143,161 @@ struct CastOpInterface
}
};

/// Verifies that the indices on load/store ops are in-bounds of the memref's
/// index space: 0 <= index#i < dim#i
template <typename LoadStoreOp>
struct LoadStoreOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<
LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
Location loc) const {
auto loadStoreOp = cast<LoadStoreOp>(op);

auto memref = loadStoreOp.getMemref();
auto rank = memref.getType().getRank();
if (rank == 0) {
return;
}
auto indices = loadStoreOp.getIndices();

auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
Value assertCond;
for (auto i : llvm::seq<int64_t>(0, rank)) {
auto index = indices[i];

auto dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i);

auto geLow = builder.createOrFold<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, index, zero);
auto ltHigh = builder.createOrFold<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, index, dimOp);
auto andOp = builder.createOrFold<arith::AndIOp>(loc, geLow, ltHigh);

assertCond =
i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, andOp)
: andOp;
}
builder.create<cf::AssertOp>(
loc, assertCond, generateErrorMessage(op, "out-of-bounds access"));
}
};
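
As an illustration, here is a hand-written sketch of the guard this interface emits in front of a 1-D `memref.load %m[%i] : memref<?xf32>`. The value names and the exact assert message are hypothetical, not the verbatim output of the pass:

```mlir
// Hypothetical sketch of the emitted bounds check, assuming the load's
// operands are %m : memref<?xf32> and %i : index.
%c0 = arith.constant 0 : index
%dim = memref.dim %m, %c0 : memref<?xf32>
%ge = arith.cmpi sge, %i, %c0 : index
%lt = arith.cmpi slt, %i, %dim : index
%ok = arith.andi %ge, %lt : i1
cf.assert %ok, "ERROR: Runtime op verification failed ... out-of-bounds access"
```

For higher-rank memrefs, one such conjunct is built per dimension and folded into a single assert condition, as in the loop above.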

/// Compute the linear index for the provided strided layout and indices.
Value computeLinearIndex(OpBuilder &builder, Location loc, OpFoldResult offset,
ArrayRef<OpFoldResult> strides,
ArrayRef<OpFoldResult> indices) {
auto [expr, values] = computeLinearIndex(offset, strides, indices);
auto index =
affine::makeComposedFoldedAffineApply(builder, loc, expr, values);
return getValueOrCreateConstantIndexOp(builder, loc, index);
}

/// Returns two Values representing the bounds of the provided strided layout
/// metadata. The bounds are returned as a half-open interval -- [low, high).
std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
OpFoldResult offset,
ArrayRef<OpFoldResult> strides,
ArrayRef<OpFoldResult> sizes) {
auto zeros = SmallVector<int64_t>(sizes.size(), 0);
auto indices = getAsIndexOpFoldResult(builder.getContext(), zeros);
auto lowerBound = computeLinearIndex(builder, loc, offset, strides, indices);
auto upperBound = computeLinearIndex(builder, loc, offset, strides, sizes);
return {lowerBound, upperBound};
}

/// Returns two Values representing the bounds of the memref. The bounds are
/// returned as a half-open interval -- [low, high).
std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
TypedValue<BaseMemRefType> memref) {
auto runtimeMetadata = builder.create<ExtractStridedMetadataOp>(loc, memref);
auto offset = runtimeMetadata.getConstifiedMixedOffset();
auto strides = runtimeMetadata.getConstifiedMixedStrides();
auto sizes = runtimeMetadata.getConstifiedMixedSizes();
return computeLinearBounds(builder, loc, offset, strides, sizes);
}
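
As a rough sketch (hypothetical value names, assuming a rank-2 input), these helpers materialize the layout metadata at runtime and linearize it: `low` is the offset and `high` is `offset + size0 * stride0 + size1 * stride1`, i.e. the linearized index of the sizes vector:

```mlir
// Hypothetical sketch of computeLinearBounds on %m : memref<?x?xf32>.
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %m
    : memref<?x?xf32> -> memref<f32>, index, index, index, index, index
// low = %offset; high = offset + size0 * stride0 + size1 * stride1.
%high = affine.apply
    affine_map<()[s0, s1, s2, s3, s4] -> (s0 + s1 * s2 + s3 * s4)>()
    [%offset, %sizes#0, %strides#0, %sizes#1, %strides#1]
```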

/// Verifies that the linear bounds of a reinterpret_cast op are within the
/// linear bounds of the base memref: low >= baseLow && high <= baseHigh
struct ReinterpretCastOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<
ReinterpretCastOpInterface, ReinterpretCastOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
Location loc) const {
auto reinterpretCast = cast<ReinterpretCastOp>(op);
auto baseMemref = reinterpretCast.getSource();
auto resultMemref =
cast<TypedValue<BaseMemRefType>>(reinterpretCast.getResult());

builder.setInsertionPointAfter(op);

// Compute the linear bounds of the base memref
auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);

// Compute the linear bounds of the resulting memref
auto [low, high] = computeLinearBounds(builder, loc, resultMemref);

// Check low >= baseLow
auto geLow = builder.createOrFold<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, low, baseLow);

// Check high <= baseHigh
auto leHigh = builder.createOrFold<arith::CmpIOp>(
loc, arith::CmpIPredicate::sle, high, baseHigh);

auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);

builder.create<cf::AssertOp>(
loc, assertCond,
generateErrorMessage(
op,
"result of reinterpret_cast is out-of-bounds of the base memref"));
}
};
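
The same pattern is reused for `memref.subview` below. In both cases the emitted check compares the two half-open intervals; a rough sketch, with hypothetical names and the bounds computed as above:

```mlir
// %low, %high: linear bounds of the result memref;
// %base_low, %base_high: linear bounds of the source memref.
%ge = arith.cmpi sge, %low, %base_low : index
%le = arith.cmpi sle, %high, %base_high : index
%ok = arith.andi %ge, %le : i1
cf.assert %ok, "ERROR: Runtime op verification failed ... out-of-bounds of the base memref"
```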

/// Verifies that the linear bounds of a subview op are within the linear bounds
/// of the base memref: low >= baseLow && high <= baseHigh
/// TODO: This is not yet a full runtime verification of subview. For example,
/// consider:
///   %m = memref.alloc(%c10, %c10) : memref<?x?xf32>
///   memref.subview %m[%c0, %c0][%c2, %c20][%c1, %c1]
///     : memref<?x?xf32> to memref<?x?xf32>
/// The subview is in-bounds of the entire base memref but the second dimension
/// is out-of-bounds. Future work would verify the bounds on a per-dimension
/// basis.
struct SubViewOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
SubViewOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
Location loc) const {
auto subView = cast<SubViewOp>(op);
auto baseMemref = cast<TypedValue<BaseMemRefType>>(subView.getSource());
auto resultMemref = cast<TypedValue<BaseMemRefType>>(subView.getResult());

builder.setInsertionPointAfter(op);

// Compute the linear bounds of the base memref
auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);

// Compute the linear bounds of the resulting memref
auto [low, high] = computeLinearBounds(builder, loc, resultMemref);

// Check low >= baseLow
auto geLow = builder.createOrFold<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, low, baseLow);

// Check high <= baseHigh
auto leHigh = builder.createOrFold<arith::CmpIOp>(
loc, arith::CmpIPredicate::sle, high, baseHigh);

auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);

builder.create<cf::AssertOp>(
loc, assertCond,
generateErrorMessage(op,
"subview is out-of-bounds of the base memref"));
}
};

struct ExpandShapeOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
ExpandShapeOp> {
@@ -183,8 +348,13 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
CastOp::attachInterface<CastOpInterface>(*ctx);
ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
SubViewOp::attachInterface<SubViewOpInterface>(*ctx);

// Load additional dialects of which ops may get created.
-ctx->loadDialect<arith::ArithDialect, cf::ControlFlowDialect>();
+ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
+                 cf::ControlFlowDialect>();
});
}
@@ -33,26 +33,26 @@ func.func @main() {
%alloc = memref.alloc() : memref<5xf32>

// CHECK: ERROR: Runtime op verification failed
-// CHECK-NEXT: memref.cast %{{.*}} : memref<?xf32> to memref<10xf32>
+// CHECK-NEXT: "memref.cast"(%{{.*}}) : (memref<?xf32>) -> memref<10xf32>
// CHECK-NEXT: ^ size mismatch of dim 0
// CHECK-NEXT: Location: loc({{.*}})
%1 = memref.cast %alloc : memref<5xf32> to memref<?xf32>
func.call @cast_to_static_dim(%1) : (memref<?xf32>) -> (memref<10xf32>)

// CHECK-NEXT: ERROR: Runtime op verification failed
-// CHECK-NEXT: memref.cast %{{.*}} : memref<*xf32> to memref<f32>
+// CHECK-NEXT: "memref.cast"(%{{.*}}) : (memref<*xf32>) -> memref<f32>
// CHECK-NEXT: ^ rank mismatch
// CHECK-NEXT: Location: loc({{.*}})
%3 = memref.cast %alloc : memref<5xf32> to memref<*xf32>
func.call @cast_to_ranked(%3) : (memref<*xf32>) -> (memref<f32>)

// CHECK-NEXT: ERROR: Runtime op verification failed
-// CHECK-NEXT: memref.cast %{{.*}} : memref<?xf32, strided<[?], offset: ?>> to memref<?xf32, strided<[9], offset: 5>>
+// CHECK-NEXT: "memref.cast"(%{{.*}}) : (memref<?xf32, strided<[?], offset: ?>>) -> memref<?xf32, strided<[9], offset: 5>>
// CHECK-NEXT: ^ offset mismatch
// CHECK-NEXT: Location: loc({{.*}})

// CHECK-NEXT: ERROR: Runtime op verification failed
-// CHECK-NEXT: memref.cast %{{.*}} : memref<?xf32, strided<[?], offset: ?>> to memref<?xf32, strided<[9], offset: 5>>
+// CHECK-NEXT: "memref.cast"(%{{.*}}) : (memref<?xf32, strided<[?], offset: ?>>) -> memref<?xf32, strided<[9], offset: 5>>
// CHECK-NEXT: ^ stride mismatch of dim 0
// CHECK-NEXT: Location: loc({{.*}})
%4 = memref.cast %alloc
@@ -0,0 +1,67 @@
// RUN: mlir-opt %s -generate-runtime-verification \
// RUN: -expand-strided-metadata \
// RUN: -finalize-memref-to-llvm \
// RUN: -test-cf-assert \
// RUN: -convert-func-to-llvm \
// RUN: -reconcile-unrealized-casts | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s

func.func @load(%memref: memref<1xf32>, %index: index) {
memref.load %memref[%index] : memref<1xf32>
return
}

func.func @load_dynamic(%memref: memref<?xf32>, %index: index) {
memref.load %memref[%index] : memref<?xf32>
return
}

func.func @load_nd_dynamic(%memref: memref<?x?x?xf32>, %index0: index, %index1: index, %index2: index) {
memref.load %memref[%index0, %index1, %index2] : memref<?x?x?xf32>
return
}

func.func @main() {
%0 = arith.constant 0 : index
%1 = arith.constant 1 : index
%n1 = arith.constant -1 : index
%2 = arith.constant 2 : index
%alloca_1 = memref.alloca() : memref<1xf32>
%alloc_1 = memref.alloc(%1) : memref<?xf32>
%alloc_2x2x2 = memref.alloc(%2, %2, %2) : memref<?x?x?xf32>

// CHECK: ERROR: Runtime op verification failed
// CHECK-NEXT: "memref.load"(%{{.*}}, %{{.*}}) : (memref<1xf32>, index) -> f32
// CHECK-NEXT: ^ out-of-bounds access
// CHECK-NEXT: Location: loc({{.*}})
func.call @load(%alloca_1, %1) : (memref<1xf32>, index) -> ()

// CHECK: ERROR: Runtime op verification failed
// CHECK-NEXT: "memref.load"(%{{.*}}, %{{.*}}) : (memref<?xf32>, index) -> f32
// CHECK-NEXT: ^ out-of-bounds access
// CHECK-NEXT: Location: loc({{.*}})
func.call @load_dynamic(%alloc_1, %1) : (memref<?xf32>, index) -> ()

// CHECK: ERROR: Runtime op verification failed
// CHECK-NEXT: "memref.load"(%{{.*}}, %{{.*}}) : (memref<?x?x?xf32>, index, index, index) -> f32
// CHECK-NEXT: ^ out-of-bounds access
// CHECK-NEXT: Location: loc({{.*}})
func.call @load_nd_dynamic(%alloc_2x2x2, %1, %n1, %0) : (memref<?x?x?xf32>, index, index, index) -> ()

// CHECK-NOT: ERROR: Runtime op verification failed
func.call @load(%alloca_1, %0) : (memref<1xf32>, index) -> ()

// CHECK-NOT: ERROR: Runtime op verification failed
func.call @load_dynamic(%alloc_1, %0) : (memref<?xf32>, index) -> ()

// CHECK-NOT: ERROR: Runtime op verification failed
func.call @load_nd_dynamic(%alloc_2x2x2, %1, %1, %0) : (memref<?x?x?xf32>, index, index, index) -> ()

memref.dealloc %alloc_1 : memref<?xf32>
memref.dealloc %alloc_2x2x2 : memref<?x?x?xf32>

return
}

@@ -0,0 +1,74 @@
// RUN: mlir-opt %s -generate-runtime-verification \
// RUN: -lower-affine \
// RUN: -finalize-memref-to-llvm \
// RUN: -test-cf-assert \
// RUN: -convert-func-to-llvm \
// RUN: -reconcile-unrealized-casts | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s

func.func @reinterpret_cast(%memref: memref<1xf32>, %offset: index) {
memref.reinterpret_cast %memref to
offset: [%offset],
sizes: [1],
strides: [1]
: memref<1xf32> to memref<1xf32, strided<[1], offset: ?>>
return
}

func.func @reinterpret_cast_fully_dynamic(%memref: memref<?xf32>, %offset: index, %size: index, %stride: index) {
memref.reinterpret_cast %memref to
offset: [%offset],
sizes: [%size],
strides: [%stride]
: memref<?xf32> to memref<?xf32, strided<[?], offset: ?>>
return
}

func.func @main() {
%0 = arith.constant 0 : index
%1 = arith.constant 1 : index
%n1 = arith.constant -1 : index
%4 = arith.constant 4 : index
%5 = arith.constant 5 : index

%alloca_1 = memref.alloca() : memref<1xf32>
%alloc_4 = memref.alloc(%4) : memref<?xf32>

// Offset is out-of-bounds
// CHECK: ERROR: Runtime op verification failed
// CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}})
// CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref
// CHECK-NEXT: Location: loc({{.*}})
func.call @reinterpret_cast(%alloca_1, %1) : (memref<1xf32>, index) -> ()

// Offset is out-of-bounds
// CHECK: ERROR: Runtime op verification failed
// CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}})
// CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref
// CHECK-NEXT: Location: loc({{.*}})
func.call @reinterpret_cast(%alloca_1, %n1) : (memref<1xf32>, index) -> ()

// Size is out-of-bounds
// CHECK: ERROR: Runtime op verification failed
// CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}})
// CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref
// CHECK-NEXT: Location: loc({{.*}})
func.call @reinterpret_cast_fully_dynamic(%alloc_4, %0, %5, %1) : (memref<?xf32>, index, index, index) -> ()

// Stride is out-of-bounds
// CHECK: ERROR: Runtime op verification failed
// CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}})
// CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref
// CHECK-NEXT: Location: loc({{.*}})
func.call @reinterpret_cast_fully_dynamic(%alloc_4, %0, %4, %4) : (memref<?xf32>, index, index, index) -> ()

// CHECK-NOT: ERROR: Runtime op verification failed
func.call @reinterpret_cast(%alloca_1, %0) : (memref<1xf32>, index) -> ()

// CHECK-NOT: ERROR: Runtime op verification failed
func.call @reinterpret_cast_fully_dynamic(%alloc_4, %0, %4, %1) : (memref<?xf32>, index, index, index) -> ()

return
}