Skip to content

Commit 033ee54

Browse files
matthias-springerGeorgeARM
authored andcommitted
[mlir][memref] Remove runtime verification for memref.reinterpret_cast (llvm#132547)
The runtime verification code used to verify that the result of a `memref.reinterpret_cast` is in-bounds with respect to the source memref. This is incorrect: `memref.reinterpret_cast` allows users to construct almost arbitrary memref descriptors and there is no correctness expectation. This op is supposed to be used when the user "knows what they are doing." Similarly, the static verifier of `memref.reinterpret_cast` does not verify in-bounds semantics either.
1 parent cc6f310 commit 033ee54

File tree

2 files changed

+1
-147
lines changed

2 files changed

+1
-147
lines changed

mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp

Lines changed: 1 addition & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -255,78 +255,6 @@ struct LoadStoreOpInterface
255255
}
256256
};
257257

258-
/// Compute the linear index for the provided strided layout and indices.
259-
Value computeLinearIndex(OpBuilder &builder, Location loc, OpFoldResult offset,
260-
ArrayRef<OpFoldResult> strides,
261-
ArrayRef<OpFoldResult> indices) {
262-
auto [expr, values] = computeLinearIndex(offset, strides, indices);
263-
auto index =
264-
affine::makeComposedFoldedAffineApply(builder, loc, expr, values);
265-
return getValueOrCreateConstantIndexOp(builder, loc, index);
266-
}
267-
268-
/// Returns two Values representing the bounds of the provided strided layout
269-
/// metadata. The bounds are returned as a half open interval -- [low, high).
270-
std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
271-
OpFoldResult offset,
272-
ArrayRef<OpFoldResult> strides,
273-
ArrayRef<OpFoldResult> sizes) {
274-
auto zeros = SmallVector<int64_t>(sizes.size(), 0);
275-
auto indices = getAsIndexOpFoldResult(builder.getContext(), zeros);
276-
auto lowerBound = computeLinearIndex(builder, loc, offset, strides, indices);
277-
auto upperBound = computeLinearIndex(builder, loc, offset, strides, sizes);
278-
return {lowerBound, upperBound};
279-
}
280-
281-
/// Returns two Values representing the bounds of the memref. The bounds are
282-
/// returned as a half open interval -- [low, high).
283-
std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
284-
TypedValue<BaseMemRefType> memref) {
285-
auto runtimeMetadata = builder.create<ExtractStridedMetadataOp>(loc, memref);
286-
auto offset = runtimeMetadata.getConstifiedMixedOffset();
287-
auto strides = runtimeMetadata.getConstifiedMixedStrides();
288-
auto sizes = runtimeMetadata.getConstifiedMixedSizes();
289-
return computeLinearBounds(builder, loc, offset, strides, sizes);
290-
}
291-
292-
/// Verifies that the linear bounds of a reinterpret_cast op are within the
293-
/// linear bounds of the base memref: low >= baseLow && high <= baseHigh
294-
struct ReinterpretCastOpInterface
295-
: public RuntimeVerifiableOpInterface::ExternalModel<
296-
ReinterpretCastOpInterface, ReinterpretCastOp> {
297-
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
298-
Location loc) const {
299-
auto reinterpretCast = cast<ReinterpretCastOp>(op);
300-
auto baseMemref = reinterpretCast.getSource();
301-
auto resultMemref =
302-
cast<TypedValue<BaseMemRefType>>(reinterpretCast.getResult());
303-
304-
builder.setInsertionPointAfter(op);
305-
306-
// Compute the linear bounds of the base memref
307-
auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
308-
309-
// Compute the linear bounds of the resulting memref
310-
auto [low, high] = computeLinearBounds(builder, loc, resultMemref);
311-
312-
// Check low >= baseLow
313-
auto geLow = builder.createOrFold<arith::CmpIOp>(
314-
loc, arith::CmpIPredicate::sge, low, baseLow);
315-
316-
// Check high <= baseHigh
317-
auto leHigh = builder.createOrFold<arith::CmpIOp>(
318-
loc, arith::CmpIPredicate::sle, high, baseHigh);
319-
320-
auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);
321-
322-
builder.create<cf::AssertOp>(
323-
loc, assertCond,
324-
RuntimeVerifiableOpInterface::generateErrorMessage(
325-
op,
326-
"result of reinterpret_cast is out-of-bounds of the base memref"));
327-
}
328-
};
329-
330258
struct SubViewOpInterface
331259
: public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
332260
SubViewOp> {
@@ -434,9 +362,9 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
434362
GenericAtomicRMWOp::attachInterface<
435363
LoadStoreOpInterface<GenericAtomicRMWOp>>(*ctx);
436364
LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
437-
ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
438365
StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
439366
SubViewOp::attachInterface<SubViewOpInterface>(*ctx);
367+
// Note: There is nothing to verify for ReinterpretCastOp.
440368

441369
// Load additional dialects of which ops may get created.
442370
ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,

mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir

Lines changed: 0 additions & 74 deletions
This file was deleted.

0 commit comments

Comments
 (0)