Skip to content

[GISel] Add support for scalable vectors in getLCMType #80306

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions llvm/include/llvm/CodeGen/GlobalISel/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -343,10 +343,13 @@ Register getFunctionLiveInPhysReg(MachineFunction &MF,
const TargetRegisterClass &RC,
const DebugLoc &DL, LLT RegTy = LLT());

/// Return the least common multiple type of \p OrigTy and \p TargetTy, by changing the
/// number of vector elements or scalar bitwidth. The intent is a
/// Return the least common multiple type of \p OrigTy and \p TargetTy, by
/// changing the number of vector elements or scalar bitwidth, so that a
/// G_MERGE_VALUES, G_BUILD_VECTOR, or G_CONCAT_VECTORS can be constructed from
/// \p OrigTy elements, and unmerged into \p TargetTy
/// \p OrigTy elements, and unmerged into \p TargetTy. It is an error to call
/// this function where one argument is a fixed vector and the other is a
/// scalable vector, since it is illegal to build a G_{MERGE|UNMERGE}_VALUES
/// between fixed and scalable vectors.
LLVM_READNONE
LLT getLCMType(LLT OrigTy, LLT TargetTy);

Expand Down
85 changes: 53 additions & 32 deletions llvm/lib/CodeGen/GlobalISel/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1071,49 +1071,70 @@ void llvm::getSelectionDAGFallbackAnalysisUsage(AnalysisUsage &AU) {
}

LLT llvm::getLCMType(LLT OrigTy, LLT TargetTy) {
const unsigned OrigSize = OrigTy.getSizeInBits();
const unsigned TargetSize = TargetTy.getSizeInBits();

if (OrigSize == TargetSize)
if (OrigTy.getSizeInBits() == TargetTy.getSizeInBits())
return OrigTy;

if (OrigTy.isVector()) {
const LLT OrigElt = OrigTy.getElementType();

if (TargetTy.isVector()) {
const LLT TargetElt = TargetTy.getElementType();
if (OrigTy.isVector() && TargetTy.isVector()) {
LLT OrigElt = OrigTy.getElementType();
LLT TargetElt = TargetTy.getElementType();

if (OrigElt.getSizeInBits() == TargetElt.getSizeInBits()) {
int GCDElts =
std::gcd(OrigTy.getNumElements(), TargetTy.getNumElements());
// Prefer the original element type.
ElementCount Mul = OrigTy.getElementCount() * TargetTy.getNumElements();
return LLT::vector(Mul.divideCoefficientBy(GCDElts),
OrigTy.getElementType());
}
} else {
if (OrigElt.getSizeInBits() == TargetSize)
return OrigTy;
// TODO: The docstring for this function says the intention is to use this
// function to build MERGE/UNMERGE instructions. It won't be the case that
// we generate a MERGE/UNMERGE between fixed and scalable vector types. We
// could implement getLCMType between the two in the future if there was a
// need, but it is not worth it now as this function should not be used in
// that way.
assert(((OrigTy.isScalableVector() && !TargetTy.isFixedVector()) ||
(OrigTy.isFixedVector() && !TargetTy.isScalableVector())) &&
"getLCMType not implemented between fixed and scalable vectors.");

if (OrigElt.getSizeInBits() == TargetElt.getSizeInBits()) {
int GCDMinElts = std::gcd(OrigTy.getElementCount().getKnownMinValue(),
TargetTy.getElementCount().getKnownMinValue());
// Prefer the original element type.
ElementCount Mul = OrigTy.getElementCount().multiplyCoefficientBy(
TargetTy.getElementCount().getKnownMinValue());
return LLT::vector(Mul.divideCoefficientBy(GCDMinElts),
OrigTy.getElementType());
}

unsigned LCMSize = std::lcm(OrigSize, TargetSize);
return LLT::fixed_vector(LCMSize / OrigElt.getSizeInBits(), OrigElt);
unsigned LCM = std::lcm(OrigTy.getSizeInBits().getKnownMinValue(),
TargetTy.getSizeInBits().getKnownMinValue());
return LLT::vector(
ElementCount::get(LCM / OrigElt.getSizeInBits(), OrigTy.isScalable()),
OrigElt);
}

if (TargetTy.isVector()) {
unsigned LCMSize = std::lcm(OrigSize, TargetSize);
return LLT::fixed_vector(LCMSize / OrigSize, OrigTy);
// One type is scalar, one type is vector
if (OrigTy.isVector() || TargetTy.isVector()) {
LLT VecTy = OrigTy.isVector() ? OrigTy : TargetTy;
LLT ScalarTy = OrigTy.isVector() ? TargetTy : OrigTy;
LLT EltTy = VecTy.getElementType();
LLT OrigEltTy = OrigTy.isVector() ? OrigTy.getElementType() : OrigTy;

// Prefer scalar type from OrigTy.
if (EltTy.getSizeInBits() == ScalarTy.getSizeInBits())
return LLT::vector(VecTy.getElementCount(), OrigEltTy);

// Different size scalars. Create vector with the same total size.
// LCM will take fixed/scalable from VecTy.
unsigned LCM = std::lcm(EltTy.getSizeInBits().getFixedValue() *
VecTy.getElementCount().getKnownMinValue(),
ScalarTy.getSizeInBits().getFixedValue());
// Prefer type from OrigTy
return LLT::vector(ElementCount::get(LCM / OrigEltTy.getSizeInBits(),
VecTy.getElementCount().isScalable()),
OrigEltTy);
}

unsigned LCMSize = std::lcm(OrigSize, TargetSize);

// At this point, both types are scalars of different size
unsigned LCM = std::lcm(OrigTy.getSizeInBits().getFixedValue(),
TargetTy.getSizeInBits().getFixedValue());
// Preserve pointer types.
if (LCMSize == OrigSize)
if (LCM == OrigTy.getSizeInBits())
return OrigTy;
if (LCMSize == TargetSize)
if (LCM == TargetTy.getSizeInBits())
return TargetTy;

return LLT::scalar(LCMSize);
return LLT::scalar(LCM);
}

LLT llvm::getCoverTy(LLT OrigTy, LLT TargetTy) {
Expand Down
87 changes: 87 additions & 0 deletions llvm/unittests/CodeGen/GlobalISel/GISelUtilsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,37 @@ static const LLT V6P0 = LLT::fixed_vector(6, P0);
static const LLT V2P1 = LLT::fixed_vector(2, P1);
static const LLT V4P1 = LLT::fixed_vector(4, P1);

// Scalable-vector LLT constants used by the scalable-vector test cases.
// Naming convention: NXV<N><Elt> is LLT::scalable_vector(N, <Elt>), where
// <Elt> is one of the element types (S1, S32, S64, P0) declared earlier in
// this file.

// Scalable vectors with S1 elements.
static const LLT NXV1S1 = LLT::scalable_vector(1, S1);
static const LLT NXV2S1 = LLT::scalable_vector(2, S1);
static const LLT NXV3S1 = LLT::scalable_vector(3, S1);
static const LLT NXV4S1 = LLT::scalable_vector(4, S1);
static const LLT NXV12S1 = LLT::scalable_vector(12, S1);
static const LLT NXV32S1 = LLT::scalable_vector(32, S1);
static const LLT NXV64S1 = LLT::scalable_vector(64, S1);
static const LLT NXV128S1 = LLT::scalable_vector(128, S1);
static const LLT NXV384S1 = LLT::scalable_vector(384, S1);

// Scalable vectors with S32 elements.
static const LLT NXV1S32 = LLT::scalable_vector(1, S32);
static const LLT NXV2S32 = LLT::scalable_vector(2, S32);
static const LLT NXV3S32 = LLT::scalable_vector(3, S32);
static const LLT NXV4S32 = LLT::scalable_vector(4, S32);
static const LLT NXV8S32 = LLT::scalable_vector(8, S32);
static const LLT NXV12S32 = LLT::scalable_vector(12, S32);
static const LLT NXV24S32 = LLT::scalable_vector(24, S32);

// Scalable vectors with S64 elements.
static const LLT NXV1S64 = LLT::scalable_vector(1, S64);
static const LLT NXV2S64 = LLT::scalable_vector(2, S64);
static const LLT NXV3S64 = LLT::scalable_vector(3, S64);
static const LLT NXV4S64 = LLT::scalable_vector(4, S64);
static const LLT NXV6S64 = LLT::scalable_vector(6, S64);
static const LLT NXV12S64 = LLT::scalable_vector(12, S64);

// Scalable vectors with P0 elements.
static const LLT NXV1P0 = LLT::scalable_vector(1, P0);
static const LLT NXV2P0 = LLT::scalable_vector(2, P0);
static const LLT NXV3P0 = LLT::scalable_vector(3, P0);
static const LLT NXV4P0 = LLT::scalable_vector(4, P0);
static const LLT NXV12P0 = LLT::scalable_vector(12, P0);

TEST(GISelUtilsTest, getGCDType) {
EXPECT_EQ(S1, getGCDType(S1, S1));
EXPECT_EQ(S32, getGCDType(S32, S32));
Expand Down Expand Up @@ -244,6 +275,62 @@ TEST(GISelUtilsTest, getLCMType) {

EXPECT_EQ(V2S64, getLCMType(V2S64, P1));
EXPECT_EQ(V4P1, getLCMType(P1, V2S64));

// Scalable, Scalable
EXPECT_EQ(NXV32S1, getLCMType(NXV1S1, NXV1S32));
EXPECT_EQ(NXV1S64, getLCMType(NXV1S64, NXV1S32));
EXPECT_EQ(NXV2S32, getLCMType(NXV1S32, NXV1S64));
EXPECT_EQ(NXV1P0, getLCMType(NXV1P0, NXV1S64));
EXPECT_EQ(NXV1S64, getLCMType(NXV1S64, NXV1P0));

EXPECT_EQ(NXV128S1, getLCMType(NXV4S1, NXV4S32));
EXPECT_EQ(NXV4S64, getLCMType(NXV4S64, NXV4S32));
EXPECT_EQ(NXV8S32, getLCMType(NXV4S32, NXV4S64));
EXPECT_EQ(NXV4P0, getLCMType(NXV4P0, NXV4S64));
EXPECT_EQ(NXV4S64, getLCMType(NXV4S64, NXV4P0));

EXPECT_EQ(NXV64S1, getLCMType(NXV4S1, NXV2S32));
EXPECT_EQ(NXV4S64, getLCMType(NXV4S64, NXV2S32));
EXPECT_EQ(NXV4S32, getLCMType(NXV4S32, NXV2S64));
EXPECT_EQ(NXV4P0, getLCMType(NXV4P0, NXV2S64));
EXPECT_EQ(NXV4S64, getLCMType(NXV4S64, NXV2P0));

EXPECT_EQ(NXV128S1, getLCMType(NXV2S1, NXV4S32));
EXPECT_EQ(NXV2S64, getLCMType(NXV2S64, NXV4S32));
EXPECT_EQ(NXV8S32, getLCMType(NXV2S32, NXV4S64));
EXPECT_EQ(NXV4P0, getLCMType(NXV2P0, NXV4S64));
EXPECT_EQ(NXV4S64, getLCMType(NXV2S64, NXV4P0));

EXPECT_EQ(NXV384S1, getLCMType(NXV3S1, NXV4S32));
EXPECT_EQ(NXV6S64, getLCMType(NXV3S64, NXV4S32));
EXPECT_EQ(NXV24S32, getLCMType(NXV3S32, NXV4S64));
EXPECT_EQ(NXV12P0, getLCMType(NXV3P0, NXV4S64));
EXPECT_EQ(NXV12S64, getLCMType(NXV3S64, NXV4P0));

EXPECT_EQ(NXV12S1, getLCMType(NXV3S1, NXV4S1));
EXPECT_EQ(NXV12S32, getLCMType(NXV3S32, NXV4S32));
EXPECT_EQ(NXV12S64, getLCMType(NXV3S64, NXV4S64));
EXPECT_EQ(NXV12P0, getLCMType(NXV3P0, NXV4P0));

// Scalable, Scalar

EXPECT_EQ(NXV1S1, getLCMType(NXV1S1, S1));
EXPECT_EQ(NXV32S1, getLCMType(NXV1S1, S32));
EXPECT_EQ(NXV1S32, getLCMType(NXV1S32, S1));
EXPECT_EQ(NXV1S32, getLCMType(NXV1S32, S32));
EXPECT_EQ(NXV2S32, getLCMType(NXV1S32, S64));
EXPECT_EQ(NXV2S32, getLCMType(NXV2S32, S1));
EXPECT_EQ(NXV2S32, getLCMType(NXV2S32, S32));
EXPECT_EQ(NXV2S32, getLCMType(NXV2S32, S64));

EXPECT_EQ(NXV1S1, getLCMType(S1, NXV1S1));
EXPECT_EQ(NXV1S32, getLCMType(S32, NXV1S1));
EXPECT_EQ(NXV32S1, getLCMType(S1, NXV1S32));
EXPECT_EQ(NXV1S32, getLCMType(S32, NXV1S32));
EXPECT_EQ(NXV1S64, getLCMType(S64, NXV1S32));
EXPECT_EQ(NXV64S1, getLCMType(S1, NXV2S32));
EXPECT_EQ(NXV2S32, getLCMType(S32, NXV2S32));
EXPECT_EQ(NXV1S64, getLCMType(S64, NXV2S32));
}

TEST_F(AArch64GISelMITest, ConstFalseTest) {
Expand Down