|
21 | 21 | #include "mlir/IR/IRMapping.h"
|
22 | 22 | #include "mlir/IR/Matchers.h"
|
23 | 23 | #include "mlir/IR/OpDefinition.h"
|
| 24 | +#include "mlir/IR/TensorEncoding.h" |
24 | 25 | #include "mlir/IR/TypeUtilities.h"
|
25 | 26 | #include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
26 | 27 | #include "mlir/Interfaces/LoopLikeInterface.h"
|
@@ -1622,7 +1623,20 @@ CollapseShapeOp::inferCollapsedType(RankedTensorType type,
|
1622 | 1623 | currentDim += dim;
|
1623 | 1624 | }
|
1624 | 1625 |
|
1625 |
| - return RankedTensorType::get(newShape, type.getElementType()); |
| 1626 | + auto encoding = type.getEncoding(); |
| 1627 | + if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>()) { |
| 1628 | + auto ignoreError = [&] { |
| 1629 | + auto emitter = mlir::emitError(UnknownLoc::get(type.getContext())); |
| 1630 | + emitter.abandon(); |
| 1631 | + return emitter; |
| 1632 | + }; |
| 1633 | + if (failed( |
| 1634 | + v.verifyEncoding(newShape, type.getElementType(), ignoreError))) { |
| 1635 | + // strip the encoding if it is not valid for the new shape. |
| 1636 | + encoding = Attribute(); |
| 1637 | + } |
| 1638 | + } |
| 1639 | + return RankedTensorType::get(newShape, type.getElementType(), encoding); |
1626 | 1640 | }
|
1627 | 1641 |
|
1628 | 1642 | void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
|
@@ -1902,7 +1916,8 @@ RankedTensorType ExtractSliceOp::inferResultType(
|
1902 | 1916 | assert(static_cast<int64_t>(staticSizes.size()) ==
|
1903 | 1917 | sourceTensorType.getRank() &&
|
1904 | 1918 | "unexpected staticSizes not equal to rank of source");
|
1905 |
| - return RankedTensorType::get(staticSizes, sourceTensorType.getElementType()); |
| 1919 | + return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(), |
| 1920 | + sourceTensorType.getEncoding()); |
1906 | 1921 | }
|
1907 | 1922 |
|
1908 | 1923 | RankedTensorType ExtractSliceOp::inferResultType(
|
@@ -1943,7 +1958,8 @@ RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
|
1943 | 1958 | if (!dimsToProject.test(pos))
|
1944 | 1959 | projectedShape.push_back(shape[pos]);
|
1945 | 1960 | inferredType =
|
1946 |
| - RankedTensorType::get(projectedShape, inferredType.getElementType()); |
| 1961 | + RankedTensorType::get(projectedShape, inferredType.getElementType(), |
| 1962 | + inferredType.getEncoding()); |
1947 | 1963 | }
|
1948 | 1964 | return inferredType;
|
1949 | 1965 | }
|
@@ -2663,8 +2679,8 @@ struct InsertSliceOpSourceCastInserter final
|
2663 | 2679 | if (!hasValidSizesOffsets(newSrcShape))
|
2664 | 2680 | return failure();
|
2665 | 2681 |
|
2666 |
| - RankedTensorType newSrcType = |
2667 |
| - RankedTensorType::get(newSrcShape, srcType.getElementType()); |
| 2682 | + RankedTensorType newSrcType = RankedTensorType::get( |
| 2683 | + newSrcShape, srcType.getElementType(), srcType.getEncoding()); |
2668 | 2684 | if (srcType == newSrcType ||
|
2669 | 2685 | !preservesStaticInformation(srcType, newSrcType) ||
|
2670 | 2686 | !tensor::CastOp::areCastCompatible(srcType, newSrcType))
|
@@ -2815,7 +2831,8 @@ RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
|
2815 | 2831 | }
|
2816 | 2832 | }
|
2817 | 2833 |
|
2818 |
| - return RankedTensorType::get(inferredShape, sourceType.getElementType()); |
| 2834 | + return RankedTensorType::get(inferredShape, sourceType.getElementType(), |
| 2835 | + sourceType.getEncoding()); |
2819 | 2836 | }
|
2820 | 2837 |
|
2821 | 2838 | void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
|
@@ -3601,9 +3618,9 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
|
3601 | 3618 | "tiling factors must equal the number of dimensions to tile");
|
3602 | 3619 | }
|
3603 | 3620 |
|
3604 |
| - ShapedType packedType = (std::is_same<OpTy, PackOp>::value) |
3605 |
| - ? packOrUnPack.getDestType() |
3606 |
| - : packOrUnPack.getSourceType(); |
| 3621 | + RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value) |
| 3622 | + ? packOrUnPack.getDestType() |
| 3623 | + : packOrUnPack.getSourceType(); |
3607 | 3624 | size_t packedRank = packedType.getRank();
|
3608 | 3625 | // Require output rank to match input rank + number of blocking factors.
|
3609 | 3626 | if (unpackedRank + mixedTiles.size() != packedRank) {
|
@@ -3870,7 +3887,8 @@ RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
|
3870 | 3887 | ArrayRef<int64_t> outerDimsPerm) {
|
3871 | 3888 | SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
|
3872 | 3889 | sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
|
3873 |
| - return RankedTensorType::get(resultShape, sourceType.getElementType()); |
| 3890 | + return RankedTensorType::get(resultShape, sourceType.getElementType(), |
| 3891 | + sourceType.getEncoding()); |
3874 | 3892 | }
|
3875 | 3893 |
|
3876 | 3894 | Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
|
|
0 commit comments