Skip to content

Commit 22ed9e0

Browse files
[mlir][LLVM] Delete getVectorElementType
1 parent a00a61d commit 22ed9e0

File tree

11 files changed

+44
-47
lines changed

11 files changed

+44
-47
lines changed

mlir/docs/Dialects/LLVM.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -334,8 +334,6 @@ compatible with the LLVM dialect:
334334

335335
- `bool LLVM::isCompatibleVectorType(Type)` - checks whether a type is a
336336
vector type compatible with the LLVM dialect;
337-
- `Type LLVM::getVectorElementType(Type)` - returns the element type of any
338-
vector type compatible with the LLVM dialect;
339337
- `llvm::ElementCount LLVM::getVectorNumElements(Type)` - returns the number
340338
of elements in any vector type compatible with the LLVM dialect;
341339
- `Type LLVM::getFixedVectorType(Type, unsigned)` - gets a fixed vector type

mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -874,7 +874,7 @@ def LLVM_MatrixColumnMajorLoadOp : LLVM_OneResultIntrOp<"matrix.column.major.loa
874874
const llvm::DataLayout &dl =
875875
builder.GetInsertBlock()->getModule()->getDataLayout();
876876
llvm::Type *ElemTy = moduleTranslation.convertType(
877-
getVectorElementType(op.getType()));
877+
op.getType().getElementType());
878878
llvm::Align align = dl.getABITypeAlign(ElemTy);
879879
$res = mb.CreateColumnMajorLoad(
880880
ElemTy, $data, align, $stride, $isVolatile, $rows,
@@ -907,7 +907,7 @@ def LLVM_MatrixColumnMajorStoreOp : LLVM_ZeroResultIntrOp<"matrix.column.major.s
907907
llvm::MatrixBuilder mb(builder);
908908
const llvm::DataLayout &dl =
909909
builder.GetInsertBlock()->getModule()->getDataLayout();
910-
Type elementType = getVectorElementType(op.getMatrix().getType());
910+
Type elementType = op.getMatrix().getType().getElementType();
911911
llvm::Align align = dl.getABITypeAlign(
912912
moduleTranslation.convertType(elementType));
913913
mb.CreateColumnMajorStore(
@@ -1164,7 +1164,8 @@ def LLVM_vector_insert
11641164
let extraClassDeclaration = [{
11651165
uint64_t getVectorBitWidth(Type vector) {
11661166
return getVectorNumElements(vector).getKnownMinValue() *
1167-
getVectorElementType(vector).getIntOrFloatBitWidth();
1167+
::llvm::cast<VectorType>(vector).getElementType()
1168+
.getIntOrFloatBitWidth();
11681169
}
11691170
uint64_t getSrcVectorBitWidth() {
11701171
return getVectorBitWidth(getSrcvec().getType());
@@ -1196,7 +1197,8 @@ def LLVM_vector_extract
11961197
let extraClassDeclaration = [{
11971198
uint64_t getVectorBitWidth(Type vector) {
11981199
return getVectorNumElements(vector).getKnownMinValue() *
1199-
getVectorElementType(vector).getIntOrFloatBitWidth();
1200+
::llvm::cast<VectorType>(vector).getElementType()
1201+
.getIntOrFloatBitWidth();
12001202
}
12011203
uint64_t getSrcVectorBitWidth() {
12021204
return getVectorBitWidth(getSrcvec().getType());
@@ -1216,8 +1218,8 @@ def LLVM_vector_interleave2
12161218
"result has twice as many elements as 'vec1'",
12171219
And<[CPred<"getVectorNumElements($res.getType()) == "
12181220
"getVectorNumElements($vec1.getType()) * 2">,
1219-
CPred<"getVectorElementType($vec1.getType()) == "
1220-
"getVectorElementType($res.getType())">]>>,
1221+
CPred<"::llvm::cast<VectorType>($vec1.getType()).getElementType() == "
1222+
"::llvm::cast<VectorType>($res.getType()).getElementType()">]>>,
12211223
]>,
12221224
Arguments<(ins LLVM_AnyVector:$vec1, LLVM_AnyVector:$vec2)>;
12231225

mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -113,27 +113,31 @@ def LLVM_AnyNonAggregate : Type<And<[LLVM_Type.predicate,
113113

114114
// Type constraint accepting any LLVM vector type.
115115
def LLVM_AnyVector : Type<CPred<"::mlir::LLVM::isCompatibleVectorType($_self)">,
116-
"LLVM dialect-compatible vector type">;
116+
"LLVM dialect-compatible vector type",
117+
"::mlir::VectorType">;
117118

118119
// Type constraint accepting any LLVM fixed-length vector type.
119120
def LLVM_AnyFixedVector : Type<CPred<
120121
"!::mlir::LLVM::isScalableVectorType($_self)">,
121-
"LLVM dialect-compatible fixed-length vector type">;
122+
"LLVM dialect-compatible fixed-length vector type",
123+
"::mlir::VectorType">;
122124

123125
// Type constraint accepting any LLVM scalable vector type.
124126
def LLVM_AnyScalableVector : Type<CPred<
125127
"::mlir::LLVM::isScalableVectorType($_self)">,
126-
"LLVM dialect-compatible scalable vector type">;
128+
"LLVM dialect-compatible scalable vector type",
129+
"::mlir::VectorType">;
127130

128131
// Type constraint accepting an LLVM vector type with an additional constraint
129132
// on the vector element type.
130133
class LLVM_VectorOf<Type element> : Type<
131134
And<[LLVM_AnyVector.predicate,
132135
SubstLeaves<
133136
"$_self",
134-
"::mlir::LLVM::getVectorElementType($_self)",
137+
"::llvm::cast<::mlir::VectorType>($_self).getElementType()",
135138
element.predicate>]>,
136-
"LLVM dialect-compatible vector of " # element.summary>;
139+
"LLVM dialect-compatible vector of " # element.summary,
140+
"::mlir::VectorType">;
137141

138142
// Type constraint accepting a constrained type, or a vector of such types.
139143
class LLVM_ScalarOrVectorOf<Type element> :

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -820,8 +820,9 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
820820
//===----------------------------------------------------------------------===//
821821

822822
def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [Pure,
823-
TypesMatchWith<"result type matches vector element type", "vector", "res",
824-
"LLVM::getVectorElementType($_self)">]> {
823+
TypesMatchWith<
824+
"result type matches vector element type", "vector", "res",
825+
"::llvm::cast<::mlir::VectorType>($_self).getElementType()">]> {
825826
let summary = "Extract an element from an LLVM vector.";
826827

827828
let arguments = (ins LLVM_AnyVector:$vector, AnySignlessInteger:$position);
@@ -881,7 +882,8 @@ def LLVM_ExtractValueOp : LLVM_Op<"extractvalue", [Pure]> {
881882

882883
def LLVM_InsertElementOp : LLVM_Op<"insertelement", [Pure,
883884
TypesMatchWith<"argument type matches vector element type", "vector",
884-
"value", "LLVM::getVectorElementType($_self)">,
885+
"value",
886+
"::llvm::cast<::mlir::VectorType>($_self).getElementType()">,
885887
AllTypesMatch<["res", "vector"]>]> {
886888
let summary = "Insert an element into an LLVM vector.";
887889

mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,6 @@ bool isCompatibleFloatingPointType(Type type);
111111
/// dialect pointers and LLVM dialect scalable vector types.
112112
bool isCompatibleVectorType(Type type);
113113

114-
/// Returns the element type of any vector type compatible with the LLVM
115-
/// dialect.
116-
Type getVectorElementType(Type type);
117-
118114
/// Returns the element count of any LLVM-compatible vector type.
119115
llvm::ElementCount getVectorNumElements(Type type);
120116

mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,9 @@ static unsigned getBitWidth(Type type) {
7878

7979
/// Returns the bit width of LLVMType integer or vector.
8080
static unsigned getLLVMTypeBitWidth(Type type) {
81-
return cast<IntegerType>((LLVM::isCompatibleVectorType(type)
82-
? LLVM::getVectorElementType(type)
83-
: type))
84-
.getWidth();
81+
if (auto vecTy = dyn_cast<VectorType>(type))
82+
type = vecTy.getElementType();
83+
return cast<IntegerType>(type).getWidth();
8584
}
8685

8786
/// Creates `IntegerAttribute` with all bits set for given type

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2734,9 +2734,9 @@ void ShuffleVectorOp::build(OpBuilder &builder, OperationState &state, Value v1,
27342734
Value v2, DenseI32ArrayAttr mask,
27352735
ArrayRef<NamedAttribute> attrs) {
27362736
auto containerType = v1.getType();
2737-
auto vType = LLVM::getVectorType(LLVM::getVectorElementType(containerType),
2738-
mask.size(),
2739-
LLVM::isScalableVectorType(containerType));
2737+
auto vType = LLVM::getVectorType(
2738+
cast<VectorType>(containerType).getElementType(), mask.size(),
2739+
LLVM::isScalableVectorType(containerType));
27402740
build(builder, state, vType, v1, v2, mask);
27412741
state.addAttributes(attrs);
27422742
}
@@ -2752,8 +2752,9 @@ static ParseResult parseShuffleType(AsmParser &parser, Type v1Type,
27522752
if (!LLVM::isCompatibleVectorType(v1Type))
27532753
return parser.emitError(parser.getCurrentLocation(),
27542754
"expected an LLVM compatible vector type");
2755-
resType = LLVM::getVectorType(LLVM::getVectorElementType(v1Type), mask.size(),
2756-
LLVM::isScalableVectorType(v1Type));
2755+
resType =
2756+
LLVM::getVectorType(cast<VectorType>(v1Type).getElementType(),
2757+
mask.size(), LLVM::isScalableVectorType(v1Type));
27572758
return success();
27582759
}
27592760

@@ -3318,7 +3319,7 @@ LogicalResult AtomicRMWOp::verify() {
33183319
if (isCompatibleVectorType(valType)) {
33193320
if (isScalableVectorType(valType))
33203321
return emitOpError("expected LLVM IR fixed vector type");
3321-
Type elemType = getVectorElementType(valType);
3322+
Type elemType = llvm::cast<VectorType>(valType).getElementType();
33223323
if (!isCompatibleFloatingPointType(elemType))
33233324
return emitOpError(
33243325
"expected LLVM IR floating point type for vector element");
@@ -3423,9 +3424,10 @@ static LogicalResult verifyExtOp(ExtOp op) {
34233424
return op.emitError("input and output vectors are of incompatible shape");
34243425
// Because this is a CastOp, the element of vectors is guaranteed to be an
34253426
// integer.
3426-
inputType = cast<IntegerType>(getVectorElementType(op.getArg().getType()));
3427-
outputType =
3428-
cast<IntegerType>(getVectorElementType(op.getResult().getType()));
3427+
inputType = cast<IntegerType>(
3428+
cast<VectorType>(op.getArg().getType()).getElementType());
3429+
outputType = cast<IntegerType>(
3430+
cast<VectorType>(op.getResult().getType()).getElementType());
34293431
} else {
34303432
// Because this is a CastOp and arg is not a vector, arg is guaranteed to be
34313433
// an integer.

mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -821,12 +821,6 @@ bool mlir::LLVM::isCompatibleVectorType(Type type) {
821821
return false;
822822
}
823823

824-
Type mlir::LLVM::getVectorElementType(Type type) {
825-
auto vecTy = dyn_cast<VectorType>(type);
826-
assert(vecTy && "incompatible with LLVM vector type");
827-
return vecTy.getElementType();
828-
}
829-
830824
llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) {
831825
auto vecTy = dyn_cast<VectorType>(type);
832826
assert(vecTy && "incompatible with LLVM vector type");

mlir/lib/Target/LLVMIR/ModuleImport.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -826,7 +826,7 @@ static Type getVectorTypeForAttr(Type type, ArrayRef<int64_t> arrayShape = {}) {
826826
}
827827

828828
// An LLVM dialect vector can only contain scalars.
829-
Type elementType = LLVM::getVectorElementType(type);
829+
Type elementType = cast<VectorType>(type).getElementType();
830830
if (!elementType.isIntOrFloat())
831831
return {};
832832

mlir/test/Dialect/LLVMIR/invalid.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -515,21 +515,21 @@ func.func @extractvalue_wrong_nesting() {
515515
// -----
516516

517517
func.func @invalid_vector_type_1(%arg0: vector<4xf32>, %arg1: i32, %arg2: f32) {
518-
// expected-error@+1 {{'vector' must be LLVM dialect-compatible vector}}
518+
// expected-error@+1 {{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
519519
%0 = llvm.extractelement %arg2[%arg1 : i32] : f32
520520
}
521521

522522
// -----
523523

524524
func.func @invalid_vector_type_2(%arg0: vector<4xf32>, %arg1: i32, %arg2: f32) {
525-
// expected-error@+1 {{'vector' must be LLVM dialect-compatible vector}}
525+
// expected-error@+1 {{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
526526
%0 = llvm.insertelement %arg2, %arg2[%arg1 : i32] : f32
527527
}
528528

529529
// -----
530530

531531
func.func @invalid_vector_type_3(%arg0: vector<4xf32>, %arg1: i32, %arg2: f32) {
532-
// expected-error@+2 {{expected an LLVM compatible vector type}}
532+
// expected-error@+1 {{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
533533
%0 = llvm.shufflevector %arg2, %arg2 [0, 0, 0, 0, 7] : f32
534534
}
535535

mlir/test/Target/LLVMIR/llvmir-invalid.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ llvm.func @vec_reduce_fmax_intr_wrong_type(%arg0 : vector<4xi32>) -> i32 {
211211
// -----
212212

213213
llvm.func @matrix_load_intr_wrong_type(%ptr : !llvm.ptr, %stride : i32) -> f32 {
214-
// expected-error @below{{op result #0 must be LLVM dialect-compatible vector type, but got 'f32'}}
214+
// expected-error @+2{{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
215215
%0 = llvm.intr.matrix.column.major.load %ptr, <stride=%stride>
216216
{ isVolatile = 0: i1, rows = 3: i32, columns = 16: i32} : f32 from !llvm.ptr stride i32
217217
llvm.return %0 : f32
@@ -229,7 +229,7 @@ llvm.func @matrix_store_intr_wrong_type(%matrix : vector<48xf32>, %ptr : i32, %s
229229
// -----
230230

231231
llvm.func @matrix_multiply_intr_wrong_type(%arg0 : vector<64xf32>, %arg1 : f32) -> vector<12xf32> {
232-
// expected-error @below{{op operand #1 must be LLVM dialect-compatible vector type, but got 'f32'}}
232+
// expected-error @+2{{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
233233
%0 = llvm.intr.matrix.multiply %arg0, %arg1
234234
{ lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32} : (vector<64xf32>, f32) -> vector<12xf32>
235235
llvm.return %0 : vector<12xf32>
@@ -238,7 +238,7 @@ llvm.func @matrix_multiply_intr_wrong_type(%arg0 : vector<64xf32>, %arg1 : f32)
238238
// -----
239239

240240
llvm.func @matrix_transpose_intr_wrong_type(%matrix : f32) -> vector<48xf32> {
241-
// expected-error @below{{op operand #0 must be LLVM dialect-compatible vector type, but got 'f32'}}
241+
// expected-error @below{{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
242242
%0 = llvm.intr.matrix.transpose %matrix {rows = 3: i32, columns = 16: i32} : f32 into vector<48xf32>
243243
llvm.return %0 : vector<48xf32>
244244
}
@@ -286,7 +286,7 @@ llvm.func @masked_gather_intr_wrong_type_scalable(%ptrs : vector<7x!llvm.ptr>, %
286286
// -----
287287

288288
llvm.func @masked_scatter_intr_wrong_type(%vec : f32, %ptrs : vector<7x!llvm.ptr>, %mask : vector<7xi1>) {
289-
// expected-error @below{{op operand #0 must be LLVM dialect-compatible vector type, but got 'f32'}}
289+
// expected-error @below{{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
290290
llvm.intr.masked.scatter %vec, %ptrs, %mask { alignment = 1: i32} : f32, vector<7xi1> into vector<7x!llvm.ptr>
291291
llvm.return
292292
}

0 commit comments

Comments
 (0)