Skip to content

Commit 622ae94

Browse files
committed
first commit
1 parent a3d2b7e commit 622ae94

File tree

3 files changed

+130
-69
lines changed

3 files changed

+130
-69
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2385,6 +2385,64 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
23852385
return success();
23862386
}
23872387

2388+
static LogicalResult
2389+
rewriteFromElementsAsShapeCast(FromElementsOp fromElementsOp,
2390+
PatternRewriter &rewriter) {
2391+
2392+
mlir::OperandRange elements = fromElementsOp.getElements();
2393+
const size_t nbElements = elements.size();
2394+
assert(nbElements > 0 && "must be at least one element");
2395+
2396+
// https://en.wikipedia.org/wiki/List_of_prime_numbers
2397+
const int prime = 5387;
2398+
bool pseudoRandomOrder = nbElements < prime;
2399+
2400+
Value source;
2401+
ArrayRef<int64_t> shape;
2402+
for (size_t elementIndex = 0ULL; elementIndex < nbElements; elementIndex++) {
2403+
2404+
// Rather than iterating through the elements in ascending order, we might
2405+
// be able to exit quickly if we go through in pseudo-random order. Use
2406+
// fact that (i * p) % a is a bijection for i in [0, a) if p is prime and
2407+
// a < p.
2408+
int currentIndex =
2409+
pseudoRandomOrder ? elementIndex : (elementIndex * prime) % nbElements;
2410+
Value element = elements[currentIndex];
2411+
2412+
// From an extract on the same source as the other elements.
2413+
auto extractOp =
2414+
dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp());
2415+
if (!extractOp)
2416+
return failure();
2417+
Value currentSource = extractOp.getVector();
2418+
if (!source) {
2419+
source = currentSource;
2420+
shape = extractOp.getSourceVectorType().getShape();
2421+
} else if (currentSource != source) {
2422+
return failure();
2423+
}
2424+
2425+
ArrayRef<int64_t> position = extractOp.getStaticPosition();
2426+
assert(position.size() == shape.size());
2427+
2428+
int64_t stride{1};
2429+
int64_t offset{0};
2430+
for (auto [pos, size] :
2431+
llvm::zip(llvm::reverse(position), llvm::reverse(shape))) {
2432+
if (pos == ShapedType::kDynamic)
2433+
return failure();
2434+
offset += pos * stride;
2435+
stride *= size;
2436+
}
2437+
if (offset != currentIndex)
2438+
return failure();
2439+
}
2440+
2441+
// Can replace with a shape_cast.
2442+
rewriter.replaceOpWithNewOp<ShapeCastOp>(fromElementsOp,
2443+
fromElementsOp.getType(), source);
2444+
}
2445+
23882446
void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
23892447
MLIRContext *context) {
23902448
results.add(rewriteFromElementsAsSplat);

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 0 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -2952,75 +2952,6 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,
29522952

29532953
// -----
29542954

2955-
// CHECK-LABEL: func @extract_scalar_from_from_elements(
2956-
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
2957-
func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) {
2958-
// Extract from 0D.
2959-
%0 = vector.from_elements %a : vector<f32>
2960-
%1 = vector.extract %0[] : f32 from vector<f32>
2961-
2962-
// Extract from 1D.
2963-
%2 = vector.from_elements %a : vector<1xf32>
2964-
%3 = vector.extract %2[0] : f32 from vector<1xf32>
2965-
%4 = vector.from_elements %a, %b, %a, %a, %b : vector<5xf32>
2966-
%5 = vector.extract %4[4] : f32 from vector<5xf32>
2967-
2968-
// Extract from 2D.
2969-
%6 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
2970-
%7 = vector.extract %6[0, 0] : f32 from vector<2x3xf32>
2971-
%8 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
2972-
%9 = vector.extract %6[1, 1] : f32 from vector<2x3xf32>
2973-
%10 = vector.extract %6[1, 2] : f32 from vector<2x3xf32>
2974-
2975-
// CHECK: return %[[a]], %[[a]], %[[b]], %[[a]], %[[a]], %[[b]], %[[b]]
2976-
return %1, %3, %5, %7, %8, %9, %10 : f32, f32, f32, f32, f32, f32, f32
2977-
}
2978-
2979-
// -----
2980-
2981-
// CHECK-LABEL: func @extract_1d_from_from_elements(
2982-
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
2983-
func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) {
2984-
%0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
2985-
// CHECK: %[[splat1:.*]] = vector.splat %[[a]] : vector<3xf32>
2986-
%1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32>
2987-
// CHECK: %[[splat2:.*]] = vector.splat %[[b]] : vector<3xf32>
2988-
%2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32>
2989-
// CHECK: return %[[splat1]], %[[splat2]]
2990-
return %1, %2 : vector<3xf32>, vector<3xf32>
2991-
}
2992-
2993-
// -----
2994-
2995-
// CHECK-LABEL: func @extract_2d_from_from_elements(
2996-
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
2997-
func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) {
2998-
%0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32>
2999-
// CHECK: %[[splat1:.*]] = vector.from_elements %[[a]], %[[a]], %[[a]], %[[b]] : vector<2x2xf32>
3000-
%1 = vector.extract %0[0] : vector<2x2xf32> from vector<3x2x2xf32>
3001-
// CHECK: %[[splat2:.*]] = vector.from_elements %[[b]], %[[b]], %[[b]], %[[a]] : vector<2x2xf32>
3002-
%2 = vector.extract %0[1] : vector<2x2xf32> from vector<3x2x2xf32>
3003-
// CHECK: return %[[splat1]], %[[splat2]]
3004-
return %1, %2 : vector<2x2xf32>, vector<2x2xf32>
3005-
}
3006-
3007-
// -----
3008-
3009-
// CHECK-LABEL: func @from_elements_to_splat(
3010-
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
3011-
func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector<f32>) {
3012-
// CHECK: %[[splat:.*]] = vector.splat %[[a]] : vector<2x3xf32>
3013-
%0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32>
3014-
// CHECK: %[[from_el:.*]] = vector.from_elements {{.*}} : vector<2x3xf32>
3015-
%1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32>
3016-
// CHECK: %[[splat2:.*]] = vector.splat %[[a]] : vector<f32>
3017-
%2 = vector.from_elements %a : vector<f32>
3018-
// CHECK: return %[[splat]], %[[from_el]], %[[splat2]]
3019-
return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
3020-
}
3021-
3022-
// -----
3023-
30242955
// CHECK-LABEL: func @vector_insert_const_regression(
30252956
// CHECK: llvm.mlir.undef
30262957
// CHECK: vector.insert
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
2+
3+
// This file contains some tests of folding/canonicalizing vector.from_elements
4+
5+
// CHECK-LABEL: func @extract_scalar_from_from_elements(
6+
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
7+
func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) {
8+
// Extract from 0D.
9+
%0 = vector.from_elements %a : vector<f32>
10+
%1 = vector.extract %0[] : f32 from vector<f32>
11+
12+
// Extract from 1D.
13+
%2 = vector.from_elements %a : vector<1xf32>
14+
%3 = vector.extract %2[0] : f32 from vector<1xf32>
15+
%4 = vector.from_elements %a, %b, %a, %a, %b : vector<5xf32>
16+
%5 = vector.extract %4[4] : f32 from vector<5xf32>
17+
18+
// Extract from 2D.
19+
%6 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
20+
%7 = vector.extract %6[0, 0] : f32 from vector<2x3xf32>
21+
%8 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
22+
%9 = vector.extract %6[1, 1] : f32 from vector<2x3xf32>
23+
%10 = vector.extract %6[1, 2] : f32 from vector<2x3xf32>
24+
25+
// CHECK: return %[[a]], %[[a]], %[[b]], %[[a]], %[[a]], %[[b]], %[[b]]
26+
return %1, %3, %5, %7, %8, %9, %10 : f32, f32, f32, f32, f32, f32, f32
27+
}
28+
29+
// -----
30+
31+
// CHECK-LABEL: func @extract_1d_from_from_elements(
32+
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
33+
func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) {
34+
%0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
35+
// CHECK: %[[splat1:.*]] = vector.splat %[[a]] : vector<3xf32>
36+
%1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32>
37+
// CHECK: %[[splat2:.*]] = vector.splat %[[b]] : vector<3xf32>
38+
%2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32>
39+
// CHECK: return %[[splat1]], %[[splat2]]
40+
return %1, %2 : vector<3xf32>, vector<3xf32>
41+
}
42+
43+
// -----
44+
45+
// CHECK-LABEL: func @extract_2d_from_from_elements(
46+
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
47+
func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) {
48+
%0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32>
49+
// CHECK: %[[splat1:.*]] = vector.from_elements %[[a]], %[[a]], %[[a]], %[[b]] : vector<2x2xf32>
50+
%1 = vector.extract %0[0] : vector<2x2xf32> from vector<3x2x2xf32>
51+
// CHECK: %[[splat2:.*]] = vector.from_elements %[[b]], %[[b]], %[[b]], %[[a]] : vector<2x2xf32>
52+
%2 = vector.extract %0[1] : vector<2x2xf32> from vector<3x2x2xf32>
53+
// CHECK: return %[[splat1]], %[[splat2]]
54+
return %1, %2 : vector<2x2xf32>, vector<2x2xf32>
55+
}
56+
57+
// -----
58+
59+
// CHECK-LABEL: func @from_elements_to_splat(
60+
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
61+
func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector<f32>) {
62+
// CHECK: %[[splat:.*]] = vector.splat %[[a]] : vector<2x3xf32>
63+
%0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32>
64+
// CHECK: %[[from_el:.*]] = vector.from_elements {{.*}} : vector<2x3xf32>
65+
%1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32>
66+
// CHECK: %[[splat2:.*]] = vector.splat %[[a]] : vector<f32>
67+
%2 = vector.from_elements %a : vector<f32>
68+
// CHECK: return %[[splat]], %[[from_el]], %[[splat2]]
69+
return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
70+
}
71+
72+
// -----

0 commit comments

Comments
 (0)