Skip to content

Commit d7f096e

Browse files
mmhaandykaylor
andauthored
[CIR] Upstream TernaryOp (#137184)
This patch adds TernaryOp to CIR plus a pass that flattens the operator in FlattenCFG. This is the first PR out of (probably) 3 wrt. TernaryOp. I split the patches up to make reviewing them easier. As such, this PR is only about adding the CIR operation. The next PR will be about the CodeGen bits from the C++ conditional operator and the final one will add the cir-simplify transform for TernaryOp and SelectOp. --------- Co-authored-by: Morris Hafner <[email protected]> Co-authored-by: Andy Kaylor <[email protected]>
1 parent bad8bf5 commit d7f096e

File tree

6 files changed

+286
-8
lines changed

6 files changed

+286
-8
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -610,9 +610,9 @@ def ConditionOp : CIR_Op<"condition", [
610610
//===----------------------------------------------------------------------===//
611611

612612
def YieldOp : CIR_Op<"yield", [ReturnLike, Terminator,
613-
ParentOneOf<["IfOp", "ScopeOp", "SwitchOp",
614-
"WhileOp", "ForOp", "CaseOp",
615-
"DoWhileOp"]>]> {
613+
ParentOneOf<["CaseOp", "DoWhileOp", "ForOp",
614+
"IfOp", "ScopeOp", "SwitchOp",
615+
"TernaryOp", "WhileOp"]>]> {
616616
let summary = "Represents the default branching behaviour of a region";
617617
let description = [{
618618
The `cir.yield` operation terminates regions on different CIR operations,
@@ -1462,6 +1462,63 @@ def SelectOp : CIR_Op<"select", [Pure,
14621462
}];
14631463
}
14641464

1465+
//===----------------------------------------------------------------------===//
1466+
// TernaryOp
1467+
//===----------------------------------------------------------------------===//
1468+
1469+
def TernaryOp : CIR_Op<"ternary",
1470+
[DeclareOpInterfaceMethods<RegionBranchOpInterface>,
1471+
RecursivelySpeculatable, AutomaticAllocationScope, NoRegionArguments]> {
1472+
let summary = "The `cond ? a : b` C/C++ ternary operation";
1473+
let description = [{
1474+
The `cir.ternary` operation represents C/C++ ternary, much like a `select`
1475+
operation. The first argument is a `cir.bool` condition to evaluate, followed
1476+
by two regions to execute (true or false). This is different from `cir.if`
1477+
since each region is one block sized and the `cir.yield` closing the block
1478+
scope should have one argument.
1479+
1480+
`cir.ternary` also represents the GNU binary conditional operator ?: which
1481+
reuses the parent operation for both the condition and the true branch to
1482+
evaluate it only once.
1483+
1484+
Example:
1485+
1486+
```mlir
1487+
// cond = a && b;
1488+
1489+
%x = cir.ternary (%cond, true_region {
1490+
...
1491+
cir.yield %a : i32
1492+
}, false_region {
1493+
...
1494+
cir.yield %b : i32
1495+
}) -> i32
1496+
```
1497+
}];
1498+
let arguments = (ins CIR_BoolType:$cond);
1499+
let regions = (region AnyRegion:$trueRegion,
1500+
AnyRegion:$falseRegion);
1501+
let results = (outs Optional<CIR_AnyType>:$result);
1502+
1503+
let skipDefaultBuilders = 1;
1504+
let builders = [
1505+
OpBuilder<(ins "mlir::Value":$cond,
1506+
"llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)>":$trueBuilder,
1507+
"llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)>":$falseBuilder)
1508+
>
1509+
];
1510+
1511+
// All constraints already verified elsewhere.
1512+
let hasVerifier = 0;
1513+
1514+
let assemblyFormat = [{
1515+
`(` $cond `,`
1516+
`true` $trueRegion `,`
1517+
`false` $falseRegion
1518+
`)` `:` functional-type(operands, results) attr-dict
1519+
}];
1520+
}
1521+
14651522
//===----------------------------------------------------------------------===//
14661523
// GlobalOp
14671524
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1187,6 +1187,49 @@ LogicalResult cir::BinOp::verify() {
11871187
return mlir::success();
11881188
}
11891189

1190+
//===----------------------------------------------------------------------===//
1191+
// TernaryOp
1192+
//===----------------------------------------------------------------------===//
1193+
1194+
/// Given the region at `point`, or the parent operation if `point` is None,
1195+
/// return the successor regions. These are the regions that may be selected
1196+
/// during the flow of control. `operands` is a set of optional attributes that
1197+
/// correspond to a constant value for each operand, or null if that operand is
1198+
/// not a constant.
1199+
void cir::TernaryOp::getSuccessorRegions(
1200+
mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1201+
// The `true` and the `false` region branch back to the parent operation.
1202+
if (!point.isParent()) {
1203+
regions.push_back(RegionSuccessor(this->getODSResults(0)));
1204+
return;
1205+
}
1206+
1207+
// When branching from the parent operation, both the true and false
1208+
// regions are considered possible successors
1209+
regions.push_back(RegionSuccessor(&getTrueRegion()));
1210+
regions.push_back(RegionSuccessor(&getFalseRegion()));
1211+
}
1212+
1213+
void cir::TernaryOp::build(
1214+
OpBuilder &builder, OperationState &result, Value cond,
1215+
function_ref<void(OpBuilder &, Location)> trueBuilder,
1216+
function_ref<void(OpBuilder &, Location)> falseBuilder) {
1217+
result.addOperands(cond);
1218+
OpBuilder::InsertionGuard guard(builder);
1219+
Region *trueRegion = result.addRegion();
1220+
Block *block = builder.createBlock(trueRegion);
1221+
trueBuilder(builder, result.location);
1222+
Region *falseRegion = result.addRegion();
1223+
builder.createBlock(falseRegion);
1224+
falseBuilder(builder, result.location);
1225+
1226+
auto yield = dyn_cast<YieldOp>(block->getTerminator());
1227+
assert((yield && yield.getNumOperands() <= 1) &&
1228+
"expected zero or one result type");
1229+
if (yield.getNumOperands() == 1)
1230+
result.addTypes(TypeRange{yield.getOperandTypes().front()});
1231+
}
1232+
11901233
//===----------------------------------------------------------------------===//
11911234
// ShiftOp
11921235
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -254,10 +254,61 @@ class CIRLoopOpInterfaceFlattening
254254
}
255255
};
256256

257+
class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> {
258+
public:
259+
using OpRewritePattern<cir::TernaryOp>::OpRewritePattern;
260+
261+
mlir::LogicalResult
262+
matchAndRewrite(cir::TernaryOp op,
263+
mlir::PatternRewriter &rewriter) const override {
264+
Location loc = op->getLoc();
265+
Block *condBlock = rewriter.getInsertionBlock();
266+
Block::iterator opPosition = rewriter.getInsertionPoint();
267+
Block *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
268+
llvm::SmallVector<mlir::Location, 2> locs;
269+
// Ternary result is optional, make sure to populate the location only
270+
// when relevant.
271+
if (op->getResultTypes().size())
272+
locs.push_back(loc);
273+
Block *continueBlock =
274+
rewriter.createBlock(remainingOpsBlock, op->getResultTypes(), locs);
275+
rewriter.create<cir::BrOp>(loc, remainingOpsBlock);
276+
277+
Region &trueRegion = op.getTrueRegion();
278+
Block *trueBlock = &trueRegion.front();
279+
mlir::Operation *trueTerminator = trueRegion.back().getTerminator();
280+
rewriter.setInsertionPointToEnd(&trueRegion.back());
281+
auto trueYieldOp = dyn_cast<cir::YieldOp>(trueTerminator);
282+
283+
rewriter.replaceOpWithNewOp<cir::BrOp>(trueYieldOp, trueYieldOp.getArgs(),
284+
continueBlock);
285+
rewriter.inlineRegionBefore(trueRegion, continueBlock);
286+
287+
Block *falseBlock = continueBlock;
288+
Region &falseRegion = op.getFalseRegion();
289+
290+
falseBlock = &falseRegion.front();
291+
mlir::Operation *falseTerminator = falseRegion.back().getTerminator();
292+
rewriter.setInsertionPointToEnd(&falseRegion.back());
293+
auto falseYieldOp = dyn_cast<cir::YieldOp>(falseTerminator);
294+
rewriter.replaceOpWithNewOp<cir::BrOp>(falseYieldOp, falseYieldOp.getArgs(),
295+
continueBlock);
296+
rewriter.inlineRegionBefore(falseRegion, continueBlock);
297+
298+
rewriter.setInsertionPointToEnd(condBlock);
299+
rewriter.create<cir::BrCondOp>(loc, op.getCond(), trueBlock, falseBlock);
300+
301+
rewriter.replaceOp(op, continueBlock->getArguments());
302+
303+
// Ok, we're done!
304+
return mlir::success();
305+
}
306+
};
307+
257308
void populateFlattenCFGPatterns(RewritePatternSet &patterns) {
258-
patterns
259-
.add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening>(
260-
patterns.getContext());
309+
patterns.add<CIRIfFlattening, CIRLoopOpInterfaceFlattening,
310+
CIRScopeOpFlattening, CIRTernaryOpFlattening>(
311+
patterns.getContext());
261312
}
262313

263314
void CIRFlattenCFGPass::runOnOperation() {
@@ -269,9 +320,8 @@ void CIRFlattenCFGPass::runOnOperation() {
269320
getOperation()->walk<mlir::WalkOrder::PostOrder>([&](Operation *op) {
270321
assert(!cir::MissingFeatures::ifOp());
271322
assert(!cir::MissingFeatures::switchOp());
272-
assert(!cir::MissingFeatures::ternaryOp());
273323
assert(!cir::MissingFeatures::tryOp());
274-
if (isa<IfOp, ScopeOp, LoopOpInterface>(op))
324+
if (isa<IfOp, ScopeOp, LoopOpInterface, TernaryOp>(op))
275325
ops.push_back(op);
276326
});
277327

clang/test/CIR/IR/ternary.cir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: cir-opt %s | cir-opt | FileCheck %s
2+
!u32i = !cir.int<u, 32>
3+
4+
module {
5+
cir.func @blue(%arg0: !cir.bool) -> !u32i {
6+
%0 = cir.ternary(%arg0, true {
7+
%a = cir.const #cir.int<0> : !u32i
8+
cir.yield %a : !u32i
9+
}, false {
10+
%b = cir.const #cir.int<1> : !u32i
11+
cir.yield %b : !u32i
12+
}) : (!cir.bool) -> !u32i
13+
cir.return %0 : !u32i
14+
}
15+
}
16+
17+
// CHECK: module {
18+
19+
// CHECK: cir.func @blue(%arg0: !cir.bool) -> !u32i {
20+
// CHECK: %0 = cir.ternary(%arg0, true {
21+
// CHECK: %1 = cir.const #cir.int<0> : !u32i
22+
// CHECK: cir.yield %1 : !u32i
23+
// CHECK: }, false {
24+
// CHECK: %1 = cir.const #cir.int<1> : !u32i
25+
// CHECK: cir.yield %1 : !u32i
26+
// CHECK: }) : (!cir.bool) -> !u32i
27+
// CHECK: cir.return %0 : !u32i
28+
// CHECK: }
29+
30+
// CHECK: }

clang/test/CIR/Lowering/ternary.cir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: cir-translate -cir-to-llvmir --disable-cc-lowering -o %t.ll %s
2+
// RUN: FileCheck --input-file=%t.ll -check-prefix=LLVM %s
3+
4+
!u32i = !cir.int<u, 32>
5+
6+
module {
7+
cir.func @blue(%arg0: !cir.bool) -> !u32i {
8+
%0 = cir.ternary(%arg0, true {
9+
%a = cir.const #cir.int<0> : !u32i
10+
cir.yield %a : !u32i
11+
}, false {
12+
%b = cir.const #cir.int<1> : !u32i
13+
cir.yield %b : !u32i
14+
}) : (!cir.bool) -> !u32i
15+
cir.return %0 : !u32i
16+
}
17+
}
18+
19+
// LLVM-LABEL: define i32 {{.*}}@blue(
20+
// LLVM-SAME: i1 [[PRED:%[[:alnum:]]+]])
21+
// LLVM: br i1 [[PRED]], label %[[B1:[[:alnum:]]+]], label %[[B2:[[:alnum:]]+]]
22+
// LLVM: [[B1]]:
23+
// LLVM: br label %[[M:[[:alnum:]]+]]
24+
// LLVM: [[B2]]:
25+
// LLVM: br label %[[M]]
26+
// LLVM: [[M]]:
27+
// LLVM: [[R:%[[:alnum:]]+]] = phi i32 [ 1, %[[B2]] ], [ 0, %[[B1]] ]
28+
// LLVM: br label %[[B3:[[:alnum:]]+]]
29+
// LLVM: [[B3]]:
30+
// LLVM: ret i32 [[R]]

clang/test/CIR/Transforms/ternary.cir

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// RUN: cir-opt %s -cir-flatten-cfg -o - | FileCheck %s
2+
3+
!s32i = !cir.int<s, 32>
4+
5+
module {
6+
cir.func @foo(%arg0: !s32i) -> !s32i {
7+
%0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["y", init] {alignment = 4 : i64}
8+
%1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"] {alignment = 4 : i64}
9+
cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
10+
%2 = cir.load %0 : !cir.ptr<!s32i>, !s32i
11+
%3 = cir.const #cir.int<0> : !s32i
12+
%4 = cir.cmp(gt, %2, %3) : !s32i, !cir.bool
13+
%5 = cir.ternary(%4, true {
14+
%7 = cir.const #cir.int<3> : !s32i
15+
cir.yield %7 : !s32i
16+
}, false {
17+
%7 = cir.const #cir.int<5> : !s32i
18+
cir.yield %7 : !s32i
19+
}) : (!cir.bool) -> !s32i
20+
cir.store %5, %1 : !s32i, !cir.ptr<!s32i>
21+
%6 = cir.load %1 : !cir.ptr<!s32i>, !s32i
22+
cir.return %6 : !s32i
23+
}
24+
25+
// CHECK: cir.func @foo(%arg0: !s32i) -> !s32i {
26+
// CHECK: %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["y", init] {alignment = 4 : i64}
27+
// CHECK: %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"] {alignment = 4 : i64}
28+
// CHECK: cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
29+
// CHECK: %2 = cir.load %0 : !cir.ptr<!s32i>, !s32i
30+
// CHECK: %3 = cir.const #cir.int<0> : !s32i
31+
// CHECK: %4 = cir.cmp(gt, %2, %3) : !s32i, !cir.bool
32+
// CHECK: cir.brcond %4 ^bb1, ^bb2
33+
// CHECK: ^bb1: // pred: ^bb0
34+
// CHECK: %5 = cir.const #cir.int<3> : !s32i
35+
// CHECK: cir.br ^bb3(%5 : !s32i)
36+
// CHECK: ^bb2: // pred: ^bb0
37+
// CHECK: %6 = cir.const #cir.int<5> : !s32i
38+
// CHECK: cir.br ^bb3(%6 : !s32i)
39+
// CHECK: ^bb3(%7: !s32i): // 2 preds: ^bb1, ^bb2
40+
// CHECK: cir.br ^bb4
41+
// CHECK: ^bb4: // pred: ^bb3
42+
// CHECK: cir.store %7, %1 : !s32i, !cir.ptr<!s32i>
43+
// CHECK: %8 = cir.load %1 : !cir.ptr<!s32i>, !s32i
44+
// CHECK: cir.return %8 : !s32i
45+
// CHECK: }
46+
47+
cir.func @foo2(%arg0: !cir.bool) {
48+
cir.ternary(%arg0, true {
49+
cir.yield
50+
}, false {
51+
cir.yield
52+
}) : (!cir.bool) -> ()
53+
cir.return
54+
}
55+
56+
// CHECK: cir.func @foo2(%arg0: !cir.bool) {
57+
// CHECK: cir.brcond %arg0 ^bb1, ^bb2
58+
// CHECK: ^bb1: // pred: ^bb0
59+
// CHECK: cir.br ^bb3
60+
// CHECK: ^bb2: // pred: ^bb0
61+
// CHECK: cir.br ^bb3
62+
// CHECK: ^bb3: // 2 preds: ^bb1, ^bb2
63+
// CHECK: cir.br ^bb4
64+
// CHECK: ^bb4: // pred: ^bb3
65+
// CHECK: cir.return
66+
// CHECK: }
67+
68+
}

0 commit comments

Comments
 (0)