Skip to content

Commit ec44c74

Browse files
[CIR] Upstream support for FlattenCFG switch and SwitchFlatOp (#139154)
This PR adds support for the `FlattenCFG` transformation on `switch` statements. It also introduces the `SwitchFlatOp`, which is necessary for subsequent lowering to LLVM.
1 parent ea4bf34 commit ec44c74

File tree

7 files changed

+852
-6
lines changed

7 files changed

+852
-6
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -971,6 +971,52 @@ def SwitchOp : CIR_Op<"switch",
971971
}];
972972
}
973973

974+
//===----------------------------------------------------------------------===//
975+
// SwitchFlatOp
976+
//===----------------------------------------------------------------------===//
977+
978+
def SwitchFlatOp : CIR_Op<"switch.flat", [AttrSizedOperandSegments,
979+
Terminator]> {
980+
981+
let description = [{
982+
The `cir.switch.flat` operation is a region-less and simplified
983+
version of the `cir.switch`.
984+
Its representation is closer to LLVM IR dialect
985+
than the C/C++ language feature.
986+
}];
987+
988+
let arguments = (ins
989+
CIR_IntType:$condition,
990+
Variadic<AnyType>:$defaultOperands,
991+
VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands,
992+
ArrayAttr:$caseValues,
993+
DenseI32ArrayAttr:$case_operand_segments
994+
);
995+
996+
let successors = (successor
997+
AnySuccessor:$defaultDestination,
998+
VariadicSuccessor<AnySuccessor>:$caseDestinations
999+
);
1000+
1001+
let assemblyFormat = [{
1002+
$condition `:` type($condition) `,`
1003+
$defaultDestination (`(` $defaultOperands^ `:` type($defaultOperands) `)`)?
1004+
custom<SwitchFlatOpCases>(ref(type($condition)), $caseValues,
1005+
$caseDestinations, $caseOperands,
1006+
type($caseOperands))
1007+
attr-dict
1008+
}];
1009+
1010+
let builders = [
1011+
OpBuilder<(ins "mlir::Value":$condition,
1012+
"mlir::Block *":$defaultDestination,
1013+
"mlir::ValueRange":$defaultOperands,
1014+
CArg<"llvm::ArrayRef<llvm::APInt>", "{}">:$caseValues,
1015+
CArg<"mlir::BlockRange", "{}">:$caseDestinations,
1016+
CArg<"llvm::ArrayRef<mlir::ValueRange>", "{}">:$caseOperands)>
1017+
];
1018+
}
1019+
9741020
//===----------------------------------------------------------------------===//
9751021
// BrOp
9761022
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "clang/CIR/Dialect/IR/CIROpsDialect.cpp.inc"
2323
#include "clang/CIR/Dialect/IR/CIROpsEnums.cpp.inc"
2424
#include "clang/CIR/MissingFeatures.h"
25+
#include <numeric>
2526

2627
using namespace mlir;
2728
using namespace cir;
@@ -962,6 +963,101 @@ bool cir::SwitchOp::isSimpleForm(llvm::SmallVectorImpl<CaseOp> &cases) {
962963
});
963964
}
964965

966+
//===----------------------------------------------------------------------===//
967+
// SwitchFlatOp
968+
//===----------------------------------------------------------------------===//
969+
970+
void cir::SwitchFlatOp::build(OpBuilder &builder, OperationState &result,
971+
Value value, Block *defaultDestination,
972+
ValueRange defaultOperands,
973+
ArrayRef<APInt> caseValues,
974+
BlockRange caseDestinations,
975+
ArrayRef<ValueRange> caseOperands) {
976+
977+
std::vector<mlir::Attribute> caseValuesAttrs;
978+
for (const APInt &val : caseValues)
979+
caseValuesAttrs.push_back(cir::IntAttr::get(value.getType(), val));
980+
mlir::ArrayAttr attrs = ArrayAttr::get(builder.getContext(), caseValuesAttrs);
981+
982+
build(builder, result, value, defaultOperands, caseOperands, attrs,
983+
defaultDestination, caseDestinations);
984+
}
985+
986+
/// <cases> ::= `[` (case (`,` case )* )? `]`
987+
/// <case> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
988+
static ParseResult parseSwitchFlatOpCases(
989+
OpAsmParser &parser, Type flagType, mlir::ArrayAttr &caseValues,
990+
SmallVectorImpl<Block *> &caseDestinations,
991+
SmallVectorImpl<llvm::SmallVector<OpAsmParser::UnresolvedOperand>>
992+
&caseOperands,
993+
SmallVectorImpl<llvm::SmallVector<Type>> &caseOperandTypes) {
994+
if (failed(parser.parseLSquare()))
995+
return failure();
996+
if (succeeded(parser.parseOptionalRSquare()))
997+
return success();
998+
llvm::SmallVector<mlir::Attribute> values;
999+
1000+
auto parseCase = [&]() {
1001+
int64_t value = 0;
1002+
if (failed(parser.parseInteger(value)))
1003+
return failure();
1004+
1005+
values.push_back(cir::IntAttr::get(flagType, value));
1006+
1007+
Block *destination;
1008+
llvm::SmallVector<OpAsmParser::UnresolvedOperand> operands;
1009+
llvm::SmallVector<Type> operandTypes;
1010+
if (parser.parseColon() || parser.parseSuccessor(destination))
1011+
return failure();
1012+
if (!parser.parseOptionalLParen()) {
1013+
if (parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
1014+
/*allowResultNumber=*/false) ||
1015+
parser.parseColonTypeList(operandTypes) || parser.parseRParen())
1016+
return failure();
1017+
}
1018+
caseDestinations.push_back(destination);
1019+
caseOperands.emplace_back(operands);
1020+
caseOperandTypes.emplace_back(operandTypes);
1021+
return success();
1022+
};
1023+
if (failed(parser.parseCommaSeparatedList(parseCase)))
1024+
return failure();
1025+
1026+
caseValues = ArrayAttr::get(flagType.getContext(), values);
1027+
1028+
return parser.parseRSquare();
1029+
}
1030+
1031+
static void printSwitchFlatOpCases(OpAsmPrinter &p, cir::SwitchFlatOp op,
1032+
Type flagType, mlir::ArrayAttr caseValues,
1033+
SuccessorRange caseDestinations,
1034+
OperandRangeRange caseOperands,
1035+
const TypeRangeRange &caseOperandTypes) {
1036+
p << '[';
1037+
p.printNewline();
1038+
if (!caseValues) {
1039+
p << ']';
1040+
return;
1041+
}
1042+
1043+
size_t index = 0;
1044+
llvm::interleave(
1045+
llvm::zip(caseValues, caseDestinations),
1046+
[&](auto i) {
1047+
p << " ";
1048+
mlir::Attribute a = std::get<0>(i);
1049+
p << mlir::cast<cir::IntAttr>(a).getValue();
1050+
p << ": ";
1051+
p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]);
1052+
},
1053+
[&] {
1054+
p << ',';
1055+
p.printNewline();
1056+
});
1057+
p.printNewline();
1058+
p << ']';
1059+
}
1060+
9651061
//===----------------------------------------------------------------------===//
9661062
// GlobalOp
9671063
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,19 @@ struct RemoveEmptyScope : public OpRewritePattern<ScopeOp> {
8484
}
8585
};
8686

87+
struct RemoveEmptySwitch : public OpRewritePattern<SwitchOp> {
88+
using OpRewritePattern<SwitchOp>::OpRewritePattern;
89+
90+
LogicalResult matchAndRewrite(SwitchOp op,
91+
PatternRewriter &rewriter) const final {
92+
if (!(op.getBody().empty() || isa<YieldOp>(op.getBody().front().front())))
93+
return failure();
94+
95+
rewriter.eraseOp(op);
96+
return success();
97+
}
98+
};
99+
87100
//===----------------------------------------------------------------------===//
88101
// CIRCanonicalizePass
89102
//===----------------------------------------------------------------------===//
@@ -127,8 +140,8 @@ void CIRCanonicalizePass::runOnOperation() {
127140
assert(!cir::MissingFeatures::callOp());
128141
// CastOp, UnaryOp and VecExtractOp are here to perform a manual `fold` in
129142
// applyOpPatternsGreedily.
130-
if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SelectOp, UnaryOp, VecExtractOp>(
131-
op))
143+
if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SwitchOp, SelectOp, UnaryOp,
144+
VecExtractOp>(op))
132145
ops.push_back(op);
133146
});
134147

0 commit comments

Comments
 (0)