Skip to content

Commit dbbdee2

Browse files
[mlir] Make the ml_program dialect allow all of its operations to be inlined. (#85479)
1 parent 426bf0c commit dbbdee2

File tree

3 files changed

+34
-1
lines changed

3 files changed

+34
-1
lines changed

mlir/lib/Dialect/MLProgram/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,6 @@ add_mlir_dialect_library(MLIRMLProgramDialect
1515
MLIRControlFlowInterfaces
1616
MLIRFunctionInterfaces
1717
MLIRInferTypeOpInterface
18+
MLIRTransforms
1819
MLIRIR
1920
)

mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
1010
#include "mlir/IR/DialectImplementation.h"
11+
#include "mlir/Transforms/InliningUtils.h"
1112
#include "llvm/ADT/TypeSwitch.h"
1213

1314
using namespace mlir;
@@ -24,6 +25,18 @@ using namespace mlir::ml_program;
2425
#include "mlir/Dialect/MLProgram/IR/MLProgramTypes.cpp.inc"
2526

2627
namespace {
28+
29+
struct MLProgramInlinerInterface : public DialectInlinerInterface {
30+
using DialectInlinerInterface::DialectInlinerInterface;
31+
32+
bool isLegalToInline(Operation *, Region *, bool,
33+
IRMapping &) const override {
34+
// We have no specific opinion on whether ops defined in this dialect should
35+
// be inlined.
36+
return true;
37+
}
38+
};
39+
2740
struct MLProgramOpAsmDialectInterface : public OpAsmDialectInterface {
2841
using OpAsmDialectInterface::OpAsmDialectInterface;
2942

@@ -53,5 +66,5 @@ void ml_program::MLProgramDialect::initialize() {
5366
#include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc"
5467
>();
5568

56-
addInterfaces<MLProgramOpAsmDialectInterface>();
69+
addInterfaces<MLProgramInlinerInterface, MLProgramOpAsmDialectInterface>();
5770
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// RUN: mlir-opt %s -inline | FileCheck %s
2+
3+
// Verifies that regions with operations from the ml_program dialect can
4+
// be inlined.
5+
6+
ml_program.global private @global(dense<4> : tensor<4xi32>) : tensor<4xi32>
7+
8+
// CHECK: @inline_into
9+
func.func @inline_into() -> tensor<4xi32> {
10+
// CHECK-NOT: @inline_from
11+
// CHECK: ml_program.global_load_const
12+
%0 = call @inline_from() : () -> tensor<4xi32>
13+
return %0 : tensor<4xi32>
14+
}
15+
16+
func.func @inline_from() -> tensor<4xi32> {
17+
%0 = ml_program.global_load_const @global : tensor<4xi32>
18+
return %0 : tensor<4xi32>
19+
}

0 commit comments

Comments
 (0)