
Commit 01b1b0c

[mlir][SVE] Add e2e for 1D depthwise WC convolution (#85225)
Follow-up for #81625
1 parent 631e54a commit 01b1b0c

1 file changed: +60 -0 lines changed

@@ -0,0 +1,60 @@
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE:   -transform-interpreter -test-transform-dialect-erase-schedule \
// DEFINE:   -one-shot-bufferize="bufferize-function-boundaries" -lower-vector-mask -cse -canonicalize -convert-vector-to-scf -arm-sve-legalize-vector-storage \
// DEFINE:   -convert-vector-to-llvm="enable-arm-sve" -test-lower-to-llvm -o %t
// DEFINE: %{entry_point} = conv
// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve" \
// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils

// RUN: %{compile} | %{run} | FileCheck %s

func.func @conv() {
  // Define input/output tensors
  %input_init = tensor.empty() : tensor<1x8x6xi32>
  %output_init = tensor.empty() : tensor<1x7x6xi32>

  %five = arith.constant 5 : i32
  %zero = arith.constant 0 : i32
  %input = linalg.fill ins(%five : i32) outs(%input_init : tensor<1x8x6xi32>) -> tensor<1x8x6xi32>
  %output = linalg.fill ins(%zero : i32) outs(%output_init : tensor<1x7x6xi32>) -> tensor<1x7x6xi32>

  // Define the filter tensor
  %filter = arith.constant dense<[
    [  1,  2,  3,  4,  5,  6],
    [ 11, 12, 13, 14, 15, 16]
  ]> : tensor<2x6xi32>

  // Static sizes -> dynamic sizes
  %input_dyn = tensor.cast %input : tensor<1x8x6xi32> to tensor<1x8x?xi32>
  %output_dyn = tensor.cast %output : tensor<1x7x6xi32> to tensor<1x7x?xi32>
  %filter_dyn = tensor.cast %filter : tensor<2x6xi32> to tensor<2x?xi32>

  // Run the convolution
  %res = linalg.depthwise_conv_1d_nwc_wc
    ins(%input_dyn, %filter_dyn : tensor<1x8x?xi32>, tensor<2x?xi32>)
    outs(%output_dyn : tensor<1x7x?xi32>) -> tensor<1x7x?xi32>

  // Print the results
  // CHECK: SVE: START OF TEST OUTPUT
  vector.print str "SVE: START OF TEST OUTPUT\n"

  // CHECK-NEXT: Unranked Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [1, 7, 6] strides = [42, 6, 1] data =
  // CHECK-COUNT-7: [60, 70, 80, 90, 100, 110]
  %xf = tensor.cast %res : tensor<1x7x?xi32> to tensor<*xi32>
  call @printMemrefI32(%xf) : (tensor<*xi32>) -> ()

  // CHECK-NEXT: SVE: END OF TEST OUTPUT
  vector.print str "SVE: END OF TEST OUTPUT\n"

  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %0 vector_sizes [1, 7, [8], 2] : !transform.any_op
    transform.yield
  }
}

func.func private @printMemrefI32(%ptr : tensor<*xi32>) attributes { llvm.emit_c_interface }
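
For reference, the FileCheck row follows directly from the usual depthwise 1D (NWC x WC) convolution semantics, output[n, w, c] = sum over kw of input[n, w + kw, c] * filter[kw, c]. A minimal Python sketch (not part of the commit) that reproduces the expected row, assuming the input is filled with 5 as above:

  # Sketch only: recompute the expected output row of the test's depthwise conv.
  input_val = 5                                   # every element of the 1x8x6 input is 5
  filter_wc = [[ 1,  2,  3,  4,  5,  6],          # the 2x6 filter from the test
               [11, 12, 13, 14, 15, 16]]
  kw = 2                                          # filter width; output width is 8 - 2 + 1 = 7
  row = [sum(input_val * filter_wc[k][c] for k in range(kw)) for c in range(6)]
  print(row)  # [60, 70, 80, 90, 100, 110]

Because the input is constant, every one of the 7 output positions yields the same row, which is why the test uses CHECK-COUNT-7.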
