Skip to content

Commit 22cd31a

Browse files
committed
[mlir][spirv] Add support for VectorAnyINTEL capability
Allow vector of any lengths between [2-2^32-1]. VectorAnyINTEL capability (part of "SPV_INTEL_vector_compute" extension) relaxes the length constraint on SPIR-V vector sizes from 2,3, and 4.
1 parent 34c6f20 commit 22cd31a

File tree

13 files changed

+124
-45
lines changed

13 files changed

+124
-45
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4146,7 +4146,12 @@ def SPIRV_Int32 : TypeAlias<I32, "Int32">;
41464146
def SPIRV_Float32 : TypeAlias<F32, "Float32">;
41474147
def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
41484148
def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
4149-
def SPIRV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16],
4149+
// Remove the vector size restriction.
4150+
// Although the vector size can be upto (2^64-1), uint64,
4151+
// 2^32-1 (UNINT32_MAX>) is a more realistic number, it should serve the purpose
4152+
// for all practical cases.
4153+
// Also unsigned is used for the number elements for composite tyeps.
4154+
def SPIRV_Vector : VectorOfLengthRangeAndType<[2, 0xFFFFFFFF],
41504155
[SPIRV_Bool, SPIRV_Integer, SPIRV_Float]>;
41514156
// Component type check is done in the type parser for the following SPIR-V
41524157
// dialect-specific types so we use "Any" here.
@@ -4206,10 +4211,10 @@ class SPIRV_JointMatrixOfType<list<Type> allowedTypes> :
42064211
"Joint Matrix">;
42074212

42084213
class SPIRV_ScalarOrVectorOf<Type type> :
4209-
AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>]>;
4214+
AnyTypeOf<[type, VectorOfLengthRangeAndType<[2, 0xFFFFFFFF], [type]>]>;
42104215

42114216
class SPIRV_ScalarOrVectorOrCoopMatrixOf<Type type> :
4212-
AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
4217+
AnyTypeOf<[type, VectorOfLengthRangeAndType<[2, 0xFFFFFFFF], [type]>,
42134218
SPIRV_CoopMatrixOfType<[type]>, SPIRV_CoopMatrixNVOfType<[type]>]>;
42144219

42154220
class SPIRV_MatrixOrCoopMatrixOf<Type type> :

mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,12 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
184184
parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
185185
return Type();
186186
}
187-
if (t.getNumElements() > 4) {
187+
// Number of elements should be between [2 - 2^32 -1],
188+
// since getNumElements() returns an unsigned, the upper limit check is
189+
// unnecessary.
190+
if (t.getNumElements() < 2) {
188191
parser.emitError(
189-
typeLoc, "vector length has to be less than or equal to 4 but found ")
192+
typeLoc, "vector length has to be between [2 - 2^32 -1] but found ")
190193
<< t.getNumElements();
191194
return Type();
192195
}

mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,11 @@ bool CompositeType::classof(Type type) {
101101
}
102102

103103
bool CompositeType::isValid(VectorType type) {
104-
return type.getRank() == 1 &&
105-
llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
106-
llvm::isa<ScalarType>(type.getElementType());
104+
// Number of elements should be between [2 - 2^32 -1],
105+
// since getNumElements() returns an unsigned, the upper limit check is
106+
// unnecessary.
107+
return type.getRank() == 1 && llvm::isa<ScalarType>(type.getElementType()) &&
108+
type.getNumElements() >= 2;
107109
}
108110

109111
Type CompositeType::getElementType(unsigned index) const {
@@ -171,9 +173,17 @@ void CompositeType::getCapabilities(
171173
.Case<VectorType>([&](VectorType type) {
172174
auto vecSize = getNumElements();
173175
if (vecSize == 8 || vecSize == 16) {
174-
static const Capability caps[] = {Capability::Vector16};
175-
ArrayRef<Capability> ref(caps, std::size(caps));
176-
capabilities.push_back(ref);
176+
static constexpr Capability caps[] = {Capability::Vector16,
177+
Capability::VectorAnyINTEL};
178+
capabilities.push_back(caps);
179+
}
180+
// VectorAnyINTEL capability removes the vector size restriction and
181+
// allows the vector size to be up to (2^32-1).
182+
// Vector16 capability allows the vector size to be 8 and 16
183+
SmallVector<unsigned, 5> allowedVecRange = {2, 3, 4, 8, 16};
184+
if (vecSize >= 2 && !llvm::is_contained(allowedVecRange, vecSize)) {
185+
static constexpr Capability caps[] = {Capability::VectorAnyINTEL};
186+
capabilities.push_back(caps);
177187
}
178188
return llvm::cast<ScalarType>(type.getElementType())
179189
.getCapabilities(capabilities, storage);

mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ module attributes {
1111
#spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Shader], []>, #spirv.resource_limits<>>
1212
} {
1313

14-
func.func @unsupported_5elem_vector(%arg0: vector<5xi32>) {
14+
func.func @unsupported_5elem_vector(%arg0: vector<5xi32>, %arg1: vector<5xi32>) {
1515
// expected-error@+1 {{failed to legalize operation 'arith.subi'}}
16-
%1 = arith.subi %arg0, %arg0: vector<5xi32>
16+
%1 = arith.subi %arg0, %arg1: vector<5xi32>
1717
return
1818
}
1919

mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1407,3 +1407,43 @@ func.func @float_scalar(%arg0: f16) {
14071407
}
14081408

14091409
} // end module
1410+
1411+
// -----
1412+
1413+
//===----------------------------------------------------------------------===//
1414+
// VectorAnyINTEL support
1415+
//===----------------------------------------------------------------------===//
1416+
1417+
// Check that with VectorAnyINTEL, VectorComputeINTEL capability,
1418+
// and SPV_INTEL_vector_compute extension, any sized (2-2^32 -1) vector is allowed.
1419+
module attributes {
1420+
spirv.target_env = #spirv.target_env<
1421+
#spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Kernel, VectorAnyINTEL], [SPV_INTEL_vector_compute]>, #spirv.resource_limits<>>
1422+
} {
1423+
1424+
// CHECK-LABEL: @any_vector
1425+
func.func @any_vector(%arg0: vector<16xi32>, %arg1: vector<16xi32>) {
1426+
// CHECK: spirv.ISub %{{.+}}, %{{.+}}: vector<16xi32>
1427+
%0 = arith.subi %arg0, %arg1: vector<16xi32>
1428+
return
1429+
}
1430+
1431+
// CHECK-LABEL: @max_vector
1432+
func.func @max_vector(%arg0: vector<4294967295xi32>, %arg1: vector<4294967295xi32>) {
1433+
// CHECK: spirv.ISub %{{.+}}, %{{.+}}: vector<4294967295xi32>
1434+
%0 = arith.subi %arg0, %arg1: vector<4294967295xi32>
1435+
return
1436+
}
1437+
1438+
1439+
// Check float vector types of any size.
1440+
// CHECK-LABEL: @float_vector58
1441+
func.func @float_vector58(%arg0: vector<5xf16>, %arg1: vector<8xf64>) {
1442+
// CHECK: spirv.FAdd %{{.*}}, %{{.*}}: vector<5xf16>
1443+
%0 = arith.addf %arg0, %arg0: vector<5xf16>
1444+
// CHECK: spirv.FMul %{{.*}}, %{{.*}}: vector<8xf64>
1445+
%1 = arith.mulf %arg1, %arg1: vector<8xf64>
1446+
return
1447+
}
1448+
1449+
} // end module

mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -351,8 +351,21 @@ module attributes {
351351
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
352352
} {
353353

354-
// CHECK-NOT: spirv.func @large_vector
355-
func.func @large_vector(%arg0: vector<1024xi32>) { return }
354+
// CHECK-NOT: spirv.func @large_vector_unsupported
355+
func.func @large_vector_unsupported(%arg0: vector<1024xi32>) { return }
356+
357+
} // end module
358+
359+
360+
// -----
361+
362+
// Check that large vectors are supported with VectorAnyINTEL or VectorComputeINTEL.
363+
module attributes {
364+
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Float16, Kernel, VectorAnyINTEL], [SPV_INTEL_vector_compute]>, #spirv.resource_limits<>>
365+
} {
366+
367+
// CHECK: spirv.func @large_any_vector
368+
func.func @large_any_vector(%arg0: vector<1024xi32>) { return }
356369

357370
} // end module
358371

mlir/test/Dialect/SPIRV/IR/bit-ops.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ func.func @bitwise_or_all_ones_vector(%arg: vector<3xi8>) -> vector<3xi8> {
137137
// -----
138138

139139
func.func @bitwise_or_float(%arg0: f16, %arg1: f16) -> f16 {
140-
// expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}}
140+
// expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2-4294967295}}
141141
%0 = spirv.BitwiseOr %arg0, %arg1 : f16
142142
return %0 : f16
143143
}
@@ -163,7 +163,7 @@ func.func @bitwise_xor_vector(%arg: vector<4xi32>) -> vector<4xi32> {
163163
// -----
164164

165165
func.func @bitwise_xor_float(%arg0: f16, %arg1: f16) -> f16 {
166-
// expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}}
166+
// expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2-4294967295}}
167167
%0 = spirv.BitwiseXor %arg0, %arg1 : f16
168168
return %0 : f16
169169
}
@@ -272,7 +272,7 @@ func.func @bitwise_and_zext_vector(%arg: vector<2xi8>) -> vector<2xi32> {
272272
// -----
273273

274274
func.func @bitwise_and_float(%arg0: f16, %arg1: f16) -> f16 {
275-
// expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}}
275+
// expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2-4294967295}}
276276
%0 = spirv.BitwiseAnd %arg0, %arg1 : f16
277277
return %0 : f16
278278
}

mlir/test/Dialect/SPIRV/IR/gl-ops.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ func.func @exp(%arg0 : i32) -> () {
2727
// -----
2828

2929
func.func @exp(%arg0 : vector<5xf32>) -> () {
30-
// expected-error @+1 {{op operand #0 must be 16/32-bit float or vector of 16/32-bit float values of length 2/3/4}}
30+
// CHECK: spirv.GL.Exp {{%.*}} : vector<5xf32
3131
%2 = spirv.GL.Exp %arg0 : vector<5xf32>
3232
return
3333
}

mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ spirv.func @f32_to_bf16_vec(%arg0 : vector<2xf32>) "None" {
2121
// -----
2222

2323
spirv.func @f32_to_bf16_unsupported(%arg0 : f64) "None" {
24-
// expected-error @+1 {{operand #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got}}
24+
// expected-error @+1 {{operand #0 must be Float32 or vector of Float32 values of length 2-4294967295, but got}}
2525
%0 = spirv.INTEL.ConvertFToBF16 %arg0 : f64 to i16
2626
spirv.Return
2727
}
@@ -57,7 +57,7 @@ spirv.func @bf16_to_f32_vec(%arg0 : vector<2xi16>) "None" {
5757
// -----
5858

5959
spirv.func @bf16_to_f32_unsupported(%arg0 : i16) "None" {
60-
// expected-error @+1 {{result #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got}}
60+
// expected-error @+1 {{result #0 must be Float32 or vector of Float32 values of length 2-4294967295, but got}}
6161
%0 = spirv.INTEL.ConvertBF16ToF %arg0 : i16 to f16
6262
spirv.Return
6363
}

mlir/test/Dialect/SPIRV/IR/logical-ops.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ func.func @logicalUnary(%arg0 : i1)
166166

167167
func.func @logicalUnary(%arg0 : i32)
168168
{
169-
// expected-error @+1 {{'operand' must be bool or vector of bool values of length 2/3/4/8/16, but got 'i32'}}
169+
// expected-error @+1 {{'operand' must be bool or vector of bool values of length 2-4294967295, but got 'i32'}}
170170
%0 = spirv.LogicalNot %arg0 : i32
171171
return
172172
}

mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,17 @@ func.func @expvec(%arg0 : vector<3xf16>) -> () {
1818

1919
// -----
2020

21-
func.func @exp(%arg0 : i32) -> () {
22-
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
23-
%2 = spirv.CL.exp %arg0 : i32
21+
func.func @exp_any_vec(%arg0 : vector<5xf32>) -> () {
22+
// CHECK: spirv.CL.exp {{%.*}} : vector<5xf32>
23+
%2 = spirv.CL.exp %arg0 : vector<5xf32>
2424
return
2525
}
2626

2727
// -----
2828

29-
func.func @exp(%arg0 : vector<5xf32>) -> () {
30-
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4}}
31-
%2 = spirv.CL.exp %arg0 : vector<5xf32>
29+
func.func @exp(%arg0 : i32) -> () {
30+
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
31+
%2 = spirv.CL.exp %arg0 : i32
3232
return
3333
}
3434

@@ -66,6 +66,14 @@ func.func @fabsvec(%arg0 : vector<3xf16>) -> () {
6666
return
6767
}
6868

69+
// -----
70+
71+
func.func @fabs_any_vec(%arg0 : vector<5xf32>) -> () {
72+
// CHECK: spirv.CL.fabs {{%.*}} : vector<5xf32>
73+
%2 = spirv.CL.fabs %arg0 : vector<5xf32>
74+
return
75+
}
76+
6977
func.func @fabsf64(%arg0 : f64) -> () {
7078
// CHECK: spirv.CL.fabs {{%.*}} : f64
7179
%2 = spirv.CL.fabs %arg0 : f64
@@ -82,14 +90,6 @@ func.func @fabs(%arg0 : i32) -> () {
8290

8391
// -----
8492

85-
func.func @fabs(%arg0 : vector<5xf32>) -> () {
86-
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4}}
87-
%2 = spirv.CL.fabs %arg0 : vector<5xf32>
88-
return
89-
}
90-
91-
// -----
92-
9393
func.func @fabs(%arg0 : f32, %arg1 : f32) -> () {
9494
// expected-error @+1 {{expected ':'}}
9595
%2 = spirv.CL.fabs %arg0, %arg1 : i32
@@ -122,6 +122,14 @@ func.func @sabsvec(%arg0 : vector<3xi16>) -> () {
122122
return
123123
}
124124

125+
// -----
126+
127+
func.func @sabs_any_vec(%arg0 : vector<5xi32>) -> () {
128+
// CHECK: spirv.CL.s_abs {{%.*}} : vector<5xi32>
129+
%2 = spirv.CL.s_abs %arg0 : vector<5xi32>
130+
return
131+
}
132+
125133
func.func @sabsi64(%arg0 : i64) -> () {
126134
// CHECK: spirv.CL.s_abs {{%.*}} : i64
127135
%2 = spirv.CL.s_abs %arg0 : i64
@@ -142,13 +150,7 @@ func.func @sabs(%arg0 : f32) -> () {
142150
return
143151
}
144152

145-
// -----
146153

147-
func.func @sabs(%arg0 : vector<5xi32>) -> () {
148-
// expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}}
149-
%2 = spirv.CL.s_abs %arg0 : vector<5xi32>
150-
return
151-
}
152154

153155
// -----
154156

mlir/test/Target/SPIRV/arithmetic-ops.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
66
%0 = spirv.FMul %arg0, %arg1 : f32
77
spirv.Return
88
}
9-
spirv.func @fadd(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) "None" {
10-
// CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : vector<4xf32>
11-
%0 = spirv.FAdd %arg0, %arg1 : vector<4xf32>
9+
spirv.func @fadd(%arg0 : vector<5xf32>, %arg1 : vector<5xf32>) "None" {
10+
// CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : vector<5xf32>
11+
%0 = spirv.FAdd %arg0, %arg1 : vector<5xf32>
1212
spirv.Return
1313
}
1414
spirv.func @fdiv(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) "None" {

mlir/test/Target/SPIRV/ocl-ops.mlir

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses], []
3939
spirv.Return
4040
}
4141

42+
spirv.func @vector_anysize(%arg0 : vector<5000xf32>) "None" {
43+
// CHECK: {{%.*}} = spirv.CL.fabs {{%.*}} : vector<5000xf32>
44+
%0 = spirv.CL.fabs %arg0 : vector<5000xf32>
45+
spirv.Return
46+
}
47+
4248
spirv.func @fma(%arg0 : f32, %arg1 : f32, %arg2 : f32) "None" {
4349
// CHECK: spirv.CL.fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : f32
4450
%13 = spirv.CL.fma %arg0, %arg1, %arg2 : f32

0 commit comments

Comments
 (0)