Skip to content

[mlir][flang] Added Weighted[Region]BranchOpInterface's. #142079

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

vzakhari
Copy link
Contributor

The new interfaces provide getters and setters for the weight
information about the branches of BranchOpInterface and
RegionBranchOpInterface operations.

These interfaces are done the same way as LLVM dialect's
BranchWeightOpInterface.

The plan is to produce this information in Flang, e.g. mark
most probably "cold" code as such and allow LLVM to order
basic blocks accordingly. An example of such a code is
copy loops generated for arrays repacking - we can mark it
as "cold" assuming that the copy will not happen dynamically.
If the copy actually happens the overhead of the copy is probably high
enough so that we may not care about the little overhead
of jumping to the "cold" code and fetching it.

The new interfaces provide getters and setters for the weight
information about the branches of BranchOpInterface and
RegionBranchOpInterface operations.

These interfaces are done the same way as LLVM dialect's
BranchWeightOpInterface.

The plan is to produce this information in Flang, e.g. mark
most probably "cold" code as such and allow LLVM to order
basic blocks accordingly. An example of such a code is
copy loops generated for arrays repacking - we can mark it
as "cold" assuming that the copy will not happen dynamically.
If the copy actually happens the overhead of the copy is probably high
enough so that we may not care about the little overhead
of jumping to the "cold" code and fetching it.
@llvmbot llvmbot added mlir flang Flang issues not falling into any other category mlir:cf flang:fir-hlfir labels May 30, 2025
@llvmbot
Copy link
Member

llvmbot commented May 30, 2025

@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-flang-fir-hlfir

@llvm/pr-subscribers-mlir

Author: Slava Zakharin (vzakhari)

Changes

The new interfaces provide getters and setters for the weight
information about the branches of BranchOpInterface and
RegionBranchOpInterface operations.

These interfaces are done the same way as LLVM dialect's
BranchWeightOpInterface.

The plan is to produce this information in Flang, e.g. mark
most probably "cold" code as such and allow LLVM to order
basic blocks accordingly. An example of such a code is
copy loops generated for arrays repacking - we can mark it
as "cold" assuming that the copy will not happen dynamically.
If the copy actually happens the overhead of the copy is probably high
enough so that we may not care about the little overhead
of jumping to the "cold" code and fetching it.


Patch is 23.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/142079.diff

14 Files Affected:

  • (modified) flang/include/flang/Optimizer/Dialect/FIROps.td (+14-4)
  • (modified) flang/lib/Optimizer/Dialect/FIROps.cpp (+20-1)
  • (modified) flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp (+3-1)
  • (added) flang/test/Fir/cfg-conversion-if.fir (+46)
  • (modified) flang/test/Fir/fir-ops.fir (+16)
  • (modified) flang/test/Fir/invalid.fir (+37)
  • (modified) mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td (+19-15)
  • (modified) mlir/include/mlir/Interfaces/ControlFlowInterfaces.h (+20)
  • (modified) mlir/include/mlir/Interfaces/ControlFlowInterfaces.td (+107)
  • (modified) mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp (+5-1)
  • (modified) mlir/lib/Interfaces/ControlFlowInterfaces.cpp (+49)
  • (modified) mlir/test/Conversion/ControlFlowToLLVM/branch.mlir (+14)
  • (modified) mlir/test/Dialect/ControlFlow/invalid.mlir (+36)
  • (modified) mlir/test/Dialect/ControlFlow/ops.mlir (+10)
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index f4b17ef7eed09..7001e25a9bcda 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2323,9 +2323,13 @@ def fir_DoLoopOp : region_Op<"do_loop", [AttrSizedOperandSegments,
   }];
 }
 
-def fir_IfOp : region_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
-    "getRegionInvocationBounds", "getEntrySuccessorRegions"]>, RecursiveMemoryEffects,
-    NoRegionArguments]> {
+def fir_IfOp
+    : region_Op<
+          "if", [DeclareOpInterfaceMethods<
+                     RegionBranchOpInterface, ["getRegionInvocationBounds",
+                                               "getEntrySuccessorRegions"]>,
+                 RecursiveMemoryEffects, NoRegionArguments,
+                 WeightedRegionBranchOpInterface]> {
   let summary = "if-then-else conditional operation";
   let description = [{
     Used to conditionally execute operations. This operation is the FIR
@@ -2342,7 +2346,8 @@ def fir_IfOp : region_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterfac
     ```
   }];
 
-  let arguments = (ins I1:$condition);
+  let arguments = (ins I1:$condition,
+      OptionalAttr<DenseI32ArrayAttr>:$region_weights);
   let results = (outs Variadic<AnyType>:$results);
 
   let regions = (region
@@ -2371,6 +2376,11 @@ def fir_IfOp : region_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterfac
 
     void resultToSourceOps(llvm::SmallVectorImpl<mlir::Value> &results,
                            unsigned resultNum);
+
+    /// Returns the display name string for the region_weights attribute.
+    static constexpr llvm::StringRef getWeightsAttrAssemblyName() {
+      return "weights";
+    }
   }];
 }
 
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index cbe93907265f6..2949120894132 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -4418,6 +4418,19 @@ mlir::ParseResult fir::IfOp::parse(mlir::OpAsmParser &parser,
       parser.resolveOperand(cond, i1Type, result.operands))
     return mlir::failure();
 
+  if (mlir::succeeded(
+          parser.parseOptionalKeyword(getWeightsAttrAssemblyName()))) {
+    if (parser.parseLParen())
+      return mlir::failure();
+    mlir::DenseI32ArrayAttr weights;
+    if (parser.parseCustomAttributeWithFallback(weights, mlir::Type{}))
+      return mlir::failure();
+    if (weights)
+      result.addAttribute(getRegionWeightsAttrName(result.name), weights);
+    if (parser.parseRParen())
+      return mlir::failure();
+  }
+
   if (parser.parseOptionalArrowTypeList(result.types))
     return mlir::failure();
 
@@ -4449,6 +4462,11 @@ llvm::LogicalResult fir::IfOp::verify() {
 void fir::IfOp::print(mlir::OpAsmPrinter &p) {
   bool printBlockTerminators = false;
   p << ' ' << getCondition();
+  if (auto weights = getRegionWeightsAttr()) {
+    p << ' ' << getWeightsAttrAssemblyName() << '(';
+    p.printStrippedAttrOrType(weights);
+    p << ')';
+  }
   if (!getResults().empty()) {
     p << " -> (" << getResultTypes() << ')';
     printBlockTerminators = true;
@@ -4464,7 +4482,8 @@ void fir::IfOp::print(mlir::OpAsmPrinter &p) {
     p.printRegion(otherReg, /*printEntryBlockArgs=*/false,
                   printBlockTerminators);
   }
-  p.printOptionalAttrDict((*this)->getAttrs());
+  p.printOptionalAttrDict((*this)->getAttrs(),
+                          /*elideAttrs=*/{getRegionWeightsAttrName()});
 }
 
 void fir::IfOp::resultToSourceOps(llvm::SmallVectorImpl<mlir::Value> &results,
diff --git a/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp b/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp
index 8a9e9b80134b8..5256ef8d53d85 100644
--- a/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp
+++ b/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp
@@ -212,9 +212,11 @@ class CfgIfConv : public mlir::OpRewritePattern<fir::IfOp> {
     }
 
     rewriter.setInsertionPointToEnd(condBlock);
-    rewriter.create<mlir::cf::CondBranchOp>(
+    auto branchOp = rewriter.create<mlir::cf::CondBranchOp>(
         loc, ifOp.getCondition(), ifOpBlock, llvm::ArrayRef<mlir::Value>(),
         otherwiseBlock, llvm::ArrayRef<mlir::Value>());
+    if (auto weights = ifOp.getRegionWeightsOrNull())
+      branchOp.setBranchWeights(weights);
     rewriter.replaceOp(ifOp, continueBlock->getArguments());
     return success();
   }
diff --git a/flang/test/Fir/cfg-conversion-if.fir b/flang/test/Fir/cfg-conversion-if.fir
new file mode 100644
index 0000000000000..1e30ee8e64f02
--- /dev/null
+++ b/flang/test/Fir/cfg-conversion-if.fir
@@ -0,0 +1,46 @@
+// RUN: fir-opt --split-input-file --cfg-conversion %s | FileCheck %s
+
+func.func private @callee() -> none
+
+// CHECK-LABEL:   func.func @if_then(
+// CHECK-SAME:      %[[ARG0:.*]]: i1) {
+// CHECK:           cf.cond_br %[[ARG0]] weights([10, 90]), ^bb1, ^bb2
+// CHECK:         ^bb1:
+// CHECK:           %[[VAL_0:.*]] = fir.call @callee() : () -> none
+// CHECK:           cf.br ^bb2
+// CHECK:         ^bb2:
+// CHECK:           return
+// CHECK:         }
+func.func @if_then(%cond: i1) {
+  fir.if %cond weights([10, 90]) {
+    fir.call @callee() : () -> none
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @if_then_else(
+// CHECK-SAME:      %[[ARG0:.*]]: i1) -> i32 {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 0 : i32
+// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : i32
+// CHECK:           cf.cond_br %[[ARG0]] weights([90, 10]), ^bb1, ^bb2
+// CHECK:         ^bb1:
+// CHECK:           cf.br ^bb3(%[[VAL_0]] : i32)
+// CHECK:         ^bb2:
+// CHECK:           cf.br ^bb3(%[[VAL_1]] : i32)
+// CHECK:         ^bb3(%[[VAL_2:.*]]: i32):
+// CHECK:           cf.br ^bb4
+// CHECK:         ^bb4:
+// CHECK:           return %[[VAL_2]] : i32
+// CHECK:         }
+func.func @if_then_else(%cond: i1) -> i32 {
+  %c0 = arith.constant 0 : i32
+  %c1 = arith.constant 1 : i32
+  %result = fir.if %cond weights([90, 10]) -> i32 {
+    fir.result %c0 : i32
+  } else {
+    fir.result %c1 : i32
+  }
+  return %result : i32
+}
diff --git a/flang/test/Fir/fir-ops.fir b/flang/test/Fir/fir-ops.fir
index 9c444d2f4e0bc..3585bf9efca3e 100644
--- a/flang/test/Fir/fir-ops.fir
+++ b/flang/test/Fir/fir-ops.fir
@@ -1015,3 +1015,19 @@ func.func @test_box_total_elements(%arg0: !fir.class<!fir.type<sometype{i:i32}>>
   %6 = arith.addi %2, %5 : index
   return %6 : index
 }
+
+// CHECK-LABEL:   func.func @test_if_weights(
+// CHECK-SAME:      %[[ARG0:.*]]: i1) {
+func.func @test_if_weights(%cond: i1) {
+// CHECK:           fir.if %[[ARG0]] weights([99, 1]) {
+// CHECK:           }
+  fir.if %cond weights([99, 1]) {
+  }
+// CHECK:           fir.if %[[ARG0]] weights([99, 1]) {
+// CHECK:           } else {
+// CHECK:           }
+  fir.if %cond weights ([99,1]) {
+  } else {
+  }
+  return
+}
diff --git a/flang/test/Fir/invalid.fir b/flang/test/Fir/invalid.fir
index fd607fd9066f7..0391cdbef71e5 100644
--- a/flang/test/Fir/invalid.fir
+++ b/flang/test/Fir/invalid.fir
@@ -1385,3 +1385,40 @@ fir.local {type = local_init} @x.localizer : f32 init {
 ^bb0(%arg0: f32, %arg1: f32):
   fir.yield(%arg0 : f32)
 }
+
+// -----
+
+func.func @wrong_weights_number_in_if_then(%cond: i1) {
+// expected-error @below {{number of weights (1) does not match the number of regions (2)}}
+  fir.if %cond weights([50]) {
+  }
+  return
+}
+
+// -----
+
+func.func @wrong_weights_number_in_if_then_else(%cond: i1) {
+// expected-error @below {{number of weights (3) does not match the number of regions (2)}}
+  fir.if %cond weights([50, 40, 10]) {
+  } else {
+  }
+  return
+}
+
+// -----
+
+func.func @negative_weight_in_if_then(%cond: i1) {
+// expected-error @below {{weight #0 must be non-negative}}
+  fir.if %cond weights([-1, 101]) {
+  }
+  return
+}
+
+// -----
+
+func.func @wrong_total_weight_in_if_then(%cond: i1) {
+// expected-error @below {{total weight 101 is not 100}}
+  fir.if %cond weights([1, 100]) {
+  }
+  return
+}
diff --git a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
index 48f12b46a57f1..79da81ba049dd 100644
--- a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
@@ -112,10 +112,11 @@ def BranchOp : CF_Op<"br", [
 // CondBranchOp
 //===----------------------------------------------------------------------===//
 
-def CondBranchOp : CF_Op<"cond_br",
-    [AttrSizedOperandSegments,
-     DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
-     Pure, Terminator]> {
+def CondBranchOp
+    : CF_Op<"cond_br", [AttrSizedOperandSegments,
+                        DeclareOpInterfaceMethods<
+                            BranchOpInterface, ["getSuccessorForOperands"]>,
+                        WeightedBranchOpInterface, Pure, Terminator]> {
   let summary = "Conditional branch operation";
   let description = [{
     The `cf.cond_br` terminator operation represents a conditional branch on a
@@ -144,20 +145,23 @@ def CondBranchOp : CF_Op<"cond_br",
     ```
   }];
 
-  let arguments = (ins I1:$condition,
-                       Variadic<AnyType>:$trueDestOperands,
-                       Variadic<AnyType>:$falseDestOperands);
+  let arguments = (ins I1:$condition, Variadic<AnyType>:$trueDestOperands,
+      Variadic<AnyType>:$falseDestOperands,
+      OptionalAttr<DenseI32ArrayAttr>:$branch_weights);
   let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
 
-  let builders = [
-    OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
-      "ValueRange":$trueOperands, "Block *":$falseDest,
-      "ValueRange":$falseOperands), [{
-      build($_builder, $_state, condition, trueOperands, falseOperands, trueDest,
+  let builders = [OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
+                                "ValueRange":$trueOperands,
+                                "Block *":$falseDest,
+                                "ValueRange":$falseOperands),
+                            [{
+      build($_builder, $_state, condition, trueOperands, falseOperands, /*branch_weights=*/{}, trueDest,
             falseDest);
     }]>,
-    OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
-      "Block *":$falseDest, CArg<"ValueRange", "{}">:$falseOperands), [{
+                  OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
+                                "Block *":$falseDest,
+                                CArg<"ValueRange", "{}">:$falseOperands),
+                            [{
       build($_builder, $_state, condition, trueDest, ValueRange(), falseDest,
             falseOperands);
     }]>];
@@ -216,7 +220,7 @@ def CondBranchOp : CF_Op<"cond_br",
 
   let hasCanonicalizer = 1;
   let assemblyFormat = [{
-    $condition `,`
+    $condition (`weights` `(` $branch_weights^ `)` )? `,`
     $trueDest (`(` $trueDestOperands^ `:` type($trueDestOperands) `)`)? `,`
     $falseDest (`(` $falseDestOperands^ `:` type($falseDestOperands) `)`)?
     attr-dict
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index 7f6967f11444f..d63800c12d132 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -142,6 +142,26 @@ LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
                                             const SuccessorOperands &operands);
 } // namespace detail
 
+//===----------------------------------------------------------------------===//
+// WeightedBranchOpInterface
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+/// Verify that the branch weights attached to an operation
+/// implementing WeightedBranchOpInterface are correct.
+LogicalResult verifyBranchWeights(Operation *op);
+} // namespace detail
+
+//===----------------------------------------------------------------------===//
+// WeightedRegiobBranchOpInterface
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+/// Verify that the region weights attached to an operation
+/// implementing WeightedRegiobBranchOpInterface are correct.
+LogicalResult verifyRegionBranchWeights(Operation *op);
+} // namespace detail
+
 //===----------------------------------------------------------------------===//
 // RegionBranchOpInterface
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 69bce78e946c8..7a47b686ac7d1 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -375,6 +375,113 @@ def SelectLikeOpInterface : OpInterface<"SelectLikeOpInterface"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// WeightedBranchOpInterface
+//===----------------------------------------------------------------------===//
+
+def WeightedBranchOpInterface : OpInterface<"WeightedBranchOpInterface"> {
+  let description = [{
+    This interface provides weight information for branching terminator
+    operations, i.e. terminator operations with successors.
+
+    This interface provides methods for getting/setting integer non-negative
+    weight of each branch in the range from 0 to 100. The sum of weights
+    must be 100. The number of weights must match the number of successors
+    of the operation.
+
+    The weights specify the probability (in percents) of taking
+    a particular branch.
+
+    The default implementations of the methods expect the operation
+    to have an attribute of type DenseI32ArrayAttr named branch_weights.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [InterfaceMethod<
+                     /*desc=*/"Returns the branch weights attribute or nullptr",
+                     /*returnType=*/"::mlir::DenseI32ArrayAttr",
+                     /*methodName=*/"getBranchWeightsOrNull",
+                     /*args=*/(ins),
+                     /*methodBody=*/[{}],
+                     /*defaultImpl=*/[{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        return op.getBranchWeightsAttr();
+      }]>,
+                 InterfaceMethod<
+                     /*desc=*/"Sets the branch weights attribute",
+                     /*returnType=*/"void",
+                     /*methodName=*/"setBranchWeights",
+                     /*args=*/(ins "::mlir::DenseI32ArrayAttr":$attr),
+                     /*methodBody=*/[{}],
+                     /*defaultImpl=*/[{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        op.setBranchWeightsAttr(attr);
+      }]>,
+  ];
+
+  let verify = [{
+    return ::mlir::detail::verifyBranchWeights($_op);
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// WeightedRegionBranchOpInterface
+//===----------------------------------------------------------------------===//
+
+// TODO: the probabilities of entering a particular region seem
+// to correlate with the values returned by
+// RegionBranchOpInterface::invocationBounds(), and we should probably
+// verify that the values are consistent. In that case, should
+// WeightedRegionBranchOpInterface extend RegionBranchOpInterface?
+def WeightedRegionBranchOpInterface
+    : OpInterface<"WeightedRegionBranchOpInterface"> {
+  let description = [{
+    This interface provides weight information for region operations
+    that exhibit branching behavior between held regions.
+
+    This interface provides methods for getting/setting integer non-negative
+    weight of each branch in the range from 0 to 100. The sum of weights
+    must be 100. The number of weights must match the number of regions
+    held by the operation (including empty regions).
+
+    The weights specify the probability (in percents) of branching
+    to a particular region when first executing the operation.
+    For example, for loop-like operations with a single region
+    the weight specifies the probability of entering the loop.
+    In this case, the weight must be either 0 or 100.
+
+    The default implementations of the methods expect the operation
+    to have an attribute of type DenseI32ArrayAttr named branch_weights.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [InterfaceMethod<
+                     /*desc=*/"Returns the region weights attribute or nullptr",
+                     /*returnType=*/"::mlir::DenseI32ArrayAttr",
+                     /*methodName=*/"getRegionWeightsOrNull",
+                     /*args=*/(ins),
+                     /*methodBody=*/[{}],
+                     /*defaultImpl=*/[{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        return op.getRegionWeightsAttr();
+      }]>,
+                 InterfaceMethod<
+                     /*desc=*/"Sets the region weights attribute",
+                     /*returnType=*/"void",
+                     /*methodName=*/"setRegionWeights",
+                     /*args=*/(ins "::mlir::DenseI32ArrayAttr":$attr),
+                     /*methodBody=*/[{}],
+                     /*defaultImpl=*/[{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        op.setRegionWeightsAttr(attr);
+      }]>,
+  ];
+
+  let verify = [{
+    return ::mlir::detail::verifyRegionBranchWeights($_op);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // ControlFlow Traits
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index debfd003bd5b5..12769e486a3c7 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -166,10 +166,14 @@ struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
                           TypeRange(adaptor.getFalseDestOperands()));
     if (failed(convertedFalseBlock))
       return failure();
-    Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
+    auto newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
         op, adaptor.getCondition(), *convertedTrueBlock,
         adaptor.getTrueDestOperands(), *convertedFalseBlock,
         adaptor.getFalseDestOperands());
+    if (auto weights = op.getBranchWeightsOrNull()) {
+      newOp.setBranchWeights(weights);
+      op.removeBranchWeightsAttr();
+    }
     // TODO: We should not just forward all attributes like that. But there are
     // existing Flang tests that depend on this behavior.
     newOp->setAttrs(op->getAttrDictionary());
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index 2ae334b517a31..e587e8f1af178 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -80,6 +80,55 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// WeightedBranchOpInterface
+//===----------------------------------------------------------------------===//
+
+LogicalResult detail::verifyBranchWeights(Operation *op) {
+  auto weights = cast<WeightedBranchOpInterface>(op).getBranchWeightsOrNull();
+  if (weights) {
+    if (weights.size() != op->getNumSuccessors())
+      return op->emitError() << "number of weights (" << weights.size()
+                             << ") does not match the number of successors ("
+                             << op->getNumSuccessors() << ")";
+    int32_t total = 0;
+    for (auto weight : llvm::enumerate(weights.asArrayRef())) {
+      if (weight.value() < 0)
+        return op->emitError()
+               << "weight #" << weight.index() << " must be non-negative";
+      total += weight.value();
+    }
+    if (total != 100)
+      return op->emitError() << "total weight " << total << " is not 100";
+  }
+  return mlir::success();
+}
+
+//===--------------------...
[truncated]

Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea!

LGTM but wait for an MLIR reviewer too

// WeightedRegionBranchOpInterface
//===----------------------------------------------------------------------===//

LogicalResult detail::verifyRegionBranchWeights(Operation *op) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Both of these could share a helper

Copy link
Contributor

@jeanPerier jeanPerier left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, makes sense to me. Are you planning on replacing the LLVM BranchWeightOpInterface with the new generic interface?

@vzakhari
Copy link
Contributor Author

vzakhari commented Jun 2, 2025

Thanks, makes sense to me. Are you planning on replacing the LLVM BranchWeightOpInterface with the new generic interface?

Good question. I think it is worth replacing LLVM dialect's BranchWeightOpInterface with this one, unless there are any objections to this.

@vzakhari vzakhari requested a review from gysit June 3, 2025 00:01
@vzakhari
Copy link
Contributor Author

vzakhari commented Jun 3, 2025

I replaced the LLVM dialect's BranchWeightOpInterface with the generic WeightedBranchOpInterface. The new interface should allow what LLVM's interface allowed (e.g. that the weights do not total to 100, and that direct calls may have weight attached to them).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category mlir:cf mlir:llvm mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants