Skip to content

[mlir][flang] Added Weighted[Region]BranchOpInterface's. #142079

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions flang/include/flang/Optimizer/Dialect/FIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2323,9 +2323,13 @@ def fir_DoLoopOp : region_Op<"do_loop", [AttrSizedOperandSegments,
}];
}

def fir_IfOp : region_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
"getRegionInvocationBounds", "getEntrySuccessorRegions"]>, RecursiveMemoryEffects,
NoRegionArguments]> {
def fir_IfOp
: region_Op<
"if", [DeclareOpInterfaceMethods<
RegionBranchOpInterface, ["getRegionInvocationBounds",
"getEntrySuccessorRegions"]>,
RecursiveMemoryEffects, NoRegionArguments,
WeightedRegionBranchOpInterface]> {
let summary = "if-then-else conditional operation";
let description = [{
Used to conditionally execute operations. This operation is the FIR
Expand All @@ -2342,7 +2346,8 @@ def fir_IfOp : region_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterfac
```
}];

let arguments = (ins I1:$condition);
let arguments = (ins I1:$condition,
OptionalAttr<DenseI32ArrayAttr>:$region_weights);
let results = (outs Variadic<AnyType>:$results);

let regions = (region
Expand Down Expand Up @@ -2371,6 +2376,11 @@ def fir_IfOp : region_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterfac

void resultToSourceOps(llvm::SmallVectorImpl<mlir::Value> &results,
unsigned resultNum);

/// Returns the display name string for the region_weights attribute.
static constexpr llvm::StringRef getWeightsAttrAssemblyName() {
return "weights";
}
}];
}

Expand Down
21 changes: 20 additions & 1 deletion flang/lib/Optimizer/Dialect/FIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4418,6 +4418,19 @@ mlir::ParseResult fir::IfOp::parse(mlir::OpAsmParser &parser,
parser.resolveOperand(cond, i1Type, result.operands))
return mlir::failure();

if (mlir::succeeded(
parser.parseOptionalKeyword(getWeightsAttrAssemblyName()))) {
if (parser.parseLParen())
return mlir::failure();
mlir::DenseI32ArrayAttr weights;
if (parser.parseCustomAttributeWithFallback(weights, mlir::Type{}))
return mlir::failure();
if (weights)
result.addAttribute(getRegionWeightsAttrName(result.name), weights);
if (parser.parseRParen())
return mlir::failure();
}

if (parser.parseOptionalArrowTypeList(result.types))
return mlir::failure();

Expand Down Expand Up @@ -4449,6 +4462,11 @@ llvm::LogicalResult fir::IfOp::verify() {
void fir::IfOp::print(mlir::OpAsmPrinter &p) {
bool printBlockTerminators = false;
p << ' ' << getCondition();
if (auto weights = getRegionWeightsAttr()) {
p << ' ' << getWeightsAttrAssemblyName() << '(';
p.printStrippedAttrOrType(weights);
p << ')';
}
if (!getResults().empty()) {
p << " -> (" << getResultTypes() << ')';
printBlockTerminators = true;
Expand All @@ -4464,7 +4482,8 @@ void fir::IfOp::print(mlir::OpAsmPrinter &p) {
p.printRegion(otherReg, /*printEntryBlockArgs=*/false,
printBlockTerminators);
}
p.printOptionalAttrDict((*this)->getAttrs());
p.printOptionalAttrDict((*this)->getAttrs(),
/*elideAttrs=*/{getRegionWeightsAttrName()});
}

void fir::IfOp::resultToSourceOps(llvm::SmallVectorImpl<mlir::Value> &results,
Expand Down
4 changes: 3 additions & 1 deletion flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,9 +212,11 @@ class CfgIfConv : public mlir::OpRewritePattern<fir::IfOp> {
}

rewriter.setInsertionPointToEnd(condBlock);
rewriter.create<mlir::cf::CondBranchOp>(
auto branchOp = rewriter.create<mlir::cf::CondBranchOp>(
loc, ifOp.getCondition(), ifOpBlock, llvm::ArrayRef<mlir::Value>(),
otherwiseBlock, llvm::ArrayRef<mlir::Value>());
if (auto weights = ifOp.getRegionWeightsOrNull())
branchOp.setBranchWeights(weights);
rewriter.replaceOp(ifOp, continueBlock->getArguments());
return success();
}
Expand Down
46 changes: 46 additions & 0 deletions flang/test/Fir/cfg-conversion-if.fir
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// RUN: fir-opt --split-input-file --cfg-conversion %s | FileCheck %s

func.func private @callee() -> none

// CHECK-LABEL: func.func @if_then(
// CHECK-SAME: %[[ARG0:.*]]: i1) {
// CHECK: cf.cond_br %[[ARG0]] weights([10, 90]), ^bb1, ^bb2
// CHECK: ^bb1:
// CHECK: %[[VAL_0:.*]] = fir.call @callee() : () -> none
// CHECK: cf.br ^bb2
// CHECK: ^bb2:
// CHECK: return
// CHECK: }
func.func @if_then(%cond: i1) {
fir.if %cond weights([10, 90]) {
fir.call @callee() : () -> none
}
return
}

// -----

// CHECK-LABEL: func.func @if_then_else(
// CHECK-SAME: %[[ARG0:.*]]: i1) -> i32 {
// CHECK: %[[VAL_0:.*]] = arith.constant 0 : i32
// CHECK: %[[VAL_1:.*]] = arith.constant 1 : i32
// CHECK: cf.cond_br %[[ARG0]] weights([90, 10]), ^bb1, ^bb2
// CHECK: ^bb1:
// CHECK: cf.br ^bb3(%[[VAL_0]] : i32)
// CHECK: ^bb2:
// CHECK: cf.br ^bb3(%[[VAL_1]] : i32)
// CHECK: ^bb3(%[[VAL_2:.*]]: i32):
// CHECK: cf.br ^bb4
// CHECK: ^bb4:
// CHECK: return %[[VAL_2]] : i32
// CHECK: }
func.func @if_then_else(%cond: i1) -> i32 {
%c0 = arith.constant 0 : i32
%c1 = arith.constant 1 : i32
%result = fir.if %cond weights([90, 10]) -> i32 {
fir.result %c0 : i32
} else {
fir.result %c1 : i32
}
return %result : i32
}
16 changes: 16 additions & 0 deletions flang/test/Fir/fir-ops.fir
Original file line number Diff line number Diff line change
Expand Up @@ -1015,3 +1015,19 @@ func.func @test_box_total_elements(%arg0: !fir.class<!fir.type<sometype{i:i32}>>
%6 = arith.addi %2, %5 : index
return %6 : index
}

// CHECK-LABEL: func.func @test_if_weights(
// CHECK-SAME: %[[ARG0:.*]]: i1) {
func.func @test_if_weights(%cond: i1) {
// CHECK: fir.if %[[ARG0]] weights([99, 1]) {
// CHECK: }
fir.if %cond weights([99, 1]) {
}
// CHECK: fir.if %[[ARG0]] weights([99, 1]) {
// CHECK: } else {
// CHECK: }
fir.if %cond weights ([99,1]) {
} else {
}
return
}
28 changes: 28 additions & 0 deletions flang/test/Fir/invalid.fir
Original file line number Diff line number Diff line change
Expand Up @@ -1385,3 +1385,31 @@ fir.local {type = local_init} @x.localizer : f32 init {
^bb0(%arg0: f32, %arg1: f32):
fir.yield(%arg0 : f32)
}

// -----

func.func @wrong_weights_number_in_if_then(%cond: i1) {
// expected-error @below {{expects number of region weights to match number of regions: 1 vs 2}}
fir.if %cond weights([50]) {
}
return
}

// -----

func.func @wrong_weights_number_in_if_then_else(%cond: i1) {
// expected-error @below {{expects number of region weights to match number of regions: 3 vs 2}}
fir.if %cond weights([50, 40, 10]) {
} else {
}
return
}

// -----

func.func @negative_weight_in_if_then(%cond: i1) {
// expected-error @below {{weight #0 must be non-negative}}
fir.if %cond weights([-1, 101]) {
}
return
}
34 changes: 19 additions & 15 deletions mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,11 @@ def BranchOp : CF_Op<"br", [
// CondBranchOp
//===----------------------------------------------------------------------===//

def CondBranchOp : CF_Op<"cond_br",
[AttrSizedOperandSegments,
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
Pure, Terminator]> {
def CondBranchOp
: CF_Op<"cond_br", [AttrSizedOperandSegments,
DeclareOpInterfaceMethods<
BranchOpInterface, ["getSuccessorForOperands"]>,
WeightedBranchOpInterface, Pure, Terminator]> {
let summary = "Conditional branch operation";
let description = [{
The `cf.cond_br` terminator operation represents a conditional branch on a
Expand Down Expand Up @@ -144,20 +145,23 @@ def CondBranchOp : CF_Op<"cond_br",
```
}];

let arguments = (ins I1:$condition,
Variadic<AnyType>:$trueDestOperands,
Variadic<AnyType>:$falseDestOperands);
let arguments = (ins I1:$condition, Variadic<AnyType>:$trueDestOperands,
Variadic<AnyType>:$falseDestOperands,
OptionalAttr<DenseI32ArrayAttr>:$branch_weights);
let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);

let builders = [
OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
"ValueRange":$trueOperands, "Block *":$falseDest,
"ValueRange":$falseOperands), [{
build($_builder, $_state, condition, trueOperands, falseOperands, trueDest,
let builders = [OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
"ValueRange":$trueOperands,
"Block *":$falseDest,
"ValueRange":$falseOperands),
[{
build($_builder, $_state, condition, trueOperands, falseOperands, /*branch_weights=*/{}, trueDest,
falseDest);
}]>,
OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
"Block *":$falseDest, CArg<"ValueRange", "{}">:$falseOperands), [{
OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
"Block *":$falseDest,
CArg<"ValueRange", "{}">:$falseOperands),
[{
build($_builder, $_state, condition, trueDest, ValueRange(), falseDest,
falseOperands);
}]>];
Expand Down Expand Up @@ -216,7 +220,7 @@ def CondBranchOp : CF_Op<"cond_br",

let hasCanonicalizer = 1;
let assemblyFormat = [{
$condition `,`
$condition (`weights` `(` $branch_weights^ `)` )? `,`
$trueDest (`(` $trueDestOperands^ `:` type($trueDestOperands) `)`)? `,`
$falseDest (`(` $falseDestOperands^ `:` type($falseDestOperands) `)`)?
attr-dict
Expand Down
36 changes: 0 additions & 36 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -168,42 +168,6 @@ def NonNegFlagInterface : OpInterface<"NonNegFlagInterface"> {
];
}

def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> {
let description = [{
An interface for operations that can carry branch weights metadata. It
provides setters and getters for the operation's branch weights attribute.
The default implementation of the interface methods expect the operation to
have an attribute of type DenseI32ArrayAttr named branch_weights.
}];

let cppNamespace = "::mlir::LLVM";

let methods = [
InterfaceMethod<
/*desc=*/ "Returns the branch weights attribute or nullptr",
/*returnType=*/ "::mlir::DenseI32ArrayAttr",
/*methodName=*/ "getBranchWeightsOrNull",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
return op.getBranchWeightsAttr();
}]
>,
InterfaceMethod<
/*desc=*/ "Sets the branch weights attribute",
/*returnType=*/ "void",
/*methodName=*/ "setBranchWeights",
/*args=*/ (ins "::mlir::DenseI32ArrayAttr":$attr),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
op.setBranchWeightsAttr(attr);
}]
>
];
}

def AccessGroupOpInterface : OpInterface<"AccessGroupOpInterface"> {
let description = [{
An interface for memory operations that can carry access groups metadata.
Expand Down
47 changes: 25 additions & 22 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -660,12 +660,12 @@ def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "FPTrunc",
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;

// Call-related operations.
def LLVM_InvokeOp : LLVM_Op<"invoke", [
AttrSizedOperandSegments,
DeclareOpInterfaceMethods<BranchOpInterface>,
DeclareOpInterfaceMethods<CallOpInterface>,
DeclareOpInterfaceMethods<BranchWeightOpInterface>,
Terminator]> {
def LLVM_InvokeOp
: LLVM_Op<"invoke", [AttrSizedOperandSegments,
DeclareOpInterfaceMethods<BranchOpInterface>,
DeclareOpInterfaceMethods<CallOpInterface>,
DeclareOpInterfaceMethods<WeightedBranchOpInterface>,
Terminator]> {
let arguments = (ins
OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$var_callee_type,
OptionalAttr<FlatSymbolRefAttr>:$callee,
Expand Down Expand Up @@ -734,12 +734,13 @@ def LLVM_VaArgOp : LLVM_Op<"va_arg"> {
// CallOp
//===----------------------------------------------------------------------===//

def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
[AttrSizedOperandSegments,
DeclareOpInterfaceMethods<FastmathFlagsInterface>,
DeclareOpInterfaceMethods<CallOpInterface>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
DeclareOpInterfaceMethods<BranchWeightOpInterface>]> {
def LLVM_CallOp
: LLVM_MemAccessOpBase<
"call", [AttrSizedOperandSegments,
DeclareOpInterfaceMethods<FastmathFlagsInterface>,
DeclareOpInterfaceMethods<CallOpInterface>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
DeclareOpInterfaceMethods<WeightedBranchOpInterface>]> {
let summary = "Call to an LLVM function.";
let description = [{
In LLVM IR, functions may return either 0 or 1 value. LLVM IR dialect
Expand Down Expand Up @@ -1047,11 +1048,12 @@ def LLVM_BrOp : LLVM_TerminatorOp<"br",
LLVM_TerminatorPassthroughOpBuilder
];
}
def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br",
[AttrSizedOperandSegments,
DeclareOpInterfaceMethods<BranchOpInterface>,
DeclareOpInterfaceMethods<BranchWeightOpInterface>,
Pure]> {
def LLVM_CondBrOp
: LLVM_TerminatorOp<
"cond_br", [AttrSizedOperandSegments,
DeclareOpInterfaceMethods<BranchOpInterface>,
DeclareOpInterfaceMethods<WeightedBranchOpInterface>,
Pure]> {
let arguments = (ins I1:$condition,
Variadic<LLVM_Type>:$trueDestOperands,
Variadic<LLVM_Type>:$falseDestOperands,
Expand Down Expand Up @@ -1136,11 +1138,12 @@ def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable"> {
}];
}

def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
[AttrSizedOperandSegments,
DeclareOpInterfaceMethods<BranchOpInterface>,
DeclareOpInterfaceMethods<BranchWeightOpInterface>,
Pure]> {
def LLVM_SwitchOp
: LLVM_TerminatorOp<
"switch", [AttrSizedOperandSegments,
DeclareOpInterfaceMethods<BranchOpInterface>,
DeclareOpInterfaceMethods<WeightedBranchOpInterface>,
Pure]> {
let arguments = (ins
AnySignlessInteger:$value,
Variadic<AnyType>:$defaultOperands,
Expand Down
20 changes: 20 additions & 0 deletions mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,26 @@ LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
const SuccessorOperands &operands);
} // namespace detail

//===----------------------------------------------------------------------===//
// WeightedBranchOpInterface
//===----------------------------------------------------------------------===//

namespace detail {
/// Verify that the branch weights attached to an operation
/// implementing WeightedBranchOpInterface are correct.
LogicalResult verifyBranchWeights(Operation *op);
} // namespace detail

//===----------------------------------------------------------------------===//
// WeightedRegiobBranchOpInterface
//===----------------------------------------------------------------------===//

namespace detail {
/// Verify that the region weights attached to an operation
/// implementing WeightedRegiobBranchOpInterface are correct.
LogicalResult verifyRegionBranchWeights(Operation *op);
} // namespace detail

//===----------------------------------------------------------------------===//
// RegionBranchOpInterface
//===----------------------------------------------------------------------===//
Expand Down
Loading
Loading