@@ -136,8 +136,8 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
136
136
def ForOp : SCF_Op<"for",
137
137
[AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
138
138
["getInitsMutable", "getLoopResults", "getRegionIterArgs",
139
- "getSingleInductionVar ", "getSingleLowerBound ", "getSingleStep ",
140
- "getSingleUpperBound ", "getYieldedValuesMutable",
139
+ "getLoopInductionVars ", "getLoopLowerBounds ", "getLoopSteps ",
140
+ "getLoopUpperBounds ", "getYieldedValuesMutable",
141
141
"promoteIfSingleIteration", "replaceWithAdditionalYields",
142
142
"yieldTiledValuesAndReplace"]>,
143
143
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
@@ -301,8 +301,8 @@ def ForallOp : SCF_Op<"forall", [
301
301
AttrSizedOperandSegments,
302
302
AutomaticAllocationScope,
303
303
DeclareOpInterfaceMethods<LoopLikeOpInterface,
304
- ["getInitsMutable", "getRegionIterArgs", "getSingleInductionVar ",
305
- "getSingleLowerBound ", "getSingleUpperBound ", "getSingleStep ",
304
+ ["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars ",
305
+ "getLoopLowerBounds ", "getLoopUpperBounds ", "getLoopSteps ",
306
306
"promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
307
307
RecursiveMemoryEffects,
308
308
SingleBlockImplicitTerminator<"scf::InParallelOp">,
@@ -510,22 +510,31 @@ def ForallOp : SCF_Op<"forall", [
510
510
];
511
511
512
512
let extraClassDeclaration = [{
513
- // Get lower bounds as OpFoldResult.
513
+ /// Get induction variables.
514
+ SmallVector<Value> getInductionVars() {
515
+ std::optional<SmallVector<Value>> maybeInductionVars = getLoopInductionVars();
516
+ assert(maybeInductionVars.has_value() && "expected values");
517
+ return *maybeInductionVars;
518
+ }
519
+ /// Get lower bounds as OpFoldResult.
514
520
SmallVector<OpFoldResult> getMixedLowerBound() {
515
- Builder b(getOperation()->getContext());
516
- return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
521
+ std::optional<SmallVector<OpFoldResult>> maybeLowerBounds = getLoopLowerBounds();
522
+ assert(maybeLowerBounds.has_value() && "expected values");
523
+ return *maybeLowerBounds;
517
524
}
518
525
519
- // Get upper bounds as OpFoldResult.
526
+ /// Get upper bounds as OpFoldResult.
520
527
SmallVector<OpFoldResult> getMixedUpperBound() {
521
- Builder b(getOperation()->getContext());
522
- return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
528
+ std::optional<SmallVector<OpFoldResult>> maybeUpperBounds = getLoopUpperBounds();
529
+ assert(maybeUpperBounds.has_value() && "expected values");
530
+ return *maybeUpperBounds;
523
531
}
524
532
525
- // Get steps as OpFoldResult.
533
+ /// Get steps as OpFoldResult.
526
534
SmallVector<OpFoldResult> getMixedStep() {
527
- Builder b(getOperation()->getContext());
528
- return getMixedValues(getStaticStep(), getDynamicStep(), b);
535
+ std::optional<SmallVector<OpFoldResult>> maybeSteps = getLoopSteps();
536
+ assert(maybeSteps.has_value() && "expected values");
537
+ return *maybeSteps;
529
538
}
530
539
531
540
/// Get lower bounds as values.
@@ -584,10 +593,6 @@ def ForallOp : SCF_Op<"forall", [
584
593
getNumDynamicControlOperands() + getRank());
585
594
}
586
595
587
- ::mlir::ValueRange getInductionVars() {
588
- return getBody()->getArguments().take_front(getRank());
589
- }
590
-
591
596
::mlir::Value getInductionVar(int64_t idx) {
592
597
return getInductionVars()[idx];
593
598
}
@@ -765,8 +770,8 @@ def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
765
770
def ParallelOp : SCF_Op<"parallel",
766
771
[AutomaticAllocationScope,
767
772
AttrSizedOperandSegments,
768
- DeclareOpInterfaceMethods<LoopLikeOpInterface, ["getSingleInductionVar ",
769
- "getSingleLowerBound ", "getSingleUpperBound ", "getSingleStep "]>,
773
+ DeclareOpInterfaceMethods<LoopLikeOpInterface, ["getLoopInductionVars ",
774
+ "getLoopLowerBounds ", "getLoopUpperBounds ", "getLoopSteps "]>,
770
775
RecursiveMemoryEffects,
771
776
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
772
777
SingleBlockImplicitTerminator<"scf::ReduceOp">,
@@ -846,8 +851,11 @@ def ParallelOp : SCF_Op<"parallel",
846
851
];
847
852
848
853
let extraClassDeclaration = [{
849
- ValueRange getInductionVars() {
850
- return getBody()->getArguments();
854
+ /// Get induction variables.
855
+ SmallVector<Value> getInductionVars() {
856
+ std::optional<SmallVector<Value>> maybeInductionVars = getLoopInductionVars();;
857
+ assert(maybeInductionVars.has_value() && "expected values");
858
+ return *maybeInductionVars;
851
859
}
852
860
unsigned getNumLoops() { return getStep().size(); }
853
861
unsigned getNumReductions() { return getInitVals().size(); }
0 commit comments