Skip to content

Commit 428b9be

Browse files
authored
[mlir] Align num elements type to LLVM ArrayType (#93230)
MLIR LLMArrayType is using `unsigned` for the number of elements while LLVM ArrayType is using `uint64_t` https://github.com/llvm/llvm-project/blob/4ae896fe979b7db501cabde4b6b3504478958682/llvm/include/llvm/IR/DerivedTypes.h#L377 This leads to silent truncation when we use it for globals in flang. ``` program test integer(8), parameter :: large = 2**30 real, dimension(large) :: bigarray common /c/ bigarray bigarray(999) = 666 end ``` The above program would result in a segfault since the global would be of size 0 because of the silent truncation. ``` fir.global common @c_(dense<0> : vector<4294967296xi8>) : !fir.array<4294967296xi8> ``` became ``` llvm.mlir.global common @c_(dense<0> : vector<4294967296xi8>) {addr_space = 0 : i32} : !llvm.array<0 x i8> ``` This patch updates the definition of MLIR ArrayType to take `uint64_t` as argument of the number of elements to be compatible with LLVM.
1 parent 4e67f45 commit 428b9be

File tree

5 files changed

+53
-7
lines changed

5 files changed

+53
-7
lines changed

flang/test/Fir/convert-to-llvm.fir

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2699,3 +2699,9 @@ func.func @coordinate_array_unknown_size_1d(%arg0: !fir.ptr<!fir.array<? x i32>>
26992699
// CHECK: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
27002700
// CHECK: llvm.return
27012701
// CHECK: }
2702+
2703+
// -----
2704+
2705+
fir.global common @c_(dense<0> : vector<4294967296xi8>) : !fir.array<4294967296xi8>
2706+
2707+
// CHECK: llvm.mlir.global common @c_(dense<0> : vector<4294967296xi8>) {addr_space = 0 : i32} : !llvm.array<4294967296 x i8>

mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def LLVMArrayType : LLVMType<"LLVMArray", "array", [
4040
```
4141
}];
4242

43-
let parameters = (ins "Type":$elementType, "unsigned":$numElements);
43+
let parameters = (ins "Type":$elementType, "uint64_t":$numElements);
4444
let assemblyFormat = [{
4545
`<` $numElements `x` custom<PrettyLLVMType>($elementType) `>`
4646
}];
@@ -49,7 +49,7 @@ def LLVMArrayType : LLVMType<"LLVMArray", "array", [
4949

5050
let builders = [
5151
TypeBuilderWithInferredContext<(ins "Type":$elementType,
52-
"unsigned":$numElements)>
52+
"uint64_t":$numElements)>
5353
];
5454

5555
let extraClassDeclaration = [{

mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,22 +154,22 @@ bool LLVMArrayType::isValidElementType(Type type) {
154154
type);
155155
}
156156

157-
LLVMArrayType LLVMArrayType::get(Type elementType, unsigned numElements) {
157+
LLVMArrayType LLVMArrayType::get(Type elementType, uint64_t numElements) {
158158
assert(elementType && "expected non-null subtype");
159159
return Base::get(elementType.getContext(), elementType, numElements);
160160
}
161161

162162
LLVMArrayType
163163
LLVMArrayType::getChecked(function_ref<InFlightDiagnostic()> emitError,
164-
Type elementType, unsigned numElements) {
164+
Type elementType, uint64_t numElements) {
165165
assert(elementType && "expected non-null subtype");
166166
return Base::getChecked(emitError, elementType.getContext(), elementType,
167167
numElements);
168168
}
169169

170170
LogicalResult
171171
LLVMArrayType::verify(function_ref<InFlightDiagnostic()> emitError,
172-
Type elementType, unsigned numElements) {
172+
Type elementType, uint64_t numElements) {
173173
if (!isValidElementType(elementType))
174174
return emitError() << "invalid array element type: " << elementType;
175175
return success();

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -632,8 +632,43 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
632632
llvm::ElementCount::get(numElements, /*Scalable=*/isScalable), child);
633633
if (llvmType->isArrayTy()) {
634634
auto *arrayType = llvm::ArrayType::get(elementType, numElements);
635-
SmallVector<llvm::Constant *, 8> constants(numElements, child);
636-
return llvm::ConstantArray::get(arrayType, constants);
635+
if (child->isZeroValue()) {
636+
return llvm::ConstantAggregateZero::get(arrayType);
637+
} else {
638+
if (llvm::ConstantDataSequential::isElementTypeCompatible(
639+
elementType)) {
640+
// TODO: Handle all compatible types. This code only handles integer.
641+
if (llvm::IntegerType *iTy =
642+
dyn_cast<llvm::IntegerType>(elementType)) {
643+
if (llvm::ConstantInt *ci = dyn_cast<llvm::ConstantInt>(child)) {
644+
if (ci->getBitWidth() == 8) {
645+
SmallVector<int8_t> constants(numElements, ci->getZExtValue());
646+
return llvm::ConstantDataArray::get(elementType->getContext(),
647+
constants);
648+
}
649+
if (ci->getBitWidth() == 16) {
650+
SmallVector<int16_t> constants(numElements, ci->getZExtValue());
651+
return llvm::ConstantDataArray::get(elementType->getContext(),
652+
constants);
653+
}
654+
if (ci->getBitWidth() == 32) {
655+
SmallVector<int32_t> constants(numElements, ci->getZExtValue());
656+
return llvm::ConstantDataArray::get(elementType->getContext(),
657+
constants);
658+
}
659+
if (ci->getBitWidth() == 64) {
660+
SmallVector<int64_t> constants(numElements, ci->getZExtValue());
661+
return llvm::ConstantDataArray::get(elementType->getContext(),
662+
constants);
663+
}
664+
}
665+
}
666+
}
667+
// std::vector is used here to accomodate large number of elements that
668+
// exceed SmallVector capacity.
669+
std::vector<llvm::Constant *> constants(numElements, child);
670+
return llvm::ConstantArray::get(arrayType, constants);
671+
}
637672
}
638673
}
639674

mlir/test/Target/LLVMIR/llvmir.mlir

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2396,3 +2396,8 @@ llvm.func @zeroinit_complex_local_aggregate() {
23962396
llvm.linker_options ["/DEFAULTLIB:", "libcmt"]
23972397
//CHECK: ![[MD1]] = !{!"/DEFAULTLIB:", !"libcmtd"}
23982398
llvm.linker_options ["/DEFAULTLIB:", "libcmtd"]
2399+
2400+
// -----
2401+
2402+
// CHECK: @big_ = common global [4294967296 x i8] zeroinitializer
2403+
llvm.mlir.global common @big_(dense<0> : vector<4294967296xi8>) {addr_space = 0 : i32} : !llvm.array<4294967296 x i8>

0 commit comments

Comments
 (0)