Skip to content

Commit a0ef12c

Browse files
[mlir][LLVM] LLVMTypeConverter: Tighten materialization checks (#116532)
This commit adds extra checks to the MemRef argument materializations in the LLVM type converter. These materializations construct a `MemRefType`/`UnrankedMemRefType` from the unpacked elements of a MemRef descriptor or from a bare pointer. The extra checks ensure that the inputs to the materialization function are correct. It is possible that a user added extra type conversion rules that convert MemRef types in a different way and the extra checks ensure that we construct a MemRef descriptor only if the inputs are what we expect. This commit also drops a check around bare pointer materializations: ``` // This is a bare pointer. We allow bare pointers only for function entry // blocks. ``` This check should not be part of the materialization function. Whether a MemRef block argument is converted into a MemRef descriptor or a bare pointer is decided in the lowering pattern. At the point of time when materialization functions are executed, we already made that decision and we should just materialize regardless of the input format.
1 parent ed1d90c commit a0ef12c

File tree

5 files changed

+154
-15
lines changed

5 files changed

+154
-15
lines changed

mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
153153
type.isVarArg());
154154
});
155155

156+
// Helper function that checks if the given value range is a bare pointer.
157+
auto isBarePointer = [](ValueRange values) {
158+
return values.size() == 1 &&
159+
isa<LLVM::LLVMPointerType>(values.front().getType());
160+
};
161+
156162
// Argument materializations convert from the new block argument types
157163
// (multiple SSA values that make up a memref descriptor) back to the
158164
// original block argument type. The dialect conversion framework will then
@@ -161,11 +167,10 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
161167
addArgumentMaterialization([&](OpBuilder &builder,
162168
UnrankedMemRefType resultType,
163169
ValueRange inputs, Location loc) {
164-
if (inputs.size() == 1) {
165-
// Bare pointers are not supported for unranked memrefs because a
166-
// memref descriptor cannot be built just from a bare pointer.
170+
// Note: Bare pointers are not supported for unranked memrefs because a
171+
// memref descriptor cannot be built just from a bare pointer.
172+
if (TypeRange(inputs) != getUnrankedMemRefDescriptorFields())
167173
return Value();
168-
}
169174
Value desc =
170175
UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
171176
// An argument materialization must return a value of type
@@ -177,20 +182,17 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
177182
addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
178183
ValueRange inputs, Location loc) {
179184
Value desc;
180-
if (inputs.size() == 1) {
181-
// This is a bare pointer. We allow bare pointers only for function entry
182-
// blocks.
183-
BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
184-
if (!barePtr)
185-
return Value();
186-
Block *block = barePtr.getOwner();
187-
if (!block->isEntryBlock() ||
188-
!isa<FunctionOpInterface>(block->getParentOp()))
189-
return Value();
185+
if (isBarePointer(inputs)) {
190186
desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
191187
inputs[0]);
192-
} else {
188+
} else if (TypeRange(inputs) ==
189+
getMemRefDescriptorFields(resultType,
190+
/*unpackAggregates=*/true)) {
193191
desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
192+
} else {
193+
// The inputs are neither a bare pointer nor an unpacked memref
194+
// descriptor. This materialization function cannot be used.
195+
return Value();
194196
}
195197
// An argument materialization must return a value of type `resultType`,
196198
// so insert a cast from the memref descriptor type (!llvm.struct) to the
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
// RUN: mlir-opt %s -test-llvm-legalize-patterns -split-input-file
2+
3+
// Test the argument materializer for ranked MemRef types.
4+
5+
// CHECK-LABEL: func @construct_ranked_memref_descriptor(
6+
// CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
7+
// CHECK-COUNT-7: llvm.insertvalue
8+
// CHECK: builtin.unrealized_conversion_cast %{{.*}} : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<5x4xf32>
9+
func.func @construct_ranked_memref_descriptor(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64) {
10+
%0 = "test.direct_replacement"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!llvm.ptr, !llvm.ptr, i64, i64, i64, i64, i64) -> (memref<5x4xf32>)
11+
"test.legal_op"(%0) : (memref<5x4xf32>) -> ()
12+
return
13+
}
14+
15+
// -----
16+
17+
// The argument materializer for ranked MemRef types is called with incorrect
18+
// input types. Make sure that the materializer is skipped and we do not
19+
// generate invalid IR.
20+
21+
// CHECK-LABEL: func @invalid_ranked_memref_descriptor(
22+
// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i1 to memref<5x4xf32>
23+
// CHECK: "test.legal_op"(%[[cast]])
24+
func.func @invalid_ranked_memref_descriptor(%arg0: i1) {
25+
%0 = "test.direct_replacement"(%arg0) : (i1) -> (memref<5x4xf32>)
26+
"test.legal_op"(%0) : (memref<5x4xf32>) -> ()
27+
return
28+
}
29+
30+
// -----
31+
32+
// Test the argument materializer for unranked MemRef types.
33+
34+
// CHECK-LABEL: func @construct_unranked_memref_descriptor(
35+
// CHECK: llvm.mlir.undef : !llvm.struct<(i64, ptr)>
36+
// CHECK-COUNT-2: llvm.insertvalue
37+
// CHECK: builtin.unrealized_conversion_cast %{{.*}} : !llvm.struct<(i64, ptr)> to memref<*xf32>
38+
func.func @construct_unranked_memref_descriptor(%arg0: i64, %arg1: !llvm.ptr) {
39+
%0 = "test.direct_replacement"(%arg0, %arg1) : (i64, !llvm.ptr) -> (memref<*xf32>)
40+
"test.legal_op"(%0) : (memref<*xf32>) -> ()
41+
return
42+
}
43+
44+
// -----
45+
46+
// The argument materializer for unranked MemRef types is called with incorrect
47+
// input types. Make sure that the materializer is skipped and we do not
48+
// generate invalid IR.
49+
50+
// CHECK-LABEL: func @invalid_unranked_memref_descriptor(
51+
// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i1 to memref<*xf32>
52+
// CHECK: "test.legal_op"(%[[cast]])
53+
func.func @invalid_unranked_memref_descriptor(%arg0: i1) {
54+
%0 = "test.direct_replacement"(%arg0) : (i1) -> (memref<*xf32>)
55+
"test.legal_op"(%0) : (memref<*xf32>) -> ()
56+
return
57+
}

mlir/test/lib/Dialect/LLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Exclude tests from libMLIR.so
22
add_mlir_library(MLIRLLVMTestPasses
33
TestLowerToLLVM.cpp
4+
TestPatterns.cpp
45

56
EXCLUDE_FROM_LIBMLIR
67

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
//===- TestPatterns.cpp - LLVM dialect test patterns ----------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
10+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
11+
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
12+
#include "mlir/Pass/Pass.h"
13+
#include "mlir/Transforms/DialectConversion.h"
14+
15+
using namespace mlir;
16+
17+
namespace {
18+
19+
/// Replace this op (which is expected to have 1 result) with the operands.
20+
struct TestDirectReplacementOp : public ConversionPattern {
21+
TestDirectReplacementOp(MLIRContext *ctx, const TypeConverter &converter)
22+
: ConversionPattern(converter, "test.direct_replacement", 1, ctx) {}
23+
LogicalResult
24+
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
25+
ConversionPatternRewriter &rewriter) const final {
26+
if (op->getNumResults() != 1)
27+
return failure();
28+
rewriter.replaceOpWithMultiple(op, {operands});
29+
return success();
30+
}
31+
};
32+
33+
struct TestLLVMLegalizePatternsPass
34+
: public PassWrapper<TestLLVMLegalizePatternsPass, OperationPass<>> {
35+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLLVMLegalizePatternsPass)
36+
37+
StringRef getArgument() const final { return "test-llvm-legalize-patterns"; }
38+
StringRef getDescription() const final {
39+
return "Run LLVM dialect legalization patterns";
40+
}
41+
42+
void getDependentDialects(DialectRegistry &registry) const override {
43+
registry.insert<LLVM::LLVMDialect>();
44+
}
45+
46+
void runOnOperation() override {
47+
MLIRContext *ctx = &getContext();
48+
LLVMTypeConverter converter(ctx);
49+
mlir::RewritePatternSet patterns(ctx);
50+
patterns.add<TestDirectReplacementOp>(ctx, converter);
51+
52+
// Define the conversion target used for the test.
53+
ConversionTarget target(*ctx);
54+
target.addLegalOp(OperationName("test.legal_op", ctx));
55+
56+
// Handle a partial conversion.
57+
DenseSet<Operation *> unlegalizedOps;
58+
ConversionConfig config;
59+
config.unlegalizedOps = &unlegalizedOps;
60+
if (failed(applyPartialConversion(getOperation(), target,
61+
std::move(patterns), config)))
62+
getOperation()->emitError() << "applyPartialConversion failed";
63+
}
64+
};
65+
} // namespace
66+
67+
//===----------------------------------------------------------------------===//
68+
// PassRegistration
69+
//===----------------------------------------------------------------------===//
70+
71+
namespace mlir {
72+
namespace test {
73+
void registerTestLLVMLegalizePatternsPass() {
74+
PassRegistration<TestLLVMLegalizePatternsPass>();
75+
}
76+
} // namespace test
77+
} // namespace mlir

mlir/tools/mlir-opt/mlir-opt.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ void registerTestLinalgRankReduceContractionOps();
113113
void registerTestLinalgTransforms();
114114
void registerTestLivenessAnalysisPass();
115115
void registerTestLivenessPass();
116+
void registerTestLLVMLegalizePatternsPass();
116117
void registerTestLoopFusion();
117118
void registerTestLoopMappingPass();
118119
void registerTestLoopUnrollingPass();
@@ -250,6 +251,7 @@ void registerTestPasses() {
250251
mlir::test::registerTestLinalgTransforms();
251252
mlir::test::registerTestLivenessAnalysisPass();
252253
mlir::test::registerTestLivenessPass();
254+
mlir::test::registerTestLLVMLegalizePatternsPass();
253255
mlir::test::registerTestLoopFusion();
254256
mlir::test::registerTestLoopMappingPass();
255257
mlir::test::registerTestLoopUnrollingPass();

0 commit comments

Comments
 (0)