Skip to content

Commit d980384

Browse files
authored
[mlir][emitc] Add op modelling C expressions (#71631)
Add an emitc.expression operation that models C expressions, and provide transforms to form and fold expressions. The translator emits the body of emitc.expression ops as a single C expression. This expression is emitted by default as the RHS of an EmitC SSA value, but if possible, expressions with a single use that is not another expression are instead inlined. Specific expression's inlining can be fine tuned by lowering passes and transforms.
1 parent b01adc6 commit d980384

File tree

18 files changed

+1077
-37
lines changed

18 files changed

+1077
-37
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
add_subdirectory(IR)
2+
add_subdirectory(Transforms)

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 90 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ include "mlir/Dialect/EmitC/IR/EmitCTypes.td"
1919
include "mlir/Interfaces/CastInterfaces.td"
2020
include "mlir/Interfaces/ControlFlowInterfaces.td"
2121
include "mlir/Interfaces/SideEffectInterfaces.td"
22+
include "mlir/IR/RegionKindInterface.td"
2223

2324
//===----------------------------------------------------------------------===//
2425
// EmitC op definitions
@@ -247,6 +248,83 @@ def EmitC_DivOp : EmitC_BinaryOp<"div", []> {
247248
let results = (outs FloatIntegerIndexOrOpaqueType);
248249
}
249250

251+
def EmitC_ExpressionOp : EmitC_Op<"expression",
252+
[HasOnlyGraphRegion, SingleBlockImplicitTerminator<"emitc::YieldOp">,
253+
NoRegionArguments]> {
254+
let summary = "Expression operation";
255+
let description = [{
256+
The `expression` operation returns a single SSA value which is yielded by
257+
its single-basic-block region. The operation doesn't take any arguments.
258+
259+
As the operation is to be emitted as a C expression, the operations within
260+
its body must form a single Def-Use tree of emitc ops whose result is
261+
yielded by a terminating `yield`.
262+
263+
Example:
264+
265+
```mlir
266+
%r = emitc.expression : () -> i32 {
267+
%0 = emitc.add %a, %b : (i32, i32) -> i32
268+
%1 = emitc.call "foo"(%0) : () -> i32
269+
%2 = emitc.add %c, %d : (i32, i32) -> i32
270+
%3 = emitc.mul %1, %2 : (i32, i32) -> i32
271+
yield %3
272+
}
273+
```
274+
275+
May be emitted as
276+
277+
```c++
278+
int32_t v7 = foo(v1 + v2) * (v3 + v4);
279+
```
280+
281+
The operations allowed within expression body are emitc.add, emitc.apply,
282+
emitc.call, emitc.cast, emitc.cmp, emitc.div, emitc.mul, emitc.rem and
283+
emitc.sub.
284+
285+
When specified, the optional `do_not_inline` indicates that the expression is
286+
to be emitted as seen above, i.e. as the rhs of an EmitC SSA value
287+
definition. Otherwise, the expression may be emitted inline, i.e. directly
288+
at its use.
289+
}];
290+
291+
let arguments = (ins UnitAttr:$do_not_inline);
292+
let results = (outs AnyType:$result);
293+
let regions = (region SizedRegion<1>:$region);
294+
295+
let hasVerifier = 1;
296+
let assemblyFormat = "attr-dict (`noinline` $do_not_inline^)? `:` type($result) $region";
297+
298+
let extraClassDeclaration = [{
299+
static bool isCExpression(Operation &op) {
300+
return isa<emitc::AddOp, emitc::ApplyOp, emitc::CallOpaqueOp,
301+
emitc::CastOp, emitc::CmpOp, emitc::DivOp, emitc::MulOp,
302+
emitc::RemOp, emitc::SubOp>(op);
303+
}
304+
bool hasSideEffects() {
305+
auto predicate = [](Operation &op) {
306+
assert(isCExpression(op) && "Expected a C expression");
307+
// Conservatively assume calls to read and write memory.
308+
if (isa<emitc::CallOpaqueOp>(op))
309+
return true;
310+
// De-referencing reads modifiable memory, address-taking has no
311+
// side-effect.
312+
auto applyOp = dyn_cast<emitc::ApplyOp>(op);
313+
if (applyOp)
314+
return applyOp.getApplicableOperator() == "*";
315+
// Any operation using variables is assumed to have a side effect of
316+
// reading memory mutable by emitc::assign ops.
317+
return llvm::any_of(op.getOperands(), [](Value operand) {
318+
Operation *def = operand.getDefiningOp();
319+
return def && isa<emitc::VariableOp>(def);
320+
});
321+
};
322+
return llvm::any_of(getRegion().front().without_terminator(), predicate);
323+
};
324+
Operation *getRootOp();
325+
}];
326+
}
327+
250328
def EmitC_ForOp : EmitC_Op<"for",
251329
[AllTypesMatch<["lowerBound", "upperBound", "step"]>,
252330
SingleBlockImplicitTerminator<"emitc::YieldOp">,
@@ -494,18 +572,24 @@ def EmitC_AssignOp : EmitC_Op<"assign", []> {
494572
}
495573

496574
def EmitC_YieldOp : EmitC_Op<"yield",
497-
[Pure, Terminator, ParentOneOf<["IfOp", "ForOp"]>]> {
575+
[Pure, Terminator, ParentOneOf<["ExpressionOp", "IfOp", "ForOp"]>]> {
498576
let summary = "block termination operation";
499577
let description = [{
500-
"yield" terminates blocks within EmitC control-flow operations. Since
501-
control-flow constructs in C do not return values, this operation doesn't
502-
take any arguments.
578+
"yield" terminates its parent EmitC op's region, optionally yielding
579+
an SSA value. The semantics of how the values are yielded is defined by the
580+
parent operation.
581+
If "yield" has an operand, the operand must match the parent operation's
582+
result. If the parent operation defines no values, then the "emitc.yield"
583+
may be left out in the custom syntax and the builders will insert one
584+
implicitly. Otherwise, it has to be present in the syntax to indicate which
585+
value is yielded.
503586
}];
504587

505-
let arguments = (ins);
588+
let arguments = (ins Optional<AnyType>:$result);
506589
let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];
507590

508-
let assemblyFormat = [{ attr-dict }];
591+
let hasVerifier = 1;
592+
let assemblyFormat = [{ attr-dict ($result^ `:` type($result))? }];
509593
}
510594

511595
def EmitC_IfOp : EmitC_Op<"if",
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
set(LLVM_TARGET_DEFINITIONS Passes.td)
2+
mlir_tablegen(Passes.h.inc -gen-pass-decls -name EmitC)
3+
add_public_tablegen_target(MLIREmitCTransformsIncGen)
4+
5+
add_mlir_doc(Passes EmitCPasses ./ -gen-pass-doc)
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_EMITC_TRANSFORMS_PASSES_H_
10+
#define MLIR_DIALECT_EMITC_TRANSFORMS_PASSES_H_
11+
12+
#include "mlir/Pass/Pass.h"
13+
14+
namespace mlir {
15+
namespace emitc {
16+
17+
//===----------------------------------------------------------------------===//
18+
// Passes
19+
//===----------------------------------------------------------------------===//
20+
21+
/// Creates an instance of the C-style expressions forming pass.
22+
std::unique_ptr<Pass> createFormExpressionsPass();
23+
24+
//===----------------------------------------------------------------------===//
25+
// Registration
26+
//===----------------------------------------------------------------------===//
27+
28+
/// Generate the code for registering passes.
29+
#define GEN_PASS_REGISTRATION
30+
#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
31+
32+
} // namespace emitc
33+
} // namespace mlir
34+
35+
#endif // MLIR_DIALECT_EMITC_TRANSFORMS_PASSES_H_
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
//===-- Passes.td - pass definition file -------------------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_EMITC_TRANSFORMS_PASSES
10+
#define MLIR_DIALECT_EMITC_TRANSFORMS_PASSES
11+
12+
include "mlir/Pass/PassBase.td"
13+
14+
def FormExpressions : Pass<"form-expressions"> {
15+
let summary = "Form C-style expressions from C-operator ops";
16+
let description = [{
17+
The pass wraps emitc ops modelling C operators in emitc.expression ops and
18+
then folds single-use expressions into their users where possible.
19+
}];
20+
let constructor = "mlir::emitc::createFormExpressionsPass()";
21+
let dependentDialects = ["emitc::EmitCDialect"];
22+
}
23+
24+
#endif // MLIR_DIALECT_EMITC_TRANSFORMS_PASSES
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
//===- Transforms.h - EmitC transformations as patterns --------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_EMITC_TRANSFORMS_TRANSFORMS_H
10+
#define MLIR_DIALECT_EMITC_TRANSFORMS_TRANSFORMS_H
11+
12+
#include "mlir/Dialect/EmitC/IR/EmitC.h"
13+
#include "mlir/IR/PatternMatch.h"
14+
15+
namespace mlir {
16+
namespace emitc {
17+
18+
//===----------------------------------------------------------------------===//
19+
// Expression transforms
20+
//===----------------------------------------------------------------------===//
21+
22+
ExpressionOp createExpression(Operation *op, OpBuilder &builder);
23+
24+
//===----------------------------------------------------------------------===//
25+
// Populate functions
26+
//===----------------------------------------------------------------------===//
27+
28+
/// Populates `patterns` with expression-related patterns.
29+
void populateExpressionPatterns(RewritePatternSet &patterns);
30+
31+
} // namespace emitc
32+
} // namespace mlir
33+
34+
#endif // MLIR_DIALECT_EMITC_TRANSFORMS_TRANSFORMS_H

mlir/include/mlir/InitAllPasses.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "mlir/Dialect/Async/Passes.h"
2424
#include "mlir/Dialect/Bufferization/Pipelines/Passes.h"
2525
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
26+
#include "mlir/Dialect/EmitC/Transforms/Passes.h"
2627
#include "mlir/Dialect/Func/Transforms/Passes.h"
2728
#include "mlir/Dialect/GPU/Pipelines/Passes.h"
2829
#include "mlir/Dialect/GPU/Transforms/Passes.h"
@@ -87,6 +88,7 @@ inline void registerAllPasses() {
8788
vector::registerVectorPasses();
8889
arm_sme::registerArmSMEPasses();
8990
arm_sve::registerArmSVEPasses();
91+
emitc::registerEmitCPasses();
9092

9193
// Dialect pipelines
9294
bufferization::registerBufferizationPipelines();

mlir/lib/Dialect/EmitC/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
add_subdirectory(IR)
2+
add_subdirectory(Transforms)

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,50 @@ LogicalResult emitc::ConstantOp::verify() {
189189

190190
OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
191191

192+
//===----------------------------------------------------------------------===//
193+
// ExpressionOp
194+
//===----------------------------------------------------------------------===//
195+
196+
Operation *ExpressionOp::getRootOp() {
197+
auto yieldOp = cast<YieldOp>(getBody()->getTerminator());
198+
Value yieldedValue = yieldOp.getResult();
199+
Operation *rootOp = yieldedValue.getDefiningOp();
200+
assert(rootOp && "Yielded value not defined within expression");
201+
return rootOp;
202+
}
203+
204+
LogicalResult ExpressionOp::verify() {
205+
Type resultType = getResult().getType();
206+
Region &region = getRegion();
207+
208+
Block &body = region.front();
209+
210+
if (!body.mightHaveTerminator())
211+
return emitOpError("must yield a value at termination");
212+
213+
auto yield = cast<YieldOp>(body.getTerminator());
214+
Value yieldResult = yield.getResult();
215+
216+
if (!yieldResult)
217+
return emitOpError("must yield a value at termination");
218+
219+
Type yieldType = yieldResult.getType();
220+
221+
if (resultType != yieldType)
222+
return emitOpError("requires yielded type to match return type");
223+
224+
for (Operation &op : region.front().without_terminator()) {
225+
if (!isCExpression(op))
226+
return emitOpError("contains an unsupported operation");
227+
if (op.getNumResults() != 1)
228+
return emitOpError("requires exactly one result for each operation");
229+
if (!op.getResult(0).hasOneUse())
230+
return emitOpError("requires exactly one use for each operation");
231+
}
232+
233+
return success();
234+
}
235+
192236
//===----------------------------------------------------------------------===//
193237
// ForOp
194238
//===----------------------------------------------------------------------===//
@@ -530,6 +574,23 @@ LogicalResult emitc::VariableOp::verify() {
530574
return success();
531575
}
532576

577+
//===----------------------------------------------------------------------===//
578+
// YieldOp
579+
//===----------------------------------------------------------------------===//
580+
581+
LogicalResult emitc::YieldOp::verify() {
582+
Value result = getResult();
583+
Operation *containingOp = getOperation()->getParentOp();
584+
585+
if (result && containingOp->getNumResults() != 1)
586+
return emitOpError() << "yields a value not returned by parent";
587+
588+
if (!result && containingOp->getNumResults() != 0)
589+
return emitOpError() << "does not yield a value to be returned by parent";
590+
591+
return success();
592+
}
593+
533594
//===----------------------------------------------------------------------===//
534595
// TableGen'd op method definitions
535596
//===----------------------------------------------------------------------===//
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
add_mlir_dialect_library(MLIREmitCTransforms
2+
Transforms.cpp
3+
FormExpressions.cpp
4+
5+
ADDITIONAL_HEADER_DIRS
6+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/EmitC/Transforms
7+
8+
DEPENDS
9+
MLIREmitCTransformsIncGen
10+
11+
LINK_LIBS PUBLIC
12+
MLIRIR
13+
MLIRPass
14+
MLIREmitCDialect
15+
MLIRTransforms
16+
)
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
//===- FormExpressions.cpp - Form C-style expressions --------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a pass that forms EmitC operations modeling C operators
10+
// into C-style expressions using the emitc.expression op.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/EmitC/IR/EmitC.h"
15+
#include "mlir/Dialect/EmitC/Transforms/Passes.h"
16+
#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
17+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18+
19+
namespace mlir {
20+
namespace emitc {
21+
#define GEN_PASS_DEF_FORMEXPRESSIONS
22+
#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
23+
} // namespace emitc
24+
} // namespace mlir
25+
26+
using namespace mlir;
27+
using namespace emitc;
28+
29+
namespace {
30+
struct FormExpressionsPass
31+
: public emitc::impl::FormExpressionsBase<FormExpressionsPass> {
32+
void runOnOperation() override {
33+
Operation *rootOp = getOperation();
34+
MLIRContext *context = rootOp->getContext();
35+
36+
// Wrap each C operator op with an expression op.
37+
OpBuilder builder(context);
38+
auto matchFun = [&](Operation *op) {
39+
if (emitc::ExpressionOp::isCExpression(*op))
40+
createExpression(op, builder);
41+
};
42+
rootOp->walk(matchFun);
43+
44+
// Fold expressions where possible.
45+
RewritePatternSet patterns(context);
46+
populateExpressionPatterns(patterns);
47+
48+
if (failed(applyPatternsAndFoldGreedily(rootOp, std::move(patterns))))
49+
return signalPassFailure();
50+
}
51+
52+
void getDependentDialects(DialectRegistry &registry) const override {
53+
registry.insert<emitc::EmitCDialect>();
54+
}
55+
};
56+
} // namespace
57+
58+
std::unique_ptr<Pass> mlir::emitc::createFormExpressionsPass() {
59+
return std::make_unique<FormExpressionsPass>();
60+
}

0 commit comments

Comments
 (0)