Skip to content

Commit f9cbef9

Browse files
committed
improvements
1 parent 622ae94 commit f9cbef9

File tree

2 files changed

+98
-21
lines changed

2 files changed

+98
-21
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2385,45 +2385,49 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
23852385
return success();
23862386
}
23872387

2388+
2389+
/// Rewrite a vecor.from_elements as a vector.shape_cast, if possible.
2390+
///
2391+
/// Example:
2392+
/// %0 = vector.extract %source[0, 0] : i8 from vector<1x2xi8>
2393+
/// %1 = vector.extract %source[0, 1] : i8 from vector<1x2xi8>
2394+
/// %2 = vector.from_elements %0, %1 : vector<2xi8>
2395+
///
2396+
/// becomes
2397+
/// %2 = vector.shape_cast %source : vector<1x2xi8> to vector<2xi8>
23882398
static LogicalResult
23892399
rewriteFromElementsAsShapeCast(FromElementsOp fromElementsOp,
23902400
PatternRewriter &rewriter) {
23912401

2392-
mlir::OperandRange elements = fromElementsOp.getElements();
2393-
const size_t nbElements = elements.size();
2394-
assert(nbElements > 0 && "must be at least one element");
2395-
2396-
// https://en.wikipedia.org/wiki/List_of_prime_numbers
2397-
const int prime = 5387;
2398-
bool pseudoRandomOrder = nbElements < prime;
2399-
2402+
// The common source of vector.extract operations (if one exists), as well
2403+
// as its shape and rank. Set in the first iteration of the loop over the
2404+
// operands of `fromElementsOp`.
24002405
Value source;
24012406
ArrayRef<int64_t> shape;
2402-
for (size_t elementIndex = 0ULL; elementIndex < nbElements; elementIndex++) {
2407+
int64_t rank;
24032408

2404-
// Rather than iterating through the elements in ascending order, we might
2405-
// be able to exit quickly if we go through in pseudo-random order. Use
2406-
// fact that (i * p) % a is a bijection for i in [0, a) if p is prime and
2407-
// a < p.
2408-
int currentIndex =
2409-
pseudoRandomOrder ? elementIndex : (elementIndex * prime) % nbElements;
2410-
Value element = elements[currentIndex];
2409+
for (auto [index, element] : llvm::enumerate(fromElementsOp.getElements())) {
24112410

2412-
// From an extract on the same source as the other elements.
2411+
// Check that the element is defined by an extract operation, and that
2412+
// the extract is on the same vector as all preceding elements.
24132413
auto extractOp =
24142414
dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp());
24152415
if (!extractOp)
24162416
return failure();
24172417
Value currentSource = extractOp.getVector();
2418-
if (!source) {
2418+
if (index == 0) {
24192419
source = currentSource;
24202420
shape = extractOp.getSourceVectorType().getShape();
2421+
rank = shape.size();
24212422
} else if (currentSource != source) {
24222423
return failure();
24232424
}
24242425

2426+
// Check that the (linearized) index of extraction is the same as the index
2427+
// in the result of `fromElementsOp`.
24252428
ArrayRef<int64_t> position = extractOp.getStaticPosition();
2426-
assert(position.size() == shape.size());
2429+
if (position.size() != rank)
2430+
return failure();
24272431

24282432
int64_t stride{1};
24292433
int64_t offset{0};
@@ -2434,18 +2438,18 @@ rewriteFromElementsAsShapeCast(FromElementsOp fromElementsOp,
24342438
offset += pos * stride;
24352439
stride *= size;
24362440
}
2437-
if (offset != currentIndex)
2441+
if (offset != index)
24382442
return failure();
24392443
}
24402444

2441-
// Can replace with a shape_cast.
24422445
rewriter.replaceOpWithNewOp<ShapeCastOp>(fromElementsOp,
24432446
fromElementsOp.getType(), source);
24442447
}
24452448

24462449
void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
24472450
MLIRContext *context) {
24482451
results.add(rewriteFromElementsAsSplat);
2452+
results.add(rewriteFromElementsAsShapeCast);
24492453
}
24502454

24512455
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
// This file contains some tests of folding/canonicalizing vector.from_elements
44

5+
///===----------------------------------------------===//
6+
/// Tests of `rewriteFromElementsAsSplat`
7+
///===----------------------------------------------===//
8+
59
// CHECK-LABEL: func @extract_scalar_from_from_elements(
610
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
711
func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) {
@@ -70,3 +74,72 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<
7074
}
7175

7276
// -----
77+
78+
79+
///===----------------------------------------------===//
80+
/// Tests of `rewriteFromElementsAsShapeCast`
81+
///===----------------------------------------------===//
82+
83+
// CHECK-LABEL: func @to_shape_cast_rank2_to_rank1(
84+
// CHECK-SAME: %[[a:.*]]: vector<1x2xi8>)
85+
// CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[a]] : vector<1x2xi8> to vector<2xi8>
86+
// CHECK: return %[[shape_cast]] : vector<2xi8>
87+
func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> {
88+
%0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
89+
%1 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8>
90+
%4 = vector.from_elements %0, %1 : vector<2xi8>
91+
return %4 : vector<2xi8>
92+
}
93+
94+
// -----
95+
96+
// CHECK-LABEL: func @to_shape_cast_rank1_to_rank3(
97+
// CHECK-SAME: %[[a:.*]]: vector<8xi8>)
98+
// CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[a]] : vector<8xi8> to vector<2x2x2xi8>
99+
// CHECK: return %[[shape_cast]] : vector<2x2x2xi8>
100+
func.func @to_shape_cast_rank1_to_rank3(%arg0: vector<8xi8>) -> vector<2x2x2xi8> {
101+
%0 = vector.extract %arg0[0] : i8 from vector<8xi8>
102+
%1 = vector.extract %arg0[1] : i8 from vector<8xi8>
103+
%2 = vector.extract %arg0[2] : i8 from vector<8xi8>
104+
%3 = vector.extract %arg0[3] : i8 from vector<8xi8>
105+
%4 = vector.extract %arg0[4] : i8 from vector<8xi8>
106+
%5 = vector.extract %arg0[5] : i8 from vector<8xi8>
107+
%6 = vector.extract %arg0[6] : i8 from vector<8xi8>
108+
%7 = vector.extract %arg0[7] : i8 from vector<8xi8>
109+
%8 = vector.from_elements %0, %1, %2, %3, %4, %5, %6, %7 : vector<2x2x2xi8>
110+
return %8 : vector<2x2x2xi8>
111+
}
112+
113+
// -----
114+
115+
// The extracted elements are recombined into a single vector, but in a new order.
116+
// CHECK-LABEL: func @negative_nonascending_order(
117+
// CHECK-NOT: shape_cast
118+
func.func @negative_nonascending_order(%arg0: vector<1x2xi8>) -> vector<2xi8> {
119+
%0 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8>
120+
%1 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
121+
%2 = vector.from_elements %0, %1 : vector<2xi8>
122+
return %2 : vector<2xi8>
123+
}
124+
125+
// -----
126+
127+
// CHECK-LABEL: func @negative_nonstatic_extract(
128+
// CHECK-NOT: shape_cast
129+
func.func @negative_nonstatic_extract(%arg0: vector<1x2xi8>, %i0 : index, %i1 : index) -> vector<2xi8> {
130+
%0 = vector.extract %arg0[0, %i0] : i8 from vector<1x2xi8>
131+
%1 = vector.extract %arg0[0, %i1] : i8 from vector<1x2xi8>
132+
%2 = vector.from_elements %0, %1 : vector<2xi8>
133+
return %2 : vector<2xi8>
134+
}
135+
136+
// -----
137+
138+
// CHECK-LABEL: func @negative_different_sources(
139+
// CHECK-NOT: shape_cast
140+
func.func @negative_different_sources(%arg0: vector<1x2xi8>, %arg1: vector<1x2xi8>) -> vector<2xi8> {
141+
%0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
142+
%1 = vector.extract %arg1[0, 1] : i8 from vector<1x2xi8>
143+
%2 = vector.from_elements %0, %1 : vector<2xi8>
144+
return %2 : vector<2xi8>
145+
}

0 commit comments

Comments
 (0)