Skip to content

Commit 5c8963b

Browse files
delete scalable vec type
1 parent ba29cf2 commit 5c8963b

File tree

13 files changed

+29
-236
lines changed

13 files changed

+29
-236
lines changed

mlir/docs/Dialects/LLVM.md

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -327,18 +327,7 @@ multiple of some fixed size in case of _scalable_ vectors, and the element type.
327327
Vectors cannot be nested and only 1D vectors are supported. Scalable vectors are
328328
still considered 1D.
329329

330-
The LLVM dialect uses built-in vector types for _fixed_-size vectors of built-in
331-
types, and provides additional types for scalable vectors of any types
332-
(`LLVMScalableVectorType`):
333-
334-
```
335-
llvm-vec-type ::= `!llvm.vec<` (`?` `x`)? integer-literal `x` type `>`
336-
```
337-
338-
Note that the sets of element types supported by built-in and LLVM dialect
339-
vector types are mutually exclusive, e.g., the built-in vector type does not
340-
accept `!llvm.ptr` and the LLVM dialect fixed-width vector type does not
341-
accept `i32`.
330+
The LLVM dialect uses built-in vector type.
342331

343332
The following functions are provided to operate on any kind of the vector types
344333
compatible with the LLVM dialect:
@@ -358,8 +347,8 @@ compatible with the LLVM dialect:
358347

359348
```mlir
360349
vector<42 x i32> // Vector of 42 32-bit integers.
361-
!llvm.vec<42 x ptr> // Vector of 42 pointers.
362-
!llvm.vec<? x 4 x i32> // Scalable vector of 32-bit integers with
350+
vector<42 x !llvm.ptr> // Vector of 42 pointers.
351+
vector<[4] x i32> // Scalable vector of 32-bit integers with
363352
// size divisible by 4.
364353
!llvm.array<2 x vector<2 x i32>> // Array of 2 vectors of 2 32-bit integers.
365354
!llvm.array<2 x vec<2 x ptr>> // Array of 2 vectors of 2 pointers.

mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -288,38 +288,6 @@ def LLVMPointerType : LLVMType<"LLVMPointer", "ptr", [
288288
];
289289
}
290290

291-
//===----------------------------------------------------------------------===//
292-
// LLVMScalableVectorType
293-
//===----------------------------------------------------------------------===//
294-
295-
def LLVMScalableVectorType : LLVMType<"LLVMScalableVector", "vec"> {
296-
let summary = "LLVM scalable vector type";
297-
let description = [{
298-
LLVM dialect scalable vector type, represents a sequence of elements of
299-
unknown length that is known to be divisible by some constant. These
300-
elements can be processed as one in SIMD context.
301-
}];
302-
303-
let typeName = "llvm.scalable_vec";
304-
305-
let parameters = (ins "Type":$elementType, "unsigned":$minNumElements);
306-
let assemblyFormat = [{
307-
`<` `?` `x` $minNumElements `x` ` ` custom<PrettyLLVMType>($elementType) `>`
308-
}];
309-
310-
let genVerifyDecl = 1;
311-
312-
let builders = [
313-
TypeBuilderWithInferredContext<(ins "Type":$elementType,
314-
"unsigned":$minNumElements)>
315-
];
316-
317-
let extraClassDeclaration = [{
318-
/// Checks if the given type can be used in a vector type.
319-
static bool isValidElementType(Type type);
320-
}];
321-
}
322-
323291
//===----------------------------------------------------------------------===//
324292
// LLVMTargetExtType
325293
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -685,8 +685,6 @@ GEPIndicesAdaptor<ValueRange> GEPOp::getIndices() {
685685
static Type extractVectorElementType(Type type) {
686686
if (auto vectorType = llvm::dyn_cast<VectorType>(type))
687687
return vectorType.getElementType();
688-
if (auto scalableVectorType = llvm::dyn_cast<LLVMScalableVectorType>(type))
689-
return scalableVectorType.getElementType();
690688
return type;
691689
}
692690

@@ -724,10 +722,9 @@ static void destructureIndices(Type currType, ArrayRef<GEPArg> indices,
724722
continue;
725723

726724
currType = TypeSwitch<Type, Type>(currType)
727-
.Case<VectorType, LLVMScalableVectorType, LLVMArrayType>(
728-
[](auto containerType) {
729-
return containerType.getElementType();
730-
})
725+
.Case<VectorType, LLVMArrayType>([](auto containerType) {
726+
return containerType.getElementType();
727+
})
731728
.Case([&](LLVMStructType structType) -> Type {
732729
int64_t memberIndex = rawConstantIndices.back();
733730
if (memberIndex >= 0 && static_cast<size_t>(memberIndex) <
@@ -836,7 +833,7 @@ verifyStructIndices(Type baseGEPType, unsigned indexPos,
836833
return verifyStructIndices(elementTypes[gepIndex], indexPos + 1,
837834
indices, emitOpError);
838835
})
839-
.Case<VectorType, LLVMScalableVectorType, LLVMArrayType>(
836+
.Case<VectorType, LLVMArrayType>(
840837
[&](auto containerType) -> LogicalResult {
841838
return verifyStructIndices(containerType.getElementType(),
842839
indexPos + 1, indices, emitOpError);
@@ -3162,16 +3159,12 @@ static int64_t getNumElements(Type t) {
31623159
if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(t))
31633160
return arrayType.getNumElements() *
31643161
getNumElements(arrayType.getElementType());
3165-
assert(!isa<LLVM::LLVMScalableVectorType>(t) &&
3166-
"number of elements of a scalable vector type is unknown");
31673162
return 1;
31683163
}
31693164

31703165
/// Check if the given type is a scalable vector type or a vector/array type
31713166
/// that contains a nested scalable vector type.
31723167
static bool hasScalableVectorType(Type t) {
3173-
if (isa<LLVM::LLVMScalableVectorType>(t))
3174-
return true;
31753168
if (auto vecType = dyn_cast<VectorType>(t)) {
31763169
if (vecType.isScalable())
31773170
return true;
@@ -3507,7 +3500,7 @@ LogicalResult LLVM::BitcastOp::verify() {
35073500
if (!resultType)
35083501
return success();
35093502

3510-
auto isVector = llvm::IsaPred<VectorType, LLVMScalableVectorType>;
3503+
auto isVector = llvm::IsaPred<VectorType>;
35113504

35123505
// Due to bitcast requiring both operands to be of the same size, it is not
35133506
// possible for only one of the two to be a pointer of vectors.

mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -134,12 +134,6 @@ static bool isSupportedTypeForConversion(Type type) {
134134
if (isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(type))
135135
return false;
136136

137-
// LLVM vector types are only used for either pointers or target specific
138-
// types. These types cannot be casted in the general case, thus the memory
139-
// optimizations do not support them.
140-
if (isa<LLVM::LLVMScalableVectorType>(type))
141-
return false;
142-
143137
if (auto vectorType = dyn_cast<VectorType>(type)) {
144138
// Vectors of pointers cannot be casted.
145139
if (isa<LLVM::LLVMPointerType>(vectorType.getElementType()))

mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp

Lines changed: 2 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ static StringRef getTypeKeyword(Type type) {
4040
.Case<LLVMMetadataType>([&](Type) { return "metadata"; })
4141
.Case<LLVMFunctionType>([&](Type) { return "func"; })
4242
.Case<LLVMPointerType>([&](Type) { return "ptr"; })
43-
.Case<LLVMScalableVectorType>([&](Type) { return "vec"; })
4443
.Case<LLVMArrayType>([&](Type) { return "array"; })
4544
.Case<LLVMStructType>([&](Type) { return "struct"; })
4645
.Case<LLVMTargetExtType>([&](Type) { return "target"; })
@@ -103,9 +102,8 @@ void mlir::LLVM::detail::printType(Type type, AsmPrinter &printer) {
103102
printer << getTypeKeyword(type);
104103

105104
llvm::TypeSwitch<Type>(type)
106-
.Case<LLVMPointerType, LLVMArrayType, LLVMScalableVectorType,
107-
LLVMFunctionType, LLVMTargetExtType, LLVMStructType>(
108-
[&](auto type) { type.print(printer); });
105+
.Case<LLVMPointerType, LLVMArrayType, LLVMFunctionType, LLVMTargetExtType,
106+
LLVMStructType>([&](auto type) { type.print(printer); });
109107
}
110108

111109
//===----------------------------------------------------------------------===//
@@ -114,41 +112,6 @@ void mlir::LLVM::detail::printType(Type type, AsmPrinter &printer) {
114112

115113
static ParseResult dispatchParse(AsmParser &parser, Type &type);
116114

117-
/// Parses an LLVM dialect vector type.
118-
/// llvm-type ::= `vec<` `? x`? integer `x` llvm-type `>`
119-
/// Supports both fixed and scalable vectors.
120-
static Type parseVectorType(AsmParser &parser) {
121-
SmallVector<int64_t, 2> dims;
122-
SMLoc dimPos, typePos;
123-
Type elementType;
124-
SMLoc loc = parser.getCurrentLocation();
125-
if (parser.parseLess() || parser.getCurrentLocation(&dimPos) ||
126-
parser.parseDimensionList(dims, /*allowDynamic=*/true) ||
127-
parser.getCurrentLocation(&typePos) ||
128-
dispatchParse(parser, elementType) || parser.parseGreater())
129-
return Type();
130-
131-
// We parsed a generic dimension list, but vectors only support two forms:
132-
// - single non-dynamic entry in the list (fixed vector);
133-
// - two elements, the first dynamic (indicated by ShapedType::kDynamic)
134-
// and the second
135-
// non-dynamic (scalable vector).
136-
if (dims.empty() || dims.size() > 2 ||
137-
((dims.size() == 2) ^ (ShapedType::isDynamic(dims[0]))) ||
138-
(dims.size() == 2 && ShapedType::isDynamic(dims[1]))) {
139-
parser.emitError(dimPos)
140-
<< "expected '? x <integer> x <type>' or '<integer> x <type>'";
141-
return Type();
142-
}
143-
144-
bool isScalable = dims.size() == 2;
145-
if (!isScalable) {
146-
parser.emitError(dimPos) << "expected scalable vector";
147-
return Type();
148-
}
149-
return parser.getChecked<LLVMScalableVectorType>(loc, elementType, dims[1]);
150-
}
151-
152115
/// Attempts to set the body of an identified structure type. Reports a parsing
153116
/// error at `subtypesLoc` in case of failure.
154117
static LLVMStructType trySetStructBody(LLVMStructType type,
@@ -307,7 +270,6 @@ static Type dispatchParse(AsmParser &parser, bool allowAny = true) {
307270
.Case("metadata", [&] { return LLVMMetadataType::get(ctx); })
308271
.Case("func", [&] { return LLVMFunctionType::parse(parser); })
309272
.Case("ptr", [&] { return LLVMPointerType::parse(parser); })
310-
.Case("vec", [&] { return parseVectorType(parser); })
311273
.Case("array", [&] { return LLVMArrayType::parse(parser); })
312274
.Case("struct", [&] { return LLVMStructType::parse(parser); })
313275
.Case("target", [&] { return LLVMTargetExtType::parse(parser); })

mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp

Lines changed: 7 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,7 @@ generatedTypePrinter(Type def, AsmPrinter &printer);
150150

151151
bool LLVMArrayType::isValidElementType(Type type) {
152152
return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
153-
LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>(
154-
type);
153+
LLVMFunctionType, LLVMTokenType>(type);
155154
}
156155

157156
LLVMArrayType LLVMArrayType::get(Type elementType, uint64_t numElements) {
@@ -659,53 +658,6 @@ LogicalResult LLVMStructType::verifyEntries(DataLayoutEntryListRef entries,
659658
return mlir::success();
660659
}
661660

662-
//===----------------------------------------------------------------------===//
663-
// LLVMScalableVectorType.
664-
//===----------------------------------------------------------------------===//
665-
666-
/// Verifies that the type about to be constructed is well-formed.
667-
template <typename VecTy>
668-
static LogicalResult
669-
verifyVectorConstructionInvariants(function_ref<InFlightDiagnostic()> emitError,
670-
Type elementType, unsigned numElements) {
671-
if (numElements == 0)
672-
return emitError() << "the number of vector elements must be positive";
673-
674-
if (!VecTy::isValidElementType(elementType))
675-
return emitError() << "invalid vector element type";
676-
677-
return success();
678-
}
679-
680-
LLVMScalableVectorType LLVMScalableVectorType::get(Type elementType,
681-
unsigned minNumElements) {
682-
assert(elementType && "expected non-null subtype");
683-
return Base::get(elementType.getContext(), elementType, minNumElements);
684-
}
685-
686-
LLVMScalableVectorType
687-
LLVMScalableVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
688-
Type elementType, unsigned minNumElements) {
689-
assert(elementType && "expected non-null subtype");
690-
return Base::getChecked(emitError, elementType.getContext(), elementType,
691-
minNumElements);
692-
}
693-
694-
bool LLVMScalableVectorType::isValidElementType(Type type) {
695-
if (auto intType = llvm::dyn_cast<IntegerType>(type))
696-
return intType.isSignless();
697-
698-
return isCompatibleFloatingPointType(type) ||
699-
llvm::isa<LLVMPointerType>(type);
700-
}
701-
702-
LogicalResult
703-
LLVMScalableVectorType::verify(function_ref<InFlightDiagnostic()> emitError,
704-
Type elementType, unsigned numElements) {
705-
return verifyVectorConstructionInvariants<LLVMScalableVectorType>(
706-
emitError, elementType, numElements);
707-
}
708-
709661
//===----------------------------------------------------------------------===//
710662
// LLVMTargetExtType.
711663
//===----------------------------------------------------------------------===//
@@ -764,7 +716,6 @@ bool mlir::LLVM::isCompatibleOuterType(Type type) {
764716
LLVMPointerType,
765717
LLVMStructType,
766718
LLVMTokenType,
767-
LLVMScalableVectorType,
768719
LLVMTargetExtType,
769720
LLVMVoidType,
770721
LLVMX86AMXType
@@ -812,7 +763,6 @@ static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
812763
})
813764
// clang-format off
814765
.Case<
815-
LLVMScalableVectorType,
816766
LLVMArrayType
817767
>([&](auto containerType) {
818768
return isCompatible(containerType.getElementType());
@@ -859,9 +809,6 @@ bool mlir::LLVM::isCompatibleFloatingPointType(Type type) {
859809
}
860810

861811
bool mlir::LLVM::isCompatibleVectorType(Type type) {
862-
if (llvm::isa<LLVMScalableVectorType>(type))
863-
return true;
864-
865812
if (auto vecType = llvm::dyn_cast<VectorType>(type)) {
866813
if (vecType.getRank() != 1)
867814
return false;
@@ -876,8 +823,7 @@ bool mlir::LLVM::isCompatibleVectorType(Type type) {
876823

877824
Type mlir::LLVM::getVectorElementType(Type type) {
878825
return llvm::TypeSwitch<Type, Type>(type)
879-
.Case<LLVMScalableVectorType, VectorType>(
880-
[](auto ty) { return ty.getElementType(); })
826+
.Case<VectorType>([](auto ty) { return ty.getElementType(); })
881827
.Default([](Type) -> Type {
882828
llvm_unreachable("incompatible with LLVM vector type");
883829
});
@@ -890,37 +836,22 @@ llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) {
890836
return llvm::ElementCount::getScalable(ty.getNumElements());
891837
return llvm::ElementCount::getFixed(ty.getNumElements());
892838
})
893-
.Case([](LLVMScalableVectorType ty) {
894-
return llvm::ElementCount::getScalable(ty.getMinNumElements());
895-
})
896839
.Default([](Type) -> llvm::ElementCount {
897840
llvm_unreachable("incompatible with LLVM vector type");
898841
});
899842
}
900843

901844
bool mlir::LLVM::isScalableVectorType(Type vectorType) {
902-
assert((llvm::isa<LLVMScalableVectorType, VectorType>(vectorType)) &&
845+
assert(llvm::isa<VectorType>(vectorType) &&
903846
"expected LLVM-compatible vector type");
904-
return llvm::isa<LLVMScalableVectorType>(vectorType) ||
905-
llvm::cast<VectorType>(vectorType).isScalable();
847+
return llvm::cast<VectorType>(vectorType).isScalable();
906848
}
907849

908850
Type mlir::LLVM::getVectorType(Type elementType, unsigned numElements,
909851
bool isScalable) {
910-
if (!isScalable) {
911-
// Non-scalable vectors always use the MLIR vector type.
912-
assert(VectorType::isValidElementType(elementType) &&
913-
"incompatible element type");
914-
return VectorType::get(numElements, elementType, {false});
915-
}
916-
917-
// This is a scalable vector.
918-
if (VectorType::isValidElementType(elementType))
919-
return VectorType::get(numElements, elementType, {true});
920-
assert(LLVMScalableVectorType::isValidElementType(elementType) &&
921-
"neither the MLIR vector type nor LLVMScalableVectorType is "
922-
"compatible with the specified element type");
923-
return LLVMScalableVectorType::get(elementType, numElements);
852+
assert(VectorType::isValidElementType(elementType) &&
853+
"incompatible element type");
854+
return VectorType::get(numElements, elementType, {isScalable});
924855
}
925856

926857
Type mlir::LLVM::getVectorType(Type elementType,
@@ -939,15 +870,6 @@ Type mlir::LLVM::getFixedVectorType(Type elementType, unsigned numElements) {
939870
}
940871

941872
Type mlir::LLVM::getScalableVectorType(Type elementType, unsigned numElements) {
942-
bool useLLVM = LLVMScalableVectorType::isValidElementType(elementType);
943-
bool useBuiltIn = VectorType::isValidElementType(elementType);
944-
(void)useBuiltIn;
945-
assert((useLLVM ^ useBuiltIn) && "expected LLVM-compatible scalable-vector "
946-
"type to be either builtin or LLVM dialect "
947-
"type");
948-
if (useLLVM)
949-
return LLVMScalableVectorType::get(elementType, numElements);
950-
951873
// LLVM vectors are always 1-D, hence only 1 bool is required to mark it as
952874
// scalable/non-scalable.
953875
return VectorType::get(numElements, elementType, /*scalableDims=*/true);

mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,8 @@ class TypeFromLLVMIRTranslatorImpl {
130130

131131
/// Translates the given scalable-vector type.
132132
Type translate(llvm::ScalableVectorType *type) {
133-
return LLVM::LLVMScalableVectorType::get(
134-
translateType(type->getElementType()), type->getMinNumElements());
133+
return LLVM::getScalableVectorType(translateType(type->getElementType()),
134+
type->getMinNumElements());
135135
}
136136

137137
/// Translates the given target extension type.

0 commit comments

Comments
 (0)