@@ -2702,10 +2702,10 @@ void SubViewOp::getAsmResultNames(
2702
2702
// / A subview result type can be fully inferred from the source type and the
2703
2703
// / static representation of offsets, sizes and strides. Special sentinels
2704
2704
// / encode the dynamic case.
2705
- Type SubViewOp::inferResultType (MemRefType sourceMemRefType,
2706
- ArrayRef<int64_t > staticOffsets,
2707
- ArrayRef<int64_t > staticSizes,
2708
- ArrayRef<int64_t > staticStrides) {
2705
+ MemRefType SubViewOp::inferResultType (MemRefType sourceMemRefType,
2706
+ ArrayRef<int64_t > staticOffsets,
2707
+ ArrayRef<int64_t > staticSizes,
2708
+ ArrayRef<int64_t > staticStrides) {
2709
2709
unsigned rank = sourceMemRefType.getRank ();
2710
2710
(void )rank;
2711
2711
assert (staticOffsets.size () == rank && " staticOffsets length mismatch" );
@@ -2744,10 +2744,10 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
2744
2744
sourceMemRefType.getMemorySpace ());
2745
2745
}
2746
2746
2747
- Type SubViewOp::inferResultType (MemRefType sourceMemRefType,
2748
- ArrayRef<OpFoldResult> offsets,
2749
- ArrayRef<OpFoldResult> sizes,
2750
- ArrayRef<OpFoldResult> strides) {
2747
+ MemRefType SubViewOp::inferResultType (MemRefType sourceMemRefType,
2748
+ ArrayRef<OpFoldResult> offsets,
2749
+ ArrayRef<OpFoldResult> sizes,
2750
+ ArrayRef<OpFoldResult> strides) {
2751
2751
SmallVector<int64_t > staticOffsets, staticSizes, staticStrides;
2752
2752
SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2753
2753
dispatchIndexOpFoldResults (offsets, dynamicOffsets, staticOffsets);
@@ -2763,13 +2763,12 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
2763
2763
staticSizes, staticStrides);
2764
2764
}
2765
2765
2766
- Type SubViewOp::inferRankReducedResultType (ArrayRef<int64_t > resultShape,
2767
- MemRefType sourceRankedTensorType,
2768
- ArrayRef<int64_t > offsets,
2769
- ArrayRef<int64_t > sizes,
2770
- ArrayRef<int64_t > strides) {
2771
- auto inferredType = llvm::cast<MemRefType>(
2772
- inferResultType (sourceRankedTensorType, offsets, sizes, strides));
2766
+ MemRefType SubViewOp::inferRankReducedResultType (
2767
+ ArrayRef<int64_t > resultShape, MemRefType sourceRankedTensorType,
2768
+ ArrayRef<int64_t > offsets, ArrayRef<int64_t > sizes,
2769
+ ArrayRef<int64_t > strides) {
2770
+ MemRefType inferredType =
2771
+ inferResultType (sourceRankedTensorType, offsets, sizes, strides);
2773
2772
assert (inferredType.getRank () >= static_cast <int64_t >(resultShape.size ()) &&
2774
2773
" expected " );
2775
2774
if (inferredType.getRank () == static_cast <int64_t >(resultShape.size ()))
@@ -2795,11 +2794,10 @@ Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
2795
2794
inferredType.getMemorySpace ());
2796
2795
}
2797
2796
2798
- Type SubViewOp::inferRankReducedResultType (ArrayRef<int64_t > resultShape,
2799
- MemRefType sourceRankedTensorType,
2800
- ArrayRef<OpFoldResult> offsets,
2801
- ArrayRef<OpFoldResult> sizes,
2802
- ArrayRef<OpFoldResult> strides) {
2797
+ MemRefType SubViewOp::inferRankReducedResultType (
2798
+ ArrayRef<int64_t > resultShape, MemRefType sourceRankedTensorType,
2799
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
2800
+ ArrayRef<OpFoldResult> strides) {
2803
2801
SmallVector<int64_t > staticOffsets, staticSizes, staticStrides;
2804
2802
SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2805
2803
dispatchIndexOpFoldResults (offsets, dynamicOffsets, staticOffsets);
@@ -2826,8 +2824,8 @@ void SubViewOp::build(OpBuilder &b, OperationState &result,
2826
2824
auto sourceMemRefType = llvm::cast<MemRefType>(source.getType ());
2827
2825
// Structuring implementation this way avoids duplication between builders.
2828
2826
if (!resultType) {
2829
- resultType = llvm::cast<MemRefType>( SubViewOp::inferResultType (
2830
- sourceMemRefType, staticOffsets, staticSizes, staticStrides) );
2827
+ resultType = SubViewOp::inferResultType (sourceMemRefType, staticOffsets,
2828
+ staticSizes, staticStrides);
2831
2829
}
2832
2830
result.addAttributes (attrs);
2833
2831
build (b, result, resultType, source, dynamicOffsets, dynamicSizes,
@@ -2992,8 +2990,8 @@ LogicalResult SubViewOp::verify() {
2992
2990
2993
2991
// Compute the expected result type, assuming that there are no rank
2994
2992
// reductions.
2995
- auto expectedType = cast<MemRefType>( SubViewOp::inferResultType (
2996
- baseType, getStaticOffsets (), getStaticSizes (), getStaticStrides ())) ;
2993
+ MemRefType expectedType = SubViewOp::inferResultType (
2994
+ baseType, getStaticOffsets (), getStaticSizes (), getStaticStrides ());
2997
2995
2998
2996
// Verify all properties of a shaped type: rank, element type and dimension
2999
2997
// sizes. This takes into account potential rank reductions.
@@ -3075,8 +3073,8 @@ static MemRefType getCanonicalSubViewResultType(
3075
3073
MemRefType currentResultType, MemRefType currentSourceType,
3076
3074
MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
3077
3075
ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
3078
- auto nonRankReducedType = llvm::cast<MemRefType>( SubViewOp::inferResultType (
3079
- sourceType, mixedOffsets, mixedSizes, mixedStrides)) ;
3076
+ MemRefType nonRankReducedType = SubViewOp::inferResultType (
3077
+ sourceType, mixedOffsets, mixedSizes, mixedStrides);
3080
3078
FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask (
3081
3079
currentSourceType, currentResultType, mixedSizes);
3082
3080
if (failed (unusedDims))
@@ -3110,9 +3108,8 @@ Value mlir::memref::createCanonicalRankReducingSubViewOp(
3110
3108
SmallVector<OpFoldResult> offsets (rank, b.getIndexAttr (0 ));
3111
3109
SmallVector<OpFoldResult> sizes = getMixedSizes (b, loc, memref);
3112
3110
SmallVector<OpFoldResult> strides (rank, b.getIndexAttr (1 ));
3113
- auto targetType =
3114
- llvm::cast<MemRefType>(SubViewOp::inferRankReducedResultType (
3115
- targetShape, memrefType, offsets, sizes, strides));
3111
+ MemRefType targetType = SubViewOp::inferRankReducedResultType (
3112
+ targetShape, memrefType, offsets, sizes, strides);
3116
3113
return b.createOrFold <memref::SubViewOp>(loc, targetType, memref, offsets,
3117
3114
sizes, strides);
3118
3115
}
@@ -3256,11 +3253,11 @@ struct SubViewReturnTypeCanonicalizer {
3256
3253
ArrayRef<OpFoldResult> mixedSizes,
3257
3254
ArrayRef<OpFoldResult> mixedStrides) {
3258
3255
// Infer a memref type without taking into account any rank reductions.
3259
- auto resTy = SubViewOp::inferResultType (op. getSourceType (), mixedOffsets,
3260
- mixedSizes, mixedStrides);
3256
+ MemRefType resTy = SubViewOp::inferResultType (
3257
+ op. getSourceType (), mixedOffsets, mixedSizes, mixedStrides);
3261
3258
if (!resTy)
3262
3259
return {};
3263
- MemRefType nonReducedType = cast<MemRefType>( resTy) ;
3260
+ MemRefType nonReducedType = resTy;
3264
3261
3265
3262
// Directly return the non-rank reduced type if there are no dropped dims.
3266
3263
llvm::SmallBitVector droppedDims = op.getDroppedDims ();
0 commit comments