Skip to content

Commit dba24f6

Browse files
committed
[mlir] Change tensor.extract/insert to take static/dynamic indices.
This changes the ODS of the `tensor.extract/insert` ops. Some new builder methods are added, and the verifiers/canonicalizers are updated. One of the canonicalization patterns of `shape.shape_of` is also updated.
1 parent 9325381 commit dba24f6

File tree

8 files changed

+242
-25
lines changed

8 files changed

+242
-25
lines changed

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -332,12 +332,37 @@ def Tensor_ExtractOp : Tensor_Op<"extract", [
332332
```mlir
333333
%4 = tensor.extract %t[%1, %2] : tensor<4x4xi32>
334334
%5 = tensor.extract %rt[%1, %2] : tensor<?x?xi32>
335+
%6 = tensor.extract %rt[3, 4] : tensor<?x?xi32>
336+
%7 = tensor.extract %rt[%1, 4] : tensor<?x?xi32>
335337
```
336338
}];
337339

338-
let arguments = (ins AnyRankedTensor:$tensor, Variadic<Index>:$indices);
340+
let arguments = (ins
341+
AnyRankedTensor:$tensor,
342+
Variadic<Index>:$indices,
343+
DenseI64ArrayAttr:$static_indices
344+
);
339345
let results = (outs AnyType:$result);
340-
let assemblyFormat = "$tensor `[` $indices `]` attr-dict `:` type($tensor)";
346+
let assemblyFormat = [{
347+
$tensor ``
348+
custom<DynamicIndexList>($indices, $static_indices)
349+
attr-dict `:` type($tensor)
350+
}];
351+
352+
let builders = [
353+
// Build an ExtractOp with mixed static and dynamic indexes.
354+
OpBuilder<(ins "Value":$tensor, "ArrayRef<OpFoldResult>":$indexes,
355+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
356+
// Build an ExtractOp with mixed static, dynamic indexes and inferred result type.
357+
OpBuilder<(ins "Type":$resultType, "Value":$tensor, "ArrayRef<OpFoldResult>":$indexes,
358+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
359+
// Build an ExtractOp with dynamic indexes.
360+
OpBuilder<(ins "Value":$source, CArg<"ValueRange", "{}">:$indexes,
361+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
362+
// Build an ExtractOp with dynamic indexes and inferred result type.
363+
OpBuilder<(ins "Type":$resultType, "Value":$source, CArg<"ValueRange", "{}">:$indexes,
364+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
365+
];
341366

342367
let hasCanonicalizer = 1;
343368
let hasFolder = 1;
@@ -808,16 +833,35 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
808833

809834
let arguments = (ins AnyType:$scalar,
810835
AnyRankedTensor:$dest,
811-
Variadic<Index>:$indices);
836+
Variadic<Index>:$indices,
837+
DenseI64ArrayAttr:$static_indices
838+
);
812839
let results = (outs AnyRankedTensor:$result);
813840
let assemblyFormat = [{
814-
$scalar `into` $dest `[` $indices `]` attr-dict `:` type($dest)
841+
$scalar `into`
842+
$dest `` custom<DynamicIndexList>($indices, $static_indices)
843+
attr-dict `:` type($dest)
815844
}];
816845

817846
let extraClassDeclaration = [{
818847
MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
819848
}];
820849

850+
let builders = [
851+
// Build an InsertOp with mixed static and dynamic indexes.
852+
OpBuilder<(ins "Value":$scalar, "Value":$dest, "ArrayRef<OpFoldResult>":$indexes,
853+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
854+
// Build an InsertOp with mixed static, dynamic indexes and inferred result type.
855+
OpBuilder<(ins "Type":$resultType, "Value":$scalar, "Value":$dest, "ArrayRef<OpFoldResult>":$indexes,
856+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
857+
// Build an InsertOp with dynamic indexes.
858+
OpBuilder<(ins "Value":$scalar, "Value":$dest, CArg<"ValueRange", "{}">:$indexes,
859+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
860+
// Build an InsertOp with dynamic indexes and inferred result type.
861+
OpBuilder<(ins "Type":$resultType, "Value":$scalar, "Value":$dest, CArg<"ValueRange", "{}">:$indexes,
862+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
863+
];
864+
821865
let hasFolder = 1;
822866
let hasVerifier = 1;
823867
}

mlir/lib/Dialect/Shape/IR/Shape.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1736,6 +1736,32 @@ struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
17361736
}
17371737
};
17381738

1739+
struct ExtractFromShapeOfExtentTensor
1740+
: public OpRewritePattern<tensor::ExtractOp> {
1741+
using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1742+
1743+
LogicalResult matchAndRewrite(tensor::ExtractOp op,
1744+
PatternRewriter &rewriter) const override {
1745+
auto tensorShapeOfOp = op.getTensor().getDefiningOp<shape::ShapeOfOp>();
1746+
if (!tensorShapeOfOp)
1747+
return rewriter.notifyMatchFailure(op, "producer is not shape.shape_of");
1748+
1749+
int64_t staticIndice = op.getStaticIndices()[0];
1750+
Type indexType = rewriter.getIndexType();
1751+
Value indice =
1752+
staticIndice != ShapedType::kDynamic
1753+
? tensorShapeOfOp->getDialect()
1754+
->materializeConstant(
1755+
rewriter, IntegerAttr::get(indexType, staticIndice),
1756+
indexType, op.getLoc())
1757+
->getResult(0)
1758+
: op.getIndices()[0];
1759+
rewriter.replaceOpWithNewOp<tensor::DimOp>(op, tensorShapeOfOp.getArg(),
1760+
indice);
1761+
return success();
1762+
}
1763+
};
1764+
17391765
// Canonicalize
17401766
// ```
17411767
// %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>

mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,3 @@ def SizeToIndexToSizeCanonicalization : Pat<
4444
def TensorCastConstShape : Pat <
4545
(Tensor_CastOp:$res (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg),
4646
[(HasStaticShape $res)]>;
47-
48-
// tensor.extract from shape_of -> tensor.dim. We can take the first index
49-
// because shape_of always returns a 1D tensor.
50-
def ExtractFromShapeOfExtentTensor : Pat<
51-
(Tensor_ExtractOp (Shape_ShapeOfOp $arg), $indices),
52-
(Tensor_DimOp $arg, (TakeFront $indices))>;

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 105 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,19 @@ using llvm::divideCeilSigned;
3939
using llvm::divideFloorSigned;
4040
using llvm::mod;
4141

42+
/// Returns success if the index lists are consistent with the rank of
/// `tensor`: `staticIndices` must carry exactly one entry per tensor
/// dimension, and the number of `ShapedType::kDynamic` placeholders in it
/// must match the number of dynamic index SSA values in `dynamicIndices`.
static LogicalResult
checkTensorRankMatchIndices(Value tensor, ValueRange dynamicIndices,
                            ArrayRef<int64_t> staticIndices) {
  auto tensorType = llvm::cast<RankedTensorType>(tensor.getType());
  // Each kDynamic entry in the static list must be backed by one SSA operand.
  int64_t dynamicDimCount = llvm::count_if(staticIndices, [](int64_t element) {
    return element == ShapedType::kDynamic;
  });
  // Cast the unsigned sizes explicitly to avoid signed/unsigned comparison
  // against the signed rank and count.
  if (tensorType.getRank() != static_cast<int64_t>(staticIndices.size()) ||
      dynamicDimCount != static_cast<int64_t>(dynamicIndices.size()))
    return LogicalResult::failure();
  return LogicalResult::success();
}
4255
/// Materialize a single constant operation from a given attribute value with
4356
/// the desired resultant type.
4457
Operation *TensorDialect::materializeConstant(OpBuilder &builder,
@@ -1120,10 +1133,49 @@ void ExtractOp::getAsmResultNames(
11201133
setNameFn(getResult(), "extracted");
11211134
}
11221135

1136+
// Build an ExtractOp with mixed static and dynamic indexes.
1137+
void ExtractOp::build(OpBuilder &b, OperationState &result, Value tensor,
1138+
ArrayRef<OpFoldResult> indices,
1139+
ArrayRef<NamedAttribute> attrs) {
1140+
Type resultType = llvm::cast<TensorType>(tensor.getType()).getElementType();
1141+
build(b, result, resultType, tensor, indices, attrs);
1142+
}
1143+
1144+
// Build an ExtractOp with mixed static, dynamic indexes and inferred result
1145+
// type.
1146+
void ExtractOp::build(OpBuilder &b, OperationState &result, Type resultType,
1147+
Value tensor, ArrayRef<OpFoldResult> indices,
1148+
ArrayRef<NamedAttribute> attrs) {
1149+
SmallVector<int64_t> staticIndices;
1150+
SmallVector<Value> dynamicIndices;
1151+
dispatchIndexOpFoldResults(indices, dynamicIndices, staticIndices);
1152+
result.addAttributes(attrs);
1153+
build(b, result, resultType, tensor, dynamicIndices,
1154+
b.getDenseI64ArrayAttr(staticIndices));
1155+
}
1156+
1157+
// Build an ExtractOp with dynamic indexes and inferred result type.
1158+
void ExtractOp::build(OpBuilder &b, OperationState &result, Type resultType,
1159+
Value tensor, ValueRange indices,
1160+
ArrayRef<NamedAttribute> attrs) {
1161+
SmallVector<OpFoldResult> indicesValues = llvm::to_vector<4>(
1162+
llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; }));
1163+
build(b, result, resultType, tensor, indicesValues, attrs);
1164+
}
1165+
1166+
// Build an ExtractOp with dynamic indexes.
1167+
void ExtractOp::build(OpBuilder &b, OperationState &result, Value tensor,
1168+
ValueRange indices, ArrayRef<NamedAttribute> attrs) {
1169+
Type resultType = llvm::cast<TensorType>(tensor.getType()).getElementType();
1170+
SmallVector<OpFoldResult> indicesValues = llvm::to_vector<4>(
1171+
llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; }));
1172+
build(b, result, resultType, tensor, indicesValues, attrs);
1173+
}
1174+
11231175
LogicalResult ExtractOp::verify() {
11241176
// Verify the # indices match if we have a ranked type.
1125-
auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
1126-
if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
1177+
if (failed(checkTensorRankMatchIndices(getTensor(), getIndices(),
1178+
getStaticIndices())))
11271179
return emitOpError("incorrect number of indices for extract_element");
11281180
return success();
11291181
}
@@ -1137,12 +1189,18 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
11371189

11381190
// Collect the constant indices into the tensor.
11391191
SmallVector<uint64_t, 8> indices;
1140-
for (Attribute indice : adaptor.getIndices()) {
1141-
if (!indice || !llvm::isa<IntegerAttr>(indice))
1142-
return {};
1143-
indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
1192+
auto dynamicIndicesIt = adaptor.getIndices().begin();
1193+
for (int64_t i : getStaticIndices()) {
1194+
if (i != ShapedType::kDynamic) {
1195+
indices.push_back(i);
1196+
} else {
1197+
Attribute indice = *dynamicIndicesIt;
1198+
if (!indice || !llvm::isa<IntegerAttr>(indice))
1199+
return {};
1200+
indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
1201+
dynamicIndicesIt++;
1202+
}
11441203
}
1145-
11461204
// Fold extract(from_elements(...)).
11471205
if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
11481206
auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
@@ -1354,10 +1412,48 @@ void InsertOp::getAsmResultNames(
13541412
setNameFn(getResult(), "inserted");
13551413
}
13561414

1415+
// Build an InsertOp with mixed static and dynamic indexes.
1416+
void InsertOp::build(OpBuilder &b, OperationState &result, Value scalar,
1417+
Value dest, ArrayRef<OpFoldResult> indices,
1418+
ArrayRef<NamedAttribute> attrs) {
1419+
build(b, result, dest.getType(), scalar, dest, indices, attrs);
1420+
}
1421+
1422+
// Build an InsertOp with mixed static, dynamic indexes and inferred result
1423+
// type.
1424+
void InsertOp::build(OpBuilder &b, OperationState &result, Type resultType,
1425+
Value scalar, Value dest, ArrayRef<OpFoldResult> indices,
1426+
ArrayRef<NamedAttribute> attrs) {
1427+
SmallVector<int64_t> staticIndices;
1428+
SmallVector<Value> dynamicIndices;
1429+
dispatchIndexOpFoldResults(indices, dynamicIndices, staticIndices);
1430+
result.addAttributes(attrs);
1431+
build(b, result, resultType, scalar, dest, dynamicIndices,
1432+
b.getDenseI64ArrayAttr(staticIndices));
1433+
}
1434+
1435+
// Build an InsertOp with dynamic indexes and inferred result type.
1436+
void InsertOp::build(OpBuilder &b, OperationState &result, Type resultType,
1437+
Value scalar, Value dest, ValueRange indices,
1438+
ArrayRef<NamedAttribute> attrs) {
1439+
SmallVector<OpFoldResult> indicesValues = llvm::to_vector<4>(
1440+
llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; }));
1441+
build(b, result, resultType, scalar, dest, indicesValues, attrs);
1442+
}
1443+
1444+
// Build an InsertOp with dynamic indexes.
1445+
void InsertOp::build(OpBuilder &b, OperationState &result, Value scalar,
1446+
Value dest, ValueRange indices,
1447+
ArrayRef<NamedAttribute> attrs) {
1448+
SmallVector<OpFoldResult> indicesValues = llvm::to_vector<4>(
1449+
llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; }));
1450+
build(b, result, dest.getType(), scalar, dest, indicesValues, attrs);
1451+
}
1452+
13571453
LogicalResult InsertOp::verify() {
13581454
// Verify the # indices match if we have a ranked type.
1359-
auto destType = llvm::cast<RankedTensorType>(getDest().getType());
1360-
if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
1455+
if (failed(checkTensorRankMatchIndices(getDest(), getIndices(),
1456+
getStaticIndices())))
13611457
return emitOpError("incorrect number of indices");
13621458
return success();
13631459
}

mlir/test/Dialect/Shape/canonicalize.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1519,6 +1519,19 @@ func.func @extract_shapeof(%arg0 : tensor<?x?xf64>) -> index {
15191519
return %result : index
15201520
}
15211521

1522+
// -----
1523+
1524+
// CHECK-LABEL: func @extract_shapeof_static_indice
1525+
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?xf64>
1526+
func.func @extract_shapeof_static_indice(%arg0 : tensor<?x?xf64>) -> index {
1527+
// CHECK: %[[C1:.*]] = arith.constant 1
1528+
%shape = shape.shape_of %arg0 : tensor<?x?xf64> -> tensor<2xindex>
1529+
// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C1]]
1530+
%result = tensor.extract %shape[1] : tensor<2xindex>
1531+
// CHECK: return %[[DIM]]
1532+
return %result : index
1533+
}
1534+
15221535

15231536
// -----
15241537

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,12 @@ func.func @fold_concat(%arg0: tensor<1x2x?xi32>) -> (tensor<1x2x3xi32>, tensor<1
137137
// -----
138138

139139
// CHECK-LABEL: func @fold_extract
140-
func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
140+
func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, i32, complex<f32>) {
141141
%const_0 = arith.constant 0 : index
142142
%const_1 = arith.constant 1 : index
143143
%const_3 = arith.constant 3 : index
144144
// CHECK-DAG: [[C64:%.+]] = arith.constant 64 : i32
145+
// CHECK-DAG: [[CNEG1:%.+]] = arith.constant -1 : i32
145146
// CHECK-DAG: [[C0:%.+]] = arith.constant 0.{{0*}}e+00 : f16
146147
// CHECK-DAG: [[CM2:%.+]] = arith.constant -2.{{0*}}e+00 : f16
147148

@@ -162,13 +163,16 @@ func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
162163
%3 = arith.constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> : tensor<2x1x4xi32>
163164
%ext_4 = tensor.extract %3[%const_1, %const_0, %const_3] : tensor<2x1x4xi32>
164165

166+
// Fold an extract into a dense tensor with mixed dynamic and static indexes.
167+
%ext_5 = tensor.extract %3[%const_1, 0, 2] : tensor<2x1x4xi32>
168+
165169
// Fold an extract into a complex constant.
166170
// CHECK-DAG: [[C5:%.+]] = complex.constant [1.200000e+00 : f32, 2.300000e+00 : f32] : complex<f32>
167171
%4 = arith.constant dense<(1.2, 2.3)> : tensor<complex<f32>>
168-
%ext_5 = tensor.extract %4[] : tensor<complex<f32>>
172+
%ext_6 = tensor.extract %4[] : tensor<complex<f32>>
169173

170-
// CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[C5]]
171-
return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5 : f32, f16, f16, i32, complex<f32>
174+
// CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[CNEG1]], [[C5]]
175+
return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5, %ext_6: f32, f16, f16, i32, i32, complex<f32>
172176
}
173177

174178
// -----

mlir/test/Dialect/Tensor/invalid.mlir

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,22 +64,56 @@ func.func @concat_static_shape_mismatch(%arg0: tensor<3xf32>) {
6464

6565
// -----
6666

67-
func.func @extract_too_many_indices(%arg0: tensor<?xf32>) {
67+
func.func @extract_too_few_indices(%arg0: tensor<?xf32>) {
6868
// expected-error@+1 {{incorrect number of indices for extract_element}}
6969
%0 = tensor.extract %arg0[] : tensor<?xf32>
7070
return
7171
}
7272

7373
// -----
7474

75-
func.func @insert_too_many_indices(%arg0: f32, %arg1: tensor<?xf32>) {
75+
func.func @extract_too_many_static_indices(%arg0: tensor<?xf32>) {
76+
// expected-error@+1 {{incorrect number of indices for extract_element}}
77+
%0 = tensor.extract %arg0[2, 3] : tensor<?xf32>
78+
return
79+
}
80+
81+
// -----
82+
83+
func.func @extract_too_many_mixed_indices(%arg0: tensor<?xf32>) {
84+
%c1 = arith.constant 1 : index
85+
// expected-error@+1 {{incorrect number of indices for extract_element}}
86+
%0 = tensor.extract %arg0[%c1, 2, 3] : tensor<?xf32>
87+
return
88+
}
89+
90+
// -----
91+
92+
func.func @insert_too_few_indices(%arg0: f32, %arg1: tensor<?xf32>) {
7693
// expected-error@+1 {{incorrect number of indices}}
7794
%0 = tensor.insert %arg0 into %arg1[] : tensor<?xf32>
7895
return
7996
}
8097

8198
// -----
8299

100+
func.func @insert_too_many_indices(%arg0: f32, %arg1: tensor<?xf32>) {
101+
// expected-error@+1 {{incorrect number of indices}}
102+
%0 = tensor.insert %arg0 into %arg1[2, 3] : tensor<?xf32>
103+
return
104+
}
105+
106+
// -----
107+
108+
func.func @insert_too_many_mixed_indices(%arg0: f32, %arg1: tensor<?xf32>) {
109+
%c1 = arith.constant 1 : index
110+
// expected-error@+1 {{incorrect number of indices}}
111+
%0 = tensor.insert %arg0 into %arg1[%c1, 2, 3] : tensor<?xf32>
112+
return
113+
}
114+
115+
// -----
116+
83117
func.func @tensor.from_elements_wrong_result_type() {
84118
// expected-error@+2 {{'tensor.from_elements' invalid kind of type specified}}
85119
%c0 = arith.constant 0 : i32

0 commit comments

Comments
 (0)