Skip to content

Commit 4cb2ef4

Browse files
authored
[mlir] add a chapter on matchers to the transform dialect tutorial (#76725)
These operations has been available for a while, but were not described in the tutorial. Add a new chapter on using and defining match operations.
1 parent 633d918 commit 4cb2ef4

File tree

16 files changed

+1375
-3
lines changed

16 files changed

+1375
-3
lines changed

mlir/docs/Tutorials/transform/Ch4.md

Lines changed: 581 additions & 0 deletions
Large diffs are not rendered by default.

mlir/docs/Tutorials/transform/_index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ The tutorial is divided into the following chapters.
2626
- [Chapter #1](Ch1.md): Combining Existing Transformations
2727
- [Chapter #2](Ch2.md): Adding a Simple New Transformation Operation
2828
- [Chapter #3](Ch3.md): More than Simple Transform Operations
29+
- [Chapter #4](Ch4.md): Matching Payload with Transform Operations
2930
- [Chapter H](ChH.md): Reproducing Halide Schedule
3031

3132
The code corresponding to this tutorial is located under

mlir/examples/transform/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ add_custom_target(TransformExample)
22

33
add_subdirectory(Ch2)
44
add_subdirectory(Ch3)
5+
add_subdirectory(Ch4)

mlir/examples/transform/Ch3/transform-opt/transform-opt.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
//
77
//===----------------------------------------------------------------------===//
88
//
9-
// This is the top-level file for the Transform dialect tutorial chapter 2.
9+
// This is the top-level file for the Transform dialect tutorial chapter 3.
1010
//
1111
//===----------------------------------------------------------------------===//
1212

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# For a better top-level template to copy, see examples/standalone.
2+
3+
include_directories(${CMAKE_CURRENT_BINARY_DIR})
4+
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include)
5+
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
6+
7+
add_subdirectory(include)
8+
add_subdirectory(lib)
9+
10+
add_dependencies(TransformExample transform-opt-ch4)
11+
add_llvm_example(transform-opt-ch4
12+
transform-opt/transform-opt.cpp)
13+
14+
target_link_libraries(transform-opt-ch4
15+
PRIVATE
16+
MLIRIR
17+
MLIRMlirOptMain
18+
MLIRSideEffectInterfaces
19+
MLIRTransformDialectTransforms
20+
MyExtensionCh4
21+
)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Tell Tablegen to use MyExtension.td as input.
2+
set(LLVM_TARGET_DEFINITIONS MyExtension.td)
3+
4+
# Ask Tablegen to generate op declarations and definitions from ODS.
5+
mlir_tablegen(MyExtension.h.inc -gen-op-decls)
6+
mlir_tablegen(MyExtension.cpp.inc -gen-op-defs)
7+
8+
# Add a CMakeTarget we can depend on to ensure the generation happens before the
9+
# compilation.
10+
add_public_tablegen_target(MyExtensionCh4IncGen)
11+
12+
# Don't forget to generate the documentation, this will produce a
13+
# MyExtensionCh4.md under Tutorials/transform
14+
add_mlir_doc(MyExtension MyExtensionCh4 Tutorials/transform/ -gen-op-doc)
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
//===-- MyExtension.h - Transform dialect tutorial --------------*- c++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines Transform dialect extension operations used in the
10+
// Chapter 4 of the Transform dialect tutorial.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Bytecode/BytecodeOpInterface.h"
15+
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
16+
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
17+
#include "mlir/Dialect/Transform/IR/TransformOps.h"
18+
19+
namespace mlir {
20+
class CallOpInterface;
21+
namespace func {
22+
class CallOp;
23+
} // namespace func
24+
} // namespace mlir
25+
26+
#define GET_OP_CLASSES
27+
#include "MyExtension.h.inc"
28+
29+
// Registers our Transform dialect extension.
30+
void registerMyExtension(::mlir::DialectRegistry &registry);
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
//===-- MyExtension.td - Transform dialect tutorial --------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines Transform dialect extension operations used in the
10+
// Chapter 4 of the Transform dialect tutorial.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef MY_EXTENSION
15+
#define MY_EXTENSION
16+
17+
include "mlir/Dialect/Transform/IR/MatchInterfaces.td"
18+
include "mlir/Dialect/Transform/IR/TransformDialect.td"
19+
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
20+
include "mlir/IR/OpBase.td"
21+
include "mlir/Interfaces/SideEffectInterfaces.td"
22+
23+
// Define the new operation. By convention, prefix its name with `match`
24+
// followed by the name of the dialect extension.
25+
def HasOperandSatisfyingOp : TransformDialectOp<"match.my.has_operand_satisfying",
26+
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
27+
DeclareOpInterfaceMethods<TransformOpInterface>,
28+
// Indicate that the operation implements MatchOpInterface in addition to
29+
// the TransformOpInterface. This interface is only used as a tag at this
30+
// point and has no methods that are mandatory to implement.
31+
MatchOpInterface,
32+
SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">]> {
33+
let summary = "Succeed if any of the operands matches all nested criteria";
34+
let arguments = (ins TransformHandleTypeInterface:$op);
35+
let results = (outs TransformParamTypeInterface:$position,
36+
Variadic<Transform_AnyHandleOrParamType>:$results);
37+
38+
// Match operations can be arbitrarily complex, e.g., containing regions.
39+
let regions = (region SizedRegion<1>:$body);
40+
let hasVerifier = 1;
41+
let assemblyFormat = [{
42+
$op `:` functional-type($op, results) attr-dict-with-keyword $body
43+
}];
44+
}
45+
46+
#endif // MY_EXTENSION
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Outside examples, this should be `add_mlir_library`.
2+
add_mlir_example_library(
3+
# Library called MyExtension.
4+
MyExtensionCh4
5+
6+
# Built from the following source files.
7+
MyExtension.cpp
8+
9+
# Make includes visible without top-level path.
10+
ADDITIONAL_HEADER_DIRS
11+
${PROJECT_SOURCE_DIR}/examples/transform/Ch4/include
12+
13+
# Make sure ODS declaration and definitions are generated before compiling this.
14+
DEPENDS
15+
MyExtensionCh4IncGen
16+
17+
# Link in the transform dialect, an all generated dialects.
18+
LINK_LIBS PRIVATE
19+
MLIRTransformDialect
20+
)
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
//===-- MyExtension.cpp - Transform dialect tutorial ----------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines Transform dialect extension operations used in the
10+
// Chapter 4 of the Transform dialect tutorial.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "MyExtension.h"
15+
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
16+
#include "llvm/Support/Debug.h"
17+
18+
#define DEBUG_TYPE_MATCHER "transform-matcher"
19+
#define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ")
20+
#define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x)
21+
22+
#define GET_OP_CLASSES
23+
#include "MyExtension.cpp.inc"
24+
25+
//===---------------------------------------------------------------------===//
26+
// MyExtension
27+
//===---------------------------------------------------------------------===//
28+
29+
// Define a new transform dialect extension. This uses the CRTP idiom to
30+
// identify extensions.
31+
class MyExtension
32+
: public ::mlir::transform::TransformDialectExtension<MyExtension> {
33+
public:
34+
// The extension must derive the base constructor.
35+
using Base::Base;
36+
37+
// This function initializes the extension, similarly to `initialize` in
38+
// dialect definitions. List individual operations and dependent dialects
39+
// here.
40+
void init();
41+
};
42+
43+
void MyExtension::init() {
44+
// Register the additional match operations with the dialect similarly to
45+
// other transform operations. List all operations generated from ODS. This
46+
// call will perform additional checks that the operations implement the
47+
// transform and memory effect interfaces required by the dialect interpreter
48+
// and assert if they do not.
49+
registerTransformOps<
50+
#define GET_OP_LIST
51+
#include "MyExtension.cpp.inc"
52+
>();
53+
}
54+
55+
//===---------------------------------------------------------------------===//
56+
// HasOperandSatisfyingOp
57+
//===---------------------------------------------------------------------===//
58+
59+
/// Returns `true` if both types implement one of the interfaces provided as
60+
/// template parameters.
61+
template <typename... Tys>
62+
static bool implementSameInterface(mlir::Type t1, mlir::Type t2) {
63+
return ((llvm::isa<Tys>(t1) && llvm::isa<Tys>(t2)) || ... || false);
64+
}
65+
66+
/// Returns `true` if both types implement one of the transform dialect
67+
/// interfaces.
68+
static bool implementSameTransformInterface(mlir::Type t1, mlir::Type t2) {
69+
return implementSameInterface<
70+
mlir::transform::TransformHandleTypeInterface,
71+
mlir::transform::TransformParamTypeInterface,
72+
mlir::transform::TransformValueHandleTypeInterface>(t1, t2);
73+
}
74+
75+
// Matcher ops implement `apply` similarly to other transform ops. They are not
76+
// expected to modify payload, but use the tri-state result to signal failure or
77+
// success to match, as well as potential irrecoverable errors.
78+
mlir::DiagnosedSilenceableFailure
79+
mlir::transform::HasOperandSatisfyingOp::apply(
80+
mlir::transform::TransformRewriter &rewriter,
81+
mlir::transform::TransformResults &results,
82+
mlir::transform::TransformState &state) {
83+
// For simplicity, only handle a single payload op. Actual implementations
84+
// can use `SingleOpMatcher` trait to simplify implementation and document
85+
// this expectation.
86+
auto payloadOps = state.getPayloadOps(getOp());
87+
if (!llvm::hasSingleElement(payloadOps))
88+
return emitSilenceableError() << "expected single payload";
89+
90+
// Iterate over all operands of the payload op to see if they can be matched
91+
// using the body of this op.
92+
Operation *payload = *payloadOps.begin();
93+
for (OpOperand &operand : payload->getOpOperands()) {
94+
// Create a scope for transform values defined in the body. This corresponds
95+
// to the syntactic scope of the region attached to this op. Any values
96+
// associated with payloads from now on will be automatically dissociated
97+
// when this object is destroyed, i.e. at the end of the iteration.
98+
// Associate the block argument handle with the operand.
99+
auto matchScope = state.make_region_scope(getBody());
100+
if (failed(state.mapBlockArgument(getBody().getArgument(0),
101+
{operand.get()}))) {
102+
return DiagnosedSilenceableFailure::definiteFailure();
103+
}
104+
105+
// Iterate over all nested matchers with the current mapping and see if they
106+
// succeed.
107+
bool matchSucceeded = true;
108+
for (Operation &matcher : getBody().front().without_terminator()) {
109+
// Matcher ops are applied similarly to any other transform op.
110+
DiagnosedSilenceableFailure diag =
111+
state.applyTransform(cast<TransformOpInterface>(matcher));
112+
113+
// Definite failures are immediately propagated as they are irrecoverable.
114+
if (diag.isDefiniteFailure())
115+
return diag;
116+
117+
// On success, keep checking the remaining conditions.
118+
if (diag.succeeded())
119+
continue;
120+
121+
// Report failure-to-match for debugging purposes and stop matching this
122+
// operand.
123+
assert(diag.isSilenceableFailure());
124+
DEBUG_MATCHER(DBGS_MATCHER()
125+
<< "failed to match operand #" << operand.getOperandNumber()
126+
<< ": " << diag.getMessage());
127+
(void)diag.silence();
128+
matchSucceeded = false;
129+
break;
130+
}
131+
// If failed to match this operand, try other operands.
132+
if (!matchSucceeded)
133+
continue;
134+
135+
// If we reached this point, the matching succeeded for the current operand.
136+
// Remap the values associated with terminator operands to be associated
137+
// with op results, and also map the parameter result to the operand's
138+
// position. Note that it is safe to do here despite the end of the scope
139+
// as `results` are integrated into `state` by the interpreter after `apply`
140+
// returns rather than immediately.
141+
SmallVector<SmallVector<MappedValue>> yieldedMappings;
142+
transform::detail::prepareValueMappings(
143+
yieldedMappings, getBody().front().getTerminator()->getOperands(),
144+
state);
145+
results.setParams(getPosition().cast<OpResult>(),
146+
{rewriter.getI32IntegerAttr(operand.getOperandNumber())});
147+
for (auto &&[result, mapping] : llvm::zip(getResults(), yieldedMappings))
148+
results.setMappedValues(result, mapping);
149+
return DiagnosedSilenceableFailure::success();
150+
}
151+
152+
// If we reached this point, none of the operands succeeded the match.
153+
return emitSilenceableError()
154+
<< "none of the operands satisfied the conditions";
155+
}
156+
157+
// By convention, operations implementing MatchOpInterface must not modify
158+
// payload IR and must therefore specify that they only read operand handles and
159+
// payload as their effects.
160+
void mlir::transform::HasOperandSatisfyingOp::getEffects(
161+
llvm::SmallVectorImpl<mlir::MemoryEffects::EffectInstance> &effects) {
162+
onlyReadsPayload(effects);
163+
onlyReadsHandle(getOp(), effects);
164+
producesHandle(getPosition(), effects);
165+
producesHandle(getResults(), effects);
166+
}
167+
168+
// Verify well-formedness of the operation and emit diagnostics if it is
169+
// ill-formed.
170+
mlir::LogicalResult mlir::transform::HasOperandSatisfyingOp::verify() {
171+
mlir::Block &bodyBlock = getBody().front();
172+
if (bodyBlock.getNumArguments() != 1 ||
173+
!isa<TransformValueHandleTypeInterface>(
174+
bodyBlock.getArgument(0).getType())) {
175+
return emitOpError()
176+
<< "expects the body to have one value handle argument";
177+
}
178+
if (bodyBlock.getTerminator()->getNumOperands() != getNumResults() - 1) {
179+
return emitOpError() << "expects the body to yield "
180+
<< (getNumResults() - 1) << " values, got "
181+
<< bodyBlock.getTerminator()->getNumOperands();
182+
}
183+
for (auto &&[i, operand, result] :
184+
llvm::enumerate(bodyBlock.getTerminator()->getOperands().getTypes(),
185+
getResults().getTypes())) {
186+
if (implementSameTransformInterface(operand, result))
187+
continue;
188+
return emitOpError() << "expects terminator operand #" << i
189+
<< " and result #" << (i + 1)
190+
<< " to implement the same transform interface";
191+
}
192+
193+
for (Operation &op : bodyBlock.without_terminator()) {
194+
if (!isa<TransformOpInterface>(op) || !isa<MatchOpInterface>(op)) {
195+
InFlightDiagnostic diag = emitOpError()
196+
<< "expects body to contain match ops";
197+
diag.attachNote(op.getLoc()) << "non-match operation";
198+
return diag;
199+
}
200+
}
201+
202+
return success();
203+
}
204+
205+
void registerMyExtension(::mlir::DialectRegistry &registry) {
206+
registry.addExtensions<MyExtension>();
207+
}

0 commit comments

Comments
 (0)