Skip to content

Commit ce16521

Browse files
[GISel] Add support for scalable vectors in getGCDType
This function can be called from buildCopyToRegs where at least one of the types is a scalable vector type. This function crashed because it did not know how to handle scalable vector types. This patch extends the functionality of getGCDType to handle when at least one of the types is a scalable vector. getGCDType between a fixed and scalable vector is not implemented since the docstring of the function explains that getGCDType is used to build MERGE/UNMERGE instructions and we will never build a MERGE/UNMERGE between fixed and scalable vectors.
1 parent 59eadcd commit ce16521

File tree

3 files changed

+127
-33
lines changed

3 files changed

+127
-33
lines changed

llvm/include/llvm/CodeGen/GlobalISel/Utils.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,10 @@ LLT getCoverTy(LLT OrigTy, LLT TargetTy);
365365
/// If these are vectors with different element types, this will try to produce
366366
/// a vector with a compatible total size, but the element type of \p OrigTy. If
367367
/// this can't be satisfied, this will produce a scalar smaller than the
368-
/// original vector elements.
368+
/// original vector elements. It is an error to call this function where
369+
/// one argument is a fixed vector and the other is a scalable vector, since it
370+
/// is illegal to build a G_{MERGE|UNMERGE}_VALUES between fixed and scalable
371+
/// vectors.
369372
///
370373
/// In the worst case, this returns LLT::scalar(1)
371374
LLVM_READNONE

llvm/lib/CodeGen/GlobalISel/Utils.cpp

Lines changed: 47 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,45 +1132,60 @@ LLT llvm::getCoverTy(LLT OrigTy, LLT TargetTy) {
11321132
}
11331133

11341134
LLT llvm::getGCDType(LLT OrigTy, LLT TargetTy) {
1135-
const unsigned OrigSize = OrigTy.getSizeInBits();
1136-
const unsigned TargetSize = TargetTy.getSizeInBits();
1137-
1138-
if (OrigSize == TargetSize)
1135+
if (OrigTy.getSizeInBits() == TargetTy.getSizeInBits())
11391136
return OrigTy;
11401137

1141-
if (OrigTy.isVector()) {
1138+
if (OrigTy.isVector() && TargetTy.isVector()) {
11421139
LLT OrigElt = OrigTy.getElementType();
1143-
if (TargetTy.isVector()) {
1144-
LLT TargetElt = TargetTy.getElementType();
1145-
if (OrigElt.getSizeInBits() == TargetElt.getSizeInBits()) {
1146-
int GCD = std::gcd(OrigTy.getNumElements(), TargetTy.getNumElements());
1147-
return LLT::scalarOrVector(ElementCount::getFixed(GCD), OrigElt);
1148-
}
1149-
} else {
1150-
// If the source is a vector of pointers, return a pointer element.
1151-
if (OrigElt.getSizeInBits() == TargetSize)
1152-
return OrigElt;
1153-
}
1154-
1155-
unsigned GCD = std::gcd(OrigSize, TargetSize);
1156-
if (GCD == OrigElt.getSizeInBits())
1157-
return OrigElt;
1140+
LLT TargetElt = TargetTy.getElementType();
11581141

1159-
// If we can't produce the original element type, we have to use a smaller
1160-
// scalar.
1161-
if (GCD < OrigElt.getSizeInBits())
1162-
return LLT::scalar(GCD);
1163-
return LLT::fixed_vector(GCD / OrigElt.getSizeInBits(), OrigElt);
1142+
// TODO: The docstring for this function says the intention is to use this
1143+
// function to build MERGE/UNMERGE instructions. It won't be the case that
1144+
// we generate a MERGE/UNMERGE between fixed and scalable vector types. We
1145+
// could implement getGCDType between the two in the future if there was a
1146+
// need, but it is not worth it now as this function should not be used in
1147+
// that way.
1148+
if ((OrigTy.isScalableVector() && TargetTy.isFixedVector()) ||
1149+
(OrigTy.isFixedVector() && TargetTy.isScalableVector()))
1150+
llvm_unreachable(
1151+
"getGCDType not implemented between fixed and scalable vectors.");
1152+
1153+
unsigned GCD = std::gcd(OrigTy.getElementCount().getKnownMinValue() *
1154+
OrigElt.getSizeInBits().getFixedValue(),
1155+
TargetTy.getElementCount().getKnownMinValue() *
1156+
TargetElt.getSizeInBits().getFixedValue());
1157+
if (GCD == OrigElt.getSizeInBits())
1158+
return LLT::scalarOrVector(ElementCount::get(1, OrigTy.isScalable()),
1159+
OrigElt);
1160+
1161+
// Cannot produce original element type, but both have vscale in common.
1162+
if (GCD < OrigElt.getSizeInBits())
1163+
return LLT::scalarOrVector(ElementCount::get(1, OrigTy.isScalable()),
1164+
GCD);
1165+
1166+
return LLT::vector(
1167+
ElementCount::get(GCD / OrigElt.getSizeInBits().getFixedValue(),
1168+
OrigTy.isScalable()),
1169+
OrigElt);
11641170
}
11651171

1166-
if (TargetTy.isVector()) {
1167-
// Try to preserve the original element type.
1168-
LLT TargetElt = TargetTy.getElementType();
1169-
if (TargetElt.getSizeInBits() == OrigSize)
1170-
return OrigTy;
1171-
}
1172+
// If one type is vector and the element size matches the scalar size, then
1173+
// the gcd is the scalar type.
1174+
if (OrigTy.isVector() &&
1175+
OrigTy.getElementType().getSizeInBits() == TargetTy.getSizeInBits())
1176+
return OrigTy.getElementType();
1177+
if (TargetTy.isVector() &&
1178+
TargetTy.getElementType().getSizeInBits() == OrigTy.getSizeInBits())
1179+
return OrigTy;
11721180

1173-
unsigned GCD = std::gcd(OrigSize, TargetSize);
1181+
// At this point, both types are either scalars of different type or one is a
1182+
// vector and one is a scalar. If both types are scalars, the GCD type is the
1183+
// GCD between the two scalar sizes. If one is vector and one is scalar, then
1184+
// the GCD type is the GCD between the scalar and the vector element size.
1185+
LLT OrigScalar = OrigTy.isVector() ? OrigTy.getElementType() : OrigTy;
1186+
LLT TargetScalar = TargetTy.isVector() ? TargetTy.getElementType() : TargetTy;
1187+
unsigned GCD = std::gcd(OrigScalar.getSizeInBits().getFixedValue(),
1188+
TargetScalar.getSizeInBits().getFixedValue());
11741189
return LLT::scalar(GCD);
11751190
}
11761191

llvm/unittests/CodeGen/GlobalISel/GISelUtilsTest.cpp

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,26 @@ static const LLT V6P0 = LLT::fixed_vector(6, P0);
4646
static const LLT V2P1 = LLT::fixed_vector(2, P1);
4747
static const LLT V4P1 = LLT::fixed_vector(4, P1);
4848

49+
static const LLT NXV1S1 = LLT::scalable_vector(1, S1);
50+
static const LLT NXV2S1 = LLT::scalable_vector(2, S1);
51+
static const LLT NXV3S1 = LLT::scalable_vector(3, S1);
52+
static const LLT NXV4S1 = LLT::scalable_vector(4, S1);
53+
54+
static const LLT NXV1S32 = LLT::scalable_vector(1, S32);
55+
static const LLT NXV2S32 = LLT::scalable_vector(2, S32);
56+
static const LLT NXV3S32 = LLT::scalable_vector(3, S32);
57+
static const LLT NXV4S32 = LLT::scalable_vector(4, S32);
58+
59+
static const LLT NXV1S64 = LLT::scalable_vector(1, S64);
60+
static const LLT NXV2S64 = LLT::scalable_vector(2, S64);
61+
static const LLT NXV3S64 = LLT::scalable_vector(3, S64);
62+
static const LLT NXV4S64 = LLT::scalable_vector(4, S64);
63+
64+
static const LLT NXV1P0 = LLT::scalable_vector(1, P0);
65+
static const LLT NXV2P0 = LLT::scalable_vector(2, P0);
66+
static const LLT NXV3P0 = LLT::scalable_vector(3, P0);
67+
static const LLT NXV4P0 = LLT::scalable_vector(4, P0);
68+
4969
TEST(GISelUtilsTest, getGCDType) {
5070
EXPECT_EQ(S1, getGCDType(S1, S1));
5171
EXPECT_EQ(S32, getGCDType(S32, S32));
@@ -152,6 +172,62 @@ TEST(GISelUtilsTest, getGCDType) {
152172

153173
EXPECT_EQ(LLT::scalar(4), getGCDType(LLT::fixed_vector(3, 4), S8));
154174
EXPECT_EQ(LLT::scalar(4), getGCDType(S8, LLT::fixed_vector(3, 4)));
175+
176+
// Scalable -> Scalable
177+
EXPECT_EQ(NXV1S1, getGCDType(NXV1S1, NXV1S32));
178+
EXPECT_EQ(NXV1S32, getGCDType(NXV1S64, NXV1S32));
179+
EXPECT_EQ(NXV1S32, getGCDType(NXV1S32, NXV1S64));
180+
EXPECT_EQ(NXV1P0, getGCDType(NXV1P0, NXV1S64));
181+
EXPECT_EQ(NXV1S64, getGCDType(NXV1S64, NXV1P0));
182+
183+
EXPECT_EQ(NXV4S1, getGCDType(NXV4S1, NXV4S32));
184+
EXPECT_EQ(NXV2S64, getGCDType(NXV4S64, NXV4S32));
185+
EXPECT_EQ(NXV4S32, getGCDType(NXV4S32, NXV4S64));
186+
EXPECT_EQ(NXV4P0, getGCDType(NXV4P0, NXV4S64));
187+
EXPECT_EQ(NXV4S64, getGCDType(NXV4S64, NXV4P0));
188+
189+
EXPECT_EQ(NXV4S1, getGCDType(NXV4S1, NXV2S32));
190+
EXPECT_EQ(NXV1S64, getGCDType(NXV4S64, NXV2S32));
191+
EXPECT_EQ(NXV4S32, getGCDType(NXV4S32, NXV2S64));
192+
EXPECT_EQ(NXV2P0, getGCDType(NXV4P0, NXV2S64));
193+
EXPECT_EQ(NXV2S64, getGCDType(NXV4S64, NXV2P0));
194+
195+
EXPECT_EQ(NXV2S1, getGCDType(NXV2S1, NXV4S32));
196+
EXPECT_EQ(NXV2S64, getGCDType(NXV2S64, NXV4S32));
197+
EXPECT_EQ(NXV2S32, getGCDType(NXV2S32, NXV4S64));
198+
EXPECT_EQ(NXV2P0, getGCDType(NXV2P0, NXV4S64));
199+
EXPECT_EQ(NXV2S64, getGCDType(NXV2S64, NXV4P0));
200+
201+
EXPECT_EQ(NXV1S1, getGCDType(NXV3S1, NXV4S32));
202+
EXPECT_EQ(NXV1S64, getGCDType(NXV3S64, NXV4S32));
203+
EXPECT_EQ(NXV1S32, getGCDType(NXV3S32, NXV4S64));
204+
EXPECT_EQ(NXV1P0, getGCDType(NXV3P0, NXV4S64));
205+
EXPECT_EQ(NXV1S64, getGCDType(NXV3S64, NXV4P0));
206+
207+
EXPECT_EQ(NXV1S1, getGCDType(NXV3S1, NXV4S1));
208+
EXPECT_EQ(NXV1S32, getGCDType(NXV3S32, NXV4S32));
209+
EXPECT_EQ(NXV1S64, getGCDType(NXV3S64, NXV4S64));
210+
EXPECT_EQ(NXV1P0, getGCDType(NXV3P0, NXV4P0));
211+
212+
// Scalable, Scalar
213+
214+
EXPECT_EQ(S1, getGCDType(NXV1S1, S1));
215+
EXPECT_EQ(S1, getGCDType(NXV1S1, S32));
216+
EXPECT_EQ(S1, getGCDType(NXV1S32, S1));
217+
EXPECT_EQ(S32, getGCDType(NXV1S32, S32));
218+
EXPECT_EQ(S32, getGCDType(NXV1S32, S64));
219+
EXPECT_EQ(S1, getGCDType(NXV2S32, S1));
220+
EXPECT_EQ(S32, getGCDType(NXV2S32, S32));
221+
EXPECT_EQ(S32, getGCDType(NXV2S32, S64));
222+
223+
EXPECT_EQ(S1, getGCDType(S1, NXV1S1));
224+
EXPECT_EQ(S1, getGCDType(S32, NXV1S1));
225+
EXPECT_EQ(S1, getGCDType(S1, NXV1S32));
226+
EXPECT_EQ(S32, getGCDType(S32, NXV1S32));
227+
EXPECT_EQ(S32, getGCDType(S64, NXV1S32));
228+
EXPECT_EQ(S1, getGCDType(S1, NXV2S32));
229+
EXPECT_EQ(S32, getGCDType(S32, NXV2S32));
230+
EXPECT_EQ(S32, getGCDType(S64, NXV2S32));
155231
}
156232

157233
TEST(GISelUtilsTest, getLCMType) {

0 commit comments

Comments
 (0)