[mlir][vector] canonicalize vector.from_elements from ascending extracts #139819

Open
wants to merge 10 commits into
base: main
Changes from 6 commits
90 changes: 90 additions & 0 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -33,6 +33,7 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Support/LLVM.h"
@@ -2385,9 +2386,98 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
return success();
}

/// Rewrite vector.from_elements as vector.shape_cast, if possible.
Contributor

[ultra nit] "if possible" is implicit and "less is more" :)

Suggested change
- /// Rewrite vector.from_elements as vector.shape_cast, if possible.
+ /// Rewrite vector.from_elements as vector.shape_cast.

///
/// Example:
/// %0 = vector.extract %source[0, 0] : i8 from vector<1x2xi8>
/// %1 = vector.extract %source[0, 1] : i8 from vector<1x2xi8>
/// %2 = vector.from_elements %0, %1 : vector<2xi8>
///
/// becomes
/// %2 = vector.shape_cast %source : vector<1x2xi8> to vector<2xi8>
///
/// The requirements for this to be valid are
/// i) all elements are extracted from the same vector (source),
/// ii) source and from_elements result have the same number of elements,
/// iii) the elements are extracted in ascending order.
///
/// It might be possible to rewrite vector.from_elements as a single
/// vector.extract if (ii) is not satisfied, or in some cases as a single
/// vector.extract_strided_slice if (ii) and (iii) are not satisfied; this is
/// left for future consideration.
class FromElementsToShapCast : public OpRewritePattern<FromElementsOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(FromElementsOp fromElements,
                                PatternRewriter &rewriter) const override {

    mlir::OperandRange elements = fromElements.getElements();
    assert(!elements.empty() && "must be at least 1 element");
    Value firstElement = elements.front();

    ExtractOp extractOp =
        dyn_cast_if_present<vector::ExtractOp>(firstElement.getDefiningOp());
    if (!extractOp) {
      return rewriter.notifyMatchFailure(
          fromElements, "first element not from vector.extract");
    }
Contributor

Why do we check the first element separately?

Contributor Author

I flip-flopped between having a conditional "if (index == 0) { // do the one-off check }" inside the loop and doing it before the loop, but I've gone back to doing it in the loop now.
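
For reference, a minimal sketch of the in-loop variant, written against plain C++ containers rather than MLIR types so that it compiles on its own. `ExtractInfo`, `kDynamic`, `linearOffset`, and `isFullAscendingExtract` are hypothetical names used only for illustration and are not part of the patch. The point it illustrates is that treating index 0 like every other operand also verifies that the first extract sits at linear offset 0.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <optional>
#include <vector>

// Hypothetical stand-in for ShapedType::kDynamic.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Hypothetical stand-in for one vector.extract feeding vector.from_elements.
struct ExtractInfo {
  int sourceId;                  // Identifies the source vector (stand-in for Value).
  std::vector<int64_t> position; // Static position; kDynamic marks a dynamic index.
};

// Row-major linear offset of `position` within `shape`, or nullopt if the
// ranks differ or any index is dynamic.
std::optional<int64_t> linearOffset(const std::vector<int64_t> &position,
                                    const std::vector<int64_t> &shape) {
  if (position.size() != shape.size())
    return std::nullopt;
  int64_t stride = 1;
  int64_t offset = 0;
  for (std::size_t i = shape.size(); i-- > 0;) {
    if (position[i] == kDynamic)
      return std::nullopt;
    offset += position[i] * stride;
    stride *= shape[i];
  }
  return offset;
}

// True if `elements` are exactly the elements of one source of shape
// `sourceShape`, listed in ascending row-major order.
bool isFullAscendingExtract(const std::vector<ExtractInfo> &elements,
                            const std::vector<int64_t> &sourceShape) {
  assert(!elements.empty() && "must be at least 1 element");

  // Element counts must match.
  int64_t numSourceElements = 1;
  for (int64_t dim : sourceShape)
    numSourceElements *= dim;
  if (static_cast<int64_t>(elements.size()) != numSourceElements)
    return false;

  // One loop over all operands, including index 0.
  for (std::size_t index = 0; index < elements.size(); ++index) {
    const ExtractInfo &element = elements[index];
    // Same source for every operand.
    if (element.sourceId != elements.front().sourceId)
      return false;
    // Operand `index` must be the element at linear offset `index`.
    std::optional<int64_t> offset = linearOffset(element.position, sourceShape);
    if (!offset || *offset != static_cast<int64_t>(index))
      return false;
  }
  return true;
}
```

Under this model, operands extracted at `[0, 1], [0, 1]` from a `1x2` source are rejected, because operand 0 has linear offset 1 rather than 0.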

    VectorType sourceType = extractOp.getSourceVectorType();
    Value source = extractOp.getVector();

    // Check condition (ii).
    if (static_cast<size_t>(sourceType.getNumElements()) != elements.size()) {
      return rewriter.notifyMatchFailure(fromElements,
                                         "number of elements differ");
    }
Contributor

[nit] I would "rebrand" this as "Condition (i)" (it's the first condition to be checked) and move it all the way to the top - it feels like a fairly high level condition that deserves a special place :)


    for (auto [indexMinusOne, element] :
         llvm::enumerate(elements.drop_front(1))) {

      extractOp =
          dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp());
      if (!extractOp) {
        return rewriter.notifyMatchFailure(fromElements,
                                           "element not from vector.extract");
      }
      Value currentSource = extractOp.getVector();
      // Check condition (i).
      if (currentSource != source) {
        return rewriter.notifyMatchFailure(fromElements,
                                           "element from different vector");
      }
Contributor

[nit] To me all of this is checking "condition (i)" and everything that's left is checking "condition (ii)". I would just move the comments and basically split the loop body into 2 blocks.


      ArrayRef<int64_t> position = extractOp.getStaticPosition();
      assert(position.size() == static_cast<size_t>(sourceType.getRank()) &&
             "scalar extract must have full rank position");
      int64_t stride{1};
      int64_t offset{0};
      for (auto [pos, size] : llvm::zip(llvm::reverse(position),
                                        llvm::reverse(sourceType.getShape()))) {
        if (pos == ShapedType::kDynamic) {
          return rewriter.notifyMatchFailure(
              fromElements, "elements not in ascending order (dynamic order)");
        }
        offset += pos * stride;
        stride *= size;
      }
      // Check condition (iii).
      if (offset != static_cast<int64_t>(indexMinusOne + 1)) {
        return rewriter.notifyMatchFailure(
            fromElements, "elements not in ascending order (static order)");
      }
    }

    rewriter.replaceOpWithNewOp<ShapeCastOp>(fromElements,
                                             fromElements.getType(), source);
    return success();
  }
};
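
The stride/offset loop in the pattern above computes a row-major linear offset. Written out as a formula (the symbols $r$, $s_k$, and $p_k$ are notation introduced here, not names from the patch), for a source of shape $(s_0, \dots, s_{r-1})$ and a static position $(p_0, \dots, p_{r-1})$:

```latex
\mathrm{offset}(p) \;=\; \sum_{k=0}^{r-1} p_k \prod_{j=k+1}^{r-1} s_j
% Worked example: shape 2x3, position [1, 2]
%   offset = 1 \cdot 3 + 2 \cdot 1 = 5
```

The check labeled condition (iii) then requires that the operand at index $i$ of vector.from_elements has offset $i$.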

void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add(rewriteFromElementsAsSplat);
  results.add<FromElementsToShapCast>(context);
}

//===----------------------------------------------------------------------===//
69 changes: 0 additions & 69 deletions mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2952,75 +2952,6 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,

// -----

// CHECK-LABEL: func @extract_scalar_from_from_elements(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) {
// Extract from 0D.
%0 = vector.from_elements %a : vector<f32>
%1 = vector.extract %0[] : f32 from vector<f32>

// Extract from 1D.
%2 = vector.from_elements %a : vector<1xf32>
%3 = vector.extract %2[0] : f32 from vector<1xf32>
%4 = vector.from_elements %a, %b, %a, %a, %b : vector<5xf32>
%5 = vector.extract %4[4] : f32 from vector<5xf32>

// Extract from 2D.
%6 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
%7 = vector.extract %6[0, 0] : f32 from vector<2x3xf32>
%8 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
%9 = vector.extract %6[1, 1] : f32 from vector<2x3xf32>
%10 = vector.extract %6[1, 2] : f32 from vector<2x3xf32>

// CHECK: return %[[a]], %[[a]], %[[b]], %[[a]], %[[a]], %[[b]], %[[b]]
return %1, %3, %5, %7, %8, %9, %10 : f32, f32, f32, f32, f32, f32, f32
}

// -----

// CHECK-LABEL: func @extract_1d_from_from_elements(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) {
%0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
// CHECK: %[[splat1:.*]] = vector.splat %[[a]] : vector<3xf32>
%1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32>
// CHECK: %[[splat2:.*]] = vector.splat %[[b]] : vector<3xf32>
%2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32>
// CHECK: return %[[splat1]], %[[splat2]]
return %1, %2 : vector<3xf32>, vector<3xf32>
}

// -----

// CHECK-LABEL: func @extract_2d_from_from_elements(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) {
%0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32>
// CHECK: %[[splat1:.*]] = vector.from_elements %[[a]], %[[a]], %[[a]], %[[b]] : vector<2x2xf32>
%1 = vector.extract %0[0] : vector<2x2xf32> from vector<3x2x2xf32>
// CHECK: %[[splat2:.*]] = vector.from_elements %[[b]], %[[b]], %[[b]], %[[a]] : vector<2x2xf32>
%2 = vector.extract %0[1] : vector<2x2xf32> from vector<3x2x2xf32>
// CHECK: return %[[splat1]], %[[splat2]]
return %1, %2 : vector<2x2xf32>, vector<2x2xf32>
}

// -----

// CHECK-LABEL: func @from_elements_to_splat(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector<f32>) {
// CHECK: %[[splat:.*]] = vector.splat %[[a]] : vector<2x3xf32>
%0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32>
// CHECK: %[[from_el:.*]] = vector.from_elements {{.*}} : vector<2x3xf32>
%1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32>
// CHECK: %[[splat2:.*]] = vector.splat %[[a]] : vector<f32>
%2 = vector.from_elements %a : vector<f32>
// CHECK: return %[[splat]], %[[from_el]], %[[splat2]]
return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
}

// -----

// CHECK-LABEL: func @vector_insert_const_regression(
// CHECK: llvm.mlir.undef
// CHECK: vector.insert
155 changes: 155 additions & 0 deletions mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
@@ -0,0 +1,155 @@
// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s

// This file contains some tests of folding/canonicalizing vector.from_elements

///===----------------------------------------------===//
/// Tests of `rewriteFromElementsAsSplat`
///===----------------------------------------------===//
Comment on lines +5 to +7
Contributor

This section was copied, right? Could you add a note in the summary so that it's easy to track the history?

Contributor Author

Yes, copied. I added a comment to the PR summary, I assume that's where you meant?


// CHECK-LABEL: func @extract_scalar_from_from_elements(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) {
// Extract from 0D.
%0 = vector.from_elements %a : vector<f32>
%1 = vector.extract %0[] : f32 from vector<f32>

// Extract from 1D.
%2 = vector.from_elements %a : vector<1xf32>
%3 = vector.extract %2[0] : f32 from vector<1xf32>
%4 = vector.from_elements %a, %b, %a, %a, %b : vector<5xf32>
%5 = vector.extract %4[4] : f32 from vector<5xf32>

// Extract from 2D.
%6 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
%7 = vector.extract %6[0, 0] : f32 from vector<2x3xf32>
%8 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
%9 = vector.extract %6[1, 1] : f32 from vector<2x3xf32>
%10 = vector.extract %6[1, 2] : f32 from vector<2x3xf32>

// CHECK: return %[[a]], %[[a]], %[[b]], %[[a]], %[[a]], %[[b]], %[[b]]
return %1, %3, %5, %7, %8, %9, %10 : f32, f32, f32, f32, f32, f32, f32
}

// -----

// CHECK-LABEL: func @extract_1d_from_from_elements(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) {
%0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
// CHECK: %[[splat1:.*]] = vector.splat %[[a]] : vector<3xf32>
%1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32>
// CHECK: %[[splat2:.*]] = vector.splat %[[b]] : vector<3xf32>
%2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32>
// CHECK: return %[[splat1]], %[[splat2]]
return %1, %2 : vector<3xf32>, vector<3xf32>
}

// -----

// CHECK-LABEL: func @extract_2d_from_from_elements(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) {
%0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32>
// CHECK: %[[splat1:.*]] = vector.from_elements %[[a]], %[[a]], %[[a]], %[[b]] : vector<2x2xf32>
%1 = vector.extract %0[0] : vector<2x2xf32> from vector<3x2x2xf32>
// CHECK: %[[splat2:.*]] = vector.from_elements %[[b]], %[[b]], %[[b]], %[[a]] : vector<2x2xf32>
%2 = vector.extract %0[1] : vector<2x2xf32> from vector<3x2x2xf32>
// CHECK: return %[[splat1]], %[[splat2]]
return %1, %2 : vector<2x2xf32>, vector<2x2xf32>
}

// -----

// CHECK-LABEL: func @from_elements_to_splat(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector<f32>) {
// CHECK: %[[splat:.*]] = vector.splat %[[a]] : vector<2x3xf32>
%0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32>
// CHECK: %[[from_el:.*]] = vector.from_elements {{.*}} : vector<2x3xf32>
%1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32>
// CHECK: %[[splat2:.*]] = vector.splat %[[a]] : vector<f32>
%2 = vector.from_elements %a : vector<f32>
// CHECK: return %[[splat]], %[[from_el]], %[[splat2]]
return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
}

// -----

///===----------------------------------------------===//
/// Tests of `FromElementsToShapeCast`
///===----------------------------------------------===//

// CHECK-LABEL: func @to_shape_cast_rank2_to_rank1(
// CHECK-SAME: %[[a:.*]]: vector<1x2xi8>)
// CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[a]] : vector<1x2xi8> to vector<2xi8>
// CHECK: return %[[shape_cast]] : vector<2xi8>
Contributor

[nit] Could you use caps for LIT variables? That's much more common. And, IMHO, easier to parse 😅

I suspect that you wanted to maintain consistency with the tests for rewriteFromElementsAsSplat? I would just update those as well (fortunately, there aren't that many LIT vars there)

Contributor Author

Easier, but still hard IMO 😆
Done

func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> {
%0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
%1 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8>
%4 = vector.from_elements %0, %1 : vector<2xi8>
return %4 : vector<2xi8>
}

// -----

// CHECK-LABEL: func @to_shape_cast_rank1_to_rank3(
// CHECK-SAME: %[[a:.*]]: vector<8xi8>)
// CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[a]] : vector<8xi8> to vector<2x2x2xi8>
// CHECK: return %[[shape_cast]] : vector<2x2x2xi8>
func.func @to_shape_cast_rank1_to_rank3(%arg0: vector<8xi8>) -> vector<2x2x2xi8> {
%0 = vector.extract %arg0[0] : i8 from vector<8xi8>
%1 = vector.extract %arg0[1] : i8 from vector<8xi8>
%2 = vector.extract %arg0[2] : i8 from vector<8xi8>
%3 = vector.extract %arg0[3] : i8 from vector<8xi8>
%4 = vector.extract %arg0[4] : i8 from vector<8xi8>
%5 = vector.extract %arg0[5] : i8 from vector<8xi8>
%6 = vector.extract %arg0[6] : i8 from vector<8xi8>
%7 = vector.extract %arg0[7] : i8 from vector<8xi8>
%8 = vector.from_elements %0, %1, %2, %3, %4, %5, %6, %7 : vector<2x2x2xi8>
return %8 : vector<2x2x2xi8>
}

// -----

// The extracted elements are recombined into a single vector, but in a new order.
// CHECK-LABEL: func @negative_nonascending_order(
// CHECK-NOT: shape_cast
func.func @negative_nonascending_order(%arg0: vector<1x2xi8>) -> vector<2xi8> {
%0 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8>
%1 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
%2 = vector.from_elements %0, %1 : vector<2xi8>
return %2 : vector<2xi8>
}

// -----

// CHECK-LABEL: func @negative_nonstatic_extract(
// CHECK-NOT: shape_cast
func.func @negative_nonstatic_extract(%arg0: vector<1x2xi8>, %i0 : index, %i1 : index) -> vector<2xi8> {
%0 = vector.extract %arg0[0, %i0] : i8 from vector<1x2xi8>
%1 = vector.extract %arg0[0, %i1] : i8 from vector<1x2xi8>
%2 = vector.from_elements %0, %1 : vector<2xi8>
return %2 : vector<2xi8>
}

// -----

// CHECK-LABEL: func @negative_different_sources(
// CHECK-NOT: shape_cast
func.func @negative_different_sources(%arg0: vector<1x2xi8>, %arg1: vector<1x2xi8>) -> vector<2xi8> {
%0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
%1 = vector.extract %arg1[0, 1] : i8 from vector<1x2xi8>
%2 = vector.from_elements %0, %1 : vector<2xi8>
return %2 : vector<2xi8>
}

// -----

// CHECK-LABEL: func @negative_source_too_large(
// CHECK-NOT: shape_cast
func.func @negative_source_too_large(%arg0: vector<1x3xi8>) -> vector<2xi8> {
%0 = vector.extract %arg0[0, 0] : i8 from vector<1x3xi8>
%1 = vector.extract %arg0[0, 1] : i8 from vector<1x3xi8>
%2 = vector.from_elements %0, %1 : vector<2xi8>
return %2 : vector<2xi8>
}
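
The positive tests above only extract from rank-1 and rank-2 sources. When writing a positive test for a higher-rank source, the operand order that satisfies condition (iii) is the row-major enumeration of the source positions. The sketch below generates that order with plain C++ containers; `delinearize` and `ascendingPositions` are hypothetical helper names for illustration, not functions from the patch or from MLIR.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Inverse of row-major linearization: the multi-dimensional position of the
// element at linear index `offset` in a vector of shape `shape`.
std::vector<int64_t> delinearize(int64_t offset,
                                 const std::vector<int64_t> &shape) {
  std::vector<int64_t> position(shape.size());
  for (std::size_t i = shape.size(); i-- > 0;) {
    assert(shape[i] > 0 && "static vector dimensions are positive");
    position[i] = offset % shape[i];
    offset /= shape[i];
  }
  return position;
}

// All positions of `shape`, in the order the from_elements operands would
// have to extract them for the elements to be ascending.
std::vector<std::vector<int64_t>>
ascendingPositions(const std::vector<int64_t> &shape) {
  int64_t numElements = 1;
  for (int64_t dim : shape)
    numElements *= dim;
  std::vector<std::vector<int64_t>> positions;
  positions.reserve(static_cast<std::size_t>(numElements));
  for (int64_t offset = 0; offset < numElements; ++offset)
    positions.push_back(delinearize(offset, shape));
  return positions;
}
```

For a `2x3` source this yields `[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]`, which is the order the extracts would need to appear in as operands.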