Skip to content

[mli][vector] canonicalize vector.from_elements from ascending extracts #139819

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jun 2, 2025
Merged
97 changes: 97 additions & 0 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Support/LLVM.h"
Expand Down Expand Up @@ -2385,9 +2386,105 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
return success();
}

/// Rewrite vector.from_elements as vector.shape_cast, if possible.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[ultra nit] "if possible" is implicit and "less is more" :)

Suggested change
/// Rewrite vector.from_elements as vector.shape_cast, if possible.
/// Rewrite vector.from_elements as vector.shape_cast.

///
/// Example:
/// %0 = vector.extract %source[0, 0] : i8 from vector<1x2xi8>
/// %1 = vector.extract %source[0, 1] : i8 from vector<1x2xi8>
/// %2 = vector.from_elements %0, %1 : vector<2xi8>
///
/// becomes
/// %2 = vector.shape_cast %source : vector<1x2xi8> to vector<2xi8>
///
/// The requirements for this to be valid are
/// i) source and from_elements result have the same number of elements,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] Otherwise it's not clear what source and from_elements are. Perhaps there's better way to clarify 🤔

Suggested change
/// i) source and from_elements result have the same number of elements,
/// i) vector.extract and vector.from_elements result have the same number of elements,

/// ii) all elements are extracted from the same vector (%source),
/// iii) the elements are extracted in ascending order.
///
/// It might be possible to rewrite vector.from_elements as a single
/// vector.extract if (i) is not satisifed, or in some cases as a
/// a single vector_extract_strided_slice if (i) and (iii) are not satisfied,
/// this is left for future consideration.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know we already have quite a few TODOs/FIXMEs that basically mean "let’s look at this later." But the phrasing "It might be possible…” feels particularly vague here - I’d suggest omitting it unless we can be more specific.

If we do want to leave a note, maybe something like:

“Consider extending to use a single vector.extract when (i) does not hold.”

Also, just a general thought: extending this pattern could quickly become quite complex. If we're seeing bad code that would benefit from such a complicated rewrite, it might be worth checking whether the producer of that code could be improved instead.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense. I'll spend a bit of time trying to canonicalize directly to vector.extract, I don't think it'll be significantly more complex.

class FromElementsToShapCast : public OpRewritePattern<FromElementsOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(FromElementsOp fromElements,
PatternRewriter &rewriter) const override {

// The source of the first element. This is initialized in the first
// iteration of the loop over elements.
TypedValue<VectorType> firstElementSource;

for (auto [insertIndex, element] :
llvm::enumerate(fromElements.getElements())) {

// Check that the element is from a vector.extract operation.
auto extractOp =
dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp());
if (!extractOp) {
return rewriter.notifyMatchFailure(fromElements,
"element not from vector.extract");
}

// Check condition (i) on the first element. As we will check that all
// elements have the same source, we don't need to check condition (i) for
// any other elements.
if (insertIndex == 0) {
firstElementSource = extractOp.getVector();
if (static_cast<size_t>(
firstElementSource.getType().getNumElements()) !=
fromElements.getType().getNumElements()) {
return rewriter.notifyMatchFailure(fromElements,
"number of elements differ");
}
}

// Check condition (ii), by checking that all elements have same source as
// the first element.
Value currentSource = extractOp.getVector();
if (currentSource != firstElementSource) {
return rewriter.notifyMatchFailure(fromElements,
"element from different vector");
}

// Check condition (iii).
// First, get the index that the element is extracted from.
int64_t extractIndex{0};
int64_t stride{1};
ArrayRef<int64_t> position = extractOp.getStaticPosition();
assert(position.size() ==
static_cast<size_t>(firstElementSource.getType().getRank()) &&
"scalar extract must have full rank position");
for (auto [pos, size] :
llvm::zip(llvm::reverse(position),
llvm::reverse(firstElementSource.getType().getShape()))) {
if (pos == ShapedType::kDynamic) {
return rewriter.notifyMatchFailure(
fromElements, "elements not in ascending order (dynamic order)");
}
extractIndex += pos * stride;
stride *= size;
}

// Second, check that the index of extraction from source and insertion in
// from_elements are the same.
if (extractIndex != static_cast<int64_t>(insertIndex)) {
return rewriter.notifyMatchFailure(
fromElements, "elements not in ascending order (static order)");
}
}

rewriter.replaceOpWithNewOp<ShapeCastOp>(
fromElements, fromElements.getType(), firstElementSource);
return success();
}
};

void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add(rewriteFromElementsAsSplat);
results.add<FromElementsToShapCast>(context);
}

//===----------------------------------------------------------------------===//
Expand Down
69 changes: 0 additions & 69 deletions mlir/test/Dialect/Vector/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2952,75 +2952,6 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,

// -----

// CHECK-LABEL: func @extract_scalar_from_from_elements(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) {
// Extract from 0D.
%0 = vector.from_elements %a : vector<f32>
%1 = vector.extract %0[] : f32 from vector<f32>

// Extract from 1D.
%2 = vector.from_elements %a : vector<1xf32>
%3 = vector.extract %2[0] : f32 from vector<1xf32>
%4 = vector.from_elements %a, %b, %a, %a, %b : vector<5xf32>
%5 = vector.extract %4[4] : f32 from vector<5xf32>

// Extract from 2D.
%6 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
%7 = vector.extract %6[0, 0] : f32 from vector<2x3xf32>
%8 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
%9 = vector.extract %6[1, 1] : f32 from vector<2x3xf32>
%10 = vector.extract %6[1, 2] : f32 from vector<2x3xf32>

// CHECK: return %[[a]], %[[a]], %[[b]], %[[a]], %[[a]], %[[b]], %[[b]]
return %1, %3, %5, %7, %8, %9, %10 : f32, f32, f32, f32, f32, f32, f32
}

// -----

// CHECK-LABEL: func @extract_1d_from_from_elements(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) {
%0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
// CHECK: %[[splat1:.*]] = vector.splat %[[a]] : vector<3xf32>
%1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32>
// CHECK: %[[splat2:.*]] = vector.splat %[[b]] : vector<3xf32>
%2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32>
// CHECK: return %[[splat1]], %[[splat2]]
return %1, %2 : vector<3xf32>, vector<3xf32>
}

// -----

// CHECK-LABEL: func @extract_2d_from_from_elements(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) {
%0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32>
// CHECK: %[[splat1:.*]] = vector.from_elements %[[a]], %[[a]], %[[a]], %[[b]] : vector<2x2xf32>
%1 = vector.extract %0[0] : vector<2x2xf32> from vector<3x2x2xf32>
// CHECK: %[[splat2:.*]] = vector.from_elements %[[b]], %[[b]], %[[b]], %[[a]] : vector<2x2xf32>
%2 = vector.extract %0[1] : vector<2x2xf32> from vector<3x2x2xf32>
// CHECK: return %[[splat1]], %[[splat2]]
return %1, %2 : vector<2x2xf32>, vector<2x2xf32>
}

// -----

// CHECK-LABEL: func @from_elements_to_splat(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector<f32>) {
// CHECK: %[[splat:.*]] = vector.splat %[[a]] : vector<2x3xf32>
%0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32>
// CHECK: %[[from_el:.*]] = vector.from_elements {{.*}} : vector<2x3xf32>
%1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32>
// CHECK: %[[splat2:.*]] = vector.splat %[[a]] : vector<f32>
%2 = vector.from_elements %a : vector<f32>
// CHECK: return %[[splat]], %[[from_el]], %[[splat2]]
return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
}

// -----

// CHECK-LABEL: func @vector_insert_const_regression(
// CHECK: llvm.mlir.undef
// CHECK: vector.insert
Expand Down
169 changes: 169 additions & 0 deletions mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s

// This file contains some tests of folding/canonicalizing vector.from_elements

///===----------------------------------------------===//
/// Tests of `rewriteFromElementsAsSplat`
///===----------------------------------------------===//
Comment on lines +5 to +7
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This section was copied, right? Could you add a note in the summary so that it's easy to track the history?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, copied. I added a comment to the PR summary, I assume that's where you meant?


// CHECK-LABEL: func @extract_scalar_from_from_elements(
// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32)
func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) {
// Extract from 0D.
%0 = vector.from_elements %a : vector<f32>
%1 = vector.extract %0[] : f32 from vector<f32>

// Extract from 1D.
%2 = vector.from_elements %a : vector<1xf32>
%3 = vector.extract %2[0] : f32 from vector<1xf32>
%4 = vector.from_elements %a, %b, %a, %a, %b : vector<5xf32>
%5 = vector.extract %4[4] : f32 from vector<5xf32>

// Extract from 2D.
%6 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
%7 = vector.extract %6[0, 0] : f32 from vector<2x3xf32>
%8 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
%9 = vector.extract %6[1, 1] : f32 from vector<2x3xf32>
%10 = vector.extract %6[1, 2] : f32 from vector<2x3xf32>

// CHECK: return %[[A]], %[[A]], %[[B]], %[[A]], %[[A]], %[[B]], %[[B]]
return %1, %3, %5, %7, %8, %9, %10 : f32, f32, f32, f32, f32, f32, f32
}

// -----

// CHECK-LABEL: func @extract_1d_from_from_elements(
// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32)
func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) {
%0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
// CHECK: %[[SPLAT1:.*]] = vector.splat %[[A]] : vector<3xf32>
%1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32>
// CHECK: %[[SPLAT2:.*]] = vector.splat %[[B]] : vector<3xf32>
%2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32>
// CHECK: return %[[SPLAT1]], %[[SPLAT2]]
return %1, %2 : vector<3xf32>, vector<3xf32>
}

// -----

// CHECK-LABEL: func @extract_2d_from_from_elements(
// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32)
func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) {
%0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32>
// CHECK: %[[SPLAT1:.*]] = vector.from_elements %[[A]], %[[A]], %[[A]], %[[B]] : vector<2x2xf32>
%1 = vector.extract %0[0] : vector<2x2xf32> from vector<3x2x2xf32>
// CHECK: %[[SPLAT2:.*]] = vector.from_elements %[[B]], %[[B]], %[[B]], %[[A]] : vector<2x2xf32>
%2 = vector.extract %0[1] : vector<2x2xf32> from vector<3x2x2xf32>
// CHECK: return %[[SPLAT1]], %[[SPLAT2]]
return %1, %2 : vector<2x2xf32>, vector<2x2xf32>
}

// -----

// CHECK-LABEL: func @from_elements_to_splat(
// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32)
func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector<f32>) {
// CHECK: %[[SPLAT:.*]] = vector.splat %[[A]] : vector<2x3xf32>
%0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32>
// CHECK: %[[FROM_EL:.*]] = vector.from_elements {{.*}} : vector<2x3xf32>
%1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32>
// CHECK: %[[SPLAT2:.*]] = vector.splat %[[A]] : vector<f32>
%2 = vector.from_elements %a : vector<f32>
// CHECK: return %[[SPLAT]], %[[FROM_EL]], %[[SPLAT2]]
return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
}

// -----

///===----------------------------------------------===//
/// Tests of `FromElementsToShapeCast`
///===----------------------------------------------===//

// CHECK-LABEL: func @to_shape_cast_rank2_to_rank1(
// CHECK-SAME: %[[A:.*]]: vector<1x2xi8>)
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[A]] : vector<1x2xi8> to vector<2xi8>
// CHECK: return %[[SHAPE_CAST]] : vector<2xi8>
func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> {
%0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
%1 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8>
%4 = vector.from_elements %0, %1 : vector<2xi8>
return %4 : vector<2xi8>
}

// -----

// CHECK-LABEL: func @to_shape_cast_rank1_to_rank3(
// CHECK-SAME: %[[A:.*]]: vector<8xi8>)
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[A]] : vector<8xi8> to vector<2x2x2xi8>
// CHECK: return %[[SHAPE_CAST]] : vector<2x2x2xi8>
func.func @to_shape_cast_rank1_to_rank3(%arg0: vector<8xi8>) -> vector<2x2x2xi8> {
%0 = vector.extract %arg0[0] : i8 from vector<8xi8>
%1 = vector.extract %arg0[1] : i8 from vector<8xi8>
%2 = vector.extract %arg0[2] : i8 from vector<8xi8>
%3 = vector.extract %arg0[3] : i8 from vector<8xi8>
%4 = vector.extract %arg0[4] : i8 from vector<8xi8>
%5 = vector.extract %arg0[5] : i8 from vector<8xi8>
%6 = vector.extract %arg0[6] : i8 from vector<8xi8>
%7 = vector.extract %arg0[7] : i8 from vector<8xi8>
%8 = vector.from_elements %0, %1, %2, %3, %4, %5, %6, %7 : vector<2x2x2xi8>
return %8 : vector<2x2x2xi8>
}

// -----

// The extracted elements are recombined into a single vector, but in a new order.
// CHECK-LABEL: func @negative_nonascending_order(
// CHECK-NOT: shape_cast
func.func @negative_nonascending_order(%arg0: vector<1x2xi8>) -> vector<2xi8> {
%0 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8>
%1 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
%2 = vector.from_elements %0, %1 : vector<2xi8>
return %2 : vector<2xi8>
}

// -----

// CHECK-LABEL: func @negative_nonstatic_extract(
// CHECK-NOT: shape_cast
func.func @negative_nonstatic_extract(%arg0: vector<1x2xi8>, %i0 : index, %i1 : index) -> vector<2xi8> {
%0 = vector.extract %arg0[0, %i0] : i8 from vector<1x2xi8>
%1 = vector.extract %arg0[0, %i1] : i8 from vector<1x2xi8>
%2 = vector.from_elements %0, %1 : vector<2xi8>
return %2 : vector<2xi8>
}

// -----

// CHECK-LABEL: func @negative_different_sources(
// CHECK-NOT: shape_cast
func.func @negative_different_sources(%arg0: vector<1x2xi8>, %arg1: vector<1x2xi8>) -> vector<2xi8> {
%0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
%1 = vector.extract %arg1[0, 1] : i8 from vector<1x2xi8>
%2 = vector.from_elements %0, %1 : vector<2xi8>
return %2 : vector<2xi8>
}

// -----

// CHECK-LABEL: func @negative_source_too_large(
// CHECK-NOT: shape_cast
func.func @negative_source_too_large(%arg0: vector<1x3xi8>) -> vector<2xi8> {
%0 = vector.extract %arg0[0, 0] : i8 from vector<1x3xi8>
%1 = vector.extract %arg0[0, 1] : i8 from vector<1x3xi8>
%2 = vector.from_elements %0, %1 : vector<2xi8>
return %2 : vector<2xi8>
}

// -----

// The inserted elements are are a subset of the extracted elements.
// [0, 1, 2] -> [1, 1, 2]
// CHECK-LABEL: func @negative_nobijection_order(
// CHECK-NOT: shape_cast
func.func @negative_nobijection_order(%arg0: vector<1x3xi8>) -> vector<3xi8> {
%0 = vector.extract %arg0[0, 1] : i8 from vector<1x3xi8>
%1 = vector.extract %arg0[0, 2] : i8 from vector<1x3xi8>
%2 = vector.from_elements %0, %0, %1 : vector<3xi8>
return %2 : vector<3xi8>
}