@@ -810,10 +810,12 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
810
810
811
811
enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
812
812
813
-
814
- // / Find the non constant dim in a linalgOp. This is used for finding contiguous
815
- // / loads and it is expected that only one dim will be non constant, if thats
816
- // / not the case this function will assert.
813
+ // / Find the non-unit dim in a linalgOp.
814
+ // / When executing this hook, it is expected that only one dim will be non-unit.
815
+ // / Other cases (i.e. reading n-D vectors) should've been labelled as gather
816
+ // / loads before calling this method. This is used for finding contiguous loads
817
+ // / (represented as `tensor.extract`) within `linalg.generic` Ops. Note that
818
+ // / this condition is expected to hold for statically shaped Linalg Ops only.
817
819
static uint64_t getNonUnitLoopDim (LinalgOp linalgOp) {
818
820
uint64_t nonUnitDim = 0 ;
819
821
uint64_t countNonUnitDim = 0 ;
@@ -823,8 +825,10 @@ static uint64_t getNonUnitLoopDim(LinalgOp linalgOp) {
823
825
countNonUnitDim++;
824
826
}
825
827
}
826
- assert (countNonUnitDim == 1 &&
827
- " Expected only one non unit loop dim in this linalg op" );
828
+
829
+ assert (linalgOp.hasDynamicShape () ||
830
+ countNonUnitDim == 1 && " For statically shaped Linalg Ops, only one "
831
+ " non-unit loop dim is expected" );
828
832
return nonUnitDim;
829
833
}
830
834
@@ -908,6 +912,8 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
908
912
Operation *defOp = val.getDefiningOp ();
909
913
assert (defOp && " This is neither a block argument nor an operation result" );
910
914
915
+ // Given the assumption on the loop ranges above, we expect only 1 non-unit
916
+ // loop dim.
911
917
auto nonUnitLoopDim = getNonUnitLoopDim (linalgOp);
912
918
913
919
if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
0 commit comments