Skip to content

Commit 1280b52

Browse files
committed
[mlir][nfc] Update 2 tests for PadOpVectorizationWithTransferWritePattern
* Relocates two tests for `PadOpVectorizationWithTransferWritePattern` in "vectorization-pad-patterns.mlir" to group them with other tests for the same pattern. * Adds a note clarifying that these are negative tests and explains the reasoning behind them. * Removes `transform.apply_patterns.linalg.decompose_pad` from the TD sequences as it's no longer needed (*). This is essentially a small clean-up in preparation for upcoming changes. (*) `transform.apply_patterns.linalg.decompose_pad` was split off from `transform.apply_patterns.linalg.pad_vectorization` in llvm#117329. "vectorization-pad-patterns.mlir" is meant to test the latter, not the former.
1 parent d0b641b commit 1280b52

File tree

1 file changed

+68
-72
lines changed

1 file changed

+68
-72
lines changed

mlir/test/Dialect/Linalg/vectorization-pad-patterns.mlir

Lines changed: 68 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,74 @@ module attributes {transform.with_named_sequence} {
114114
}
115115
}
116116

117+
// -----
118+
119+
func.func private @make_vector() -> vector<7x9xf32>
120+
121+
// Negative test - low pad is non-zero
122+
123+
// CHECK-LABEL: func @pad_and_transfer_write_static_non_zero_low_pad
124+
// CHECK: tensor.pad
125+
func.func @pad_and_transfer_write_static_non_zero_low_pad(
126+
%arg0: tensor<5x6xf32>) -> tensor<5x6xf32> {
127+
%c0 = arith.constant 0 : index
128+
%c5 = arith.constant 5.0 : f32
129+
%0 = tensor.pad %arg0 low[0, 1] high[5, 6] {
130+
^bb0(%arg2: index, %arg3: index):
131+
tensor.yield %c5 : f32
132+
} : tensor<5x6xf32> to tensor<10x13xf32>
133+
%1 = call @make_vector() : () -> vector<7x9xf32>
134+
%2 = vector.transfer_write %1, %0[%c0, %c0]
135+
: vector<7x9xf32>, tensor<10x13xf32>
136+
%3 = tensor.extract_slice %2[0, 0] [5, 6] [1, 1] : tensor<10x13xf32> to tensor<5x6xf32>
137+
return %3 : tensor<5x6xf32>
138+
}
139+
140+
module attributes {transform.with_named_sequence} {
141+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
142+
%func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">
143+
144+
transform.apply_patterns to %func_op {
145+
transform.apply_patterns.linalg.pad_vectorization
146+
} : !transform.op<"func.func">
147+
transform.yield
148+
}
149+
}
150+
151+
// -----
152+
153+
// Negative test - TransferWriteOp result is not _directly_ consumed by an
154+
// ExtractSliceOp (noet the non-zero offset).
155+
156+
func.func private @make_vector() -> vector<7x9xf32>
157+
158+
// CHECK-LABEL: func @pad_and_transfer_write_static_non_zero_offset
159+
// CHECK: tensor.pad
160+
func.func @pad_and_transfer_write_static_non_zero_offset(
161+
%arg0: tensor<5x6xf32>) -> tensor<5x6xf32> {
162+
%c0 = arith.constant 0 : index
163+
%c5 = arith.constant 5.0 : f32
164+
%0 = tensor.pad %arg0 low[0, 0] high[5, 7] {
165+
^bb0(%arg2: index, %arg3: index):
166+
tensor.yield %c5 : f32
167+
} : tensor<5x6xf32> to tensor<10x13xf32>
168+
%1 = call @make_vector() : () -> vector<7x9xf32>
169+
%2 = vector.transfer_write %1, %0[%c0, %c0]
170+
: vector<7x9xf32>, tensor<10x13xf32>
171+
%3 = tensor.extract_slice %2[0, 1] [5, 6] [1, 1] : tensor<10x13xf32> to tensor<5x6xf32>
172+
return %3 : tensor<5x6xf32>
173+
}
174+
175+
module attributes {transform.with_named_sequence} {
176+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
177+
%func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">
178+
179+
transform.apply_patterns to %func_op {
180+
transform.apply_patterns.linalg.pad_vectorization
181+
} : !transform.op<"func.func">
182+
transform.yield
183+
}
184+
}
117185

118186
// -----
119187

@@ -209,75 +277,3 @@ module attributes {transform.with_named_sequence} {
209277
transform.yield
210278
}
211279
}
212-
213-
// -----
214-
func.func private @make_vector() -> vector<7x9xf32>
215-
216-
// Variant of @pad_and_transfer_write_static
217-
218-
// CHECK-LABEL: func @pad_and_transfer_write_static_non_zero_low_pad
219-
// CHECK-NOT: tensor.pad
220-
// CHECK: linalg.fill
221-
func.func @pad_and_transfer_write_static_non_zero_low_pad(
222-
%arg0: tensor<5x6xf32>) -> tensor<5x6xf32> {
223-
%c0 = arith.constant 0 : index
224-
%c5 = arith.constant 5.0 : f32
225-
%0 = tensor.pad %arg0 low[0, 1] high[5, 6] {
226-
^bb0(%arg2: index, %arg3: index):
227-
tensor.yield %c5 : f32
228-
} : tensor<5x6xf32> to tensor<10x13xf32>
229-
%1 = call @make_vector() : () -> vector<7x9xf32>
230-
%2 = vector.transfer_write %1, %0[%c0, %c0]
231-
: vector<7x9xf32>, tensor<10x13xf32>
232-
%3 = tensor.extract_slice %2[0, 0] [5, 6] [1, 1] : tensor<10x13xf32> to tensor<5x6xf32>
233-
return %3 : tensor<5x6xf32>
234-
}
235-
236-
module attributes {transform.with_named_sequence} {
237-
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
238-
%func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">
239-
240-
transform.apply_patterns to %func_op {
241-
// TODO: Split into two tests, one for each pattern
242-
transform.apply_patterns.linalg.decompose_pad
243-
transform.apply_patterns.linalg.pad_vectorization
244-
} : !transform.op<"func.func">
245-
transform.yield
246-
}
247-
}
248-
249-
// -----
250-
func.func private @make_vector() -> vector<7x9xf32>
251-
252-
// Variant of @pad_and_transfer_write_static
253-
254-
// CHECK-LABEL: func @pad_and_transfer_write_static_non_zero_offset
255-
// CHECK-NOT: tensor.pad
256-
// CHECK: linalg.fill
257-
func.func @pad_and_transfer_write_static_non_zero_offset(
258-
%arg0: tensor<5x6xf32>) -> tensor<5x6xf32> {
259-
%c0 = arith.constant 0 : index
260-
%c5 = arith.constant 5.0 : f32
261-
%0 = tensor.pad %arg0 low[0, 1] high[5, 6] {
262-
^bb0(%arg2: index, %arg3: index):
263-
tensor.yield %c5 : f32
264-
} : tensor<5x6xf32> to tensor<10x13xf32>
265-
%1 = call @make_vector() : () -> vector<7x9xf32>
266-
%2 = vector.transfer_write %1, %0[%c0, %c0]
267-
: vector<7x9xf32>, tensor<10x13xf32>
268-
%3 = tensor.extract_slice %2[0, 1] [5, 6] [1, 1] : tensor<10x13xf32> to tensor<5x6xf32>
269-
return %3 : tensor<5x6xf32>
270-
}
271-
272-
module attributes {transform.with_named_sequence} {
273-
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
274-
%func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">
275-
276-
transform.apply_patterns to %func_op {
277-
// TODO: Split into two tests, one for each pattern
278-
transform.apply_patterns.linalg.decompose_pad
279-
transform.apply_patterns.linalg.pad_vectorization
280-
} : !transform.op<"func.func">
281-
transform.yield
282-
}
283-
}

0 commit comments

Comments
 (0)