Skip to content

Commit 847a6f8

Browse files
authored
[mlir][MemRef] Add runtime bounds checking (#75817)
This change adds (runtime) bounds checks for `memref` ops using the existing `RuntimeVerifiableOpInterface`. For `memref.load` and `memref.store`, we check that the indices are in-bounds of the memref's index space. For `memref.reinterpret_cast` and `memref.subview` we check that the resulting address space is in-bounds of the input memref's address space.
1 parent 62d8ae0 commit 847a6f8

File tree

5 files changed

+405
-5
lines changed

5 files changed

+405
-5
lines changed

mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp

Lines changed: 171 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,14 @@
88

99
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
1010

11+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1112
#include "mlir/Dialect/Arith/IR/Arith.h"
13+
#include "mlir/Dialect/Arith/Utils/Utils.h"
1214
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
1315
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
1416
#include "mlir/Dialect/MemRef/IR/MemRef.h"
17+
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
18+
#include "mlir/Dialect/Utils/IndexingUtils.h"
1519
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
1620

1721
using namespace mlir;
@@ -21,6 +25,12 @@ static std::string generateErrorMessage(Operation *op, const std::string &msg) {
2125
std::string buffer;
2226
llvm::raw_string_ostream stream(buffer);
2327
OpPrintingFlags flags;
28+
// We may generate a lot of error messages and so we need to ensure the
29+
// printing is fast.
30+
flags.elideLargeElementsAttrs();
31+
flags.printGenericOpForm();
32+
flags.skipRegions();
33+
flags.useLocalScope();
2434
stream << "ERROR: Runtime op verification failed\n";
2535
op->print(stream, flags);
2636
stream << "\n^ " << msg;
@@ -133,6 +143,161 @@ struct CastOpInterface
133143
}
134144
};
135145

146+
/// Verifies that the indices on load/store ops are in-bounds of the memref's
147+
/// index space: 0 <= index#i < dim#i
148+
template <typename LoadStoreOp>
149+
struct LoadStoreOpInterface
150+
: public RuntimeVerifiableOpInterface::ExternalModel<
151+
LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
152+
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
153+
Location loc) const {
154+
auto loadStoreOp = cast<LoadStoreOp>(op);
155+
156+
auto memref = loadStoreOp.getMemref();
157+
auto rank = memref.getType().getRank();
158+
if (rank == 0) {
159+
return;
160+
}
161+
auto indices = loadStoreOp.getIndices();
162+
163+
auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
164+
Value assertCond;
165+
for (auto i : llvm::seq<int64_t>(0, rank)) {
166+
auto index = indices[i];
167+
168+
auto dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i);
169+
170+
auto geLow = builder.createOrFold<arith::CmpIOp>(
171+
loc, arith::CmpIPredicate::sge, index, zero);
172+
auto ltHigh = builder.createOrFold<arith::CmpIOp>(
173+
loc, arith::CmpIPredicate::slt, index, dimOp);
174+
auto andOp = builder.createOrFold<arith::AndIOp>(loc, geLow, ltHigh);
175+
176+
assertCond =
177+
i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, andOp)
178+
: andOp;
179+
}
180+
builder.create<cf::AssertOp>(
181+
loc, assertCond, generateErrorMessage(op, "out-of-bounds access"));
182+
}
183+
};
184+
185+
/// Compute the linear index for the provided strided layout and indices.
186+
Value computeLinearIndex(OpBuilder &builder, Location loc, OpFoldResult offset,
187+
ArrayRef<OpFoldResult> strides,
188+
ArrayRef<OpFoldResult> indices) {
189+
auto [expr, values] = computeLinearIndex(offset, strides, indices);
190+
auto index =
191+
affine::makeComposedFoldedAffineApply(builder, loc, expr, values);
192+
return getValueOrCreateConstantIndexOp(builder, loc, index);
193+
}
194+
195+
/// Returns two Values representing the bounds of the provided strided layout
196+
/// metadata. The bounds are returned as a half open interval -- [low, high).
197+
std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
198+
OpFoldResult offset,
199+
ArrayRef<OpFoldResult> strides,
200+
ArrayRef<OpFoldResult> sizes) {
201+
auto zeros = SmallVector<int64_t>(sizes.size(), 0);
202+
auto indices = getAsIndexOpFoldResult(builder.getContext(), zeros);
203+
auto lowerBound = computeLinearIndex(builder, loc, offset, strides, indices);
204+
auto upperBound = computeLinearIndex(builder, loc, offset, strides, sizes);
205+
return {lowerBound, upperBound};
206+
}
207+
208+
/// Returns two Values representing the bounds of the memref. The bounds are
209+
/// returned as a half open interval -- [low, high).
210+
std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
211+
TypedValue<BaseMemRefType> memref) {
212+
auto runtimeMetadata = builder.create<ExtractStridedMetadataOp>(loc, memref);
213+
auto offset = runtimeMetadata.getConstifiedMixedOffset();
214+
auto strides = runtimeMetadata.getConstifiedMixedStrides();
215+
auto sizes = runtimeMetadata.getConstifiedMixedSizes();
216+
return computeLinearBounds(builder, loc, offset, strides, sizes);
217+
}
218+
219+
/// Verifies that the linear bounds of a reinterpret_cast op are within the
220+
/// linear bounds of the base memref: low >= baseLow && high <= baseHigh
221+
struct ReinterpretCastOpInterface
222+
: public RuntimeVerifiableOpInterface::ExternalModel<
223+
ReinterpretCastOpInterface, ReinterpretCastOp> {
224+
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
225+
Location loc) const {
226+
auto reinterpretCast = cast<ReinterpretCastOp>(op);
227+
auto baseMemref = reinterpretCast.getSource();
228+
auto resultMemref =
229+
cast<TypedValue<BaseMemRefType>>(reinterpretCast.getResult());
230+
231+
builder.setInsertionPointAfter(op);
232+
233+
// Compute the linear bounds of the base memref
234+
auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
235+
236+
// Compute the linear bounds of the resulting memref
237+
auto [low, high] = computeLinearBounds(builder, loc, resultMemref);
238+
239+
// Check low >= baseLow
240+
auto geLow = builder.createOrFold<arith::CmpIOp>(
241+
loc, arith::CmpIPredicate::sge, low, baseLow);
242+
243+
// Check high <= baseHigh
244+
auto leHigh = builder.createOrFold<arith::CmpIOp>(
245+
loc, arith::CmpIPredicate::sle, high, baseHigh);
246+
247+
auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);
248+
249+
builder.create<cf::AssertOp>(
250+
loc, assertCond,
251+
generateErrorMessage(
252+
op,
253+
"result of reinterpret_cast is out-of-bounds of the base memref"));
254+
}
255+
};
256+
257+
/// Verifies that the linear bounds of a subview op are within the linear bounds
258+
/// of the base memref: low >= baseLow && high <= baseHigh
259+
/// TODO: This is not yet a full runtime verification of subview. For example,
260+
/// consider:
261+
/// %m = memref.alloc(%c10, %c10) : memref<10x10xf32>
262+
/// memref.subview %m[%c0, %c0][%c20, %c2][%c1, %c1]
263+
/// : memref<?x?xf32> to memref<?x?xf32>
264+
/// The subview is in-bounds of the entire base memref but the first dimension
265+
/// is out-of-bounds. Future work would verify the bounds on a per-dimension
266+
/// basis.
267+
struct SubViewOpInterface
268+
: public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
269+
SubViewOp> {
270+
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
271+
Location loc) const {
272+
auto subView = cast<SubViewOp>(op);
273+
auto baseMemref = cast<TypedValue<BaseMemRefType>>(subView.getSource());
274+
auto resultMemref = cast<TypedValue<BaseMemRefType>>(subView.getResult());
275+
276+
builder.setInsertionPointAfter(op);
277+
278+
// Compute the linear bounds of the base memref
279+
auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
280+
281+
// Compute the linear bounds of the resulting memref
282+
auto [low, high] = computeLinearBounds(builder, loc, resultMemref);
283+
284+
// Check low >= baseLow
285+
auto geLow = builder.createOrFold<arith::CmpIOp>(
286+
loc, arith::CmpIPredicate::sge, low, baseLow);
287+
288+
// Check high <= baseHigh
289+
auto leHigh = builder.createOrFold<arith::CmpIOp>(
290+
loc, arith::CmpIPredicate::sle, high, baseHigh);
291+
292+
auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);
293+
294+
builder.create<cf::AssertOp>(
295+
loc, assertCond,
296+
generateErrorMessage(op,
297+
"subview is out-of-bounds of the base memref"));
298+
}
299+
};
300+
136301
struct ExpandShapeOpInterface
137302
: public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
138303
ExpandShapeOp> {
@@ -183,8 +348,13 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
183348
registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
184349
CastOp::attachInterface<CastOpInterface>(*ctx);
185350
ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
351+
LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
352+
ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
353+
StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
354+
SubViewOp::attachInterface<SubViewOpInterface>(*ctx);
186355

187356
// Load additional dialects of which ops may get created.
188-
ctx->loadDialect<arith::ArithDialect, cf::ControlFlowDialect>();
357+
ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
358+
cf::ControlFlowDialect>();
189359
});
190360
}

mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,26 +33,26 @@ func.func @main() {
3333
%alloc = memref.alloc() : memref<5xf32>
3434

3535
// CHECK: ERROR: Runtime op verification failed
36-
// CHECK-NEXT: memref.cast %{{.*}} : memref<?xf32> to memref<10xf32>
36+
// CHECK-NEXT: "memref.cast"(%{{.*}}) : (memref<?xf32>) -> memref<10xf32>
3737
// CHECK-NEXT: ^ size mismatch of dim 0
3838
// CHECK-NEXT: Location: loc({{.*}})
3939
%1 = memref.cast %alloc : memref<5xf32> to memref<?xf32>
4040
func.call @cast_to_static_dim(%1) : (memref<?xf32>) -> (memref<10xf32>)
4141

4242
// CHECK-NEXT: ERROR: Runtime op verification failed
43-
// CHECK-NEXT: memref.cast %{{.*}} : memref<*xf32> to memref<f32>
43+
// CHECK-NEXT: "memref.cast"(%{{.*}}) : (memref<*xf32>) -> memref<f32>
4444
// CHECK-NEXT: ^ rank mismatch
4545
// CHECK-NEXT: Location: loc({{.*}})
4646
%3 = memref.cast %alloc : memref<5xf32> to memref<*xf32>
4747
func.call @cast_to_ranked(%3) : (memref<*xf32>) -> (memref<f32>)
4848

4949
// CHECK-NEXT: ERROR: Runtime op verification failed
50-
// CHECK-NEXT: memref.cast %{{.*}} : memref<?xf32, strided<[?], offset: ?>> to memref<?xf32, strided<[9], offset: 5>>
50+
// CHECK-NEXT: "memref.cast"(%{{.*}}) : (memref<?xf32, strided<[?], offset: ?>>) -> memref<?xf32, strided<[9], offset: 5>>
5151
// CHECK-NEXT: ^ offset mismatch
5252
// CHECK-NEXT: Location: loc({{.*}})
5353

5454
// CHECK-NEXT: ERROR: Runtime op verification failed
55-
// CHECK-NEXT: memref.cast %{{.*}} : memref<?xf32, strided<[?], offset: ?>> to memref<?xf32, strided<[9], offset: 5>>
55+
// CHECK-NEXT: "memref.cast"(%{{.*}}) : (memref<?xf32, strided<[?], offset: ?>>) -> memref<?xf32, strided<[9], offset: 5>>
5656
// CHECK-NEXT: ^ stride mismatch of dim 0
5757
// CHECK-NEXT: Location: loc({{.*}})
5858
%4 = memref.cast %alloc
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
// RUN: mlir-opt %s -generate-runtime-verification \
2+
// RUN: -expand-strided-metadata \
3+
// RUN: -finalize-memref-to-llvm \
4+
// RUN: -test-cf-assert \
5+
// RUN: -convert-func-to-llvm \
6+
// RUN: -reconcile-unrealized-casts | \
7+
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
8+
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
9+
// RUN: FileCheck %s
10+
11+
func.func @load(%memref: memref<1xf32>, %index: index) {
12+
memref.load %memref[%index] : memref<1xf32>
13+
return
14+
}
15+
16+
func.func @load_dynamic(%memref: memref<?xf32>, %index: index) {
17+
memref.load %memref[%index] : memref<?xf32>
18+
return
19+
}
20+
21+
func.func @load_nd_dynamic(%memref: memref<?x?x?xf32>, %index0: index, %index1: index, %index2: index) {
22+
memref.load %memref[%index0, %index1, %index2] : memref<?x?x?xf32>
23+
return
24+
}
25+
26+
func.func @main() {
27+
%0 = arith.constant 0 : index
28+
%1 = arith.constant 1 : index
29+
%n1 = arith.constant -1 : index
30+
%2 = arith.constant 2 : index
31+
%alloca_1 = memref.alloca() : memref<1xf32>
32+
%alloc_1 = memref.alloc(%1) : memref<?xf32>
33+
%alloc_2x2x2 = memref.alloc(%2, %2, %2) : memref<?x?x?xf32>
34+
35+
// CHECK: ERROR: Runtime op verification failed
36+
// CHECK-NEXT: "memref.load"(%{{.*}}, %{{.*}}) : (memref<1xf32>, index) -> f32
37+
// CHECK-NEXT: ^ out-of-bounds access
38+
// CHECK-NEXT: Location: loc({{.*}})
39+
func.call @load(%alloca_1, %1) : (memref<1xf32>, index) -> ()
40+
41+
// CHECK: ERROR: Runtime op verification failed
42+
// CHECK-NEXT: "memref.load"(%{{.*}}, %{{.*}}) : (memref<?xf32>, index) -> f32
43+
// CHECK-NEXT: ^ out-of-bounds access
44+
// CHECK-NEXT: Location: loc({{.*}})
45+
func.call @load_dynamic(%alloc_1, %1) : (memref<?xf32>, index) -> ()
46+
47+
// CHECK: ERROR: Runtime op verification failed
48+
// CHECK-NEXT: "memref.load"(%{{.*}}, %{{.*}}) : (memref<?x?x?xf32>, index, index, index) -> f32
49+
// CHECK-NEXT: ^ out-of-bounds access
50+
// CHECK-NEXT: Location: loc({{.*}})
51+
func.call @load_nd_dynamic(%alloc_2x2x2, %1, %n1, %0) : (memref<?x?x?xf32>, index, index, index) -> ()
52+
53+
// CHECK-NOT: ERROR: Runtime op verification failed
54+
func.call @load(%alloca_1, %0) : (memref<1xf32>, index) -> ()
55+
56+
// CHECK-NOT: ERROR: Runtime op verification failed
57+
func.call @load_dynamic(%alloc_1, %0) : (memref<?xf32>, index) -> ()
58+
59+
// CHECK-NOT: ERROR: Runtime op verification failed
60+
func.call @load_nd_dynamic(%alloc_2x2x2, %1, %1, %0) : (memref<?x?x?xf32>, index, index, index) -> ()
61+
62+
memref.dealloc %alloc_1 : memref<?xf32>
63+
memref.dealloc %alloc_2x2x2 : memref<?x?x?xf32>
64+
65+
return
66+
}
67+
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
// RUN: mlir-opt %s -generate-runtime-verification \
2+
// RUN: -lower-affine \
3+
// RUN: -finalize-memref-to-llvm \
4+
// RUN: -test-cf-assert \
5+
// RUN: -convert-func-to-llvm \
6+
// RUN: -reconcile-unrealized-casts | \
7+
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
8+
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
9+
// RUN: FileCheck %s
10+
11+
func.func @reinterpret_cast(%memref: memref<1xf32>, %offset: index) {
12+
memref.reinterpret_cast %memref to
13+
offset: [%offset],
14+
sizes: [1],
15+
strides: [1]
16+
: memref<1xf32> to memref<1xf32, strided<[1], offset: ?>>
17+
return
18+
}
19+
20+
func.func @reinterpret_cast_fully_dynamic(%memref: memref<?xf32>, %offset: index, %size: index, %stride: index) {
21+
memref.reinterpret_cast %memref to
22+
offset: [%offset],
23+
sizes: [%size],
24+
strides: [%stride]
25+
: memref<?xf32> to memref<?xf32, strided<[?], offset: ?>>
26+
return
27+
}
28+
29+
func.func @main() {
30+
%0 = arith.constant 0 : index
31+
%1 = arith.constant 1 : index
32+
%n1 = arith.constant -1 : index
33+
%4 = arith.constant 4 : index
34+
%5 = arith.constant 5 : index
35+
36+
%alloca_1 = memref.alloca() : memref<1xf32>
37+
%alloc_4 = memref.alloc(%4) : memref<?xf32>
38+
39+
// Offset is out-of-bounds
40+
// CHECK: ERROR: Runtime op verification failed
41+
// CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}})
42+
// CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref
43+
// CHECK-NEXT: Location: loc({{.*}})
44+
func.call @reinterpret_cast(%alloca_1, %1) : (memref<1xf32>, index) -> ()
45+
46+
// Offset is out-of-bounds
47+
// CHECK: ERROR: Runtime op verification failed
48+
// CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}})
49+
// CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref
50+
// CHECK-NEXT: Location: loc({{.*}})
51+
func.call @reinterpret_cast(%alloca_1, %n1) : (memref<1xf32>, index) -> ()
52+
53+
// Size is out-of-bounds
54+
// CHECK: ERROR: Runtime op verification failed
55+
// CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}})
56+
// CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref
57+
// CHECK-NEXT: Location: loc({{.*}})
58+
func.call @reinterpret_cast_fully_dynamic(%alloc_4, %0, %5, %1) : (memref<?xf32>, index, index, index) -> ()
59+
60+
// Stride is out-of-bounds
61+
// CHECK: ERROR: Runtime op verification failed
62+
// CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}})
63+
// CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref
64+
// CHECK-NEXT: Location: loc({{.*}})
65+
func.call @reinterpret_cast_fully_dynamic(%alloc_4, %0, %4, %4) : (memref<?xf32>, index, index, index) -> ()
66+
67+
// CHECK-NOT: ERROR: Runtime op verification failed
68+
func.call @reinterpret_cast(%alloca_1, %0) : (memref<1xf32>, index) -> ()
69+
70+
// CHECK-NOT: ERROR: Runtime op verification failed
71+
func.call @reinterpret_cast_fully_dynamic(%alloc_4, %0, %4, %1) : (memref<?xf32>, index, index, index) -> ()
72+
73+
return
74+
}

0 commit comments

Comments
 (0)