
Commit a8406b3

fixup! [mlir][linalg] Split GenericPadOpVectorizationPattern into two patterns
* Incorporate suggestions from Hanhan
* Add a negative test to document when vectorization of tensor.insert_slice might fail
* Update `@pad_and_insert_slice_dest` that was added in #112504 (this change means that _all_ qualifying `tensor.insert_slice` Ops are vectorized)
* Add more tests to demonstrate other cases (e.g. default vs non-default pad value)
1 parent 45318f3 commit a8406b3
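
In short, the new InsertSliceVectorizePattern rewrites a qualifying tensor.insert_slice into a vector.transfer_read + vector.transfer_write pair. A minimal before/after sketch, distilled from the @insert_static_slice_default_pad test updated below (value names are illustrative):

// Before: a statically shaped slice insertion.
%res = tensor.insert_slice %src into %dest[0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 2, 3] [1, 1, 1, 1, 1, 1]
    : tensor<1x2x3xf32> into tensor<9x8x7x1x2x3xf32>

// After: an in-bounds read of the source (the 0.0 pad value is never
// actually read here) written into the destination at the slice offsets.
%c0 = arith.constant 0 : index
%pad = arith.constant 0.000000e+00 : f32
%read = vector.transfer_read %src[%c0, %c0, %c0], %pad {in_bounds = [true, true, true]}
    : tensor<1x2x3xf32>, vector<1x2x3xf32>
%res = vector.transfer_write %read, %dest[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true]}
    : vector<1x2x3xf32>, tensor<9x8x7x1x2x3xf32>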

5 files changed: +115 -40 lines changed


mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 1 addition & 0 deletions
@@ -256,6 +256,7 @@ void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
 void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   linalg::populatePadOpVectorizationPatterns(patterns);
+  linalg::populateInsertSliceVectorizationPatterns(patterns);
 }
 
 //===----------------------------------------------------------------------===//
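
With this change, ApplyPadVectorizationPatternsOp hands out the insert_slice patterns as well. A minimal sketch of driving it from the transform dialect, mirroring the harness used in vectorization-pad-patterns.mlir (the op's mnemonic, apply_patterns.linalg.pad_vectorization, is assumed from its definition):

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    // Grab the enclosing function and apply the (now extended) pattern set.
    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %func {
      transform.apply_patterns.linalg.pad_vectorization
    } : !transform.any_op
    transform.yield
  }
}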

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 18 additions & 32 deletions
@@ -2514,35 +2514,18 @@ struct PadOpVectorizationWithTransferWritePattern
   }
 };
 
-/// Given an ArrayRef of OpFoldResults, return a vector of Values.
-/// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are
-/// not supported.
-static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
-                                           ArrayRef<OpFoldResult> ofrs) {
-  SmallVector<Value> result;
-  for (auto o : ofrs) {
-    if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
-      result.push_back(val);
-    } else {
-      result.push_back(rewriter.create<arith::ConstantIndexOp>(
-          loc, cast<IntegerAttr>(cast<Attribute>(o)).getInt()));
-    }
-  }
-  return result;
-}
-
 /// Returns the effective Pad value for the input op, provided it's a scalar.
 ///
 /// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
 /// this Op performs padding, retrieve the padding value provided that it's
 /// a scalar and static/fixed for all the padded values. Returns an empty value
 /// otherwise.
-static Value getStaticPadVl(Operation *op) {
+static Value getStaticPadVal(Operation *op) {
   if (!op)
     return {};
 
-  // 1. vector.broadcast - return the value that's being broadcast,
-  // provided that it's a scalar.
+  // 1. vector.broadcast (f32 -> vector <...xf32>) - return the value that's
+  // being broadcast, provided that it's a scalar.
   if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
     auto source = bcast.getSource();
     if (llvm::dyn_cast<VectorType>(source.getType()))
@@ -2551,31 +2534,31 @@ static Value getStaticPadVl(Operation *op) {
     return source;
   }
 
-  // 1. linalg.fill - use the scalar input value that used to fill the output
+  // 2. linalg.fill - use the scalar input value that used to fill the output
   // tensor.
   if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
     return fill.getInputs()[0];
   }
 
-  // 2. tensor.generateOp - can't guarantee the value is fixed without
+  // 3. tensor.generateOp - can't guarantee the value is fixed without
   // analysing, bail out.
   if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
     return {};
   }
 
-  // 3. vector.transfer_write - inspect the input vector that's written from. If
+  // 4. vector.transfer_write - inspect the input vector that's written from. If
   // it contains a single value that has been broadcast (e.g. via
   // vector.broadcast), extract it, fail otherwise.
   if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
-    return getStaticPadVl(xferWrite.getVector().getDefiningOp());
+    return getStaticPadVal(xferWrite.getVector().getDefiningOp());
 
-  // 4. tensor.insert_slice - inspect the destination tensor. If it's larger
+  // 5. tensor.insert_slice - inspect the destination tensor. If it's larger
   // than the input tensor, then, provided it's constant, we'll extract the
   // value that was used to generate it (via e.g. linalg.fill), fail otherwise.
   // TODO: Clarify the semantics when the input tensor is larger than the
   // destination.
   if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
-    return getStaticPadVl(slice.getDest().getDefiningOp());
+    return getStaticPadVal(slice.getDest().getDefiningOp());
 
   return {};
 }
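
To make the producer-chasing concrete, here is a hypothetical chain that getStaticPadVal resolves: for the tensor.insert_slice below, case 5 follows the destination to the linalg.fill, where case 2 returns %pad. This is exactly the setup exercised by @insert_static_slice_non_zero_pad in the tests further down:

%init = tensor.empty() : tensor<9x8x7x1x2x3xf32>
// The fill value becomes the effective pad value for the whole destination.
%fill = linalg.fill ins(%pad : f32) outs(%init : tensor<9x8x7x1x2x3xf32>) -> tensor<9x8x7x1x2x3xf32>
// getStaticPadVal on this insert_slice returns %pad.
%res = tensor.insert_slice %src into %fill[0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 2, 3] [1, 1, 1, 1, 1, 1]
    : tensor<1x2x3xf32> into tensor<9x8x7x1x2x3xf32>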
@@ -2619,7 +2602,7 @@ struct InsertSliceVectorizePattern
     // remains a TODO.
     //
     // When the value is not known and not needed, use 0. Otherwise, bail out.
-    Value padValue = getStaticPadVl(sliceOp);
+    Value padValue = getStaticPadVal(sliceOp);
     bool isOutOfBoundsRead = !sourceType.hasStaticShape();
 
     if (!padValue && isOutOfBoundsRead) {
@@ -2637,6 +2620,7 @@ struct InsertSliceVectorizePattern
     SmallVector<int64_t> vecShape;
     SmallVector<bool> readInBounds;
     SmallVector<bool> writeInBounds;
+    size_t rankDiff = resultType.getRank() - sourceType.getRank();
     for (unsigned i = 0; i < sourceType.getRank(); ++i) {
       if (!sourceType.isDynamicDim(i)) {
         vecShape.push_back(sourceType.getDimSize(i));
@@ -2648,7 +2632,9 @@ struct InsertSliceVectorizePattern
         // Source shape is not statically known, but result shape is.
         // Vectorize with size of result shape. This may be larger than the
         // source size.
-        vecShape.push_back(resultType.getDimSize(i));
+        // FIXME: Using rankDiff implies that the source tensor is inserted at
+        // the end of the destination tensor. However, that's not required.
+        vecShape.push_back(resultType.getDimSize(rankDiff + i));
         // Read may be out-of-bounds because the result size could be larger
         // than the source size.
         readInBounds.push_back(false);
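
A worked example of the rankDiff indexing, matching the @insert_dynamic_slice_non_zero_pad test added below:

  source = tensor<1x?x3xf32> (rank 3), result = tensor<9x8x7x1x2x3xf32> (rank 6)
  rankDiff = 6 - 3 = 3
  dim i = 1 is dynamic => vecShape[1] = resultType.getDimSize(3 + 1) = 2
  inferred vector type: vector<1x2x3xf32>, with in_bounds = [true, false, true] on the read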
@@ -2673,8 +2659,8 @@ struct InsertSliceVectorizePattern
         ArrayRef<bool>{readInBounds});
 
     // 4. Generate TransferWriteOp.
-    auto writeIndices =
-        ofrToIndexValues(rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
+    auto writeIndices = getValueOrCreateConstantIndexOp(
+        rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
 
     // 5. Finalize
     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
@@ -2761,8 +2747,8 @@ struct PadOpVectorizationWithInsertSlicePattern
     // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
     // specified offsets. Write is fully in-bounds because an InsertSliceOp's
     // source must fit into the destination at the specified offsets.
-    auto writeIndices =
-        ofrToIndexValues(rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
+    auto writeIndices = getValueOrCreateConstantIndexOp(
+        rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
     SmallVector<bool> inBounds(vecRank, true);
     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
         insertOp, read, insertOp.getDest(), writeIndices,

mlir/test/Dialect/Linalg/vectorization-pad-patterns.mlir

Lines changed: 7 additions & 4 deletions
@@ -161,7 +161,8 @@ module attributes {transform.with_named_sequence} {
 
 ///----------------------------------------------------------------------------------------
 /// tensor::PadOp -> tensor::EmptyOp + linalg::FillOp/tensor::GenerateOp + tensor::InsertSliceOp
-/// [Pattern: GenericPadOpVectorizationPattern]
+/// [Pattern: GenericPadOpVectorizationPattern + InsertSliceVectorizePattern]
+/// TODO: Split the test into two, one for each pattern.
 ///----------------------------------------------------------------------------------------
 
 func.func private @make_vector() -> tensor<12x13xf32>
@@ -174,12 +175,14 @@ func.func private @make_vector() -> tensor<12x13xf32>
 // CHECK-NOT: tensor.pad
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[PAD:.*]] = arith.constant 5.000000e+00 : f32
+// CHECK-DAG: %[[PAD_READ:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x12x13xf32>
 // CHECK: %[[FILL:.*]] = linalg.fill ins(%[[PAD]] : f32) outs(%[[EMPTY]] : tensor<1x12x13xf32>) -> tensor<1x12x13xf32>
-// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG_0]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true, true]} : tensor<1x5x6xf32>, vector<1x5x6xf32>
-// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[FILL]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x5x6xf32>, tensor<1x12x13xf32>
+// CHECK: %[[READ_1:.*]] = vector.transfer_read %[[ARG_0]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true, true]} : tensor<1x5x6xf32>, vector<1x5x6xf32>
+// CHECK: %[[WRITE_1:.*]] = vector.transfer_write %[[READ_1]], %[[FILL]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x5x6xf32>, tensor<1x12x13xf32>
 // CHECK: %[[VEC:.*]] = call @make_vector() : () -> tensor<12x13xf32>
-// CHECK: %[[RES:.*]] = tensor.insert_slice %[[VEC]] into %[[WRITE]][0, 0, 0] [1, 12, 13] [1, 1, 1] : tensor<12x13xf32> into tensor<1x12x13xf32>
+// CHECK: %[[READ_2:.*]] = vector.transfer_read %[[VEC]]{{\[}}%[[C0]], %[[C0]]], %[[PAD_READ]] {in_bounds = [true, true]} : tensor<12x13xf32>, vector<12x13xf32>
+// CHECK: %[[RES:.*]] = vector.transfer_write %[[READ_2]], %[[WRITE_1]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<12x13xf32>, tensor<1x12x13xf32>
 // CHECK: return %[[RES]] : tensor<1x12x13xf32>
 
 func.func @pad_and_insert_slice_dest(

mlir/test/Dialect/Linalg/vectorization-unsupported.mlir

Lines changed: 22 additions & 0 deletions
@@ -253,3 +253,25 @@
     transform.yield
   }
 }
+
+// -----
+
+// With a dynamically shaped source, the vectorizer infers the vector size for
+// xfer Ops from the destination tensor and, conservatively, assumes
+// out-of-bounds accesses. Out-of-bounds accesses require a pad value, but
+// that's impossible to recover in this example. Hence the vectorization fails.
+
+func.func @insert_slice_default_pad(%arg0: tensor<1x?x3xf32>, %arg1: tensor<9x8x7x1x2x3xf32>, %size: index) -> tensor<9x8x7x1x2x3xf32> {
+  // expected-error @+1 {{Attempted to vectorize, but failed}}
+  %res = tensor.insert_slice %arg0 into %arg1[0, 0, 0, 0, 0, 0] [1, 1, 1, 1, %size, 3][1, 1, 1, 1, 1, 1] : tensor<1x?x3xf32> into tensor<9x8x7x1x2x3xf32>
+  return %res : tensor<9x8x7x1x2x3xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_padding } : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
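
For contrast with this failure, the same dynamically shaped source does vectorize once the pad value is recoverable from the destination; a sketch of the qualifying IR (this is the @insert_dynamic_slice_non_zero_pad case in vectorization-with-patterns.mlir below):

// linalg.fill makes %pad recoverable, so the out-of-bounds read is legal.
%fill = linalg.fill ins(%pad : f32) outs(%init : tensor<9x8x7x1x2x3xf32>) -> tensor<9x8x7x1x2x3xf32>
%res = tensor.insert_slice %arg0 into %fill[0, 0, 0, 0, 0, 0] [1, 1, 1, 1, %size, 3] [1, 1, 1, 1, 1, 1]
    : tensor<1x?x3xf32> into tensor<9x8x7x1x2x3xf32>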

mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir

Lines changed: 67 additions & 4 deletions
@@ -1935,17 +1935,80 @@ module attributes {transform.with_named_sequence} {
 /// tensor.insert_slice
 ///----------------------------------------------------------------------------------------
 
-// CHECK-LABEL: func @insert_slice
+// The pad value for xfer-read is neither needed nor available - use the default (0.0).
+
+// CHECK-LABEL: func @insert_static_slice_default_pad
 // CHECK-SAME: %[[ARG_0:.*]]: tensor<1x2x3xf32>,
 // CHECK-SAME: %[[ARG_1:.*]]: tensor<9x8x7x1x2x3xf32>) -> tensor<9x8x7x1x2x3xf32> {
 // CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG_0]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true, true]} : tensor<1x2x3xf32>, vector<1x2x3xf32>
 // CHECK: %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[ARG_1]]{{\[}}%[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x2x3xf32>, tensor<9x8x7x1x2x3xf32>
 // CHECK: return %[[WRITE]] : tensor<9x8x7x1x2x3xf32>
-func.func @insert_slice(%arg0: tensor<1x2x3xf32>, %arg1: tensor<9x8x7x1x2x3xf32>) -> tensor<9x8x7x1x2x3xf32> {
-  %0 = tensor.insert_slice %arg0 into %arg1[0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 2, 3][1, 1, 1, 1, 1, 1] : tensor<1x2x3xf32> into tensor<9x8x7x1x2x3xf32>
-  return %0 : tensor<9x8x7x1x2x3xf32>
+func.func @insert_static_slice_default_pad(%arg0: tensor<1x2x3xf32>, %arg1: tensor<9x8x7x1x2x3xf32>) -> tensor<9x8x7x1x2x3xf32> {
+  %res = tensor.insert_slice %arg0 into %arg1[0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 2, 3][1, 1, 1, 1, 1, 1] : tensor<1x2x3xf32> into tensor<9x8x7x1x2x3xf32>
+  return %res : tensor<9x8x7x1x2x3xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_padding } : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Same as above, but there's a pad value available that should be used instead of the default value.
+
+// CHECK-LABEL: func.func @insert_static_slice_non_zero_pad
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x2x3xf32>,
+// CHECK-SAME: %[[PAD:.*]]: f32) -> tensor<9x8x7x1x2x3xf32> {
+// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<9x8x7x1x2x3xf32>
+// CHECK: %[[BC:.*]] = vector.broadcast %[[PAD]] : f32 to vector<9x8x7x1x2x3xf32>
+// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[BC]], %[[EMPTY]]{{.*}} {in_bounds = [true, true, true, true, true, true]} : vector<9x8x7x1x2x3xf32>, tensor<9x8x7x1x2x3xf32>
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG_0]]{{.*}}, %[[PAD]] {in_bounds = [true, true, true]} : tensor<1x2x3xf32>, vector<1x2x3xf32>
+// CHECK: %[[RES:.*]] = vector.transfer_write %[[READ]], %[[WRITE]]{{.*}} {in_bounds = [true, true, true]} : vector<1x2x3xf32>, tensor<9x8x7x1x2x3xf32>
+// CHECK: return %[[RES]] : tensor<9x8x7x1x2x3xf32>
+func.func @insert_static_slice_non_zero_pad(%arg0: tensor<1x2x3xf32>, %pad : f32) -> tensor<9x8x7x1x2x3xf32> {
+  %init = tensor.empty() : tensor<9x8x7x1x2x3xf32>
+  %fill = linalg.fill ins(%pad : f32) outs(%init : tensor<9x8x7x1x2x3xf32>) -> tensor<9x8x7x1x2x3xf32>
+  %res = tensor.insert_slice %arg0 into %fill[0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 2, 3][1, 1, 1, 1, 1, 1] : tensor<1x2x3xf32> into tensor<9x8x7x1x2x3xf32>
+  return %res : tensor<9x8x7x1x2x3xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_padding } : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Same as above, but the source type is dynamically shaped. This means
+// that the pad value is now required and the vector dim corresponding to the
+// dynamic shape has to be inferred from the shape of the destination tensor.
+
+// CHECK-LABEL: func.func @insert_dynamic_slice_non_zero_pad(
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x?x3xf32>,
+// CHECK-SAME: %[[PAD:.*]]: f32,
+// CHECK-SAME: %[[SIZE:.*]]: index) -> tensor<9x8x7x1x2x3xf32> {
+// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<9x8x7x1x2x3xf32>
+// CHECK: %[[BC:.*]] = vector.broadcast %[[PAD]] : f32 to vector<9x8x7x1x2x3xf32>
+// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[BC]], %[[EMPTY]]{{.*}} {in_bounds = [true, true, true, true, true, true]} : vector<9x8x7x1x2x3xf32>, tensor<9x8x7x1x2x3xf32>
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG_0]]{{.*}}, %[[PAD]] {in_bounds = [true, false, true]} : tensor<1x?x3xf32>, vector<1x2x3xf32>
+// CHECK: %[[RES:.*]] = vector.transfer_write %[[READ]], %[[WRITE]]{{.*}} {in_bounds = [true, true, true]} : vector<1x2x3xf32>, tensor<9x8x7x1x2x3xf32>
+// CHECK: return %[[RES]] : tensor<9x8x7x1x2x3xf32>
+func.func @insert_dynamic_slice_non_zero_pad(%arg0: tensor<1x?x3xf32>, %pad : f32, %size: index) -> tensor<9x8x7x1x2x3xf32> {
+  %init = tensor.empty() : tensor<9x8x7x1x2x3xf32>
+  %fill = linalg.fill ins(%pad : f32) outs(%init : tensor<9x8x7x1x2x3xf32>) -> tensor<9x8x7x1x2x3xf32>
+  %res = tensor.insert_slice %arg0 into %fill[0, 0, 0, 0, 0, 0] [1, 1, 1, 1, %size, 3][1, 1, 1, 1, 1, 1] : tensor<1x?x3xf32> into tensor<9x8x7x1x2x3xf32>
+  return %res : tensor<9x8x7x1x2x3xf32>
 }
 
 module attributes {transform.with_named_sequence}
