Skip to content

[SandboxIR] Implement BranchInst #100063

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions llvm/include/llvm/SandboxIR/SandboxIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class Context;
class Function;
class Instruction;
class SelectInst;
class BranchInst;
class LoadInst;
class ReturnInst;
class StoreInst;
Expand Down Expand Up @@ -179,6 +180,7 @@ class Value {
friend class User; // For getting `Val`.
friend class Use; // For getting `Val`.
friend class SelectInst; // For getting `Val`.
friend class BranchInst; // For getting `Val`.
friend class LoadInst; // For getting `Val`.
friend class StoreInst; // For getting `Val`.
friend class ReturnInst; // For getting `Val`.
Expand Down Expand Up @@ -343,6 +345,14 @@ class User : public Value {
virtual unsigned getUseOperandNo(const Use &Use) const = 0;
friend unsigned Use::getOperandNo() const; // For getUseOperandNo()

void swapOperandsInternal(unsigned OpIdxA, unsigned OpIdxB) {
assert(OpIdxA < getNumOperands() && "OpIdxA out of bounds!");
assert(OpIdxB < getNumOperands() && "OpIdxB out of bounds!");
auto UseA = getOperandUse(OpIdxA);
auto UseB = getOperandUse(OpIdxB);
UseA.swap(UseB);
}

#ifndef NDEBUG
void verifyUserOfLLVMUse(const llvm::Use &Use) const;
#endif // NDEBUG
Expand Down Expand Up @@ -504,6 +514,7 @@ class Instruction : public sandboxir::User {
/// returns its topmost LLVM IR instruction.
llvm::Instruction *getTopmostLLVMInstruction() const;
friend class SelectInst; // For getTopmostLLVMInstruction().
friend class BranchInst; // For getTopmostLLVMInstruction().
friend class LoadInst; // For getTopmostLLVMInstruction().
friend class StoreInst; // For getTopmostLLVMInstruction().
friend class ReturnInst; // For getTopmostLLVMInstruction().
Expand Down Expand Up @@ -617,6 +628,100 @@ class SelectInst : public Instruction {
#endif
};

class BranchInst : public Instruction {
/// Use Context::createBranchInst(). Don't call the constructor directly.
BranchInst(llvm::BranchInst *BI, Context &Ctx)
: Instruction(ClassID::Br, Opcode::Br, BI, Ctx) {}
friend Context; // for BranchInst()
Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
return getOperandUseDefault(OpIdx, Verify);
}
SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const final {
return {cast<llvm::Instruction>(Val)};
}

public:
unsigned getUseOperandNo(const Use &Use) const final {
return getUseOperandNoDefault(Use);
}
unsigned getNumOfIRInstrs() const final { return 1u; }
static BranchInst *create(BasicBlock *IfTrue, Instruction *InsertBefore,
Context &Ctx);
static BranchInst *create(BasicBlock *IfTrue, BasicBlock *InsertAtEnd,
Context &Ctx);
static BranchInst *create(BasicBlock *IfTrue, BasicBlock *IfFalse,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why can't IfFalse == nullptr be the one that creates the instruction with only the true branch? That way you save creating two functions or put it as a default parameter.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that would make sense, but I guess it is not as explicit as having a separate function. This is how it's done in llvm::BranchInst so I just created a similar API.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good!

Value *Cond, Instruction *InsertBefore,
Context &Ctx);
static BranchInst *create(BasicBlock *IfTrue, BasicBlock *IfFalse,
Value *Cond, BasicBlock *InsertAtEnd, Context &Ctx);
/// For isa/dyn_cast.
static bool classof(const Value *From);
bool isUnconditional() const {
return cast<llvm::BranchInst>(Val)->isUnconditional();
}
bool isConditional() const {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy-paste?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do I need both?

Copy link
Contributor Author

@vporpo vporpo Jul 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both of these function exist in llvm::BranchInst, so since we are trying to create a similar API I guess we should keep both.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ignore that, I was too fast.

return cast<llvm::BranchInst>(Val)->isConditional();
}
Value *getCondition() const;
void setCondition(Value *V) { setOperand(0, V); }
unsigned getNumSuccessors() const { return 1 + isConditional(); }

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

number + bool?

Copy link

@tschuett tschuett Jul 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return isConditional() ? 2 : 1;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess I could have casted it to unsigned first, but this is how it's done in llvm::BranchInst::getNumSuccessors() so I just copied it.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am just afraid some compiler, including future Clang will complain.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wouldn't worry too much about it, it's legal c++. Casting bool to int gives us 0 or 1.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ignore me again.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's always good to point out things that look strange or potentially wrong, so thanks for the comments.

BasicBlock *getSuccessor(unsigned SuccIdx) const;
void setSuccessor(unsigned Idx, BasicBlock *NewSucc);
void swapSuccessors() { swapOperandsInternal(1, 2); }

private:
struct LLVMBBToSBBB {
Context &Ctx;
LLVMBBToSBBB(Context &Ctx) : Ctx(Ctx) {}
BasicBlock *operator()(llvm::BasicBlock *BB) const;
};

struct ConstLLVMBBToSBBB {
Context &Ctx;
ConstLLVMBBToSBBB(Context &Ctx) : Ctx(Ctx) {}
const BasicBlock *operator()(const llvm::BasicBlock *BB) const;
};

public:
using sb_succ_op_iterator =
mapped_iterator<llvm::BranchInst::succ_op_iterator, LLVMBBToSBBB>;
iterator_range<sb_succ_op_iterator> successors() {
iterator_range<llvm::BranchInst::succ_op_iterator> LLVMRange =
cast<llvm::BranchInst>(Val)->successors();
LLVMBBToSBBB BBMap(Ctx);
sb_succ_op_iterator MappedBegin = map_iterator(LLVMRange.begin(), BBMap);
sb_succ_op_iterator MappedEnd = map_iterator(LLVMRange.end(), BBMap);
return make_range(MappedBegin, MappedEnd);
}

using const_sb_succ_op_iterator =
mapped_iterator<llvm::BranchInst::const_succ_op_iterator,
ConstLLVMBBToSBBB>;
iterator_range<const_sb_succ_op_iterator> successors() const {
iterator_range<llvm::BranchInst::const_succ_op_iterator> ConstLLVMRange =
static_cast<const llvm::BranchInst *>(cast<llvm::BranchInst>(Val))
->successors();
ConstLLVMBBToSBBB ConstBBMap(Ctx);
const_sb_succ_op_iterator ConstMappedBegin =
map_iterator(ConstLLVMRange.begin(), ConstBBMap);
const_sb_succ_op_iterator ConstMappedEnd =
map_iterator(ConstLLVMRange.end(), ConstBBMap);
return make_range(ConstMappedBegin, ConstMappedEnd);
}

#ifndef NDEBUG
void verify() const final {
assert(isa<llvm::BranchInst>(Val) && "Expected BranchInst!");
}
friend raw_ostream &operator<<(raw_ostream &OS, const BranchInst &BI) {
BI.dump(OS);
return OS;
}
void dump(raw_ostream &OS) const override;
LLVM_DUMP_METHOD void dump() const override;
#endif
};

class LoadInst final : public Instruction {
/// Use LoadInst::create() instead of calling the constructor.
LoadInst(llvm::LoadInst *LI, Context &Ctx)
Expand Down Expand Up @@ -870,6 +975,8 @@ class Context {

SelectInst *createSelectInst(llvm::SelectInst *SI);
friend SelectInst; // For createSelectInst()
BranchInst *createBranchInst(llvm::BranchInst *I);
friend BranchInst; // For createBranchInst()
LoadInst *createLoadInst(llvm::LoadInst *LI);
friend LoadInst; // For createLoadInst()
StoreInst *createStoreInst(llvm::StoreInst *SI);
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/SandboxIR/SandboxIRValues.def
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ DEF_USER(Constant, Constant)
// ClassID, Opcode(s), Class
DEF_INSTR(Opaque, OP(Opaque), OpaqueInst)
DEF_INSTR(Select, OP(Select), SelectInst)
DEF_INSTR(Br, OP(Br), BranchInst)
DEF_INSTR(Load, OP(Load), LoadInst)
DEF_INSTR(Store, OP(Store), StoreInst)
DEF_INSTR(Ret, OP(Ret), ReturnInst)
Expand Down
21 changes: 21 additions & 0 deletions llvm/include/llvm/SandboxIR/Tracker.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,27 @@ class UseSet : public IRChangeBase {
#endif
};

/// Tracks swapping a Use with another Use.
class UseSwap : public IRChangeBase {
Use ThisUse;
Use OtherUse;

public:
UseSwap(const Use &ThisUse, const Use &OtherUse, Tracker &Tracker)
: IRChangeBase(Tracker), ThisUse(ThisUse), OtherUse(OtherUse) {
assert(ThisUse.getUser() == OtherUse.getUser() && "Expected same user!");
}
void revert() final { ThisUse.swap(OtherUse); }
void accept() final {}
#ifndef NDEBUG
void dump(raw_ostream &OS) const final {
dumpCommon(OS);
OS << "UseSwap";
}
LLVM_DUMP_METHOD void dump() const final;
#endif
};

class EraseFromParent : public IRChangeBase {
/// Contains all the data we need to restore an "erased" (i.e., detached)
/// instruction: the instruction itself and its operands in order.
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/SandboxIR/Use.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class Use {
void set(Value *V);
class User *getUser() const { return Usr; }
unsigned getOperandNo() const;
void swap(Use &OtherUse);
Context *getContext() const { return Ctx; }
bool operator==(const Use &Other) const {
assert(Ctx == Other.Ctx && "Contexts differ!");
Expand Down
96 changes: 96 additions & 0 deletions llvm/lib/SandboxIR/SandboxIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@ void Use::set(Value *V) { LLVMUse->set(V->Val); }

unsigned Use::getOperandNo() const { return Usr->getUseOperandNo(*this); }

void Use::swap(Use &OtherUse) {
auto &Tracker = Ctx->getTracker();
if (Tracker.isTracking())
Tracker.track(std::make_unique<UseSwap>(*this, OtherUse, Tracker));
LLVMUse->swap(*OtherUse.LLVMUse);
}

#ifndef NDEBUG
void Use::dump(raw_ostream &OS) const {
Value *Def = nullptr;
Expand Down Expand Up @@ -500,6 +507,85 @@ void SelectInst::dump() const {
}
#endif // NDEBUG

BranchInst *BranchInst::create(BasicBlock *IfTrue, Instruction *InsertBefore,
Context &Ctx) {
auto &Builder = Ctx.getLLVMIRBuilder();
Builder.SetInsertPoint(cast<llvm::Instruction>(InsertBefore->Val));
llvm::BranchInst *NewBr =
Builder.CreateBr(cast<llvm::BasicBlock>(IfTrue->Val));
return Ctx.createBranchInst(NewBr);
}

BranchInst *BranchInst::create(BasicBlock *IfTrue, BasicBlock *InsertAtEnd,
Context &Ctx) {
auto &Builder = Ctx.getLLVMIRBuilder();
Builder.SetInsertPoint(cast<llvm::BasicBlock>(InsertAtEnd->Val));
llvm::BranchInst *NewBr =
Builder.CreateBr(cast<llvm::BasicBlock>(IfTrue->Val));
return Ctx.createBranchInst(NewBr);
}

BranchInst *BranchInst::create(BasicBlock *IfTrue, BasicBlock *IfFalse,
Value *Cond, Instruction *InsertBefore,
Context &Ctx) {
auto &Builder = Ctx.getLLVMIRBuilder();
Builder.SetInsertPoint(cast<llvm::Instruction>(InsertBefore->Val));
llvm::BranchInst *NewBr =
Builder.CreateCondBr(Cond->Val, cast<llvm::BasicBlock>(IfTrue->Val),
cast<llvm::BasicBlock>(IfFalse->Val));
return Ctx.createBranchInst(NewBr);
}

BranchInst *BranchInst::create(BasicBlock *IfTrue, BasicBlock *IfFalse,
Value *Cond, BasicBlock *InsertAtEnd,
Context &Ctx) {
auto &Builder = Ctx.getLLVMIRBuilder();
Builder.SetInsertPoint(cast<llvm::BasicBlock>(InsertAtEnd->Val));
llvm::BranchInst *NewBr =
Builder.CreateCondBr(Cond->Val, cast<llvm::BasicBlock>(IfTrue->Val),
cast<llvm::BasicBlock>(IfFalse->Val));
return Ctx.createBranchInst(NewBr);
}

bool BranchInst::classof(const Value *From) {
return From->getSubclassID() == ClassID::Br;
}

Value *BranchInst::getCondition() const {
assert(isConditional() && "Cannot get condition of an uncond branch!");
return Ctx.getValue(cast<llvm::BranchInst>(Val)->getCondition());
}

BasicBlock *BranchInst::getSuccessor(unsigned SuccIdx) const {
assert(SuccIdx < getNumSuccessors() &&
"Successor # out of range for Branch!");
return cast_or_null<BasicBlock>(
Ctx.getValue(cast<llvm::BranchInst>(Val)->getSuccessor(SuccIdx)));
}

void BranchInst::setSuccessor(unsigned Idx, BasicBlock *NewSucc) {
assert((Idx == 0 || Idx == 1) && "Out of bounds!");
setOperand(2u - Idx, NewSucc);
}

BasicBlock *BranchInst::LLVMBBToSBBB::operator()(llvm::BasicBlock *BB) const {
return cast<BasicBlock>(Ctx.getValue(BB));
}
const BasicBlock *
BranchInst::ConstLLVMBBToSBBB::operator()(const llvm::BasicBlock *BB) const {
return cast<BasicBlock>(Ctx.getValue(BB));
}
#ifndef NDEBUG
void BranchInst::dump(raw_ostream &OS) const {
dumpCommonPrefix(OS);
dumpCommonSuffix(OS);
}
void BranchInst::dump() const {
dump(dbgs());
dbgs() << "\n";
}
#endif // NDEBUG

LoadInst *LoadInst::create(Type *Ty, Value *Ptr, MaybeAlign Align,
Instruction *InsertBefore, Context &Ctx,
const Twine &Name) {
Expand Down Expand Up @@ -758,6 +844,11 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
It->second = std::unique_ptr<SelectInst>(new SelectInst(LLVMSel, *this));
return It->second.get();
}
case llvm::Instruction::Br: {
auto *LLVMBr = cast<llvm::BranchInst>(LLVMV);
It->second = std::unique_ptr<BranchInst>(new BranchInst(LLVMBr, *this));
return It->second.get();
}
case llvm::Instruction::Load: {
auto *LLVMLd = cast<llvm::LoadInst>(LLVMV);
It->second = std::unique_ptr<LoadInst>(new LoadInst(LLVMLd, *this));
Expand Down Expand Up @@ -796,6 +887,11 @@ SelectInst *Context::createSelectInst(llvm::SelectInst *SI) {
return cast<SelectInst>(registerValue(std::move(NewPtr)));
}

BranchInst *Context::createBranchInst(llvm::BranchInst *BI) {
auto NewPtr = std::unique_ptr<BranchInst>(new BranchInst(BI, *this));
return cast<BranchInst>(registerValue(std::move(NewPtr)));
}

LoadInst *Context::createLoadInst(llvm::LoadInst *LI) {
auto NewPtr = std::unique_ptr<LoadInst>(new LoadInst(LI, *this));
return cast<LoadInst>(registerValue(std::move(NewPtr)));
Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/SandboxIR/Tracker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ void UseSet::dump() const {
dump(dbgs());
dbgs() << "\n";
}

void UseSwap::dump() const {
dump(dbgs());
dbgs() << "\n";
}
#endif // NDEBUG

Tracker::~Tracker() {
Expand Down
Loading
Loading