Skip to content

Commit 263ede9

Browse files
committed
Add VectorToSPIRV patterns to GPUToSPIRVPass, and fix errors in e2e test
1 parent e7a41e8 commit 263ede9

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
1717
#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
1818
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
19+
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
1920
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
2021
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
2122
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -132,6 +133,7 @@ void GPUToSPIRVPass::runOnOperation() {
132133
mlir::arith::populateArithToSPIRVPatterns(typeConverter, patterns);
133134
populateMemRefToSPIRVPatterns(typeConverter, patterns);
134135
populateFuncToSPIRVPatterns(typeConverter, patterns);
136+
populateVectorToSPIRVPatterns(typeConverter, patterns);
135137

136138
if (failed(applyFullConversion(gpuModule, *target, std::move(patterns))))
137139
return signalPassFailure();

mlir/test/mlir-vulkan-runner/vector-interleave.mlir

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,33 +9,41 @@ module attributes {
99
#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
1010
} {
1111
gpu.module @kernels {
12-
gpu.func @kernel_vector_interleave(%arg0 : vector<2xi32>, %arg1 : vector<2xi32>, %arg2 : memref<4xi32>)
12+
gpu.func @kernel_vector_interleave(%arg0 : memref<2xi32>, %arg1 : memref<2xi32>, %arg2 : memref<4xi32>)
1313
kernel attributes { spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [1, 1, 1]>} {
1414
%c0 = arith.constant 0 : index
15-
%result = vector.interleave %arg0, %arg1 : vector<2xi32>
15+
%vec0 = vector.load %arg0[%c0] : memref<2xi32>, vector<2xi32>
16+
%vec1 = vector.load %arg1[%c0] : memref<2xi32>, vector<2xi32>
17+
%result = vector.interleave %vec0, %vec1 : vector<2xi32> -> vector<4xi32>
1618
vector.store %result, %arg2[%c0] : memref<4xi32>, vector<4xi32>
1719
gpu.return
1820
}
1921
}
2022

2123
func.func @main() {
2224
// Allocate 3 buffers.
23-
%buf0 = arith.constant dense<[0, 1]> : vector<2xi32>
24-
%buf1 = arith.constant dense<[2, 3]> : vector<2xi32>
25+
%buf0 = memref.alloc() : memref<2xi32>
26+
%buf1 = memref.alloc() : memref<2xi32>
2527
%buf2 = memref.alloc() : memref<4xi32>
2628

2729
%idx0 = arith.constant 0 : index
2830
%idx1 = arith.constant 1 : index
2931
%idx4 = arith.constant 4 : index
3032

33+
// Initialize input buffer
34+
%buf0_vals = arith.constant dense<[0, 1]> : vector<2xi32>
35+
%buf1_vals = arith.constant dense<[2, 3]> : vector<2xi32>
36+
vector.store %buf0_vals, %buf0[%idx0] : memref<2xi32>, vector<2xi32>
37+
vector.store %buf1_vals, %buf1[%idx0] : memref<2xi32>, vector<2xi32>
38+
3139
// Initialize output buffer.
3240
%value0 = arith.constant 0 : i32
3341
%buf3 = memref.cast %buf2 : memref<4xi32> to memref<?xi32>
3442
call @fillResource1DInt(%buf3, %value0) : (memref<?xi32>, i32) -> ()
3543

3644
gpu.launch_func @kernels::@kernel_vector_interleave
3745
blocks in (%idx4, %idx1, %idx1) threads in (%idx1, %idx1, %idx1)
38-
args(%buf0 : vector<2xi32>, %buf1 : vector<2xi32>, %buf2 : memref<4xi32>)
46+
args(%buf0 : memref<2xi32>, %buf1 : memref<2xi32>, %buf2 : memref<4xi32>)
3947
%buf4 = memref.cast %buf3 : memref<?xi32> to memref<*xi32>
4048
call @printMemrefI32(%buf4) : (memref<*xi32>) -> ()
4149
return

0 commit comments

Comments
 (0)