Skip to content

Commit ddf2d62

Browse files
michaltnicolasvasilache
authored andcommitted
[mlir][Vector] First step for 0D vector type
There seems to be a consensus that we should allow 0D vectors: https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097 This commit is only the first step: it changes the verifier and the parser to allow vectors like `vector<f32>` (but does not allow explicit 0 dimensions, i.e., `vector<0xf32>` is not allowed). Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D114086
1 parent 35ff3a0 commit ddf2d62

File tree

5 files changed

+7
-14
lines changed

5 files changed

+7
-14
lines changed

mlir/include/mlir/IR/BuiltinTypes.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -895,21 +895,21 @@ def Builtin_Vector : Builtin_Type<"Vector", [
895895
vector-type ::= `vector` `<` static-dimension-list vector-element-type `>`
896896
vector-element-type ::= float-type | integer-type | index-type
897897

898-
static-dimension-list ::= (decimal-literal `x`)+
898+
static-dimension-list ::= (decimal-literal `x`)*
899899
```
900900

901901
The vector type represents a SIMD style vector, used by target-specific
902902
operation sets like AVX. While the most common use is for 1D vectors (e.g.
903903
vector<16 x f32>) we also support multidimensional registers on targets that
904904
support them (like TPUs).
905905

906-
Vector shapes must be positive decimal integers.
906+
Vector shapes must be positive decimal integers. 0D vectors are allowed by
907+
omitting the dimension: `vector<f32>`.
907908

908909
Note: hexadecimal integer literals are not allowed in vector type
909910
declarations, `vector<0x42xi32>` is invalid because it is interpreted as a
910911
2D vector with shape `(0, 42)` and zero shapes are not allowed.
911912

912-
913913
Examples:
914914

915915
```mlir

mlir/lib/IR/BuiltinTypes.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -441,9 +441,6 @@ bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const {
441441

442442
LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
443443
ArrayRef<int64_t> shape, Type elementType) {
444-
if (shape.empty())
445-
return emitError() << "vector types must have at least one dimension";
446-
447444
if (!isValidElementType(elementType))
448445
return emitError()
449446
<< "vector elements must be int/index/float type but got "

mlir/lib/Parser/TypeParser.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -442,9 +442,7 @@ Type Parser::parseTupleType() {
442442

443443
/// Parse a vector type.
444444
///
445-
/// vector-type ::= `vector` `<` non-empty-static-dimension-list type `>`
446-
/// non-empty-static-dimension-list ::= decimal-literal `x`
447-
/// static-dimension-list
445+
/// vector-type ::= `vector` `<` static-dimension-list type `>`
448446
/// static-dimension-list ::= (decimal-literal `x`)*
449447
///
450448
VectorType Parser::parseVectorType() {
@@ -456,8 +454,6 @@ VectorType Parser::parseVectorType() {
456454
SmallVector<int64_t, 4> dimensions;
457455
if (parseDimensionListRanked(dimensions, /*allowDynamic=*/false))
458456
return nullptr;
459-
if (dimensions.empty())
460-
return (emitError("expected dimension size in vector type"), nullptr);
461457
if (any_of(dimensions, [](int64_t i) { return i <= 0; }))
462458
return emitError(getToken().getLoc(),
463459
"vector types must have positive constant sizes"),

mlir/test/IR/invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -949,7 +949,7 @@ func @zero_in_vector_type() -> vector<1x0xi32>
949949
950950
// -----
951951
952-
// expected-error @+1 {{expected dimension size in vector type}}
952+
// expected-error @+1 {{expected non-function type}}
953953
func @negative_vector_size() -> vector<-1xi32>
954954
955955
// -----

mlir/test/IR/parser.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@ func private @uint_types(ui2, ui4) -> (ui7, ui1023)
6767
// CHECK: func private @float_types(f80, f128)
6868
func private @float_types(f80, f128)
6969

70-
// CHECK: func private @vectors(vector<1xf32>, vector<2x4xf32>)
71-
func private @vectors(vector<1 x f32>, vector<2x4xf32>)
70+
// CHECK: func private @vectors(vector<f32>, vector<1xf32>, vector<2x4xf32>)
71+
func private @vectors(vector<f32>, vector<1 x f32>, vector<2x4xf32>)
7272

7373
// CHECK: func private @tensors(tensor<*xf32>, tensor<*xvector<2x4xf32>>, tensor<1x?x4x?x?xi32>, tensor<i8>)
7474
func private @tensors(tensor<* x f32>, tensor<* x vector<2x4xf32>>,

0 commit comments

Comments
 (0)