
Commit ac4bd74

[mlir] Add apply_patterns.linalg.pad_vectorization TD Op (#112504)
This PR wraps `populatePadOpVectorizationPatterns` in a new Transform Dialect Op: `apply_patterns.linalg.pad_vectorization`. This makes it possible to run (and test) the corresponding patterns without `transform.structured.vectorize_children_and_apply_patterns`.

Note that the latter Op only supports non-masked vectorization (i.e. when the inputs are static), so, effectively, only fixed-width vectorization (as opposed to scalable vectorization). This change is therefore required in order to construct vectorization pipelines for tensor.pad that target scalable vectors.

To test the new Op and the corresponding patterns, I added "vectorization-pad-patterns.mlir"; most tests have been extracted from "vectorization-with-patterns.mlir".
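For reference, a minimal sketch of how the new Op is driven (this mirrors the transform sequences in the tests added below):

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %arg1
        : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op {
      transform.apply_patterns.linalg.pad_vectorization
    } : !transform.op<"func.func">
    transform.yield
  }
}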
1 parent 6e73750 commit ac4bd74

File tree

5 files changed: +302 -143 lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 20 additions & 0 deletions
@@ -84,6 +84,26 @@ def ApplyFoldAddIntoDestPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyPadVectorizationPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.linalg.pad_vectorization",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Apply patterns that vectorize tensor.pad.
+
+    These patterns rewrite tensor.pad Ops using vector.transfer_read and
+    vector.transfer_write operations. This is done either by:
+      1. Folding tensor.pad with an existing vector.transfer_read /
+         vector.transfer_write Op (generated prior to running these patterns).
+      2. Rewriting it (when matched together with a tensor.insert_slice
+         consumer Op) as a vector.transfer_read + vector.transfer_write pair.
+
+    In both cases, these patterns look at producers and consumers for the
+    matched tensor.pad Op to find opportunities for vectorization.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 //===----------------------------------------------------------------------===//
 // BufferizeToAllocationOp
 //===----------------------------------------------------------------------===//

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 5 additions & 0 deletions
@@ -253,6 +253,11 @@ void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
   linalg::populateFoldAddIntoDestPatterns(patterns);
 }
 
+void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  linalg::populatePadOpVectorizationPatterns(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // BufferizeToAllocationOp
 //===----------------------------------------------------------------------===//

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 3 additions & 0 deletions
@@ -2712,6 +2712,9 @@ struct PadOpVectorizationWithInsertSlicePattern
 
 void mlir::linalg::populatePadOpVectorizationPatterns(
     RewritePatternSet &patterns, PatternBenefit baseBenefit) {
+  // TODO: The following pattern implements "decomposition" and
+  // optional "vectorization". Separate "decomposition" into a separate
+  // pre-processing pattern group.
   patterns.add<GenericPadOpVectorizationPattern>(patterns.getContext(),
                                                  baseBenefit);
   // Try these specialized patterns first before resorting to the generic one.
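For context on the TODO above, a rough before/after sketch of what GenericPadOpVectorizationPattern does (shapes and value names are illustrative only, loosely following the @pad_and_insert_slice_dest test below):

// Before: a tensor.pad with static shapes.
%0 = tensor.pad %src low[0, 0] high[2, 3] {
^bb0(%i: index, %j: index):
  tensor.yield %pad_val : f32
} : tensor<5x6xf32> to tensor<7x9xf32>

// After: "decomposition" into an init + fill, followed by "vectorization"
// of the copy as a transfer_read/transfer_write pair.
%empty = tensor.empty() : tensor<7x9xf32>
%fill = linalg.fill ins(%pad_val : f32) outs(%empty : tensor<7x9xf32>) -> tensor<7x9xf32>
%read = vector.transfer_read %src[%c0, %c0], %pad_val {in_bounds = [true, true]}
    : tensor<5x6xf32>, vector<5x6xf32>
%0 = vector.transfer_write %read, %fill[%c0, %c0] {in_bounds = [true, true]}
    : vector<5x6xf32>, tensor<7x9xf32>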
mlir/test/Dialect/Linalg/vectorization-pad-patterns.mlir

Lines changed: 274 additions & 0 deletions
@@ -0,0 +1,274 @@
// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s

///----------------------------------------------------------------------------------------
/// [Pattern: PadOpVectorizationWithTransferReadPattern]
///----------------------------------------------------------------------------------------
// CHECK-LABEL: func @pad_and_transfer_read
// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
// CHECK-NOT: tensor.pad
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C5:.*]] = arith.constant 5.0
// CHECK: %[[RESULT:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], %[[C5]] : tensor<5x6xf32>, vector<7x9xf32>
// CHECK: return %[[RESULT]]
func.func @pad_and_transfer_read(%arg0: tensor<5x6xf32>) -> vector<7x9xf32> {
  %c0 = arith.constant 0 : index
  %c5 = arith.constant 5.0 : f32
  %c6 = arith.constant 6.0 : f32
  %0 = tensor.pad %arg0 low[0, 0] high[5, 7] {
  ^bb0(%arg1: index, %arg2: index):
    tensor.yield %c5 : f32
  } : tensor<5x6xf32> to tensor<10x13xf32>
  %1 = vector.transfer_read %0[%c0, %c0], %c6
      : tensor<10x13xf32>, vector<7x9xf32>
  return %1 : vector<7x9xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">

    transform.apply_patterns to %func_op {
      transform.apply_patterns.linalg.pad_vectorization
    } : !transform.op<"func.func">
    transform.yield
  }
}

// -----

///----------------------------------------------------------------------------------------
/// [Pattern: PadOpVectorizationWithTransferWritePattern]
///----------------------------------------------------------------------------------------
func.func private @make_vector() -> vector<7x9xf32>

// CHECK-LABEL: func @pad_and_transfer_write_static_low_and_high
// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
// CHECK-NOT: tensor.pad
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[VEC0:.*]] = call @make_vector() : () -> vector<7x9xf32>
// CHECK: %[[RESULT:.*]] = vector.transfer_write %[[VEC0]], %[[ARG0]][%[[C0]], %[[C0]]] : vector<7x9xf32>, tensor<5x6xf32>
// CHECK: return %[[RESULT]]
func.func @pad_and_transfer_write_static_low_and_high(
    %arg0: tensor<5x6xf32>) -> tensor<5x6xf32> {
  %c0 = arith.constant 0 : index
  %c5 = arith.constant 5.0 : f32
  %0 = tensor.pad %arg0 low[0, 0] high[5, 7] {
  ^bb0(%arg2: index, %arg3: index):
    tensor.yield %c5 : f32
  } : tensor<5x6xf32> to tensor<10x13xf32>
  %1 = call @make_vector() : () -> vector<7x9xf32>
  %2 = vector.transfer_write %1, %0[%c0, %c0]
      : vector<7x9xf32>, tensor<10x13xf32>
  %3 = tensor.extract_slice %2[0, 0] [5, 6] [1, 1] : tensor<10x13xf32> to tensor<5x6xf32>
  return %3 : tensor<5x6xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">

    transform.apply_patterns to %func_op {
      transform.apply_patterns.linalg.pad_vectorization
    } : !transform.op<"func.func">
    transform.yield
  }
}

// -----

func.func private @make_vector() -> vector<7x9xf32>

// CHECK-LABEL: func @pad_and_transfer_write_static_low_dynamic_high
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?xf32>, %[[SIZE:.*]]: index, %[[PADDING:.*]]: index
// CHECK-NOT: tensor.pad
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[SUB:.*]] = tensor.extract_slice %[[ARG0]][0, 0] [%[[SIZE]], 6] [1, 1] : tensor<?x?xf32> to tensor<?x6xf32>
// CHECK: %[[VEC0:.*]] = call @make_vector() : () -> vector<7x9xf32>
// CHECK: %[[RESULT:.*]] = vector.transfer_write %[[VEC0]], %[[SUB]][%[[C0]], %[[C0]]] : vector<7x9xf32>, tensor<?x6xf32>
// CHECK: return %[[RESULT]]
func.func @pad_and_transfer_write_static_low_dynamic_high(
    %arg0: tensor<?x?xf32>, %size: index, %padding: index) -> tensor<?x6xf32> {
  %c0 = arith.constant 0 : index
  %c5 = arith.constant 5.0 : f32
  %s = tensor.extract_slice %arg0[0, 0] [%size, 6] [1, 1]
      : tensor<?x?xf32> to tensor<?x6xf32>
  %0 = tensor.pad %s low[0, 0] high[%padding, 7] {
  ^bb0(%arg2: index, %arg3: index):
    tensor.yield %c5 : f32
  } : tensor<?x6xf32> to tensor<?x13xf32>
  %1 = call @make_vector() : () -> vector<7x9xf32>
  %2 = vector.transfer_write %1, %0[%c0, %c0]
      : vector<7x9xf32>, tensor<?x13xf32>
  %3 = tensor.extract_slice %2[0, 0] [%size, 6] [1, 1] : tensor<?x13xf32> to tensor<?x6xf32>
  return %3 : tensor<?x6xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">

    transform.apply_patterns to %func_op {
      transform.apply_patterns.linalg.pad_vectorization
    } : !transform.op<"func.func">
    transform.yield
  }
}

// -----

///----------------------------------------------------------------------------------------
/// [Pattern: PadOpVectorizationWithInsertSlicePattern]
///----------------------------------------------------------------------------------------

func.func private @make_vector() -> tensor<12x13xf32>

// CHECK-LABEL: func @pad_and_insert_slice_source
// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
// CHECK-NOT: tensor.pad
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C5:.*]] = arith.constant 5.0
// CHECK: %[[VEC0:.*]] = call @make_vector() : () -> tensor<12x13xf32>
// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], %[[C5]] : tensor<5x6xf32>, vector<7x9xf32>
// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[VEC0]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<7x9xf32>, tensor<12x13xf32>
// CHECK: return %[[WRITE]]
func.func @pad_and_insert_slice_source(
    %arg0: tensor<5x6xf32>) -> tensor<12x13xf32> {
  %c0 = arith.constant 0 : index
  %c5 = arith.constant 5.0 : f32
  %0 = tensor.pad %arg0 low[0, 0] high[2, 3] {
  ^bb0(%arg2: index, %arg3: index):
    tensor.yield %c5 : f32
  } : tensor<5x6xf32> to tensor<7x9xf32>
  %1 = call @make_vector() : () -> tensor<12x13xf32>
  %r = tensor.insert_slice %0 into %1[0, 0][7, 9][1, 1] : tensor<7x9xf32> into tensor<12x13xf32>
  return %r : tensor<12x13xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">

    transform.apply_patterns to %func_op {
      transform.apply_patterns.linalg.pad_vectorization
    } : !transform.op<"func.func">
    transform.yield
  }
}

// -----

///----------------------------------------------------------------------------------------
/// tensor::PadOp -> tensor::EmptyOp + linalg::FillOp/tensor::GenerateOp + tensor::InsertSliceOp
/// [Pattern: GenericPadOpVectorizationPattern]
///----------------------------------------------------------------------------------------

func.func private @make_vector() -> tensor<12x13xf32>

// Same as @pad_and_insert_slice_dest in vectorization-with-patterns.mlir, but
// here linalg.fill is not vectorized (patterns for linalg.fill are not
// included in this test).
// CHECK-LABEL: func.func @pad_and_insert_slice_dest(
// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x5x6xf32>) -> tensor<1x12x13xf32> {
// CHECK-NOT: tensor.pad
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[PAD:.*]] = arith.constant 5.000000e+00 : f32
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x12x13xf32>
// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[PAD]] : f32) outs(%[[EMPTY]] : tensor<1x12x13xf32>) -> tensor<1x12x13xf32>
// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG_0]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true, true]} : tensor<1x5x6xf32>, vector<1x5x6xf32>
// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[FILL]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x5x6xf32>, tensor<1x12x13xf32>
// CHECK: %[[VEC:.*]] = call @make_vector() : () -> tensor<12x13xf32>
// CHECK: %[[RES:.*]] = tensor.insert_slice %[[VEC]] into %[[WRITE]][0, 0, 0] [1, 12, 13] [1, 1, 1] : tensor<12x13xf32> into tensor<1x12x13xf32>
// CHECK: return %[[RES]] : tensor<1x12x13xf32>

func.func @pad_and_insert_slice_dest(
    %arg0: tensor<1x5x6xf32>) -> tensor<1x12x13xf32> {
  %c5 = arith.constant 5.0 : f32
  %0 = tensor.pad %arg0 low[0, 0, 0] high[0, 7, 7] {
  ^bb0(%arg2: index, %arg3: index, %arg4: index):
    tensor.yield %c5 : f32
  } : tensor<1x5x6xf32> to tensor<1x12x13xf32>
  %1 = call @make_vector() : () -> tensor<12x13xf32>
  %r = tensor.insert_slice %1 into %0[0, 0, 0][1, 12, 13][1, 1, 1] : tensor<12x13xf32> into tensor<1x12x13xf32>
  return %r : tensor<1x12x13xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">

    transform.apply_patterns to %func_op {
      transform.apply_patterns.linalg.pad_vectorization
    } : !transform.op<"func.func">
    transform.yield
  }
}

// -----
func.func private @make_vector() -> vector<7x9xf32>

// Variant of @pad_and_transfer_write_static_low_and_high.

// CHECK-LABEL: func @pad_and_transfer_write_static_non_zero_low_pad
// CHECK-NOT: tensor.pad
// CHECK: linalg.fill
func.func @pad_and_transfer_write_static_non_zero_low_pad(
    %arg0: tensor<5x6xf32>) -> tensor<5x6xf32> {
  %c0 = arith.constant 0 : index
  %c5 = arith.constant 5.0 : f32
  %0 = tensor.pad %arg0 low[0, 1] high[5, 6] {
  ^bb0(%arg2: index, %arg3: index):
    tensor.yield %c5 : f32
  } : tensor<5x6xf32> to tensor<10x13xf32>
  %1 = call @make_vector() : () -> vector<7x9xf32>
  %2 = vector.transfer_write %1, %0[%c0, %c0]
      : vector<7x9xf32>, tensor<10x13xf32>
  %3 = tensor.extract_slice %2[0, 0] [5, 6] [1, 1] : tensor<10x13xf32> to tensor<5x6xf32>
  return %3 : tensor<5x6xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">

    transform.apply_patterns to %func_op {
      transform.apply_patterns.linalg.pad_vectorization
    } : !transform.op<"func.func">
    transform.yield
  }
}

// -----
func.func private @make_vector() -> vector<7x9xf32>

// Variant of @pad_and_transfer_write_static_low_and_high.

// CHECK-LABEL: func @pad_and_transfer_write_static_non_zero_offset
// CHECK-NOT: tensor.pad
// CHECK: linalg.fill
func.func @pad_and_transfer_write_static_non_zero_offset(
    %arg0: tensor<5x6xf32>) -> tensor<5x6xf32> {
  %c0 = arith.constant 0 : index
  %c5 = arith.constant 5.0 : f32
  %0 = tensor.pad %arg0 low[0, 1] high[5, 6] {
  ^bb0(%arg2: index, %arg3: index):
    tensor.yield %c5 : f32
  } : tensor<5x6xf32> to tensor<10x13xf32>
  %1 = call @make_vector() : () -> vector<7x9xf32>
  %2 = vector.transfer_write %1, %0[%c0, %c0]
      : vector<7x9xf32>, tensor<10x13xf32>
  %3 = tensor.extract_slice %2[0, 1] [5, 6] [1, 1] : tensor<10x13xf32> to tensor<5x6xf32>
  return %3 : tensor<5x6xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">

    transform.apply_patterns to %func_op {
      transform.apply_patterns.linalg.pad_vectorization
    } : !transform.op<"func.func">
    transform.yield
  }
}
