Skip to content

[mlir][LLVM] LLVMTypeConverter: Tighten materialization checks #116532

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
type.isVarArg());
});

// Helper function that checks if the given value range is a bare pointer.
auto isBarePointer = [](ValueRange values) {
return values.size() == 1 &&
isa<LLVM::LLVMPointerType>(values.front().getType());
};

// Argument materializations convert from the new block argument types
// (multiple SSA values that make up a memref descriptor) back to the
// original block argument type. The dialect conversion framework will then
Expand All @@ -161,11 +167,10 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
addArgumentMaterialization([&](OpBuilder &builder,
UnrankedMemRefType resultType,
ValueRange inputs, Location loc) {
if (inputs.size() == 1) {
// Bare pointers are not supported for unranked memrefs because a
// memref descriptor cannot be built just from a bare pointer.
// Note: Bare pointers are not supported for unranked memrefs because a
// memref descriptor cannot be built just from a bare pointer.
if (TypeRange(inputs) != getUnrankedMemRefDescriptorFields())
return Value();
}
Value desc =
UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
// An argument materialization must return a value of type
Expand All @@ -177,20 +182,17 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
ValueRange inputs, Location loc) {
Value desc;
if (inputs.size() == 1) {
// This is a bare pointer. We allow bare pointers only for function entry
// blocks.
BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
if (!barePtr)
return Value();
Block *block = barePtr.getOwner();
if (!block->isEntryBlock() ||
!isa<FunctionOpInterface>(block->getParentOp()))
return Value();
if (isBarePointer(inputs)) {
desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
inputs[0]);
} else {
} else if (TypeRange(inputs) ==
getMemRefDescriptorFields(resultType,
/*unpackAggregates=*/true)) {
desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
} else {
// The inputs are neither a bare pointer nor an unpacked memref
// descriptor. This materialization function cannot be used.
return Value();
}
// An argument materialization must return a value of type `resultType`,
// so insert a cast from the memref descriptor type (!llvm.struct) to the
Expand Down
57 changes: 57 additions & 0 deletions mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// RUN: mlir-opt %s -test-llvm-legalize-patterns -split-input-file

// Test the argument materializer for ranked MemRef types.

// CHECK-LABEL: func @construct_ranked_memref_descriptor(
// CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-COUNT-7: llvm.insertvalue
// CHECK: builtin.unrealized_conversion_cast %{{.*}} : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<5x4xf32>
func.func @construct_ranked_memref_descriptor(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64) {
%0 = "test.direct_replacement"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!llvm.ptr, !llvm.ptr, i64, i64, i64, i64, i64) -> (memref<5x4xf32>)
"test.legal_op"(%0) : (memref<5x4xf32>) -> ()
return
}

// -----

// The argument materializer for ranked MemRef types is called with incorrect
// input types. Make sure that the materializer is skipped and we do not
// generate invalid IR.

// CHECK-LABEL: func @invalid_ranked_memref_descriptor(
// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i1 to memref<5x4xf32>
// CHECK: "test.legal_op"(%[[cast]])
func.func @invalid_ranked_memref_descriptor(%arg0: i1) {
%0 = "test.direct_replacement"(%arg0) : (i1) -> (memref<5x4xf32>)
"test.legal_op"(%0) : (memref<5x4xf32>) -> ()
return
}

// -----

// Test the argument materializer for unranked MemRef types.

// CHECK-LABEL: func @construct_unranked_memref_descriptor(
// CHECK: llvm.mlir.undef : !llvm.struct<(i64, ptr)>
// CHECK-COUNT-2: llvm.insertvalue
// CHECK: builtin.unrealized_conversion_cast %{{.*}} : !llvm.struct<(i64, ptr)> to memref<*xf32>
func.func @construct_unranked_memref_descriptor(%arg0: i64, %arg1: !llvm.ptr) {
%0 = "test.direct_replacement"(%arg0, %arg1) : (i64, !llvm.ptr) -> (memref<*xf32>)
"test.legal_op"(%0) : (memref<*xf32>) -> ()
return
}

// -----

// The argument materializer for unranked MemRef types is called with incorrect
// input types. Make sure that the materializer is skipped and we do not
// generate invalid IR.

// CHECK-LABEL: func @invalid_unranked_memref_descriptor(
// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i1 to memref<*xf32>
// CHECK: "test.legal_op"(%[[cast]])
func.func @invalid_unranked_memref_descriptor(%arg0: i1) {
%0 = "test.direct_replacement"(%arg0) : (i1) -> (memref<*xf32>)
"test.legal_op"(%0) : (memref<*xf32>) -> ()
return
}
1 change: 1 addition & 0 deletions mlir/test/lib/Dialect/LLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRLLVMTestPasses
TestLowerToLLVM.cpp
TestPatterns.cpp

EXCLUDE_FROM_LIBMLIR

Expand Down
77 changes: 77 additions & 0 deletions mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
//===- TestPatterns.cpp - LLVM dialect test patterns ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

namespace {

/// Replace this op (which is expected to have 1 result) with the operands.
struct TestDirectReplacementOp : public ConversionPattern {
TestDirectReplacementOp(MLIRContext *ctx, const TypeConverter &converter)
: ConversionPattern(converter, "test.direct_replacement", 1, ctx) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
if (op->getNumResults() != 1)
return failure();
rewriter.replaceOpWithMultiple(op, {operands});
return success();
}
};

struct TestLLVMLegalizePatternsPass
: public PassWrapper<TestLLVMLegalizePatternsPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLLVMLegalizePatternsPass)

StringRef getArgument() const final { return "test-llvm-legalize-patterns"; }
StringRef getDescription() const final {
return "Run LLVM dialect legalization patterns";
}

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<LLVM::LLVMDialect>();
}

void runOnOperation() override {
MLIRContext *ctx = &getContext();
LLVMTypeConverter converter(ctx);
mlir::RewritePatternSet patterns(ctx);
patterns.add<TestDirectReplacementOp>(ctx, converter);

// Define the conversion target used for the test.
ConversionTarget target(*ctx);
target.addLegalOp(OperationName("test.legal_op", ctx));

// Handle a partial conversion.
DenseSet<Operation *> unlegalizedOps;
ConversionConfig config;
config.unlegalizedOps = &unlegalizedOps;
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns), config)))
getOperation()->emitError() << "applyPartialConversion failed";
}
};
} // namespace

//===----------------------------------------------------------------------===//
// PassRegistration
//===----------------------------------------------------------------------===//

namespace mlir {
namespace test {
void registerTestLLVMLegalizePatternsPass() {
PassRegistration<TestLLVMLegalizePatternsPass>();
}
} // namespace test
} // namespace mlir
2 changes: 2 additions & 0 deletions mlir/tools/mlir-opt/mlir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ void registerTestLinalgRankReduceContractionOps();
void registerTestLinalgTransforms();
void registerTestLivenessAnalysisPass();
void registerTestLivenessPass();
void registerTestLLVMLegalizePatternsPass();
void registerTestLoopFusion();
void registerTestLoopMappingPass();
void registerTestLoopUnrollingPass();
Expand Down Expand Up @@ -250,6 +251,7 @@ void registerTestPasses() {
mlir::test::registerTestLinalgTransforms();
mlir::test::registerTestLivenessAnalysisPass();
mlir::test::registerTestLivenessPass();
mlir::test::registerTestLLVMLegalizePatternsPass();
mlir::test::registerTestLoopFusion();
mlir::test::registerTestLoopMappingPass();
mlir::test::registerTestLoopUnrollingPass();
Expand Down
Loading