Skip to content

Commit c8b7a74

Browse files
committed
[mlir][vector] Propagate vector.extract through elementwise ops
Propagate `Extract(Elementwise(...))` -> `Elementwise(Extract(...))`. Currently limited to the case where the extract is the single use of the elementwise op, to avoid introducing additional computations.
1 parent 6a030b3 commit c8b7a74

File tree

2 files changed

+85
-1
lines changed

2 files changed

+85
-1
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2237,6 +2237,47 @@ class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
22372237
}
22382238
};
22392239

2240+
/// Pattern to rewrite a ExtractOp(Elementwise) -> Elementwise(ExtractOp).
2241+
class ExtractOpFromElemetwise final : public OpRewritePattern<ExtractOp> {
2242+
public:
2243+
using OpRewritePattern::OpRewritePattern;
2244+
2245+
LogicalResult matchAndRewrite(ExtractOp op,
2246+
PatternRewriter &rewriter) const override {
2247+
Operation *eltwise = op.getVector().getDefiningOp();
2248+
2249+
// Elementwise op with single result and `extract` is single user.
2250+
if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) ||
2251+
eltwise->getNumResults() != 1 || !eltwise->hasOneUse())
2252+
return failure();
2253+
2254+
// Arguments and result types must match.
2255+
if (!llvm::all_equal(llvm::concat<Type>(eltwise->getOperandTypes(),
2256+
eltwise->getResultTypes())))
2257+
return failure();
2258+
2259+
Type dstType = op.getType();
2260+
2261+
OpBuilder::InsertionGuard g(rewriter);
2262+
rewriter.setInsertionPoint(eltwise);
2263+
2264+
IRMapping mapping;
2265+
Location loc = eltwise->getLoc();
2266+
for (auto &&[i, arg] : llvm::enumerate(eltwise->getOperands())) {
2267+
Value newArg =
2268+
rewriter.create<ExtractOp>(loc, arg, op.getMixedPosition());
2269+
mapping.map(arg, newArg);
2270+
}
2271+
2272+
Operation *newEltwise = rewriter.clone(*eltwise, mapping);
2273+
newEltwise->getResult(0).setType(dstType);
2274+
2275+
rewriter.replaceOp(op, newEltwise);
2276+
rewriter.eraseOp(eltwise);
2277+
return success();
2278+
}
2279+
};
2280+
22402281
// Folds extract(shape_cast(..)) into shape_cast when the total element count
22412282
// does not change.
22422283
LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
@@ -2309,7 +2350,8 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
23092350

23102351
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
23112352
MLIRContext *context) {
2312-
results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2353+
results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask,
2354+
ExtractOpFromElemetwise>(context);
23132355
results.add(foldExtractFromShapeCastToShapeCast);
23142356
results.add(foldExtractFromFromElements);
23152357
}

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,48 @@ func.func @extract_from_non_constant_create_mask(%dim0: index) -> vector<[2]xi1>
244244

245245
// -----
246246

247+
// CHECK-LABEL: @extract_elementwise
248+
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
249+
func.func @extract_elementwise(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
250+
// CHECK: %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<4xf32>
251+
// CHECK: %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
252+
// CHECK: %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32
253+
// CHECK: return %[[RES]] : f32
254+
%0 = arith.addf %arg0, %arg1 : vector<4xf32>
255+
%1 = vector.extract %0[1] : f32 from vector<4xf32>
256+
return %1 : f32
257+
}
258+
259+
// -----
260+
261+
// CHECK-LABEL: @extract_vec_elementwise
262+
// CHECK-SAME: (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<2x4xf32>)
263+
func.func @extract_vec_elementwise(%arg0: vector<2x4xf32>, %arg1: vector<2x4xf32>) -> vector<4xf32> {
264+
// CHECK: %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : vector<4xf32> from vector<2x4xf32>
265+
// CHECK: %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf32> from vector<2x4xf32>
266+
// CHECK: %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : vector<4xf32>
267+
// CHECK: return %[[RES]] : vector<4xf32>
268+
%0 = arith.addf %arg0, %arg1 : vector<2x4xf32>
269+
%1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
270+
return %1 : vector<4xf32>
271+
}
272+
273+
// -----
274+
275+
// CHECK-LABEL: @extract_elementwise_use
276+
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
277+
func.func @extract_elementwise_use(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
278+
// Do not propagate extract, as elementwise has other uses.
279+
// CHECK: %[[ELT:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<4xf32>
280+
// CHECK: %[[EXT:.*]] = vector.extract %[[ELT]][1] : f32 from vector<4xf32>
281+
// CHECK: return %[[EXT]], %[[ELT]] : f32, vector<4xf32>
282+
%0 = arith.addf %arg0, %arg1 : vector<4xf32>
283+
%1 = vector.extract %0[1] : f32 from vector<4xf32>
284+
return %1, %0 : f32, vector<4xf32>
285+
}
286+
287+
// -----
288+
247289
// CHECK-LABEL: constant_mask_transpose_to_transposed_constant_mask
248290
func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x4xi1>, vector<4x2x3xi1>) {
249291
// CHECK: vector.constant_mask [1, 2, 3] : vector<2x3x4xi1>

0 commit comments

Comments
 (0)