@@ -9,33 +9,41 @@ module attributes {
9
9
#spirv.vce <v1.0 , [Shader ], [SPV_KHR_storage_buffer_storage_class ]>, #spirv.resource_limits <>>
10
10
} {
11
11
gpu.module @kernels {
12
- gpu.func @kernel_vector_interleave (%arg0 : vector <2 xi32 >, %arg1 : vector <2 xi32 >, %arg2 : memref <4 xi32 >)
12
+ gpu.func @kernel_vector_interleave (%arg0 : memref <2 xi32 >, %arg1 : memref <2 xi32 >, %arg2 : memref <4 xi32 >)
13
13
kernel attributes { spirv.entry_point_abi = #spirv.entry_point_abi <workgroup_size = [1 , 1 , 1 ]>} {
14
14
%c0 = arith.constant 0 : index
15
- %result = vector.interleave %arg0 , %arg1 : vector <2 xi32 >
15
+ %vec0 = vector.load %arg0 [%c0 ] : memref <2 xi32 >, vector <2 xi32 >
16
+ %vec1 = vector.load %arg1 [%c0 ] : memref <2 xi32 >, vector <2 xi32 >
17
+ %result = vector.interleave %vec0 , %vec1 : vector <2 xi32 > -> vector <4 xi32 >
16
18
vector.store %result , %arg2 [%c0 ] : memref <4 xi32 >, vector <4 xi32 >
17
19
gpu.return
18
20
}
19
21
}
20
22
21
23
func.func @main () {
22
24
// Allocate 3 buffers.
23
- %buf0 = arith.constant dense <[ 0 , 1 ]> : vector <2 xi32 >
24
- %buf1 = arith.constant dense <[ 2 , 3 ]> : vector <2 xi32 >
25
+ %buf0 = memref.alloc () : memref <2 xi32 >
26
+ %buf1 = memref.alloc () : memref <2 xi32 >
25
27
%buf2 = memref.alloc () : memref <4 xi32 >
26
28
27
29
%idx0 = arith.constant 0 : index
28
30
%idx1 = arith.constant 1 : index
29
31
%idx4 = arith.constant 4 : index
30
32
33
+ // Initialize input buffer
34
+ %buf0_vals = arith.constant dense <[0 , 1 ]> : vector <2 xi32 >
35
+ %buf1_vals = arith.constant dense <[2 , 3 ]> : vector <2 xi32 >
36
+ vector.store %buf0_vals , %buf0 [%idx0 ] : memref <2 xi32 >, vector <2 xi32 >
37
+ vector.store %buf1_vals , %buf1 [%idx0 ] : memref <2 xi32 >, vector <2 xi32 >
38
+
31
39
// Initialize output buffer.
32
40
%value0 = arith.constant 0 : i32
33
41
%buf3 = memref.cast %buf2 : memref <4 xi32 > to memref <?xi32 >
34
42
call @fillResource1DInt (%buf3 , %value0 ) : (memref <?xi32 >, i32 ) -> ()
35
43
36
44
gpu.launch_func @kernels ::@kernel_vector_interleave
37
45
blocks in (%idx4 , %idx1 , %idx1 ) threads in (%idx1 , %idx1 , %idx1 )
38
- args (%buf0 : vector <2 xi32 >, %buf1 : vector <2 xi32 >, %buf2 : memref <4 xi32 >)
46
+ args (%buf0 : memref <2 xi32 >, %buf1 : memref <2 xi32 >, %buf2 : memref <4 xi32 >)
39
47
%buf4 = memref.cast %buf3 : memref <?xi32 > to memref <*xi32 >
40
48
call @printMemrefI32 (%buf4 ) : (memref <*xi32 >) -> ()
41
49
return
0 commit comments