@@ -917,7 +917,7 @@ static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
917
917
// / This accounts for cases where there are multiple unit-dims, but only a
918
918
// / subset of those are dropped. For MemRefTypes these can be disambiguated
919
919
// / using the strides. If a dimension is dropped the stride must be dropped too.
920
- static std::optional <llvm::SmallBitVector>
920
+ static FailureOr <llvm::SmallBitVector>
921
921
computeMemRefRankReductionMask (MemRefType originalType, MemRefType reducedType,
922
922
ArrayRef<OpFoldResult> sizes) {
923
923
llvm::SmallBitVector unusedDims (originalType.getRank ());
@@ -941,7 +941,7 @@ computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
941
941
getStridesAndOffset (originalType, originalStrides, originalOffset)) ||
942
942
failed (
943
943
getStridesAndOffset (reducedType, candidateStrides, candidateOffset)))
944
- return std::nullopt ;
944
+ return failure () ;
945
945
946
946
// For memrefs, a dimension is truly dropped if its corresponding stride is
947
947
// also dropped. This is particularly important when more than one of the dims
@@ -976,22 +976,22 @@ computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
976
976
candidateStridesNumOccurences[originalStride]) {
977
977
// This should never happen. Cant have a stride in the reduced rank type
978
978
// that wasnt in the original one.
979
- return std::nullopt ;
979
+ return failure () ;
980
980
}
981
981
}
982
982
983
983
if ((int64_t )unusedDims.count () + reducedType.getRank () !=
984
984
originalType.getRank ())
985
- return std::nullopt ;
985
+ return failure () ;
986
986
return unusedDims;
987
987
}
988
988
989
989
llvm::SmallBitVector SubViewOp::getDroppedDims () {
990
990
MemRefType sourceType = getSourceType ();
991
991
MemRefType resultType = getType ();
992
- std::optional <llvm::SmallBitVector> unusedDims =
992
+ FailureOr <llvm::SmallBitVector> unusedDims =
993
993
computeMemRefRankReductionMask (sourceType, resultType, getMixedSizes ());
994
- assert (unusedDims && " unable to find unused dims of subview" );
994
+ assert (succeeded ( unusedDims) && " unable to find unused dims of subview" );
995
995
return *unusedDims;
996
996
}
997
997
@@ -2745,7 +2745,7 @@ void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2745
2745
// / For ViewLikeOpInterface.
2746
2746
Value SubViewOp::getViewSource () { return getSource (); }
2747
2747
2748
- // / Return true if t1 and t2 have equal offsets (both dynamic or of same
2748
+ // / Return true if `t1` and `t2` have equal offsets (both dynamic or of same
2749
2749
// / static value).
2750
2750
static bool haveCompatibleOffsets (MemRefType t1, MemRefType t2) {
2751
2751
int64_t t1Offset, t2Offset;
@@ -2755,56 +2755,41 @@ static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
2755
2755
return succeeded (res1) && succeeded (res2) && t1Offset == t2Offset;
2756
2756
}
2757
2757
2758
- // / Checks if `original` Type type can be rank reduced to `reduced` type.
2759
- // / This function is slight variant of `is subsequence` algorithm where
2760
- // / not matching dimension must be 1.
2761
- static SliceVerificationResult
2762
- isRankReducedMemRefType (MemRefType originalType,
2763
- MemRefType candidateRankReducedType,
2764
- ArrayRef<OpFoldResult> sizes) {
2765
- auto partialRes = isRankReducedType (originalType, candidateRankReducedType);
2766
- if (partialRes != SliceVerificationResult::Success)
2767
- return partialRes;
2768
-
2769
- auto optionalUnusedDimsMask = computeMemRefRankReductionMask (
2770
- originalType, candidateRankReducedType, sizes);
2771
-
2772
- // Sizes cannot be matched in case empty vector is returned.
2773
- if (!optionalUnusedDimsMask)
2774
- return SliceVerificationResult::LayoutMismatch;
2775
-
2776
- if (originalType.getMemorySpace () !=
2777
- candidateRankReducedType.getMemorySpace ())
2778
- return SliceVerificationResult::MemSpaceMismatch;
2779
-
2780
- // No amount of stride dropping can reconcile incompatible offsets.
2781
- if (!haveCompatibleOffsets (originalType, candidateRankReducedType))
2782
- return SliceVerificationResult::LayoutMismatch;
2783
-
2784
- return SliceVerificationResult::Success;
2758
+ // / Return true if `t1` and `t2` have equal strides (both dynamic or of same
2759
+ // / static value).
2760
+ static bool haveCompatibleStrides (MemRefType t1, MemRefType t2) {
2761
+ int64_t t1Offset, t2Offset;
2762
+ SmallVector<int64_t > t1Strides, t2Strides;
2763
+ auto res1 = getStridesAndOffset (t1, t1Strides, t1Offset);
2764
+ auto res2 = getStridesAndOffset (t2, t2Strides, t2Offset);
2765
+ if (failed (res1) || failed (res2))
2766
+ return false ;
2767
+ for (auto [s1, s2] : llvm::zip_equal (t1Strides, t2Strides))
2768
+ if (s1 != s2)
2769
+ return false ;
2770
+ return true ;
2785
2771
}
2786
2772
2787
- template <typename OpTy>
2788
2773
static LogicalResult produceSubViewErrorMsg (SliceVerificationResult result,
2789
- OpTy op, Type expectedType) {
2774
+ Operation * op, Type expectedType) {
2790
2775
auto memrefType = llvm::cast<ShapedType>(expectedType);
2791
2776
switch (result) {
2792
2777
case SliceVerificationResult::Success:
2793
2778
return success ();
2794
2779
case SliceVerificationResult::RankTooLarge:
2795
- return op. emitError (" expected result rank to be smaller or equal to " )
2780
+ return op-> emitError (" expected result rank to be smaller or equal to " )
2796
2781
<< " the source rank. " ;
2797
2782
case SliceVerificationResult::SizeMismatch:
2798
- return op. emitError (" expected result type to be " )
2783
+ return op-> emitError (" expected result type to be " )
2799
2784
<< expectedType
2800
2785
<< " or a rank-reduced version. (mismatch of result sizes) " ;
2801
2786
case SliceVerificationResult::ElemTypeMismatch:
2802
- return op. emitError (" expected result element type to be " )
2787
+ return op-> emitError (" expected result element type to be " )
2803
2788
<< memrefType.getElementType ();
2804
2789
case SliceVerificationResult::MemSpaceMismatch:
2805
- return op. emitError (" expected result and source memory spaces to match." );
2790
+ return op-> emitError (" expected result and source memory spaces to match." );
2806
2791
case SliceVerificationResult::LayoutMismatch:
2807
- return op. emitError (" expected result type to be " )
2792
+ return op-> emitError (" expected result type to be " )
2808
2793
<< expectedType
2809
2794
<< " or a rank-reduced version. (mismatch of result layout) " ;
2810
2795
}
@@ -2826,13 +2811,46 @@ LogicalResult SubViewOp::verify() {
2826
2811
if (!isStrided (baseType))
2827
2812
return emitError (" base type " ) << baseType << " is not strided" ;
2828
2813
2829
- // Verify result type against inferred type.
2830
- auto expectedType = SubViewOp::inferResultType (
2831
- baseType, getStaticOffsets (), getStaticSizes (), getStaticStrides ());
2814
+ // Compute the expected result type, assuming that there are no rank
2815
+ // reductions.
2816
+ auto expectedType = cast<MemRefType>(SubViewOp::inferResultType (
2817
+ baseType, getStaticOffsets (), getStaticSizes (), getStaticStrides ()));
2818
+
2819
+ // Verify all properties of a shaped type: rank, element type and dimension
2820
+ // sizes. This takes into account potential rank reductions.
2821
+ auto shapedTypeVerification = isRankReducedType (
2822
+ /* originalType=*/ expectedType, /* candidateReducedType=*/ subViewType);
2823
+ if (shapedTypeVerification != SliceVerificationResult::Success)
2824
+ return produceSubViewErrorMsg (shapedTypeVerification, *this , expectedType);
2825
+
2826
+ // Make sure that the memory space did not change.
2827
+ if (expectedType.getMemorySpace () != subViewType.getMemorySpace ())
2828
+ return produceSubViewErrorMsg (SliceVerificationResult::MemSpaceMismatch,
2829
+ *this , expectedType);
2830
+
2831
+ // Verify the offset of the layout map.
2832
+ if (!haveCompatibleOffsets (expectedType, subViewType))
2833
+ return produceSubViewErrorMsg (SliceVerificationResult::LayoutMismatch,
2834
+ *this , expectedType);
2835
+
2836
+ // The only thing that's left to verify now are the strides. First, compute
2837
+ // the unused dimensions due to rank reductions. We have to look at sizes and
2838
+ // strides to decide which dimensions were dropped. This function also
2839
+ // verifies strides in case of rank reductions.
2840
+ auto unusedDims = computeMemRefRankReductionMask (expectedType, subViewType,
2841
+ getMixedSizes ());
2842
+ if (failed (unusedDims))
2843
+ return produceSubViewErrorMsg (SliceVerificationResult::LayoutMismatch,
2844
+ *this , expectedType);
2845
+
2846
+ // Strides must match if there are no rank reductions.
2847
+ // TODO: Verify strides when there are rank reductions. Strides are partially
2848
+ // checked in `computeMemRefRankReductionMask`.
2849
+ if (unusedDims->none () && !haveCompatibleStrides (expectedType, subViewType))
2850
+ return produceSubViewErrorMsg (SliceVerificationResult::LayoutMismatch,
2851
+ *this , expectedType);
2832
2852
2833
- auto result = isRankReducedMemRefType (llvm::cast<MemRefType>(expectedType),
2834
- subViewType, getMixedSizes ());
2835
- return produceSubViewErrorMsg (result, *this , expectedType);
2853
+ return success ();
2836
2854
}
2837
2855
2838
2856
raw_ostream &mlir::operator <<(raw_ostream &os, const Range &range) {
@@ -2882,11 +2900,9 @@ static MemRefType getCanonicalSubViewResultType(
2882
2900
ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
2883
2901
auto nonRankReducedType = llvm::cast<MemRefType>(SubViewOp::inferResultType (
2884
2902
sourceType, mixedOffsets, mixedSizes, mixedStrides));
2885
- std::optional<llvm::SmallBitVector> unusedDims =
2886
- computeMemRefRankReductionMask (currentSourceType, currentResultType,
2887
- mixedSizes);
2888
- // Return nullptr as failure mode.
2889
- if (!unusedDims)
2903
+ FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask (
2904
+ currentSourceType, currentResultType, mixedSizes);
2905
+ if (failed (unusedDims))
2890
2906
return nullptr ;
2891
2907
2892
2908
auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout ());
0 commit comments