8
8
9
9
#include " mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
10
10
11
+ #include " mlir/Dialect/Affine/IR/AffineOps.h"
11
12
#include " mlir/Dialect/Arith/IR/Arith.h"
13
+ #include " mlir/Dialect/Arith/Utils/Utils.h"
12
14
#include " mlir/Dialect/ControlFlow/IR/ControlFlow.h"
13
15
#include " mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
14
16
#include " mlir/Dialect/MemRef/IR/MemRef.h"
17
+ #include " mlir/Dialect/MemRef/Utils/MemRefUtils.h"
18
+ #include " mlir/Dialect/Utils/IndexingUtils.h"
15
19
#include " mlir/Interfaces/RuntimeVerifiableOpInterface.h"
16
20
17
21
using namespace mlir ;
@@ -21,6 +25,12 @@ static std::string generateErrorMessage(Operation *op, const std::string &msg) {
21
25
std::string buffer;
22
26
llvm::raw_string_ostream stream (buffer);
23
27
OpPrintingFlags flags;
28
+ // We may generate a lot of error messages and so we need to ensure the
29
+ // printing is fast.
30
+ flags.elideLargeElementsAttrs ();
31
+ flags.printGenericOpForm ();
32
+ flags.skipRegions ();
33
+ flags.useLocalScope ();
24
34
stream << " ERROR: Runtime op verification failed\n " ;
25
35
op->print (stream, flags);
26
36
stream << " \n ^ " << msg;
@@ -133,6 +143,161 @@ struct CastOpInterface
133
143
}
134
144
};
135
145
146
+ // / Verifies that the indices on load/store ops are in-bounds of the memref's
147
+ // / index space: 0 <= index#i < dim#i
148
+ template <typename LoadStoreOp>
149
+ struct LoadStoreOpInterface
150
+ : public RuntimeVerifiableOpInterface::ExternalModel<
151
+ LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
152
+ void generateRuntimeVerification (Operation *op, OpBuilder &builder,
153
+ Location loc) const {
154
+ auto loadStoreOp = cast<LoadStoreOp>(op);
155
+
156
+ auto memref = loadStoreOp.getMemref ();
157
+ auto rank = memref.getType ().getRank ();
158
+ if (rank == 0 ) {
159
+ return ;
160
+ }
161
+ auto indices = loadStoreOp.getIndices ();
162
+
163
+ auto zero = builder.create <arith::ConstantIndexOp>(loc, 0 );
164
+ Value assertCond;
165
+ for (auto i : llvm::seq<int64_t >(0 , rank)) {
166
+ auto index = indices[i];
167
+
168
+ auto dimOp = builder.createOrFold <memref::DimOp>(loc, memref, i);
169
+
170
+ auto geLow = builder.createOrFold <arith::CmpIOp>(
171
+ loc, arith::CmpIPredicate::sge, index , zero);
172
+ auto ltHigh = builder.createOrFold <arith::CmpIOp>(
173
+ loc, arith::CmpIPredicate::slt, index , dimOp);
174
+ auto andOp = builder.createOrFold <arith::AndIOp>(loc, geLow, ltHigh);
175
+
176
+ assertCond =
177
+ i > 0 ? builder.createOrFold <arith::AndIOp>(loc, assertCond, andOp)
178
+ : andOp;
179
+ }
180
+ builder.create <cf::AssertOp>(
181
+ loc, assertCond, generateErrorMessage (op, " out-of-bounds access" ));
182
+ }
183
+ };
184
+
185
+ // / Compute the linear index for the provided strided layout and indices.
186
+ Value computeLinearIndex (OpBuilder &builder, Location loc, OpFoldResult offset,
187
+ ArrayRef<OpFoldResult> strides,
188
+ ArrayRef<OpFoldResult> indices) {
189
+ auto [expr, values] = computeLinearIndex (offset, strides, indices);
190
+ auto index =
191
+ affine::makeComposedFoldedAffineApply (builder, loc, expr, values);
192
+ return getValueOrCreateConstantIndexOp (builder, loc, index );
193
+ }
194
+
195
+ // / Returns two Values representing the bounds of the provided strided layout
196
+ // / metadata. The bounds are returned as a half open interval -- [low, high).
197
+ std::pair<Value, Value> computeLinearBounds (OpBuilder &builder, Location loc,
198
+ OpFoldResult offset,
199
+ ArrayRef<OpFoldResult> strides,
200
+ ArrayRef<OpFoldResult> sizes) {
201
+ auto zeros = SmallVector<int64_t >(sizes.size (), 0 );
202
+ auto indices = getAsIndexOpFoldResult (builder.getContext (), zeros);
203
+ auto lowerBound = computeLinearIndex (builder, loc, offset, strides, indices);
204
+ auto upperBound = computeLinearIndex (builder, loc, offset, strides, sizes);
205
+ return {lowerBound, upperBound};
206
+ }
207
+
208
+ // / Returns two Values representing the bounds of the memref. The bounds are
209
+ // / returned as a half open interval -- [low, high).
210
+ std::pair<Value, Value> computeLinearBounds (OpBuilder &builder, Location loc,
211
+ TypedValue<BaseMemRefType> memref) {
212
+ auto runtimeMetadata = builder.create <ExtractStridedMetadataOp>(loc, memref);
213
+ auto offset = runtimeMetadata.getConstifiedMixedOffset ();
214
+ auto strides = runtimeMetadata.getConstifiedMixedStrides ();
215
+ auto sizes = runtimeMetadata.getConstifiedMixedSizes ();
216
+ return computeLinearBounds (builder, loc, offset, strides, sizes);
217
+ }
218
+
219
+ // / Verifies that the linear bounds of a reinterpret_cast op are within the
220
+ // / linear bounds of the base memref: low >= baseLow && high <= baseHigh
221
+ struct ReinterpretCastOpInterface
222
+ : public RuntimeVerifiableOpInterface::ExternalModel<
223
+ ReinterpretCastOpInterface, ReinterpretCastOp> {
224
+ void generateRuntimeVerification (Operation *op, OpBuilder &builder,
225
+ Location loc) const {
226
+ auto reinterpretCast = cast<ReinterpretCastOp>(op);
227
+ auto baseMemref = reinterpretCast.getSource ();
228
+ auto resultMemref =
229
+ cast<TypedValue<BaseMemRefType>>(reinterpretCast.getResult ());
230
+
231
+ builder.setInsertionPointAfter (op);
232
+
233
+ // Compute the linear bounds of the base memref
234
+ auto [baseLow, baseHigh] = computeLinearBounds (builder, loc, baseMemref);
235
+
236
+ // Compute the linear bounds of the resulting memref
237
+ auto [low, high] = computeLinearBounds (builder, loc, resultMemref);
238
+
239
+ // Check low >= baseLow
240
+ auto geLow = builder.createOrFold <arith::CmpIOp>(
241
+ loc, arith::CmpIPredicate::sge, low, baseLow);
242
+
243
+ // Check high <= baseHigh
244
+ auto leHigh = builder.createOrFold <arith::CmpIOp>(
245
+ loc, arith::CmpIPredicate::sle, high, baseHigh);
246
+
247
+ auto assertCond = builder.createOrFold <arith::AndIOp>(loc, geLow, leHigh);
248
+
249
+ builder.create <cf::AssertOp>(
250
+ loc, assertCond,
251
+ generateErrorMessage (
252
+ op,
253
+ " result of reinterpret_cast is out-of-bounds of the base memref" ));
254
+ }
255
+ };
256
+
257
+ // / Verifies that the linear bounds of a subview op are within the linear bounds
258
+ // / of the base memref: low >= baseLow && high <= baseHigh
259
+ // / TODO: This is not yet a full runtime verification of subview. For example,
260
+ // / consider:
261
+ // / %m = memref.alloc(%c10, %c10) : memref<10x10xf32>
262
+ // / memref.subview %m[%c0, %c0][%c20, %c2][%c1, %c1]
263
+ // / : memref<?x?xf32> to memref<?x?xf32>
264
+ // / The subview is in-bounds of the entire base memref but the first dimension
265
+ // / is out-of-bounds. Future work would verify the bounds on a per-dimension
266
+ // / basis.
267
+ struct SubViewOpInterface
268
+ : public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
269
+ SubViewOp> {
270
+ void generateRuntimeVerification (Operation *op, OpBuilder &builder,
271
+ Location loc) const {
272
+ auto subView = cast<SubViewOp>(op);
273
+ auto baseMemref = cast<TypedValue<BaseMemRefType>>(subView.getSource ());
274
+ auto resultMemref = cast<TypedValue<BaseMemRefType>>(subView.getResult ());
275
+
276
+ builder.setInsertionPointAfter (op);
277
+
278
+ // Compute the linear bounds of the base memref
279
+ auto [baseLow, baseHigh] = computeLinearBounds (builder, loc, baseMemref);
280
+
281
+ // Compute the linear bounds of the resulting memref
282
+ auto [low, high] = computeLinearBounds (builder, loc, resultMemref);
283
+
284
+ // Check low >= baseLow
285
+ auto geLow = builder.createOrFold <arith::CmpIOp>(
286
+ loc, arith::CmpIPredicate::sge, low, baseLow);
287
+
288
+ // Check high <= baseHigh
289
+ auto leHigh = builder.createOrFold <arith::CmpIOp>(
290
+ loc, arith::CmpIPredicate::sle, high, baseHigh);
291
+
292
+ auto assertCond = builder.createOrFold <arith::AndIOp>(loc, geLow, leHigh);
293
+
294
+ builder.create <cf::AssertOp>(
295
+ loc, assertCond,
296
+ generateErrorMessage (op,
297
+ " subview is out-of-bounds of the base memref" ));
298
+ }
299
+ };
300
+
136
301
struct ExpandShapeOpInterface
137
302
: public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
138
303
ExpandShapeOp> {
@@ -183,8 +348,13 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
183
348
registry.addExtension (+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
184
349
CastOp::attachInterface<CastOpInterface>(*ctx);
185
350
ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
351
+ LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
352
+ ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
353
+ StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
354
+ SubViewOp::attachInterface<SubViewOpInterface>(*ctx);
186
355
187
356
// Load additional dialects of which ops may get created.
188
- ctx->loadDialect <arith::ArithDialect, cf::ControlFlowDialect>();
357
+ ctx->loadDialect <affine::AffineDialect, arith::ArithDialect,
358
+ cf::ControlFlowDialect>();
189
359
});
190
360
}
0 commit comments