Skip to content

Commit e89bcfc

Browse files
authored
[SandboxIR] Add tracking for ShuffleVectorInst::commute. (llvm#106644)
Track it as an operand swap + a `setShuffleMask` and delegate to the `llvm::ShuffleVectorInst` implementation.
1 parent 4f403e8 commit e89bcfc

File tree

3 files changed

+24
-5
lines changed

3 files changed

+24
-5
lines changed

llvm/include/llvm/SandboxIR/SandboxIR.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1143,7 +1143,7 @@ class ShuffleVectorInst final
11431143

11441144
/// Swap the operands and adjust the mask to preserve the semantics of the
11451145
/// instruction.
1146-
void commute() { cast<llvm::ShuffleVectorInst>(Val)->commute(); }
1146+
void commute();
11471147

11481148
/// Return true if a shufflevector instruction can be formed with the
11491149
/// specified operands.

llvm/lib/SandboxIR/SandboxIR.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2185,6 +2185,13 @@ VectorType *ShuffleVectorInst::getType() const {
21852185
Ctx.getType(cast<llvm::ShuffleVectorInst>(Val)->getType()));
21862186
}
21872187

2188+
void ShuffleVectorInst::commute() {
2189+
Ctx.getTracker().emplaceIfTracking<ShuffleVectorSetMask>(this);
2190+
Ctx.getTracker().emplaceIfTracking<UseSwap>(getOperandUse(0),
2191+
getOperandUse(1));
2192+
cast<llvm::ShuffleVectorInst>(Val)->commute();
2193+
}
2194+
21882195
Constant *ShuffleVectorInst::getShuffleMaskForBitcode() const {
21892196
return Ctx.getOrCreateConstant(
21902197
cast<llvm::ShuffleVectorInst>(Val)->getShuffleMaskForBitcode());

llvm/unittests/SandboxIR/TrackerTest.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -964,7 +964,7 @@ define void @foo(i32 %cond0, i32 %cond1) {
964964
EXPECT_EQ(Switch->findCaseDest(BB1), One);
965965
}
966966

967-
TEST_F(TrackerTest, ShuffleVectorInstSetters) {
967+
TEST_F(TrackerTest, ShuffleVectorInst) {
968968
parseIR(C, R"IR(
969969
define void @foo(<2 x i8> %v1, <2 x i8> %v2) {
970970
%shuf = shufflevector <2 x i8> %v1, <2 x i8> %v2, <2 x i32> <i32 1, i32 2>
@@ -983,10 +983,22 @@ define void @foo(<2 x i8> %v1, <2 x i8> %v2) {
983983
SmallVector<int, 2> OrigMask(SVI->getShuffleMask());
984984
Ctx.save();
985985
SVI->setShuffleMask(ArrayRef<int>({0, 0}));
986-
EXPECT_THAT(SVI->getShuffleMask(),
987-
testing::Not(testing::ElementsAreArray(OrigMask)));
986+
EXPECT_NE(SVI->getShuffleMask(), ArrayRef<int>(OrigMask));
988987
Ctx.revert();
989-
EXPECT_THAT(SVI->getShuffleMask(), testing::ElementsAreArray(OrigMask));
988+
EXPECT_EQ(SVI->getShuffleMask(), ArrayRef<int>(OrigMask));
989+
990+
// Check commute.
991+
auto *Op0 = SVI->getOperand(0);
992+
auto *Op1 = SVI->getOperand(1);
993+
Ctx.save();
994+
SVI->commute();
995+
EXPECT_EQ(SVI->getOperand(0), Op1);
996+
EXPECT_EQ(SVI->getOperand(1), Op0);
997+
EXPECT_NE(SVI->getShuffleMask(), ArrayRef<int>(OrigMask));
998+
Ctx.revert();
999+
EXPECT_EQ(SVI->getOperand(0), Op0);
1000+
EXPECT_EQ(SVI->getOperand(1), Op1);
1001+
EXPECT_EQ(SVI->getShuffleMask(), ArrayRef<int>(OrigMask));
9901002
}
9911003

9921004
TEST_F(TrackerTest, PossiblyDisjointInstSetters) {

0 commit comments

Comments
 (0)