Skip to content

Commit 8c2ea14

Browse files
committed
[mlir][vector] Fold scalar vector.extract of non-splat n-D constants
Add a new pattern to fold `vector.extract` over n-D constants that extract scalars. The previous code handled ND splat constants only. The new pattern is conservative and does handle sub-vector constants. This is to aid the `arith::EmulateWideInt` pass which emits a lot of 2-element vector constants. Reviewed By: Mogball, dcaballe Differential Revision: https://reviews.llvm.org/D133742
1 parent 49f3f97 commit 8c2ea14

File tree

2 files changed

+125
-13
lines changed

2 files changed

+125
-13
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 69 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1534,33 +1534,94 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
15341534
};
15351535

15361536
// Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp.
1537-
class ExtractOpConstantFolder final : public OpRewritePattern<ExtractOp> {
1537+
class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
15381538
public:
15391539
using OpRewritePattern::OpRewritePattern;
15401540

15411541
LogicalResult matchAndRewrite(ExtractOp extractOp,
15421542
PatternRewriter &rewriter) const override {
1543-
// Return if 'extractStridedSliceOp' operand is not defined by a
1543+
// Return if 'ExtractOp' operand is not defined by a splat vector
15441544
// ConstantOp.
1545-
auto constantOp = extractOp.getVector().getDefiningOp<arith::ConstantOp>();
1546-
if (!constantOp)
1545+
Value sourceVector = extractOp.getVector();
1546+
Attribute vectorCst;
1547+
if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
15471548
return failure();
1548-
auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
1549-
if (!dense)
1549+
auto splat = vectorCst.dyn_cast<SplatElementsAttr>();
1550+
if (!splat)
15501551
return failure();
1551-
Attribute newAttr = dense.getSplatValue<Attribute>();
1552+
Attribute newAttr = splat.getSplatValue<Attribute>();
15521553
if (auto vecDstType = extractOp.getType().dyn_cast<VectorType>())
15531554
newAttr = DenseElementsAttr::get(vecDstType, newAttr);
15541555
rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
15551556
return success();
15561557
}
15571558
};
15581559

1560+
// Pattern to rewrite a ExtractOp(vector<...xT> ConstantOp)[...] -> ConstantOp,
1561+
// where the position array specifies a scalar element.
1562+
class ExtractOpScalarVectorConstantFolder final
1563+
: public OpRewritePattern<ExtractOp> {
1564+
public:
1565+
using OpRewritePattern::OpRewritePattern;
1566+
1567+
LogicalResult matchAndRewrite(ExtractOp extractOp,
1568+
PatternRewriter &rewriter) const override {
1569+
// Return if 'ExtractOp' operand is not defined by a compatible vector
1570+
// ConstantOp.
1571+
Value sourceVector = extractOp.getVector();
1572+
Attribute vectorCst;
1573+
if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
1574+
return failure();
1575+
1576+
auto vecTy = sourceVector.getType().cast<VectorType>();
1577+
Type elemTy = vecTy.getElementType();
1578+
ArrayAttr positions = extractOp.getPosition();
1579+
if (vecTy.isScalable())
1580+
return failure();
1581+
// Do not allow extracting sub-vectors to limit the size of the generated
1582+
// constants.
1583+
if (vecTy.getRank() != static_cast<int64_t>(positions.size()))
1584+
return failure();
1585+
// TODO: Handle more element types, e.g., complex values.
1586+
if (!elemTy.isIntOrIndexOrFloat())
1587+
return failure();
1588+
1589+
// The splat case is handled by `ExtractOpSplatConstantFolder`.
1590+
auto dense = vectorCst.dyn_cast<DenseElementsAttr>();
1591+
if (!dense || dense.isSplat())
1592+
return failure();
1593+
1594+
// Calculate the flattened position.
1595+
int64_t elemPosition = 0;
1596+
int64_t innerElems = 1;
1597+
for (auto [dimSize, positionInDim] :
1598+
llvm::reverse(llvm::zip(vecTy.getShape(), positions))) {
1599+
int64_t positionVal = positionInDim.cast<IntegerAttr>().getInt();
1600+
elemPosition += positionVal * innerElems;
1601+
innerElems *= dimSize;
1602+
}
1603+
1604+
Attribute newAttr;
1605+
if (vecTy.getElementType().isIntOrIndex()) {
1606+
auto values = to_vector(dense.getValues<APInt>());
1607+
newAttr = IntegerAttr::get(extractOp.getType(), values[elemPosition]);
1608+
} else if (vecTy.getElementType().isa<FloatType>()) {
1609+
auto values = to_vector(dense.getValues<APFloat>());
1610+
newAttr = FloatAttr::get(extractOp.getType(), values[elemPosition]);
1611+
}
1612+
assert(newAttr && "Unhandled case");
1613+
1614+
rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
1615+
return success();
1616+
}
1617+
};
1618+
15591619
} // namespace
15601620

15611621
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
15621622
MLIRContext *context) {
1563-
results.add<ExtractOpConstantFolder, ExtractOpFromBroadcast>(context);
1623+
results.add<ExtractOpSplatConstantFolder, ExtractOpScalarVectorConstantFolder,
1624+
ExtractOpFromBroadcast>(context);
15641625
}
15651626

15661627
static void populateFromInt64AttrArray(ArrayAttr arrayAttr,

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1367,11 +1367,11 @@ func.func @insert_extract_to_broadcast(%arg0 : vector<1x1x4xf32>,
13671367

13681368
// -----
13691369

1370-
// CHECK-LABEL: extract_constant
1371-
// CHECK-DAG: %[[CST1:.*]] = arith.constant 1 : i32
1372-
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<7xf32>
1373-
// CHECK: return %[[CST0]], %[[CST1]] : vector<7xf32>, i32
1374-
func.func @extract_constant() -> (vector<7xf32>, i32) {
1370+
// CHECK-LABEL: func.func @extract_splat_constant
1371+
// CHECK-DAG: %[[CST1:.*]] = arith.constant 1 : i32
1372+
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<7xf32>
1373+
// CHECK-NEXT: return %[[CST0]], %[[CST1]] : vector<7xf32>, i32
1374+
func.func @extract_splat_constant() -> (vector<7xf32>, i32) {
13751375
%cst = arith.constant dense<2.000000e+00> : vector<29x7xf32>
13761376
%cst_1 = arith.constant dense<1> : vector<4x37x9xi32>
13771377
%0 = vector.extract %cst[2] : vector<29x7xf32>
@@ -1381,6 +1381,57 @@ func.func @extract_constant() -> (vector<7xf32>, i32) {
13811381

13821382
// -----
13831383

1384+
// CHECK-LABEL: func.func @extract_1d_constant
1385+
// CHECK-DAG: %[[I32CST:.*]] = arith.constant 3 : i32
1386+
// CHECK-DAG: %[[IDXCST:.*]] = arith.constant 1 : index
1387+
// CHECK-DAG: %[[F32CST:.*]] = arith.constant 2.000000e+00 : f32
1388+
// CHECK-NEXT: return %[[I32CST]], %[[IDXCST]], %[[F32CST]] : i32, index, f32
1389+
func.func @extract_1d_constant() -> (i32, index, f32) {
1390+
%icst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32>
1391+
%e = vector.extract %icst[2] : vector<4xi32>
1392+
%idx_cst = arith.constant dense<[0, 1, 2]> : vector<3xindex>
1393+
%f = vector.extract %idx_cst[1] : vector<3xindex>
1394+
%fcst = arith.constant dense<[2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<3xf32>
1395+
%g = vector.extract %fcst[0] : vector<3xf32>
1396+
return %e, %f, %g : i32, index, f32
1397+
}
1398+
1399+
// -----
1400+
1401+
// CHECK-LABEL: func.func @extract_2d_constant
1402+
// CHECK-DAG: %[[ACST:.*]] = arith.constant 0 : i32
1403+
// CHECK-DAG: %[[BCST:.*]] = arith.constant 2 : i32
1404+
// CHECK-DAG: %[[CCST:.*]] = arith.constant 3 : i32
1405+
// CHECK-DAG: %[[DCST:.*]] = arith.constant 5 : i32
1406+
// CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]], %[[DCST]] : i32, i32, i32, i32
1407+
func.func @extract_2d_constant() -> (i32, i32, i32, i32) {
1408+
%cst = arith.constant dense<[[0, 1, 2], [3, 4, 5]]> : vector<2x3xi32>
1409+
%a = vector.extract %cst[0, 0] : vector<2x3xi32>
1410+
%b = vector.extract %cst[0, 2] : vector<2x3xi32>
1411+
%c = vector.extract %cst[1, 0] : vector<2x3xi32>
1412+
%d = vector.extract %cst[1, 2] : vector<2x3xi32>
1413+
return %a, %b, %c, %d : i32, i32, i32, i32
1414+
}
1415+
1416+
// -----
1417+
1418+
// CHECK-LABEL: func.func @extract_3d_constant
1419+
// CHECK-DAG: %[[ACST:.*]] = arith.constant 0 : i32
1420+
// CHECK-DAG: %[[BCST:.*]] = arith.constant 1 : i32
1421+
// CHECK-DAG: %[[CCST:.*]] = arith.constant 9 : i32
1422+
// CHECK-DAG: %[[DCST:.*]] = arith.constant 10 : i32
1423+
// CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]], %[[DCST]] : i32, i32, i32, i32
1424+
func.func @extract_3d_constant() -> (i32, i32, i32, i32) {
1425+
%cst = arith.constant dense<[[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]]> : vector<2x3x2xi32>
1426+
%a = vector.extract %cst[0, 0, 0] : vector<2x3x2xi32>
1427+
%b = vector.extract %cst[0, 0, 1] : vector<2x3x2xi32>
1428+
%c = vector.extract %cst[1, 1, 1] : vector<2x3x2xi32>
1429+
%d = vector.extract %cst[1, 2, 0] : vector<2x3x2xi32>
1430+
return %a, %b, %c, %d : i32, i32, i32, i32
1431+
}
1432+
1433+
// -----
1434+
13841435
// CHECK-LABEL: extract_extract_strided
13851436
// CHECK-SAME: %[[A:.*]]: vector<32x16x4xf16>
13861437
// CHECK: %[[V:.*]] = vector.extract %[[A]][9, 7] : vector<32x16x4xf16>

0 commit comments

Comments
 (0)