Skip to content

Commit 1b3bce9

Browse files
[LLVM][IR] Add native vector support to ConstantInt & ConstantFP.
NOTE: For brevity the following talks about ConstantInt but everything extends to cover ConstantFP as well. Whilst ConstantInt::get() supports the creation of vectors whereby each lane has the same value, it achieves this via other constants: * ConstantVector for fixed-length vectors * ConstantExprs for scalable vectors However, ConstantExprs are being deprecated and ConstantVector is not space efficient for larger vector types. By extending ConstantInt we can represent vector splats by only storing the underlying scalar value. More specifically: * ConstantInt gains an ElementCount variant of get(). * LLVMContext is extended to map <EC,APInt>->ConstantInt. * BitcodeReader/Writer support is extended to allow vector types. Whilst this patch adds the base support, more work is required before it's production ready. For example, there's likely to be many places where isa<ConstantInt> assumes a scalar type. Accordingly the default behaviour of ConstantInt::get() remains unchanged but a set of flags are added to allow wider testing and thus help with the migration: --use-constant-int-for-fixed-length-splat --use-constant-fp-for-fixed-length-splat --use-constant-int-for-scalable-splat --use-constant-fp-for-scalable-splat NOTE: No change is required to the bitcode format because types and values are handled separately. NOTE: For similar reasons as above, code generation doesn't work out-the-box.
1 parent 970152b commit 1b3bce9

File tree

8 files changed

+217
-38
lines changed

8 files changed

+217
-38
lines changed

llvm/include/llvm/IR/Constants.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ class ConstantInt final : public ConstantData {
8181

8282
APInt Val;
8383

84-
ConstantInt(IntegerType *Ty, const APInt &V);
84+
ConstantInt(Type *Ty, const APInt &V);
8585

8686
void destroyConstantImpl();
8787

@@ -123,6 +123,12 @@ class ConstantInt final : public ConstantData {
123123
/// type is the integer type that corresponds to the bit width of the value.
124124
static ConstantInt *get(LLVMContext &Context, const APInt &V);
125125

126+
/// Return a ConstantInt with the specified value and an implied Type. The
127+
/// type is the vector type whose integer element type corresponds to the bit
128+
/// width of the value.
129+
static ConstantInt *get(LLVMContext &Context, ElementCount EC,
130+
const APInt &V);
131+
126132
/// Return a ConstantInt constructed from the string strStart with the given
127133
/// radix.
128134
static ConstantInt *get(IntegerType *Ty, StringRef Str, uint8_t Radix);
@@ -136,7 +142,7 @@ class ConstantInt final : public ConstantData {
136142
/// Return the constant's value.
137143
inline const APInt &getValue() const { return Val; }
138144

139-
/// getBitWidth - Return the bitwidth of this constant.
145+
/// getBitWidth - Return the scalar bitwidth of this constant.
140146
unsigned getBitWidth() const { return Val.getBitWidth(); }
141147

142148
/// Return the constant as a 64-bit unsigned integer value after it
@@ -281,6 +287,8 @@ class ConstantFP final : public ConstantData {
281287

282288
static Constant *get(Type *Ty, StringRef Str);
283289
static ConstantFP *get(LLVMContext &Context, const APFloat &V);
290+
static ConstantFP *get(LLVMContext &Context, ElementCount EC,
291+
const APFloat &V);
284292
static Constant *getNaN(Type *Ty, bool Negative = false,
285293
uint64_t Payload = 0);
286294
static Constant *getQNaN(Type *Ty, bool Negative = false,

llvm/lib/Bitcode/Reader/BitcodeReader.cpp

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3057,48 +3057,49 @@ Error BitcodeReader::parseConstants() {
30573057
V = Constant::getNullValue(CurTy);
30583058
break;
30593059
case bitc::CST_CODE_INTEGER: // INTEGER: [intval]
3060-
if (!CurTy->isIntegerTy() || Record.empty())
3060+
if (!CurTy->isIntOrIntVectorTy() || Record.empty())
30613061
return error("Invalid integer const record");
30623062
V = ConstantInt::get(CurTy, decodeSignRotatedValue(Record[0]));
30633063
break;
30643064
case bitc::CST_CODE_WIDE_INTEGER: {// WIDE_INTEGER: [n x intval]
3065-
if (!CurTy->isIntegerTy() || Record.empty())
3065+
if (!CurTy->isIntOrIntVectorTy() || Record.empty())
30663066
return error("Invalid wide integer const record");
30673067

3068-
APInt VInt =
3069-
readWideAPInt(Record, cast<IntegerType>(CurTy)->getBitWidth());
3070-
V = ConstantInt::get(Context, VInt);
3071-
3068+
auto *ScalarTy = cast<IntegerType>(CurTy->getScalarType());
3069+
APInt VInt = readWideAPInt(Record, ScalarTy->getBitWidth());
3070+
V = ConstantInt::get(CurTy, VInt);
30723071
break;
30733072
}
30743073
case bitc::CST_CODE_FLOAT: { // FLOAT: [fpval]
30753074
if (Record.empty())
30763075
return error("Invalid float const record");
3077-
if (CurTy->isHalfTy())
3078-
V = ConstantFP::get(Context, APFloat(APFloat::IEEEhalf(),
3079-
APInt(16, (uint16_t)Record[0])));
3080-
else if (CurTy->isBFloatTy())
3081-
V = ConstantFP::get(Context, APFloat(APFloat::BFloat(),
3082-
APInt(16, (uint32_t)Record[0])));
3083-
else if (CurTy->isFloatTy())
3084-
V = ConstantFP::get(Context, APFloat(APFloat::IEEEsingle(),
3085-
APInt(32, (uint32_t)Record[0])));
3086-
else if (CurTy->isDoubleTy())
3087-
V = ConstantFP::get(Context, APFloat(APFloat::IEEEdouble(),
3088-
APInt(64, Record[0])));
3089-
else if (CurTy->isX86_FP80Ty()) {
3076+
3077+
auto *ScalarTy = CurTy->getScalarType();
3078+
if (ScalarTy->isHalfTy())
3079+
V = ConstantFP::get(CurTy, APFloat(APFloat::IEEEhalf(),
3080+
APInt(16, (uint16_t)Record[0])));
3081+
else if (ScalarTy->isBFloatTy())
3082+
V = ConstantFP::get(
3083+
CurTy, APFloat(APFloat::BFloat(), APInt(16, (uint32_t)Record[0])));
3084+
else if (ScalarTy->isFloatTy())
3085+
V = ConstantFP::get(CurTy, APFloat(APFloat::IEEEsingle(),
3086+
APInt(32, (uint32_t)Record[0])));
3087+
else if (ScalarTy->isDoubleTy())
3088+
V = ConstantFP::get(
3089+
CurTy, APFloat(APFloat::IEEEdouble(), APInt(64, Record[0])));
3090+
else if (ScalarTy->isX86_FP80Ty()) {
30903091
// Bits are not stored the same way as a normal i80 APInt, compensate.
30913092
uint64_t Rearrange[2];
30923093
Rearrange[0] = (Record[1] & 0xffffLL) | (Record[0] << 16);
30933094
Rearrange[1] = Record[0] >> 48;
3094-
V = ConstantFP::get(Context, APFloat(APFloat::x87DoubleExtended(),
3095-
APInt(80, Rearrange)));
3096-
} else if (CurTy->isFP128Ty())
3097-
V = ConstantFP::get(Context, APFloat(APFloat::IEEEquad(),
3098-
APInt(128, Record)));
3099-
else if (CurTy->isPPC_FP128Ty())
3100-
V = ConstantFP::get(Context, APFloat(APFloat::PPCDoubleDouble(),
3101-
APInt(128, Record)));
3095+
V = ConstantFP::get(
3096+
CurTy, APFloat(APFloat::x87DoubleExtended(), APInt(80, Rearrange)));
3097+
} else if (ScalarTy->isFP128Ty())
3098+
V = ConstantFP::get(CurTy,
3099+
APFloat(APFloat::IEEEquad(), APInt(128, Record)));
3100+
else if (ScalarTy->isPPC_FP128Ty())
3101+
V = ConstantFP::get(
3102+
CurTy, APFloat(APFloat::PPCDoubleDouble(), APInt(128, Record)));
31023103
else
31033104
V = UndefValue::get(CurTy);
31043105
break;

llvm/lib/Bitcode/Writer/BitcodeWriter.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2609,7 +2609,7 @@ void ModuleBitcodeWriter::writeConstants(unsigned FirstVal, unsigned LastVal,
26092609
}
26102610
} else if (const ConstantFP *CFP = dyn_cast<ConstantFP>(C)) {
26112611
Code = bitc::CST_CODE_FLOAT;
2612-
Type *Ty = CFP->getType();
2612+
Type *Ty = CFP->getType()->getScalarType();
26132613
if (Ty->isHalfTy() || Ty->isBFloatTy() || Ty->isFloatTy() ||
26142614
Ty->isDoubleTy()) {
26152615
Record.push_back(CFP->getValueAPF().bitcastToAPInt().getZExtValue());

llvm/lib/IR/AsmWriter.cpp

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1394,16 +1394,32 @@ static void WriteOptimizationInfo(raw_ostream &Out, const User *U) {
13941394
static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
13951395
AsmWriterContext &WriterCtx) {
13961396
if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV)) {
1397-
if (CI->getType()->isIntegerTy(1)) {
1398-
Out << (CI->getZExtValue() ? "true" : "false");
1399-
return;
1397+
if (CI->getType()->isVectorTy()) {
1398+
Out << "splat (";
1399+
WriterCtx.TypePrinter->print(CI->getType()->getScalarType(), Out);
1400+
Out << " ";
14001401
}
1401-
Out << CI->getValue();
1402+
1403+
if (CI->getType()->getScalarType()->isIntegerTy(1))
1404+
Out << (CI->getZExtValue() ? "true" : "false");
1405+
else
1406+
Out << CI->getValue();
1407+
1408+
if (CI->getType()->isVectorTy())
1409+
Out << ")";
1410+
14021411
return;
14031412
}
14041413

14051414
if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CV)) {
14061415
const APFloat &APF = CFP->getValueAPF();
1416+
1417+
if (CFP->getType()->isVectorTy()) {
1418+
Out << "splat (";
1419+
WriterCtx.TypePrinter->print(CFP->getType()->getScalarType(), Out);
1420+
Out << " ";
1421+
}
1422+
14071423
if (&APF.getSemantics() == &APFloat::IEEEsingle() ||
14081424
&APF.getSemantics() == &APFloat::IEEEdouble()) {
14091425
// We would like to output the FP constant value in exponential notation,
@@ -1429,6 +1445,10 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
14291445
// Reparse stringized version!
14301446
if (APFloat(APFloat::IEEEdouble(), StrVal).convertToDouble() == Val) {
14311447
Out << StrVal;
1448+
1449+
if (CFP->getType()->isVectorTy())
1450+
Out << ")";
1451+
14321452
return;
14331453
}
14341454
}
@@ -1454,6 +1474,10 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
14541474
}
14551475
}
14561476
Out << format_hex(apf.bitcastToAPInt().getZExtValue(), 0, /*Upper=*/true);
1477+
1478+
if (CFP->getType()->isVectorTy())
1479+
Out << ")";
1480+
14571481
return;
14581482
}
14591483

@@ -1468,7 +1492,6 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
14681492
/*Upper=*/true);
14691493
Out << format_hex_no_prefix(API.getLoBits(64).getZExtValue(), 16,
14701494
/*Upper=*/true);
1471-
return;
14721495
} else if (&APF.getSemantics() == &APFloat::IEEEquad()) {
14731496
Out << 'L';
14741497
Out << format_hex_no_prefix(API.getLoBits(64).getZExtValue(), 16,
@@ -1491,6 +1514,10 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
14911514
/*Upper=*/true);
14921515
} else
14931516
llvm_unreachable("Unsupported floating point type");
1517+
1518+
if (CFP->getType()->isVectorTy())
1519+
Out << ")";
1520+
14941521
return;
14951522
}
14961523

llvm/lib/IR/Constants.cpp

Lines changed: 79 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,20 @@
3535
using namespace llvm;
3636
using namespace PatternMatch;
3737

38+
// As set of temporary options to help migrate how splats are represented.
39+
static cl::opt<bool> UseConstantIntForFixedLengthSplat(
40+
"use-constant-int-for-fixed-length-splat", cl::init(false), cl::Hidden,
41+
cl::desc("Use ConstantInt's native fixed-length vector splat support."));
42+
static cl::opt<bool> UseConstantFPForFixedLengthSplat(
43+
"use-constant-fp-for-fixed-length-splat", cl::init(false), cl::Hidden,
44+
cl::desc("Use ConstantFP's native fixed-length vector splat support."));
45+
static cl::opt<bool> UseConstantIntForScalableSplat(
46+
"use-constant-int-for-scalable-splat", cl::init(false), cl::Hidden,
47+
cl::desc("Use ConstantInt's native scalable vector splat support."));
48+
static cl::opt<bool> UseConstantFPForScalableSplat(
49+
"use-constant-fp-for-scalable-splat", cl::init(false), cl::Hidden,
50+
cl::desc("Use ConstantFP's native scalable vector splat support."));
51+
3852
//===----------------------------------------------------------------------===//
3953
// Constant Class
4054
//===----------------------------------------------------------------------===//
@@ -825,9 +839,11 @@ bool Constant::isManifestConstant() const {
825839
// ConstantInt
826840
//===----------------------------------------------------------------------===//
827841

828-
ConstantInt::ConstantInt(IntegerType *Ty, const APInt &V)
842+
ConstantInt::ConstantInt(Type *Ty, const APInt &V)
829843
: ConstantData(Ty, ConstantIntVal), Val(V) {
830-
assert(V.getBitWidth() == Ty->getBitWidth() && "Invalid constant for type");
844+
assert(V.getBitWidth() ==
845+
cast<IntegerType>(Ty->getScalarType())->getBitWidth() &&
846+
"Invalid constant for type");
831847
}
832848

833849
ConstantInt *ConstantInt::getTrue(LLVMContext &Context) {
@@ -885,6 +901,26 @@ ConstantInt *ConstantInt::get(LLVMContext &Context, const APInt &V) {
885901
return Slot.get();
886902
}
887903

904+
// Get a ConstantInt vector with each lane set to the same APInt.
905+
ConstantInt *ConstantInt::get(LLVMContext &Context, ElementCount EC,
906+
const APInt &V) {
907+
// Get an existing value or the insertion position.
908+
std::unique_ptr<ConstantInt> &Slot =
909+
Context.pImpl->IntSplatConstants[std::make_pair(EC, V)];
910+
if (!Slot) {
911+
IntegerType *ITy = IntegerType::get(Context, V.getBitWidth());
912+
VectorType *VTy = VectorType::get(ITy, EC);
913+
Slot.reset(new ConstantInt(VTy, V));
914+
}
915+
916+
#ifndef NDEBUG
917+
IntegerType *ITy = IntegerType::get(Context, V.getBitWidth());
918+
VectorType *VTy = VectorType::get(ITy, EC);
919+
assert(Slot->getType() == VTy);
920+
#endif
921+
return Slot.get();
922+
}
923+
888924
Constant *ConstantInt::get(Type *Ty, uint64_t V, bool isSigned) {
889925
Constant *C = get(cast<IntegerType>(Ty->getScalarType()), V, isSigned);
890926

@@ -1024,6 +1060,26 @@ ConstantFP* ConstantFP::get(LLVMContext &Context, const APFloat& V) {
10241060
return Slot.get();
10251061
}
10261062

1063+
// Get a ConstantFP vector with each lane set to the same APFloat.
1064+
ConstantFP *ConstantFP::get(LLVMContext &Context, ElementCount EC,
1065+
const APFloat &V) {
1066+
// Get an existing value or the insertion position.
1067+
std::unique_ptr<ConstantFP> &Slot =
1068+
Context.pImpl->FPSplatConstants[std::make_pair(EC, V)];
1069+
if (!Slot) {
1070+
Type *EltTy = Type::getFloatingPointTy(Context, V.getSemantics());
1071+
VectorType *VTy = VectorType::get(EltTy, EC);
1072+
Slot.reset(new ConstantFP(VTy, V));
1073+
}
1074+
1075+
#ifndef NDEBUG
1076+
Type *EltTy = Type::getFloatingPointTy(Context, V.getSemantics());
1077+
VectorType *VTy = VectorType::get(EltTy, EC);
1078+
assert(Slot->getType() == VTy);
1079+
#endif
1080+
return Slot.get();
1081+
}
1082+
10271083
Constant *ConstantFP::getInfinity(Type *Ty, bool Negative) {
10281084
const fltSemantics &Semantics = Ty->getScalarType()->getFltSemantics();
10291085
Constant *C = get(Ty->getContext(), APFloat::getInf(Semantics, Negative));
@@ -1036,7 +1092,7 @@ Constant *ConstantFP::getInfinity(Type *Ty, bool Negative) {
10361092

10371093
ConstantFP::ConstantFP(Type *Ty, const APFloat &V)
10381094
: ConstantData(Ty, ConstantFPVal), Val(V) {
1039-
assert(&V.getSemantics() == &Ty->getFltSemantics() &&
1095+
assert(&V.getSemantics() == &Ty->getScalarType()->getFltSemantics() &&
10401096
"FP type Mismatch");
10411097
}
10421098

@@ -1384,6 +1440,16 @@ Constant *ConstantVector::getImpl(ArrayRef<Constant*> V) {
13841440

13851441
Constant *ConstantVector::getSplat(ElementCount EC, Constant *V) {
13861442
if (!EC.isScalable()) {
1443+
// Maintain special handling of zero.
1444+
if (!V->isNullValue()) {
1445+
if (UseConstantIntForFixedLengthSplat && isa<ConstantInt>(V))
1446+
return ConstantInt::get(V->getContext(), EC,
1447+
cast<ConstantInt>(V)->getValue());
1448+
if (UseConstantFPForFixedLengthSplat && isa<ConstantFP>(V))
1449+
return ConstantFP::get(V->getContext(), EC,
1450+
cast<ConstantFP>(V)->getValue());
1451+
}
1452+
13871453
// If this splat is compatible with ConstantDataVector, use it instead of
13881454
// ConstantVector.
13891455
if ((isa<ConstantFP>(V) || isa<ConstantInt>(V)) &&
@@ -1394,6 +1460,16 @@ Constant *ConstantVector::getSplat(ElementCount EC, Constant *V) {
13941460
return get(Elts);
13951461
}
13961462

1463+
// Maintain special handling of zero.
1464+
if (!V->isNullValue()) {
1465+
if (UseConstantIntForScalableSplat && isa<ConstantInt>(V))
1466+
return ConstantInt::get(V->getContext(), EC,
1467+
cast<ConstantInt>(V)->getValue());
1468+
if (UseConstantFPForScalableSplat && isa<ConstantFP>(V))
1469+
return ConstantFP::get(V->getContext(), EC,
1470+
cast<ConstantFP>(V)->getValue());
1471+
}
1472+
13971473
Type *VTy = VectorType::get(V->getType(), EC);
13981474

13991475
if (V->isNullValue())

llvm/lib/IR/LLVMContextImpl.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,9 @@ LLVMContextImpl::~LLVMContextImpl() {
119119
IntZeroConstants.clear();
120120
IntOneConstants.clear();
121121
IntConstants.clear();
122+
IntSplatConstants.clear();
122123
FPConstants.clear();
124+
FPSplatConstants.clear();
123125
CDSConstants.clear();
124126

125127
// Destroy attribute node lists.

llvm/lib/IR/LLVMContextImpl.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1488,8 +1488,12 @@ class LLVMContextImpl {
14881488
DenseMap<unsigned, std::unique_ptr<ConstantInt>> IntZeroConstants;
14891489
DenseMap<unsigned, std::unique_ptr<ConstantInt>> IntOneConstants;
14901490
DenseMap<APInt, std::unique_ptr<ConstantInt>> IntConstants;
1491+
DenseMap<std::pair<ElementCount, APInt>, std::unique_ptr<ConstantInt>>
1492+
IntSplatConstants;
14911493

14921494
DenseMap<APFloat, std::unique_ptr<ConstantFP>> FPConstants;
1495+
DenseMap<std::pair<ElementCount, APFloat>, std::unique_ptr<ConstantFP>>
1496+
FPSplatConstants;
14931497

14941498
FoldingSet<AttributeImpl> AttrsSet;
14951499
FoldingSet<AttributeListImpl> AttrsLists;

0 commit comments

Comments
 (0)