9
9
#include " mlir/Transforms/RegionUtils.h"
10
10
#include " mlir/Analysis/TopologicalSortUtils.h"
11
11
#include " mlir/IR/Block.h"
12
+ #include " mlir/IR/BuiltinOps.h"
12
13
#include " mlir/IR/IRMapping.h"
13
14
#include " mlir/IR/Operation.h"
14
15
#include " mlir/IR/PatternMatch.h"
15
16
#include " mlir/IR/RegionGraphTraits.h"
16
17
#include " mlir/IR/Value.h"
17
18
#include " mlir/Interfaces/ControlFlowInterfaces.h"
18
19
#include " mlir/Interfaces/SideEffectInterfaces.h"
20
+ #include " mlir/Support/LogicalResult.h"
19
21
20
22
#include " llvm/ADT/DepthFirstIterator.h"
21
23
#include " llvm/ADT/PostOrderIterator.h"
24
+ #include " llvm/ADT/STLExtras.h"
25
+ #include " llvm/ADT/SmallSet.h"
22
26
23
27
#include < deque>
28
+ #include < iterator>
24
29
25
30
using namespace mlir ;
26
31
@@ -699,9 +704,8 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
699
704
blockIterators.push_back (mergeBlock->begin ());
700
705
701
706
// Update each of the predecessor terminators with the new arguments.
702
- SmallVector<SmallVector<Value, 8 >, 2 > newArguments (
703
- 1 + blocksToMerge.size (),
704
- SmallVector<Value, 8 >(operandsToMerge.size ()));
707
+ SmallVector<SmallVector<Value, 8 >, 2 > newArguments (1 + blocksToMerge.size (),
708
+ SmallVector<Value, 8 >());
705
709
unsigned curOpIndex = 0 ;
706
710
for (const auto &it : llvm::enumerate (operandsToMerge)) {
707
711
unsigned nextOpOffset = it.value ().first - curOpIndex;
@@ -712,13 +716,22 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
712
716
Block::iterator &blockIter = blockIterators[i];
713
717
std::advance (blockIter, nextOpOffset);
714
718
auto &operand = blockIter->getOpOperand (it.value ().second );
715
- newArguments[i][it.index ()] = operand.get ();
716
-
717
- // Update the operand and insert an argument if this is the leader.
718
- if (i == 0 ) {
719
- Value operandVal = operand.get ();
720
- operand.set (leaderBlock->addArgument (operandVal.getType (),
721
- operandVal.getLoc ()));
719
+ Value operandVal = operand.get ();
720
+ Value *it = std::find (newArguments[i].begin (), newArguments[i].end (),
721
+ operandVal);
722
+ if (it == newArguments[i].end ()) {
723
+ newArguments[i].push_back (operandVal);
724
+ // Update the operand and insert an argument if this is the leader.
725
+ if (i == 0 ) {
726
+ operand.set (leaderBlock->addArgument (operandVal.getType (),
727
+ operandVal.getLoc ()));
728
+ }
729
+ } else if (i == 0 ) {
730
+ // If this is the leader, update the operand but do not insert a new
731
+ // argument. Instead, the opearand should point to one of the
732
+ // arguments we already passed (and that contained `operandVal`)
733
+ operand.set (leaderBlock->getArgument (
734
+ std::distance (newArguments[i].begin (), it)));
722
735
}
723
736
}
724
737
}
@@ -818,6 +831,109 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
818
831
return success (anyChanged);
819
832
}
820
833
834
+ static LogicalResult dropRedundantArguments (RewriterBase &rewriter,
835
+ Block &block) {
836
+ SmallVector<size_t > argsToErase;
837
+
838
+ // Go through the arguments of the block
839
+ for (size_t argIdx = 0 ; argIdx < block.getNumArguments (); argIdx++) {
840
+ bool sameArg = true ;
841
+ Value commonValue;
842
+
843
+ // Go through the block predecessor and flag if they pass to the block
844
+ // different values for the same argument
845
+ for (auto predIt = block.pred_begin (), predE = block.pred_end ();
846
+ predIt != predE; ++predIt) {
847
+ auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator ());
848
+ if (!branch) {
849
+ sameArg = false ;
850
+ break ;
851
+ }
852
+ unsigned succIndex = predIt.getSuccessorIndex ();
853
+ SuccessorOperands succOperands = branch.getSuccessorOperands (succIndex);
854
+ auto operands = succOperands.getForwardedOperands ();
855
+ if (!commonValue) {
856
+ commonValue = operands[argIdx];
857
+ } else {
858
+ if (operands[argIdx] != commonValue) {
859
+ sameArg = false ;
860
+ break ;
861
+ }
862
+ }
863
+ }
864
+
865
+ // If they are passing the same value, drop the argument
866
+ if (commonValue && sameArg) {
867
+ argsToErase.push_back (argIdx);
868
+
869
+ // Remove the argument from the block
870
+ Value argVal = block.getArgument (argIdx);
871
+ rewriter.replaceAllUsesWith (argVal, commonValue);
872
+ }
873
+ }
874
+
875
+ // Remove the arguments
876
+ for (auto argIdx : llvm::reverse (argsToErase)) {
877
+ block.eraseArgument (argIdx);
878
+
879
+ // Remove the argument from the branch ops
880
+ for (auto predIt = block.pred_begin (), predE = block.pred_end ();
881
+ predIt != predE; ++predIt) {
882
+ auto branch = cast<BranchOpInterface>((*predIt)->getTerminator ());
883
+ unsigned succIndex = predIt.getSuccessorIndex ();
884
+ SuccessorOperands succOperands = branch.getSuccessorOperands (succIndex);
885
+ succOperands.erase (argIdx);
886
+ }
887
+ }
888
+ return success (!argsToErase.empty ());
889
+ }
890
+
891
+ // / This optimization drops redundant argument to blocks. I.e., if a given
892
+ // / argument to a block receives the same value from each of the block
893
+ // / predecessors, we can remove the argument from the block and use directly the
894
+ // / original value. This is a simple example:
895
+ // /
896
+ // / %cond = llvm.call @rand() : () -> i1
897
+ // / %val0 = llvm.mlir.constant(1 : i64) : i64
898
+ // / %val1 = llvm.mlir.constant(2 : i64) : i64
899
+ // / %val2 = llvm.mlir.constant(3 : i64) : i64
900
+ // / llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
901
+ // / : i64)
902
+ // /
903
+ // / ^bb1(%arg0 : i64, %arg1 : i64):
904
+ // / llvm.call @foo(%arg0, %arg1)
905
+ // /
906
+ // / The previous IR can be rewritten as:
907
+ // / %cond = llvm.call @rand() : () -> i1
908
+ // / %val0 = llvm.mlir.constant(1 : i64) : i64
909
+ // / %val1 = llvm.mlir.constant(2 : i64) : i64
910
+ // / %val2 = llvm.mlir.constant(3 : i64) : i64
911
+ // / llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
912
+ // /
913
+ // / ^bb1(%arg0 : i64):
914
+ // / llvm.call @foo(%val0, %arg0)
915
+ // /
916
+ static LogicalResult dropRedundantArguments (RewriterBase &rewriter,
917
+ MutableArrayRef<Region> regions) {
918
+ llvm::SmallSetVector<Region *, 1 > worklist;
919
+ for (auto ®ion : regions)
920
+ worklist.insert (®ion);
921
+ bool anyChanged = false ;
922
+ while (!worklist.empty ()) {
923
+ Region *region = worklist.pop_back_val ();
924
+
925
+ // Add any nested regions to the worklist.
926
+ for (Block &block : *region) {
927
+ anyChanged = succeeded (dropRedundantArguments (rewriter, block));
928
+
929
+ for (auto &op : block)
930
+ for (auto &nestedRegion : op.getRegions ())
931
+ worklist.insert (&nestedRegion);
932
+ }
933
+ }
934
+ return success (anyChanged);
935
+ }
936
+
821
937
// ===----------------------------------------------------------------------===//
822
938
// Region Simplification
823
939
// ===----------------------------------------------------------------------===//
@@ -832,8 +948,12 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
832
948
bool eliminatedBlocks = succeeded (eraseUnreachableBlocks (rewriter, regions));
833
949
bool eliminatedOpsOrArgs = succeeded (runRegionDCE (rewriter, regions));
834
950
bool mergedIdenticalBlocks = false ;
835
- if (mergeBlocks)
951
+ bool droppedRedundantArguments = false ;
952
+ if (mergeBlocks) {
836
953
mergedIdenticalBlocks = succeeded (mergeIdenticalBlocks (rewriter, regions));
954
+ droppedRedundantArguments =
955
+ succeeded (dropRedundantArguments (rewriter, regions));
956
+ }
837
957
return success (eliminatedBlocks || eliminatedOpsOrArgs ||
838
- mergedIdenticalBlocks);
958
+ mergedIdenticalBlocks || droppedRedundantArguments );
839
959
}
0 commit comments