
Commit ac7fc12

tensor fixes
1 parent 8075f0d commit ac7fc12

File tree

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Linalg/collapse-dim.mlir

2 files changed: +35 −17 lines


mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 28 additions & 10 deletions
@@ -21,6 +21,7 @@
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/TensorEncoding.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
@@ -1622,7 +1623,20 @@ CollapseShapeOp::inferCollapsedType(RankedTensorType type,
     currentDim += dim;
   }
 
-  return RankedTensorType::get(newShape, type.getElementType());
+  auto encoding = type.getEncoding();
+  if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>()) {
+    auto ignoreError = [&] {
+      auto emitter = mlir::emitError(UnknownLoc::get(type.getContext()));
+      emitter.abandon();
+      return emitter;
+    };
+    if (failed(
+            v.verifyEncoding(newShape, type.getElementType(), ignoreError))) {
+      // strip the encoding if it is not valid for the new shape.
+      encoding = Attribute();
+    }
+  }
+  return RankedTensorType::get(newShape, type.getElementType(), encoding);
 }
 
 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
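The hunk above probes the encoding with an abandoned diagnostic, so a failed verification stays silent instead of printing an error. A minimal sketch of that pattern, assuming MLIR's VerifiableTensorEncoding interface from mlir/IR/TensorEncoding.h; the helper name keepEncodingIfValid is made up for illustration and is not part of this commit:

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/TensorEncoding.h"

using namespace mlir;

// Hypothetical helper: return the encoding if it still verifies for
// `newShape`, or a null Attribute if it does not.
static Attribute keepEncodingIfValid(RankedTensorType type,
                                     ArrayRef<int64_t> newShape) {
  Attribute encoding = type.getEncoding();
  auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>();
  if (!v)
    return encoding; // Not a verifiable encoding; keep it unchanged.
  // verifyEncoding() only invokes this callback to report a failure;
  // abandoning the diagnostic up front means nothing is ever printed.
  auto ignoreError = [&] {
    auto diag = emitError(UnknownLoc::get(type.getContext()));
    diag.abandon();
    return diag;
  };
  if (failed(v.verifyEncoding(newShape, type.getElementType(), ignoreError)))
    return Attribute(); // Drop the encoding that no longer fits the shape.
  return encoding;
}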
@@ -1902,7 +1916,8 @@ RankedTensorType ExtractSliceOp::inferResultType(
   assert(static_cast<int64_t>(staticSizes.size()) ==
              sourceTensorType.getRank() &&
          "unexpected staticSizes not equal to rank of source");
-  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType());
+  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
+                               sourceTensorType.getEncoding());
 }
 
 RankedTensorType ExtractSliceOp::inferResultType(
@@ -1943,7 +1958,8 @@ RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
       if (!dimsToProject.test(pos))
         projectedShape.push_back(shape[pos]);
     inferredType =
-        RankedTensorType::get(projectedShape, inferredType.getElementType());
+        RankedTensorType::get(projectedShape, inferredType.getElementType(),
+                              inferredType.getEncoding());
   }
   return inferredType;
 }
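This and the remaining TensorOps.cpp hunks all apply one fix: call the three-argument RankedTensorType::get overload so the source's encoding attribute is carried into the inferred type instead of being dropped. A minimal sketch of the pattern; withSameEncoding is an illustrative name, not a function from this commit:

#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Rebuild a tensor type with a new shape while preserving the element
// type and the (possibly null) encoding attribute.
static RankedTensorType withSameEncoding(RankedTensorType source,
                                         ArrayRef<int64_t> newShape) {
  // A null encoding is fine: RankedTensorType::get then behaves like the
  // two-argument overload, e.g. tensor<4x5xf32, 1 : i64> keeps its "1 : i64"
  // while tensor<4x5xf32> stays unencoded.
  return RankedTensorType::get(newShape, source.getElementType(),
                               source.getEncoding());
}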
@@ -2663,8 +2679,8 @@ struct InsertSliceOpSourceCastInserter final
     if (!hasValidSizesOffsets(newSrcShape))
       return failure();
 
-    RankedTensorType newSrcType =
-        RankedTensorType::get(newSrcShape, srcType.getElementType());
+    RankedTensorType newSrcType = RankedTensorType::get(
+        newSrcShape, srcType.getElementType(), srcType.getEncoding());
     if (srcType == newSrcType ||
         !preservesStaticInformation(srcType, newSrcType) ||
         !tensor::CastOp::areCastCompatible(srcType, newSrcType))
@@ -2815,7 +2831,8 @@ RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
     }
   }
 
-  return RankedTensorType::get(inferredShape, sourceType.getElementType());
+  return RankedTensorType::get(inferredShape, sourceType.getElementType(),
+                               sourceType.getEncoding());
 }
 
 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
@@ -3601,9 +3618,9 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
         "tiling factors must equal the number of dimensions to tile");
   }
 
-  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
-                              ? packOrUnPack.getDestType()
-                              : packOrUnPack.getSourceType();
+  RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
+                                    ? packOrUnPack.getDestType()
+                                    : packOrUnPack.getSourceType();
   size_t packedRank = packedType.getRank();
   // Require output rank to match input rank + number of blocking factors.
   if (unpackedRank + mixedTiles.size() != packedRank) {
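Note on the hunk above: binding the ternary result to RankedTensorType rather than the looser ShapedType compiles only because, presumably, PackOp::getDestType() and UnPackOp::getSourceType() already return ranked tensor types; the tighter static type keeps rank and encoding accessors available without casts in the verifier that follows.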
@@ -3870,7 +3887,8 @@ RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
                                          ArrayRef<int64_t> outerDimsPerm) {
   SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
       sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
-  return RankedTensorType::get(resultShape, sourceType.getElementType());
+  return RankedTensorType::get(resultShape, sourceType.getElementType(),
+                               sourceType.getEncoding());
 }
 
 Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
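A sketch of how a caller would observe the PackOp change, assuming the inferPackedType signature shown in the hunk above; the tile size and dimension position are made-up example values, and examplePackedType is not part of the commit:

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

static RankedTensorType examplePackedType(RankedTensorType source) {
  // Pack dimension 1 with an inner tile of 8, no outer permutation.
  RankedTensorType packed = tensor::PackOp::inferPackedType(
      source, /*innerTileSizes=*/{8}, /*innerDimsPos=*/{1},
      /*outerDimsPerm=*/{});
  // With this commit, `packed` carries source.getEncoding(); previously
  // the encoding was silently dropped from the inferred type.
  return packed;
}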

mlir/test/Dialect/Linalg/collapse-dim.mlir

Lines changed: 7 additions & 7 deletions
@@ -122,13 +122,13 @@ func.func @uncollapsable_strided_memref(%arg0: memref<2x6x24x48xi32>, %arg1: mem
 // CHECK-LABEL: func.func @linalg_copy(
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3x4x5xf32, 1 : i64>,
 // CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3x4x5xf32, 3 : i64>) -> tensor<1x2x3x4x5xf32, 3 : i64> {
-// CHECK: %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 1 : i64> into tensor<1x2x12x5xf32>
-// CHECK: %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 3 : i64> into tensor<1x2x12x5xf32>
-// CHECK: %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
-// CHECK: %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
-// CHECK: %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32>) outs(%[[VAL_5]] : tensor<1x2x60xf32>) -> tensor<1x2x60xf32>
-// CHECK: %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x60xf32> into tensor<1x2x12x5xf32>
-// CHECK: %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x12x5xf32> into tensor<1x2x3x4x5xf32, 3 : i64>
+// CHECK: %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 1 : i64> into tensor<1x2x12x5xf32, 1 : i64>
+// CHECK: %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 3 : i64> into tensor<1x2x12x5xf32, 3 : i64>
+// CHECK: %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32, 1 : i64> into tensor<1x2x60xf32, 1 : i64>
+// CHECK: %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32, 3 : i64> into tensor<1x2x60xf32, 3 : i64>
+// CHECK: %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32, 1 : i64>) outs(%[[VAL_5]] : tensor<1x2x60xf32, 3 : i64>) -> tensor<1x2x60xf32, 3 : i64>
+// CHECK: %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x60xf32, 3 : i64> into tensor<1x2x12x5xf32, 3 : i64>
+// CHECK: %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x12x5xf32, 3 : i64> into tensor<1x2x3x4x5xf32, 3 : i64>
 // CHECK: return %[[VAL_8]] : tensor<1x2x3x4x5xf32, 3 : i64>
 // CHECK: }