14
14
#include " mlir/Dialect/Utils/IndexingUtils.h"
15
15
#include " mlir/Dialect/Utils/ReshapeOpsUtils.h"
16
16
#include " mlir/Dialect/Utils/StaticValueUtils.h"
17
+ #include " mlir/IR/Attributes.h"
17
18
#include " mlir/IR/Builders.h"
18
19
#include " mlir/IR/BuiltinAttributeInterfaces.h"
19
20
#include " mlir/IR/BuiltinTypeInterfaces.h"
27
28
#include " llvm/ADT/DenseSet.h"
28
29
#include " llvm/ADT/STLExtras.h"
29
30
#include " llvm/ADT/SmallBitVector.h"
31
+ #include " llvm/ADT/SmallVector.h"
30
32
#include " llvm/ADT/StringRef.h"
33
+ #include " llvm/Support/Casting.h"
31
34
#include " llvm/Support/MathExtras.h"
32
35
#include < algorithm>
33
36
#include < optional>
@@ -39,6 +42,19 @@ using llvm::divideCeilSigned;
39
42
using llvm::divideFloorSigned;
40
43
using llvm::mod;
41
44
45
+ static LogicalResult
46
+ checkTensorRankMatchIndices (Value tensor, ValueRange dynamicIndices,
47
+ ArrayRef<int64_t > staticIndices) {
48
+ auto tensorType = llvm::cast<RankedTensorType>(tensor.getType ());
49
+ int64_t dynamicDimCount = llvm::count_if (staticIndices, [](int64_t element) {
50
+ return element == ShapedType::kDynamic ;
51
+ });
52
+ if (tensorType.getRank () != staticIndices.size () ||
53
+ dynamicDimCount != static_cast <int64_t >(dynamicIndices.size ()))
54
+ return LogicalResult::failure ();
55
+ return LogicalResult::success ();
56
+ }
57
+
42
58
// / Materialize a single constant operation from a given attribute value with
43
59
// / the desired resultant type.
44
60
Operation *TensorDialect::materializeConstant (OpBuilder &builder,
@@ -1120,10 +1136,49 @@ void ExtractOp::getAsmResultNames(
1120
1136
setNameFn (getResult (), " extracted" );
1121
1137
}
1122
1138
1139
+ // Build an ExtractOp with mixed static and dynamic indexes.
1140
+ void ExtractOp::build (OpBuilder &b, OperationState &result, Value tensor,
1141
+ ArrayRef<OpFoldResult> indices,
1142
+ ArrayRef<NamedAttribute> attrs) {
1143
+ Type resultType = llvm::cast<TensorType>(tensor.getType ()).getElementType ();
1144
+ build (b, result, resultType, tensor, indices, attrs);
1145
+ }
1146
+
1147
+ // Build an ExtractOp with mixed static, dynamic indexes and inferred result
1148
+ // Type.
1149
+ void ExtractOp::build (OpBuilder &b, OperationState &result, Type resultType,
1150
+ Value tensor, ArrayRef<OpFoldResult> indices,
1151
+ ArrayRef<NamedAttribute> attrs) {
1152
+ SmallVector<int64_t > staticIndices;
1153
+ SmallVector<Value> dynamicIndices;
1154
+ dispatchIndexOpFoldResults (indices, dynamicIndices, staticIndices);
1155
+ result.addAttributes (attrs);
1156
+ build (b, result, resultType, tensor, dynamicIndices,
1157
+ b.getDenseI64ArrayAttr (staticIndices));
1158
+ }
1159
+
1160
+ // Build an ExtractOp with dynamic indexes and inferred result type.
1161
+ void ExtractOp::build (OpBuilder &b, OperationState &result, Type resultType,
1162
+ Value tensor, ValueRange indices,
1163
+ ArrayRef<NamedAttribute> attrs) {
1164
+ SmallVector<OpFoldResult> indicesValues = llvm::to_vector<4 >(
1165
+ llvm::map_range (indices, [](Value v) -> OpFoldResult { return v; }));
1166
+ build (b, result, resultType, tensor, indicesValues, attrs);
1167
+ }
1168
+
1169
+ // Build an ExtractOp with dynamic indexes.
1170
+ void ExtractOp::build (OpBuilder &b, OperationState &result, Value tensor,
1171
+ ValueRange indices, ArrayRef<NamedAttribute> attrs) {
1172
+ Type resultType = llvm::cast<TensorType>(tensor.getType ()).getElementType ();
1173
+ SmallVector<OpFoldResult> indicesValues = llvm::to_vector<4 >(
1174
+ llvm::map_range (indices, [](Value v) -> OpFoldResult { return v; }));
1175
+ build (b, result, resultType, tensor, indicesValues, attrs);
1176
+ }
1177
+
1123
1178
LogicalResult ExtractOp::verify () {
1124
1179
// Verify the # indices match if we have a ranked type.
1125
- auto tensorType = llvm::cast<RankedTensorType>( getTensor (). getType ());
1126
- if (tensorType. getRank () != static_cast < int64_t >( getIndices (). size ( )))
1180
+ if ( failed ( checkTensorRankMatchIndices ( getTensor (), getIndices (),
1181
+ getStaticIndices () )))
1127
1182
return emitOpError (" incorrect number of indices for extract_element" );
1128
1183
return success ();
1129
1184
}
@@ -1137,12 +1192,18 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
1137
1192
1138
1193
// Collect the constant indices into the tensor.
1139
1194
SmallVector<uint64_t , 8 > indices;
1140
- for (Attribute indice : adaptor.getIndices ()) {
1141
- if (!indice || !llvm::isa<IntegerAttr>(indice))
1142
- return {};
1143
- indices.push_back (llvm::cast<IntegerAttr>(indice).getInt ());
1195
+ auto dynamicIndicesIt = adaptor.getIndices ().begin ();
1196
+ for (int64_t i : getStaticIndices ()) {
1197
+ if (i != ShapedType::kDynamic ) {
1198
+ indices.push_back (i);
1199
+ } else {
1200
+ Attribute indice = *dynamicIndicesIt;
1201
+ if (!indice || !llvm::isa<IntegerAttr>(indice))
1202
+ return {};
1203
+ indices.push_back (llvm::cast<IntegerAttr>(indice).getInt ());
1204
+ dynamicIndicesIt++;
1205
+ }
1144
1206
}
1145
-
1146
1207
// Fold extract(from_elements(...)).
1147
1208
if (auto fromElementsOp = getTensor ().getDefiningOp <FromElementsOp>()) {
1148
1209
auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType ());
@@ -1354,10 +1415,48 @@ void InsertOp::getAsmResultNames(
1354
1415
setNameFn (getResult (), " inserted" );
1355
1416
}
1356
1417
1418
+ // Build an ExtractOp with mixed static and dynamic indexes.
1419
+ void InsertOp::build (OpBuilder &b, OperationState &result, Value scalar,
1420
+ Value dest, ArrayRef<OpFoldResult> indices,
1421
+ ArrayRef<NamedAttribute> attrs) {
1422
+ build (b, result, dest.getType (), scalar, dest, indices, attrs);
1423
+ }
1424
+
1425
+ // Build an InsertOp with mixed static, dynamic indexes and inferred result
1426
+ // Type.
1427
+ void InsertOp::build (OpBuilder &b, OperationState &result, Type resultType,
1428
+ Value scalar, Value dest, ArrayRef<OpFoldResult> indices,
1429
+ ArrayRef<NamedAttribute> attrs) {
1430
+ SmallVector<int64_t > staticIndices;
1431
+ SmallVector<Value> dynamicIndices;
1432
+ dispatchIndexOpFoldResults (indices, dynamicIndices, staticIndices);
1433
+ result.addAttributes (attrs);
1434
+ build (b, result, resultType, scalar, dest, dynamicIndices,
1435
+ b.getDenseI64ArrayAttr (staticIndices));
1436
+ }
1437
+
1438
+ // Build an ExtractOp with dynamic indexes and inferred result type.
1439
+ void InsertOp::build (OpBuilder &b, OperationState &result, Type resultType,
1440
+ Value scalar, Value dest, ValueRange indices,
1441
+ ArrayRef<NamedAttribute> attrs) {
1442
+ SmallVector<OpFoldResult> indicesValues = llvm::to_vector<4 >(
1443
+ llvm::map_range (indices, [](Value v) -> OpFoldResult { return v; }));
1444
+ build (b, result, resultType, scalar, dest, indicesValues, attrs);
1445
+ }
1446
+
1447
+ // Build an InsertOp with dynamic indexes.
1448
+ void InsertOp::build (OpBuilder &b, OperationState &result, Value scalar,
1449
+ Value dest, ValueRange indices,
1450
+ ArrayRef<NamedAttribute> attrs) {
1451
+ SmallVector<OpFoldResult> indicesValues = llvm::to_vector<4 >(
1452
+ llvm::map_range (indices, [](Value v) -> OpFoldResult { return v; }));
1453
+ build (b, result, dest.getType (), scalar, dest, indicesValues, attrs);
1454
+ }
1455
+
1357
1456
LogicalResult InsertOp::verify () {
1358
1457
// Verify the # indices match if we have a ranked type.
1359
- auto destType = llvm::cast<RankedTensorType>( getDest (). getType ());
1360
- if (destType. getRank () != static_cast < int64_t >( getIndices (). size ( )))
1458
+ if ( failed ( checkTensorRankMatchIndices ( getDest (), getIndices (),
1459
+ getStaticIndices () )))
1361
1460
return emitOpError (" incorrect number of indices" );
1362
1461
return success ();
1363
1462
}
0 commit comments