Skip to content

Commit f6d55c3

Browse files
committed
[mlir] Add apply_patterns.linalg.pad_vectorization TD Op
This PR simply wraps `populatePadOpVectorizationPatterns` into a new Transform Dialect Op: `apply_patterns.linalg.pad_vectorization`. This change makes it possible to run (and test) the corresponding patterns without `transform.structured.vectorize_children_and_apply_patterns`. Note that the latter Op only supports non-masked vectorization (i.e. when the inputs are static), so, effectively, only fixed-width vectorization (as opposed to scalable vectorization). As such, this change is required to construct vectorization pipelines for tensor.pad targeting scalable vectors. To test the new Op and the corresponding patterns, I added "vectorization-pad-patterns.mlir" - most tests have been extracted from "vectorization-with-patterns.mlir". As a side note, I feel that we should move `GenericPadOpVectorizationPattern` out of `populatePadOpVectorizationPatterns` as that's a "lower tensor.pad" rather than a "vectorize tensor.pad" pattern. I am leaving that as a TODO.
1 parent 37ad65f commit f6d55c3

File tree

5 files changed

+295
-144
lines changed

5 files changed

+295
-144
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,23 @@ def ApplyFoldAddIntoDestPatternsOp : Op<Transform_Dialect,
8484
let assemblyFormat = "attr-dict";
8585
}
8686

87+
def ApplyPadVectorizationPatternsOp : Op<Transform_Dialect,
88+
"apply_patterns.linalg.pad_vectorization",
89+
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
90+
let description = [{
91+
Apply patterns that take tensor.pad and rewrite it as
92+
vector.transfer_read/vector.transfer_write Ops.
93+
94+
These patterns will either fold tensor.pad with an existing
95+
vector.transfer_read or vector.transfer_write producers/consumers (requires
96+
other surrounding Ops to be already vectorized) or rewrite it, together
97+
with tensor.insert_slice consumer, as a vector.transfer_read +
98+
vector.transfer_write pair.
99+
}];
100+
101+
let assemblyFormat = "attr-dict";
102+
}
103+
87104
//===----------------------------------------------------------------------===//
88105
// BufferizeToAllocationOp
89106
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,11 @@ void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
253253
linalg::populateFoldAddIntoDestPatterns(patterns);
254254
}
255255

256+
void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
257+
RewritePatternSet &patterns) {
258+
linalg::populatePadOpVectorizationPatterns(patterns);
259+
}
260+
256261
//===----------------------------------------------------------------------===//
257262
// BufferizeToAllocationOp
258263
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2285,7 +2285,7 @@ static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
22852285
return result;
22862286
}
22872287

2288-
/// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
2288+
/// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp/GenerateOp and
22892289
/// InsertSliceOp. For now, only constant padding values are supported.
22902290
/// If there is enough static type information, TransferReadOps and
22912291
/// TransferWriteOps may be generated instead of InsertSliceOps.
Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s
2+
3+
///----------------------------------------------------------------------------------------
4+
/// [Pattern: PadOpVectorizationWithTransferReadPattern]
5+
///----------------------------------------------------------------------------------------
6+
// CHECK-LABEL: func @pad_and_transfer_read
7+
// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
8+
// CHECK-NOT: tensor.pad
9+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
10+
// CHECK-DAG: %[[C5:.*]] = arith.constant 5.0
11+
// CHECK: %[[RESULT:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], %[[C5]] : tensor<5x6xf32>, vector<7x9xf32>
12+
// CHECK: return %[[RESULT]]
13+
func.func @pad_and_transfer_read(%arg0: tensor<5x6xf32>) -> vector<7x9xf32> {
14+
%c0 = arith.constant 0 : index
15+
%c5 = arith.constant 5.0 : f32
16+
%c6 = arith.constant 6.0 : f32
17+
%0 = tensor.pad %arg0 low[0, 0] high[5, 7] {
18+
^bb0(%arg1: index, %arg2: index):
19+
tensor.yield %c5 : f32
20+
} : tensor<5x6xf32> to tensor<10x13xf32>
21+
%1 = vector.transfer_read %0[%c0, %c0], %c6
22+
: tensor<10x13xf32>, vector<7x9xf32>
23+
return %1 : vector<7x9xf32>
24+
}
25+
26+
module attributes {transform.with_named_sequence} {
27+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
28+
%func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">
29+
30+
transform.apply_patterns to %func_op {
31+
transform.apply_patterns.linalg.pad_vectorization
32+
} : !transform.op<"func.func">
33+
transform.yield
34+
}
35+
}
36+
37+
// -----
38+
39+
///----------------------------------------------------------------------------------------
40+
/// [Pattern: PadOpVectorizationWithTransferWritePattern]
41+
///----------------------------------------------------------------------------------------
42+
func.func private @make_vector() -> vector<7x9xf32>
43+
44+
// CHECK-LABEL: func @pad_and_transfer_write_static
45+
// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
46+
// CHECK-NOT: tensor.pad
47+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
48+
// CHECK: %[[VEC0:.*]] = call @make_vector() : () -> vector<7x9xf32>
49+
// CHECK: %[[RESULT:.*]] = vector.transfer_write %[[VEC0]], %[[ARG0]][%[[C0]], %[[C0]]] : vector<7x9xf32>, tensor<5x6xf32>
50+
// CHECK: return %[[RESULT]]
51+
func.func @pad_and_transfer_write_static(
52+
%arg0: tensor<5x6xf32>) -> tensor<5x6xf32> {
53+
%c0 = arith.constant 0 : index
54+
%c5 = arith.constant 5.0 : f32
55+
%0 = tensor.pad %arg0 low[0, 0] high[5, 7] {
56+
^bb0(%arg2: index, %arg3: index):
57+
tensor.yield %c5 : f32
58+
} : tensor<5x6xf32> to tensor<10x13xf32>
59+
%1 = call @make_vector() : () -> vector<7x9xf32>
60+
%2 = vector.transfer_write %1, %0[%c0, %c0]
61+
: vector<7x9xf32>, tensor<10x13xf32>
62+
%3 = tensor.extract_slice %2[0, 0] [5, 6] [1, 1] : tensor<10x13xf32> to tensor<5x6xf32>
63+
return %3 : tensor<5x6xf32>
64+
}
65+
66+
module attributes {transform.with_named_sequence} {
67+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
68+
%func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">
69+
70+
transform.apply_patterns to %func_op {
71+
transform.apply_patterns.linalg.pad_vectorization
72+
} : !transform.op<"func.func">
73+
transform.yield
74+
}
75+
}
76+
77+
// -----
78+
79+
func.func private @make_vector() -> vector<7x9xf32>
80+
81+
// CHECK-LABEL: func @pad_and_transfer_write_dynamic_static
82+
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?xf32>, %[[SIZE:.*]]: index, %[[PADDING:.*]]: index
83+
// CHECK-NOT: tensor.pad
84+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
85+
// CHECK: %[[SUB:.*]] = tensor.extract_slice %[[ARG0]][0, 0] [%[[SIZE]], 6] [1, 1] : tensor<?x?xf32> to tensor<?x6xf32>
86+
// CHECK: %[[VEC0:.*]] = call @make_vector() : () -> vector<7x9xf32>
87+
// CHECK: %[[RESULT:.*]] = vector.transfer_write %[[VEC0]], %[[SUB]][%[[C0]], %[[C0]]] : vector<7x9xf32>, tensor<?x6xf32>
88+
// CHECK: return %[[RESULT]]
89+
func.func @pad_and_transfer_write_dynamic_static(
90+
%arg0: tensor<?x?xf32>, %size: index, %padding: index) -> tensor<?x6xf32> {
91+
%c0 = arith.constant 0 : index
92+
%c5 = arith.constant 5.0 : f32
93+
%s = tensor.extract_slice %arg0[0, 0] [%size, 6] [1, 1]
94+
: tensor<?x?xf32> to tensor<?x6xf32>
95+
%0 = tensor.pad %s low[0, 0] high[%padding, 7] {
96+
^bb0(%arg2: index, %arg3: index):
97+
tensor.yield %c5 : f32
98+
} : tensor<?x6xf32> to tensor<?x13xf32>
99+
%1 = call @make_vector() : () -> vector<7x9xf32>
100+
%2 = vector.transfer_write %1, %0[%c0, %c0]
101+
: vector<7x9xf32>, tensor<?x13xf32>
102+
%3 = tensor.extract_slice %2[0, 0] [%size, 6] [1, 1] : tensor<?x13xf32> to tensor<?x6xf32>
103+
return %3 : tensor<?x6xf32>
104+
}
105+
106+
module attributes {transform.with_named_sequence} {
107+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
108+
%func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">
109+
110+
transform.apply_patterns to %func_op {
111+
transform.apply_patterns.linalg.pad_vectorization
112+
} : !transform.op<"func.func">
113+
transform.yield
114+
}
115+
}
116+
117+
118+
// -----
119+
120+
///----------------------------------------------------------------------------------------
121+
/// [Pattern: PadOpVectorizationWithInsertSlicePattern]
122+
///----------------------------------------------------------------------------------------
123+
124+
func.func private @make_vector() -> tensor<12x13xf32>
125+
126+
// CHECK-LABEL: func @pad_and_insert_slice_source
127+
// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
128+
// CHECK-NOT: tensor.pad
129+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
130+
// CHECK-DAG: %[[C5:.*]] = arith.constant 5.0
131+
// CHECK: %[[VEC0:.*]] = call @make_vector() : () -> tensor<12x13xf32>
132+
// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], %[[C5]] : tensor<5x6xf32>, vector<7x9xf32>
133+
// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[VEC0]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<7x9xf32>, tensor<12x13xf32>
134+
// CHECK: return %[[WRITE]]
135+
func.func @pad_and_insert_slice_source(
136+
%arg0: tensor<5x6xf32>) -> tensor<12x13xf32> {
137+
%c0 = arith.constant 0 : index
138+
%c5 = arith.constant 5.0 : f32
139+
%0 = tensor.pad %arg0 low[0, 0] high[2, 3] {
140+
^bb0(%arg2: index, %arg3: index):
141+
tensor.yield %c5 : f32
142+
} : tensor<5x6xf32> to tensor<7x9xf32>
143+
%1 = call @make_vector() : () -> tensor<12x13xf32>
144+
%r = tensor.insert_slice %0 into %1[0, 0][7, 9][1, 1] : tensor<7x9xf32> into tensor<12x13xf32>
145+
return %r : tensor<12x13xf32>
146+
}
147+
148+
module attributes {transform.with_named_sequence} {
149+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
150+
%func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">
151+
152+
transform.apply_patterns to %func_op {
153+
transform.apply_patterns.linalg.pad_vectorization
154+
} : !transform.op<"func.func">
155+
transform.yield
156+
}
157+
}
158+
159+
160+
// -----
161+
162+
///----------------------------------------------------------------------------------------
163+
/// tensor::PadOp -> tensor::EmptyOp + linalg::FillOp/tensor::GenerateOp + tensor::InsertSliceOp
164+
/// [Pattern: GenericPadOpVectorizationPattern]
165+
///----------------------------------------------------------------------------------------
166+
167+
func.func private @make_vector() -> tensor<12x13xf32>
168+
169+
// Same as @pad_and_insert_slice_dest in vectorization-with-patterns.mlir, but
// exercised via the `apply_patterns.linalg.pad_vectorization` TD Op.
170+
// CHECK-LABEL: func.func @pad_and_insert_slice_dest(
171+
// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x5x6xf32>) -> tensor<1x12x13xf32> {
172+
// CHECK-NOT: tensor.pad
173+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
174+
// CHECK-DAG: %[[PAD:.*]] = arith.constant 5.000000e+00 : f32
175+
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x12x13xf32>
176+
// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[PAD]] : f32) outs(%[[EMPTY]] : tensor<1x12x13xf32>) -> tensor<1x12x13xf32>
177+
// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG_0]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true, true]} : tensor<1x5x6xf32>, vector<1x5x6xf32>
178+
// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[FILL]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x5x6xf32>, tensor<1x12x13xf32>
179+
// CHECK: %[[VEC:.*]] = call @make_vector() : () -> tensor<12x13xf32>
180+
// CHECK: %[[RES:.*]] = tensor.insert_slice %[[VEC]] into %[[WRITE]][0, 0, 0] [1, 12, 13] [1, 1, 1] : tensor<12x13xf32> into tensor<1x12x13xf32>
181+
// CHECK: return %[[RES]] : tensor<1x12x13xf32>
182+
183+
func.func @pad_and_insert_slice_dest(
184+
%arg0: tensor<1x5x6xf32>) -> tensor<1x12x13xf32> {
185+
%c5 = arith.constant 5.0 : f32
186+
%0 = tensor.pad %arg0 low[0, 0, 0] high[0, 7, 7] {
187+
^bb0(%arg2: index, %arg3: index, %arg4: index):
188+
tensor.yield %c5 : f32
189+
} : tensor<1x5x6xf32> to tensor<1x12x13xf32>
190+
%1 = call @make_vector() : () -> tensor<12x13xf32>
191+
%r = tensor.insert_slice %1 into %0[0, 0, 0][1, 12, 13][1, 1, 1] : tensor<12x13xf32> into tensor<1x12x13xf32>
192+
return %r : tensor<1x12x13xf32>
193+
}
194+
195+
module attributes {transform.with_named_sequence} {
196+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
197+
%func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">
198+
199+
transform.apply_patterns to %func_op {
200+
transform.apply_patterns.linalg.pad_vectorization
201+
} : !transform.op<"func.func">
202+
transform.yield
203+
}
204+
}
205+
206+
// -----
207+
func.func private @make_vector() -> vector<7x9xf32>
208+
209+
// Variant of @pad_and_transfer_write_static
210+
211+
// CHECK-LABEL: func @pad_and_transfer_write_static_non_zero_low_pad
212+
// CHECK-NOT: tensor.pad
213+
// CHECK: linalg.fill
214+
func.func @pad_and_transfer_write_static_non_zero_low_pad(
215+
%arg0: tensor<5x6xf32>) -> tensor<5x6xf32> {
216+
%c0 = arith.constant 0 : index
217+
%c5 = arith.constant 5.0 : f32
218+
%0 = tensor.pad %arg0 low[0, 1] high[5, 6] {
219+
^bb0(%arg2: index, %arg3: index):
220+
tensor.yield %c5 : f32
221+
} : tensor<5x6xf32> to tensor<10x13xf32>
222+
%1 = call @make_vector() : () -> vector<7x9xf32>
223+
%2 = vector.transfer_write %1, %0[%c0, %c0]
224+
: vector<7x9xf32>, tensor<10x13xf32>
225+
%3 = tensor.extract_slice %2[0, 0] [5, 6] [1, 1] : tensor<10x13xf32> to tensor<5x6xf32>
226+
return %3 : tensor<5x6xf32>
227+
}
228+
229+
module attributes {transform.with_named_sequence} {
230+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
231+
%func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">
232+
233+
transform.apply_patterns to %func_op {
234+
transform.apply_patterns.linalg.pad_vectorization
235+
} : !transform.op<"func.func">
236+
transform.yield
237+
}
238+
}
239+
240+
// -----
241+
func.func private @make_vector() -> vector<7x9xf32>
242+
243+
// Variant of @pad_and_transfer_write_static
244+
245+
// CHECK-LABEL: func @pad_and_transfer_write_static_non_zero_offset
246+
// CHECK-NOT: tensor.pad
247+
// CHECK: linalg.fill
248+
func.func @pad_and_transfer_write_static_non_zero_offset(
249+
%arg0: tensor<5x6xf32>) -> tensor<5x6xf32> {
250+
%c0 = arith.constant 0 : index
251+
%c5 = arith.constant 5.0 : f32
252+
%0 = tensor.pad %arg0 low[0, 1] high[5, 6] {
253+
^bb0(%arg2: index, %arg3: index):
254+
tensor.yield %c5 : f32
255+
} : tensor<5x6xf32> to tensor<10x13xf32>
256+
%1 = call @make_vector() : () -> vector<7x9xf32>
257+
%2 = vector.transfer_write %1, %0[%c0, %c0]
258+
: vector<7x9xf32>, tensor<10x13xf32>
259+
%3 = tensor.extract_slice %2[0, 1] [5, 6] [1, 1] : tensor<10x13xf32> to tensor<5x6xf32>
260+
return %3 : tensor<5x6xf32>
261+
}
262+
263+
module attributes {transform.with_named_sequence} {
264+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
265+
%func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">
266+
267+
transform.apply_patterns to %func_op {
268+
transform.apply_patterns.linalg.pad_vectorization
269+
} : !transform.op<"func.func">
270+
transform.yield
271+
}
272+
}

0 commit comments

Comments
 (0)