@@ -985,13 +985,15 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
985
985
for (auto [count, op] : enumerate(entryBlock.getOperations ())) {
986
986
// We first look for operands that are placeholders for initially legal
987
987
// arguments.
988
+ Operation &curOp = op;
988
989
for (auto [operandIdx, operandVal] : llvm::enumerate (op.getOperands ())) {
989
990
Operation *operandOp = operandVal.getDefiningOp ();
990
- auto it = tmpOps.find (operandOp);
991
- if (it != tmpOps. end ())
992
- rewriter.modifyOpInPlace (&op , [&] {
993
- op .setOperand (operandIdx , newFuncOp.getArgument (it->second ));
991
+ if ( auto it = tmpOps.find (operandOp); it != tmpOps. end ()) {
992
+ size_t idx = operandIdx;
993
+ rewriter.modifyOpInPlace (&curOp , [&curOp, &newFuncOp, it, idx ] {
994
+ curOp .setOperand (idx , newFuncOp.getArgument (it->second ));
994
995
});
996
+ }
995
997
}
996
998
// Since all newly created operations are in the beginning, reaching the
997
999
// end of them means that any later `vector.insert_strided_slice` should
@@ -1000,8 +1002,8 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
1000
1002
continue ;
1001
1003
if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
1002
1004
size_t unrolledInputNo = unrolledInputNums[idx];
1003
- rewriter.modifyOpInPlace (&op , [&] {
1004
- op .setOperand (0 , newFuncOp.getArgument (unrolledInputNo));
1005
+ rewriter.modifyOpInPlace (&curOp , [&] {
1006
+ curOp .setOperand (0 , newFuncOp.getArgument (unrolledInputNo));
1005
1007
});
1006
1008
++idx;
1007
1009
}
0 commit comments