Skip to content

Commit 6c6baf8

Browse files
[LLVM][IR] Add native vector support to ConstantInt & ConstantFP.
NOTE: For brevity the following talks about ConstantInt but everything extends to cover ConstantFP as well. Whilst ConstantInt::get() supports the creation of vectors whereby each lane has the same value, it achieves this via other constants: * ConstantVector for fixed-length vectors * ConstantExprs for scalable vectors However, ConstantExprs are being deprecated and ConstantVector is not space efficient for larger vector types. By extending ConstantInt we can represent vector splats by only storing the underlying scalar value. More specifically: * ConstantInt gains an ElementCount variant of get(). * LLVMContext is extended to map <EC,APInt>->ConstantInt. * BitcodeReader/Writer support is extended to allow vector types. Whilst this patch adds the base support, more work is required before it's production ready. For example, there's likely to be many places where isa<ConstantInt> assumes a scalar type. Accordingly the default behaviour of ConstantInt::get() remains unchanged but a set of flags are added to allow wider testing and thus help with the migration: --use-constant-int-for-fixed-length-splat --use-constant-fp-for-fixed-length-splat --use-constant-int-for-scalable-splat --use-constant-fp-for-scalable-splat NOTE: No change is required to the bitcode format because types and values are handled separately. NOTE: For similar reasons as above, code generation doesn't work out-the-box.
1 parent 2cb61a1 commit 6c6baf8

File tree

8 files changed

+217
-38
lines changed

8 files changed

+217
-38
lines changed

llvm/include/llvm/IR/Constants.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ class ConstantInt final : public ConstantData {
8181

8282
APInt Val;
8383

84-
ConstantInt(IntegerType *Ty, const APInt &V);
84+
ConstantInt(Type *Ty, const APInt &V);
8585

8686
void destroyConstantImpl();
8787

@@ -123,6 +123,12 @@ class ConstantInt final : public ConstantData {
123123
/// type is the integer type that corresponds to the bit width of the value.
124124
static ConstantInt *get(LLVMContext &Context, const APInt &V);
125125

126+
/// Return a ConstantInt with the specified value and an implied Type. The
127+
/// type is the vector type whose integer element type corresponds to the bit
128+
/// width of the value.
129+
static ConstantInt *get(LLVMContext &Context, ElementCount EC,
130+
const APInt &V);
131+
126132
/// Return a ConstantInt constructed from the string strStart with the given
127133
/// radix.
128134
static ConstantInt *get(IntegerType *Ty, StringRef Str, uint8_t Radix);
@@ -136,7 +142,7 @@ class ConstantInt final : public ConstantData {
136142
/// Return the constant's value.
137143
inline const APInt &getValue() const { return Val; }
138144

139-
/// getBitWidth - Return the bitwidth of this constant.
145+
/// getBitWidth - Return the scalar bitwidth of this constant.
140146
unsigned getBitWidth() const { return Val.getBitWidth(); }
141147

142148
/// Return the constant as a 64-bit unsigned integer value after it
@@ -281,6 +287,8 @@ class ConstantFP final : public ConstantData {
281287

282288
static Constant *get(Type *Ty, StringRef Str);
283289
static ConstantFP *get(LLVMContext &Context, const APFloat &V);
290+
static ConstantFP *get(LLVMContext &Context, ElementCount EC,
291+
const APFloat &V);
284292
static Constant *getNaN(Type *Ty, bool Negative = false,
285293
uint64_t Payload = 0);
286294
static Constant *getQNaN(Type *Ty, bool Negative = false,

llvm/lib/Bitcode/Reader/BitcodeReader.cpp

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3060,48 +3060,49 @@ Error BitcodeReader::parseConstants() {
30603060
V = Constant::getNullValue(CurTy);
30613061
break;
30623062
case bitc::CST_CODE_INTEGER: // INTEGER: [intval]
3063-
if (!CurTy->isIntegerTy() || Record.empty())
3063+
if (!CurTy->isIntOrIntVectorTy() || Record.empty())
30643064
return error("Invalid integer const record");
30653065
V = ConstantInt::get(CurTy, decodeSignRotatedValue(Record[0]));
30663066
break;
30673067
case bitc::CST_CODE_WIDE_INTEGER: {// WIDE_INTEGER: [n x intval]
3068-
if (!CurTy->isIntegerTy() || Record.empty())
3068+
if (!CurTy->isIntOrIntVectorTy() || Record.empty())
30693069
return error("Invalid wide integer const record");
30703070

3071-
APInt VInt =
3072-
readWideAPInt(Record, cast<IntegerType>(CurTy)->getBitWidth());
3073-
V = ConstantInt::get(Context, VInt);
3074-
3071+
auto *ScalarTy = cast<IntegerType>(CurTy->getScalarType());
3072+
APInt VInt = readWideAPInt(Record, ScalarTy->getBitWidth());
3073+
V = ConstantInt::get(CurTy, VInt);
30753074
break;
30763075
}
30773076
case bitc::CST_CODE_FLOAT: { // FLOAT: [fpval]
30783077
if (Record.empty())
30793078
return error("Invalid float const record");
3080-
if (CurTy->isHalfTy())
3081-
V = ConstantFP::get(Context, APFloat(APFloat::IEEEhalf(),
3082-
APInt(16, (uint16_t)Record[0])));
3083-
else if (CurTy->isBFloatTy())
3084-
V = ConstantFP::get(Context, APFloat(APFloat::BFloat(),
3085-
APInt(16, (uint32_t)Record[0])));
3086-
else if (CurTy->isFloatTy())
3087-
V = ConstantFP::get(Context, APFloat(APFloat::IEEEsingle(),
3088-
APInt(32, (uint32_t)Record[0])));
3089-
else if (CurTy->isDoubleTy())
3090-
V = ConstantFP::get(Context, APFloat(APFloat::IEEEdouble(),
3091-
APInt(64, Record[0])));
3092-
else if (CurTy->isX86_FP80Ty()) {
3079+
3080+
auto *ScalarTy = CurTy->getScalarType();
3081+
if (ScalarTy->isHalfTy())
3082+
V = ConstantFP::get(CurTy, APFloat(APFloat::IEEEhalf(),
3083+
APInt(16, (uint16_t)Record[0])));
3084+
else if (ScalarTy->isBFloatTy())
3085+
V = ConstantFP::get(
3086+
CurTy, APFloat(APFloat::BFloat(), APInt(16, (uint32_t)Record[0])));
3087+
else if (ScalarTy->isFloatTy())
3088+
V = ConstantFP::get(CurTy, APFloat(APFloat::IEEEsingle(),
3089+
APInt(32, (uint32_t)Record[0])));
3090+
else if (ScalarTy->isDoubleTy())
3091+
V = ConstantFP::get(
3092+
CurTy, APFloat(APFloat::IEEEdouble(), APInt(64, Record[0])));
3093+
else if (ScalarTy->isX86_FP80Ty()) {
30933094
// Bits are not stored the same way as a normal i80 APInt, compensate.
30943095
uint64_t Rearrange[2];
30953096
Rearrange[0] = (Record[1] & 0xffffLL) | (Record[0] << 16);
30963097
Rearrange[1] = Record[0] >> 48;
3097-
V = ConstantFP::get(Context, APFloat(APFloat::x87DoubleExtended(),
3098-
APInt(80, Rearrange)));
3099-
} else if (CurTy->isFP128Ty())
3100-
V = ConstantFP::get(Context, APFloat(APFloat::IEEEquad(),
3101-
APInt(128, Record)));
3102-
else if (CurTy->isPPC_FP128Ty())
3103-
V = ConstantFP::get(Context, APFloat(APFloat::PPCDoubleDouble(),
3104-
APInt(128, Record)));
3098+
V = ConstantFP::get(
3099+
CurTy, APFloat(APFloat::x87DoubleExtended(), APInt(80, Rearrange)));
3100+
} else if (ScalarTy->isFP128Ty())
3101+
V = ConstantFP::get(CurTy,
3102+
APFloat(APFloat::IEEEquad(), APInt(128, Record)));
3103+
else if (ScalarTy->isPPC_FP128Ty())
3104+
V = ConstantFP::get(
3105+
CurTy, APFloat(APFloat::PPCDoubleDouble(), APInt(128, Record)));
31053106
else
31063107
V = UndefValue::get(CurTy);
31073108
break;

llvm/lib/Bitcode/Writer/BitcodeWriter.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2624,7 +2624,7 @@ void ModuleBitcodeWriter::writeConstants(unsigned FirstVal, unsigned LastVal,
26242624
}
26252625
} else if (const ConstantFP *CFP = dyn_cast<ConstantFP>(C)) {
26262626
Code = bitc::CST_CODE_FLOAT;
2627-
Type *Ty = CFP->getType();
2627+
Type *Ty = CFP->getType()->getScalarType();
26282628
if (Ty->isHalfTy() || Ty->isBFloatTy() || Ty->isFloatTy() ||
26292629
Ty->isDoubleTy()) {
26302630
Record.push_back(CFP->getValueAPF().bitcastToAPInt().getZExtValue());

llvm/lib/IR/AsmWriter.cpp

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1409,16 +1409,32 @@ static void WriteOptimizationInfo(raw_ostream &Out, const User *U) {
14091409
static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
14101410
AsmWriterContext &WriterCtx) {
14111411
if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV)) {
1412-
if (CI->getType()->isIntegerTy(1)) {
1413-
Out << (CI->getZExtValue() ? "true" : "false");
1414-
return;
1412+
if (CI->getType()->isVectorTy()) {
1413+
Out << "splat (";
1414+
WriterCtx.TypePrinter->print(CI->getType()->getScalarType(), Out);
1415+
Out << " ";
14151416
}
1416-
Out << CI->getValue();
1417+
1418+
if (CI->getType()->getScalarType()->isIntegerTy(1))
1419+
Out << (CI->getZExtValue() ? "true" : "false");
1420+
else
1421+
Out << CI->getValue();
1422+
1423+
if (CI->getType()->isVectorTy())
1424+
Out << ")";
1425+
14171426
return;
14181427
}
14191428

14201429
if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CV)) {
14211430
const APFloat &APF = CFP->getValueAPF();
1431+
1432+
if (CFP->getType()->isVectorTy()) {
1433+
Out << "splat (";
1434+
WriterCtx.TypePrinter->print(CFP->getType()->getScalarType(), Out);
1435+
Out << " ";
1436+
}
1437+
14221438
if (&APF.getSemantics() == &APFloat::IEEEsingle() ||
14231439
&APF.getSemantics() == &APFloat::IEEEdouble()) {
14241440
// We would like to output the FP constant value in exponential notation,
@@ -1444,6 +1460,10 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
14441460
// Reparse stringized version!
14451461
if (APFloat(APFloat::IEEEdouble(), StrVal).convertToDouble() == Val) {
14461462
Out << StrVal;
1463+
1464+
if (CFP->getType()->isVectorTy())
1465+
Out << ")";
1466+
14471467
return;
14481468
}
14491469
}
@@ -1469,6 +1489,10 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
14691489
}
14701490
}
14711491
Out << format_hex(apf.bitcastToAPInt().getZExtValue(), 0, /*Upper=*/true);
1492+
1493+
if (CFP->getType()->isVectorTy())
1494+
Out << ")";
1495+
14721496
return;
14731497
}
14741498

@@ -1483,7 +1507,6 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
14831507
/*Upper=*/true);
14841508
Out << format_hex_no_prefix(API.getLoBits(64).getZExtValue(), 16,
14851509
/*Upper=*/true);
1486-
return;
14871510
} else if (&APF.getSemantics() == &APFloat::IEEEquad()) {
14881511
Out << 'L';
14891512
Out << format_hex_no_prefix(API.getLoBits(64).getZExtValue(), 16,
@@ -1506,6 +1529,10 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
15061529
/*Upper=*/true);
15071530
} else
15081531
llvm_unreachable("Unsupported floating point type");
1532+
1533+
if (CFP->getType()->isVectorTy())
1534+
Out << ")";
1535+
15091536
return;
15101537
}
15111538

llvm/lib/IR/Constants.cpp

Lines changed: 79 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,20 @@
3535
using namespace llvm;
3636
using namespace PatternMatch;
3737

38+
// As set of temporary options to help migrate how splats are represented.
39+
static cl::opt<bool> UseConstantIntForFixedLengthSplat(
40+
"use-constant-int-for-fixed-length-splat", cl::init(false), cl::Hidden,
41+
cl::desc("Use ConstantInt's native fixed-length vector splat support."));
42+
static cl::opt<bool> UseConstantFPForFixedLengthSplat(
43+
"use-constant-fp-for-fixed-length-splat", cl::init(false), cl::Hidden,
44+
cl::desc("Use ConstantFP's native fixed-length vector splat support."));
45+
static cl::opt<bool> UseConstantIntForScalableSplat(
46+
"use-constant-int-for-scalable-splat", cl::init(false), cl::Hidden,
47+
cl::desc("Use ConstantInt's native scalable vector splat support."));
48+
static cl::opt<bool> UseConstantFPForScalableSplat(
49+
"use-constant-fp-for-scalable-splat", cl::init(false), cl::Hidden,
50+
cl::desc("Use ConstantFP's native scalable vector splat support."));
51+
3852
//===----------------------------------------------------------------------===//
3953
// Constant Class
4054
//===----------------------------------------------------------------------===//
@@ -825,9 +839,11 @@ bool Constant::isManifestConstant() const {
825839
// ConstantInt
826840
//===----------------------------------------------------------------------===//
827841

828-
ConstantInt::ConstantInt(IntegerType *Ty, const APInt &V)
842+
ConstantInt::ConstantInt(Type *Ty, const APInt &V)
829843
: ConstantData(Ty, ConstantIntVal), Val(V) {
830-
assert(V.getBitWidth() == Ty->getBitWidth() && "Invalid constant for type");
844+
assert(V.getBitWidth() ==
845+
cast<IntegerType>(Ty->getScalarType())->getBitWidth() &&
846+
"Invalid constant for type");
831847
}
832848

833849
ConstantInt *ConstantInt::getTrue(LLVMContext &Context) {
@@ -885,6 +901,26 @@ ConstantInt *ConstantInt::get(LLVMContext &Context, const APInt &V) {
885901
return Slot.get();
886902
}
887903

904+
// Get a ConstantInt vector with each lane set to the same APInt.
905+
ConstantInt *ConstantInt::get(LLVMContext &Context, ElementCount EC,
906+
const APInt &V) {
907+
// Get an existing value or the insertion position.
908+
std::unique_ptr<ConstantInt> &Slot =
909+
Context.pImpl->IntSplatConstants[std::make_pair(EC, V)];
910+
if (!Slot) {
911+
IntegerType *ITy = IntegerType::get(Context, V.getBitWidth());
912+
VectorType *VTy = VectorType::get(ITy, EC);
913+
Slot.reset(new ConstantInt(VTy, V));
914+
}
915+
916+
#ifndef NDEBUG
917+
IntegerType *ITy = IntegerType::get(Context, V.getBitWidth());
918+
VectorType *VTy = VectorType::get(ITy, EC);
919+
assert(Slot->getType() == VTy);
920+
#endif
921+
return Slot.get();
922+
}
923+
888924
Constant *ConstantInt::get(Type *Ty, uint64_t V, bool isSigned) {
889925
Constant *C = get(cast<IntegerType>(Ty->getScalarType()), V, isSigned);
890926

@@ -1024,6 +1060,26 @@ ConstantFP* ConstantFP::get(LLVMContext &Context, const APFloat& V) {
10241060
return Slot.get();
10251061
}
10261062

1063+
// Get a ConstantFP vector with each lane set to the same APFloat.
1064+
ConstantFP *ConstantFP::get(LLVMContext &Context, ElementCount EC,
1065+
const APFloat &V) {
1066+
// Get an existing value or the insertion position.
1067+
std::unique_ptr<ConstantFP> &Slot =
1068+
Context.pImpl->FPSplatConstants[std::make_pair(EC, V)];
1069+
if (!Slot) {
1070+
Type *EltTy = Type::getFloatingPointTy(Context, V.getSemantics());
1071+
VectorType *VTy = VectorType::get(EltTy, EC);
1072+
Slot.reset(new ConstantFP(VTy, V));
1073+
}
1074+
1075+
#ifndef NDEBUG
1076+
Type *EltTy = Type::getFloatingPointTy(Context, V.getSemantics());
1077+
VectorType *VTy = VectorType::get(EltTy, EC);
1078+
assert(Slot->getType() == VTy);
1079+
#endif
1080+
return Slot.get();
1081+
}
1082+
10271083
Constant *ConstantFP::getInfinity(Type *Ty, bool Negative) {
10281084
const fltSemantics &Semantics = Ty->getScalarType()->getFltSemantics();
10291085
Constant *C = get(Ty->getContext(), APFloat::getInf(Semantics, Negative));
@@ -1036,7 +1092,7 @@ Constant *ConstantFP::getInfinity(Type *Ty, bool Negative) {
10361092

10371093
ConstantFP::ConstantFP(Type *Ty, const APFloat &V)
10381094
: ConstantData(Ty, ConstantFPVal), Val(V) {
1039-
assert(&V.getSemantics() == &Ty->getFltSemantics() &&
1095+
assert(&V.getSemantics() == &Ty->getScalarType()->getFltSemantics() &&
10401096
"FP type Mismatch");
10411097
}
10421098

@@ -1384,6 +1440,16 @@ Constant *ConstantVector::getImpl(ArrayRef<Constant*> V) {
13841440

13851441
Constant *ConstantVector::getSplat(ElementCount EC, Constant *V) {
13861442
if (!EC.isScalable()) {
1443+
// Maintain special handling of zero.
1444+
if (!V->isNullValue()) {
1445+
if (UseConstantIntForFixedLengthSplat && isa<ConstantInt>(V))
1446+
return ConstantInt::get(V->getContext(), EC,
1447+
cast<ConstantInt>(V)->getValue());
1448+
if (UseConstantFPForFixedLengthSplat && isa<ConstantFP>(V))
1449+
return ConstantFP::get(V->getContext(), EC,
1450+
cast<ConstantFP>(V)->getValue());
1451+
}
1452+
13871453
// If this splat is compatible with ConstantDataVector, use it instead of
13881454
// ConstantVector.
13891455
if ((isa<ConstantFP>(V) || isa<ConstantInt>(V)) &&
@@ -1394,6 +1460,16 @@ Constant *ConstantVector::getSplat(ElementCount EC, Constant *V) {
13941460
return get(Elts);
13951461
}
13961462

1463+
// Maintain special handling of zero.
1464+
if (!V->isNullValue()) {
1465+
if (UseConstantIntForScalableSplat && isa<ConstantInt>(V))
1466+
return ConstantInt::get(V->getContext(), EC,
1467+
cast<ConstantInt>(V)->getValue());
1468+
if (UseConstantFPForScalableSplat && isa<ConstantFP>(V))
1469+
return ConstantFP::get(V->getContext(), EC,
1470+
cast<ConstantFP>(V)->getValue());
1471+
}
1472+
13971473
Type *VTy = VectorType::get(V->getType(), EC);
13981474

13991475
if (V->isNullValue())

llvm/lib/IR/LLVMContextImpl.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,9 @@ LLVMContextImpl::~LLVMContextImpl() {
119119
IntZeroConstants.clear();
120120
IntOneConstants.clear();
121121
IntConstants.clear();
122+
IntSplatConstants.clear();
122123
FPConstants.clear();
124+
FPSplatConstants.clear();
123125
CDSConstants.clear();
124126

125127
// Destroy attribute node lists.

llvm/lib/IR/LLVMContextImpl.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1488,8 +1488,12 @@ class LLVMContextImpl {
14881488
DenseMap<unsigned, std::unique_ptr<ConstantInt>> IntZeroConstants;
14891489
DenseMap<unsigned, std::unique_ptr<ConstantInt>> IntOneConstants;
14901490
DenseMap<APInt, std::unique_ptr<ConstantInt>> IntConstants;
1491+
DenseMap<std::pair<ElementCount, APInt>, std::unique_ptr<ConstantInt>>
1492+
IntSplatConstants;
14911493

14921494
DenseMap<APFloat, std::unique_ptr<ConstantFP>> FPConstants;
1495+
DenseMap<std::pair<ElementCount, APFloat>, std::unique_ptr<ConstantFP>>
1496+
FPSplatConstants;
14931497

14941498
FoldingSet<AttributeImpl> AttrsSet;
14951499
FoldingSet<AttributeListImpl> AttrsLists;

0 commit comments

Comments
 (0)