Skip to content

Commit cbb24e1

Browse files
[LLVM][IR] Add native vector support to ConstantInt & ConstantFP. (#74502)
NOTE: For brevity, the following talks about ConstantInt, but everything extends to cover ConstantFP as well.

Whilst ConstantInt::get() supports the creation of vectors whereby each lane has the same value, it achieves this via other constants:
* ConstantVector for fixed-length vectors
* ConstantExprs for scalable vectors

However, ConstantExprs are being deprecated and ConstantVector is not space efficient for larger vector types. By extending ConstantInt we can represent vector splats by only storing the underlying scalar value.

More specifically:
* ConstantInt gains an ElementCount variant of get().
* LLVMContext is extended to map <EC,APInt>->ConstantInt.
* BitcodeReader/Writer support is extended to allow vector types.

Whilst this patch adds the base support, more work is required before it's production ready. For example, there are likely to be many places where isa<ConstantInt> assumes a scalar type. Accordingly, the default behaviour of ConstantInt::get() remains unchanged, but a set of flags is added to allow wider testing and thus help with the migration:
  --use-constant-int-for-fixed-length-splat
  --use-constant-fp-for-fixed-length-splat
  --use-constant-int-for-scalable-splat
  --use-constant-fp-for-scalable-splat

NOTE: No change is required to the bitcode format because types and values are handled separately.

NOTE: For similar reasons as above, code generation doesn't work out of the box.
1 parent 5b8e560 commit cbb24e1

File tree

8 files changed

+243
-39
lines changed

8 files changed

+243
-39
lines changed

llvm/include/llvm/IR/Constants.h

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,20 @@ class ConstantData : public Constant {
7878
/// Class for constant integers.
7979
class ConstantInt final : public ConstantData {
8080
friend class Constant;
81+
friend class ConstantVector;
8182

8283
APInt Val;
8384

84-
ConstantInt(IntegerType *Ty, const APInt &V);
85+
ConstantInt(Type *Ty, const APInt &V);
8586

8687
void destroyConstantImpl();
8788

89+
/// Return a ConstantInt with the specified value and an implied Type. The
90+
/// type is the vector type whose integer element type corresponds to the bit
91+
/// width of the value.
92+
static ConstantInt *get(LLVMContext &Context, ElementCount EC,
93+
const APInt &V);
94+
8895
public:
8996
ConstantInt(const ConstantInt &) = delete;
9097

@@ -136,7 +143,7 @@ class ConstantInt final : public ConstantData {
136143
/// Return the constant's value.
137144
inline const APInt &getValue() const { return Val; }
138145

139-
/// getBitWidth - Return the bitwidth of this constant.
146+
/// getBitWidth - Return the scalar bitwidth of this constant.
140147
unsigned getBitWidth() const { return Val.getBitWidth(); }
141148

142149
/// Return the constant as a 64-bit unsigned integer value after it
@@ -259,13 +266,20 @@ class ConstantInt final : public ConstantData {
259266
///
260267
class ConstantFP final : public ConstantData {
261268
friend class Constant;
269+
friend class ConstantVector;
262270

263271
APFloat Val;
264272

265273
ConstantFP(Type *Ty, const APFloat &V);
266274

267275
void destroyConstantImpl();
268276

277+
/// Return a ConstantFP with the specified value and an implied Type. The
278+
/// type is the vector type whose element type has the same floating point
279+
/// semantics as the value.
280+
static ConstantFP *get(LLVMContext &Context, ElementCount EC,
281+
const APFloat &V);
282+
269283
public:
270284
ConstantFP(const ConstantFP &) = delete;
271285

llvm/lib/Bitcode/Reader/BitcodeReader.cpp

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3060,48 +3060,49 @@ Error BitcodeReader::parseConstants() {
30603060
V = Constant::getNullValue(CurTy);
30613061
break;
30623062
case bitc::CST_CODE_INTEGER: // INTEGER: [intval]
3063-
if (!CurTy->isIntegerTy() || Record.empty())
3063+
if (!CurTy->isIntOrIntVectorTy() || Record.empty())
30643064
return error("Invalid integer const record");
30653065
V = ConstantInt::get(CurTy, decodeSignRotatedValue(Record[0]));
30663066
break;
30673067
case bitc::CST_CODE_WIDE_INTEGER: {// WIDE_INTEGER: [n x intval]
3068-
if (!CurTy->isIntegerTy() || Record.empty())
3068+
if (!CurTy->isIntOrIntVectorTy() || Record.empty())
30693069
return error("Invalid wide integer const record");
30703070

3071-
APInt VInt =
3072-
readWideAPInt(Record, cast<IntegerType>(CurTy)->getBitWidth());
3073-
V = ConstantInt::get(Context, VInt);
3074-
3071+
auto *ScalarTy = cast<IntegerType>(CurTy->getScalarType());
3072+
APInt VInt = readWideAPInt(Record, ScalarTy->getBitWidth());
3073+
V = ConstantInt::get(CurTy, VInt);
30753074
break;
30763075
}
30773076
case bitc::CST_CODE_FLOAT: { // FLOAT: [fpval]
30783077
if (Record.empty())
30793078
return error("Invalid float const record");
3080-
if (CurTy->isHalfTy())
3081-
V = ConstantFP::get(Context, APFloat(APFloat::IEEEhalf(),
3082-
APInt(16, (uint16_t)Record[0])));
3083-
else if (CurTy->isBFloatTy())
3084-
V = ConstantFP::get(Context, APFloat(APFloat::BFloat(),
3085-
APInt(16, (uint32_t)Record[0])));
3086-
else if (CurTy->isFloatTy())
3087-
V = ConstantFP::get(Context, APFloat(APFloat::IEEEsingle(),
3088-
APInt(32, (uint32_t)Record[0])));
3089-
else if (CurTy->isDoubleTy())
3090-
V = ConstantFP::get(Context, APFloat(APFloat::IEEEdouble(),
3091-
APInt(64, Record[0])));
3092-
else if (CurTy->isX86_FP80Ty()) {
3079+
3080+
auto *ScalarTy = CurTy->getScalarType();
3081+
if (ScalarTy->isHalfTy())
3082+
V = ConstantFP::get(CurTy, APFloat(APFloat::IEEEhalf(),
3083+
APInt(16, (uint16_t)Record[0])));
3084+
else if (ScalarTy->isBFloatTy())
3085+
V = ConstantFP::get(
3086+
CurTy, APFloat(APFloat::BFloat(), APInt(16, (uint32_t)Record[0])));
3087+
else if (ScalarTy->isFloatTy())
3088+
V = ConstantFP::get(CurTy, APFloat(APFloat::IEEEsingle(),
3089+
APInt(32, (uint32_t)Record[0])));
3090+
else if (ScalarTy->isDoubleTy())
3091+
V = ConstantFP::get(
3092+
CurTy, APFloat(APFloat::IEEEdouble(), APInt(64, Record[0])));
3093+
else if (ScalarTy->isX86_FP80Ty()) {
30933094
// Bits are not stored the same way as a normal i80 APInt, compensate.
30943095
uint64_t Rearrange[2];
30953096
Rearrange[0] = (Record[1] & 0xffffLL) | (Record[0] << 16);
30963097
Rearrange[1] = Record[0] >> 48;
3097-
V = ConstantFP::get(Context, APFloat(APFloat::x87DoubleExtended(),
3098-
APInt(80, Rearrange)));
3099-
} else if (CurTy->isFP128Ty())
3100-
V = ConstantFP::get(Context, APFloat(APFloat::IEEEquad(),
3101-
APInt(128, Record)));
3102-
else if (CurTy->isPPC_FP128Ty())
3103-
V = ConstantFP::get(Context, APFloat(APFloat::PPCDoubleDouble(),
3104-
APInt(128, Record)));
3098+
V = ConstantFP::get(
3099+
CurTy, APFloat(APFloat::x87DoubleExtended(), APInt(80, Rearrange)));
3100+
} else if (ScalarTy->isFP128Ty())
3101+
V = ConstantFP::get(CurTy,
3102+
APFloat(APFloat::IEEEquad(), APInt(128, Record)));
3103+
else if (ScalarTy->isPPC_FP128Ty())
3104+
V = ConstantFP::get(
3105+
CurTy, APFloat(APFloat::PPCDoubleDouble(), APInt(128, Record)));
31053106
else
31063107
V = UndefValue::get(CurTy);
31073108
break;

llvm/lib/Bitcode/Writer/BitcodeWriter.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2624,7 +2624,7 @@ void ModuleBitcodeWriter::writeConstants(unsigned FirstVal, unsigned LastVal,
26242624
}
26252625
} else if (const ConstantFP *CFP = dyn_cast<ConstantFP>(C)) {
26262626
Code = bitc::CST_CODE_FLOAT;
2627-
Type *Ty = CFP->getType();
2627+
Type *Ty = CFP->getType()->getScalarType();
26282628
if (Ty->isHalfTy() || Ty->isBFloatTy() || Ty->isFloatTy() ||
26292629
Ty->isDoubleTy()) {
26302630
Record.push_back(CFP->getValueAPF().bitcastToAPInt().getZExtValue());

llvm/lib/IR/AsmWriter.cpp

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1505,16 +1505,39 @@ static void WriteAPFloatInternal(raw_ostream &Out, const APFloat &APF) {
15051505
static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
15061506
AsmWriterContext &WriterCtx) {
15071507
if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV)) {
1508-
if (CI->getType()->isIntegerTy(1)) {
1509-
Out << (CI->getZExtValue() ? "true" : "false");
1510-
return;
1508+
Type *Ty = CI->getType();
1509+
1510+
if (Ty->isVectorTy()) {
1511+
Out << "splat (";
1512+
WriterCtx.TypePrinter->print(Ty->getScalarType(), Out);
1513+
Out << " ";
15111514
}
1512-
Out << CI->getValue();
1515+
1516+
if (Ty->getScalarType()->isIntegerTy(1))
1517+
Out << (CI->getZExtValue() ? "true" : "false");
1518+
else
1519+
Out << CI->getValue();
1520+
1521+
if (Ty->isVectorTy())
1522+
Out << ")";
1523+
15131524
return;
15141525
}
15151526

15161527
if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CV)) {
1528+
Type *Ty = CFP->getType();
1529+
1530+
if (Ty->isVectorTy()) {
1531+
Out << "splat (";
1532+
WriterCtx.TypePrinter->print(Ty->getScalarType(), Out);
1533+
Out << " ";
1534+
}
1535+
15171536
WriteAPFloatInternal(Out, CFP->getValueAPF());
1537+
1538+
if (Ty->isVectorTy())
1539+
Out << ")";
1540+
15181541
return;
15191542
}
15201543

llvm/lib/IR/Constants.cpp

Lines changed: 89 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,20 @@
3535
using namespace llvm;
3636
using namespace PatternMatch;
3737

38+
// A set of temporary options to help migrate how splats are represented.
39+
static cl::opt<bool> UseConstantIntForFixedLengthSplat(
40+
"use-constant-int-for-fixed-length-splat", cl::init(false), cl::Hidden,
41+
cl::desc("Use ConstantInt's native fixed-length vector splat support."));
42+
static cl::opt<bool> UseConstantFPForFixedLengthSplat(
43+
"use-constant-fp-for-fixed-length-splat", cl::init(false), cl::Hidden,
44+
cl::desc("Use ConstantFP's native fixed-length vector splat support."));
45+
static cl::opt<bool> UseConstantIntForScalableSplat(
46+
"use-constant-int-for-scalable-splat", cl::init(false), cl::Hidden,
47+
cl::desc("Use ConstantInt's native scalable vector splat support."));
48+
static cl::opt<bool> UseConstantFPForScalableSplat(
49+
"use-constant-fp-for-scalable-splat", cl::init(false), cl::Hidden,
50+
cl::desc("Use ConstantFP's native scalable vector splat support."));
51+
3852
//===----------------------------------------------------------------------===//
3953
// Constant Class
4054
//===----------------------------------------------------------------------===//
@@ -825,9 +839,11 @@ bool Constant::isManifestConstant() const {
825839
// ConstantInt
826840
//===----------------------------------------------------------------------===//
827841

828-
ConstantInt::ConstantInt(IntegerType *Ty, const APInt &V)
842+
ConstantInt::ConstantInt(Type *Ty, const APInt &V)
829843
: ConstantData(Ty, ConstantIntVal), Val(V) {
830-
assert(V.getBitWidth() == Ty->getBitWidth() && "Invalid constant for type");
844+
assert(V.getBitWidth() ==
845+
cast<IntegerType>(Ty->getScalarType())->getBitWidth() &&
846+
"Invalid constant for type");
831847
}
832848

833849
ConstantInt *ConstantInt::getTrue(LLVMContext &Context) {
@@ -885,6 +901,26 @@ ConstantInt *ConstantInt::get(LLVMContext &Context, const APInt &V) {
885901
return Slot.get();
886902
}
887903

904+
// Get a ConstantInt vector with each lane set to the same APInt.
905+
ConstantInt *ConstantInt::get(LLVMContext &Context, ElementCount EC,
906+
const APInt &V) {
907+
// Get an existing value or the insertion position.
908+
std::unique_ptr<ConstantInt> &Slot =
909+
Context.pImpl->IntSplatConstants[std::make_pair(EC, V)];
910+
if (!Slot) {
911+
IntegerType *ITy = IntegerType::get(Context, V.getBitWidth());
912+
VectorType *VTy = VectorType::get(ITy, EC);
913+
Slot.reset(new ConstantInt(VTy, V));
914+
}
915+
916+
#ifndef NDEBUG
917+
IntegerType *ITy = IntegerType::get(Context, V.getBitWidth());
918+
VectorType *VTy = VectorType::get(ITy, EC);
919+
assert(Slot->getType() == VTy);
920+
#endif
921+
return Slot.get();
922+
}
923+
888924
Constant *ConstantInt::get(Type *Ty, uint64_t V, bool isSigned) {
889925
Constant *C = get(cast<IntegerType>(Ty->getScalarType()), V, isSigned);
890926

@@ -1024,6 +1060,26 @@ ConstantFP* ConstantFP::get(LLVMContext &Context, const APFloat& V) {
10241060
return Slot.get();
10251061
}
10261062

1063+
// Get a ConstantFP vector with each lane set to the same APFloat.
1064+
ConstantFP *ConstantFP::get(LLVMContext &Context, ElementCount EC,
1065+
const APFloat &V) {
1066+
// Get an existing value or the insertion position.
1067+
std::unique_ptr<ConstantFP> &Slot =
1068+
Context.pImpl->FPSplatConstants[std::make_pair(EC, V)];
1069+
if (!Slot) {
1070+
Type *EltTy = Type::getFloatingPointTy(Context, V.getSemantics());
1071+
VectorType *VTy = VectorType::get(EltTy, EC);
1072+
Slot.reset(new ConstantFP(VTy, V));
1073+
}
1074+
1075+
#ifndef NDEBUG
1076+
Type *EltTy = Type::getFloatingPointTy(Context, V.getSemantics());
1077+
VectorType *VTy = VectorType::get(EltTy, EC);
1078+
assert(Slot->getType() == VTy);
1079+
#endif
1080+
return Slot.get();
1081+
}
1082+
10271083
Constant *ConstantFP::getInfinity(Type *Ty, bool Negative) {
10281084
const fltSemantics &Semantics = Ty->getScalarType()->getFltSemantics();
10291085
Constant *C = get(Ty->getContext(), APFloat::getInf(Semantics, Negative));
@@ -1036,7 +1092,7 @@ Constant *ConstantFP::getInfinity(Type *Ty, bool Negative) {
10361092

10371093
ConstantFP::ConstantFP(Type *Ty, const APFloat &V)
10381094
: ConstantData(Ty, ConstantFPVal), Val(V) {
1039-
assert(&V.getSemantics() == &Ty->getFltSemantics() &&
1095+
assert(&V.getSemantics() == &Ty->getScalarType()->getFltSemantics() &&
10401096
"FP type Mismatch");
10411097
}
10421098

@@ -1356,11 +1412,13 @@ Constant *ConstantVector::getImpl(ArrayRef<Constant*> V) {
13561412
bool isZero = C->isNullValue();
13571413
bool isUndef = isa<UndefValue>(C);
13581414
bool isPoison = isa<PoisonValue>(C);
1415+
bool isSplatFP = UseConstantFPForFixedLengthSplat && isa<ConstantFP>(C);
1416+
bool isSplatInt = UseConstantIntForFixedLengthSplat && isa<ConstantInt>(C);
13591417

1360-
if (isZero || isUndef) {
1418+
if (isZero || isUndef || isSplatFP || isSplatInt) {
13611419
for (unsigned i = 1, e = V.size(); i != e; ++i)
13621420
if (V[i] != C) {
1363-
isZero = isUndef = isPoison = false;
1421+
isZero = isUndef = isPoison = isSplatFP = isSplatInt = false;
13641422
break;
13651423
}
13661424
}
@@ -1371,6 +1429,12 @@ Constant *ConstantVector::getImpl(ArrayRef<Constant*> V) {
13711429
return PoisonValue::get(T);
13721430
if (isUndef)
13731431
return UndefValue::get(T);
1432+
if (isSplatFP)
1433+
return ConstantFP::get(C->getContext(), T->getElementCount(),
1434+
cast<ConstantFP>(C)->getValue());
1435+
if (isSplatInt)
1436+
return ConstantInt::get(C->getContext(), T->getElementCount(),
1437+
cast<ConstantInt>(C)->getValue());
13741438

13751439
// Check to see if all of the elements are ConstantFP or ConstantInt and if
13761440
// the element type is compatible with ConstantDataVector. If so, use it.
@@ -1384,6 +1448,16 @@ Constant *ConstantVector::getImpl(ArrayRef<Constant*> V) {
13841448

13851449
Constant *ConstantVector::getSplat(ElementCount EC, Constant *V) {
13861450
if (!EC.isScalable()) {
1451+
// Maintain special handling of zero.
1452+
if (!V->isNullValue()) {
1453+
if (UseConstantIntForFixedLengthSplat && isa<ConstantInt>(V))
1454+
return ConstantInt::get(V->getContext(), EC,
1455+
cast<ConstantInt>(V)->getValue());
1456+
if (UseConstantFPForFixedLengthSplat && isa<ConstantFP>(V))
1457+
return ConstantFP::get(V->getContext(), EC,
1458+
cast<ConstantFP>(V)->getValue());
1459+
}
1460+
13871461
// If this splat is compatible with ConstantDataVector, use it instead of
13881462
// ConstantVector.
13891463
if ((isa<ConstantFP>(V) || isa<ConstantInt>(V)) &&
@@ -1394,6 +1468,16 @@ Constant *ConstantVector::getSplat(ElementCount EC, Constant *V) {
13941468
return get(Elts);
13951469
}
13961470

1471+
// Maintain special handling of zero.
1472+
if (!V->isNullValue()) {
1473+
if (UseConstantIntForScalableSplat && isa<ConstantInt>(V))
1474+
return ConstantInt::get(V->getContext(), EC,
1475+
cast<ConstantInt>(V)->getValue());
1476+
if (UseConstantFPForScalableSplat && isa<ConstantFP>(V))
1477+
return ConstantFP::get(V->getContext(), EC,
1478+
cast<ConstantFP>(V)->getValue());
1479+
}
1480+
13971481
Type *VTy = VectorType::get(V->getType(), EC);
13981482

13991483
if (V->isNullValue())

llvm/lib/IR/LLVMContextImpl.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,9 @@ LLVMContextImpl::~LLVMContextImpl() {
119119
IntZeroConstants.clear();
120120
IntOneConstants.clear();
121121
IntConstants.clear();
122+
IntSplatConstants.clear();
122123
FPConstants.clear();
124+
FPSplatConstants.clear();
123125
CDSConstants.clear();
124126

125127
// Destroy attribute node lists.

llvm/lib/IR/LLVMContextImpl.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1488,8 +1488,12 @@ class LLVMContextImpl {
14881488
DenseMap<unsigned, std::unique_ptr<ConstantInt>> IntZeroConstants;
14891489
DenseMap<unsigned, std::unique_ptr<ConstantInt>> IntOneConstants;
14901490
DenseMap<APInt, std::unique_ptr<ConstantInt>> IntConstants;
1491+
DenseMap<std::pair<ElementCount, APInt>, std::unique_ptr<ConstantInt>>
1492+
IntSplatConstants;
14911493

14921494
DenseMap<APFloat, std::unique_ptr<ConstantFP>> FPConstants;
1495+
DenseMap<std::pair<ElementCount, APFloat>, std::unique_ptr<ConstantFP>>
1496+
FPSplatConstants;
14931497

14941498
FoldingSet<AttributeImpl> AttrsSet;
14951499
FoldingSet<AttributeListImpl> AttrsLists;

0 commit comments

Comments
 (0)