Skip to content

[WebAssembly] Handle block and polymorphic stack in AsmTypeCheck #110770

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,9 @@ class WebAssemblyAsmParser final : public MCTargetAsmParser {

void addBlockTypeOperand(OperandVector &Operands, SMLoc NameLoc,
WebAssembly::BlockType BT) {
if (BT != WebAssembly::BlockType::Void) {
if (BT == WebAssembly::BlockType::Void) {
TC.setLastSig(wasm::WasmSignature{});
} else {
Comment on lines +501 to +503
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before, when we have a non-void block we set the signature in the type checker, but not when we have a void block. So if we have a non-void block and then a void block, the type checker incorrectly thought void block's signature was the same as the previous (non-void ) one.

wasm::WasmSignature Sig({static_cast<wasm::ValType>(BT)}, {});
TC.setLastSig(Sig);
NestingStack.back().Sig = Sig;
Expand Down Expand Up @@ -1002,7 +1004,8 @@ class WebAssemblyAsmParser final : public MCTargetAsmParser {
auto *Signature = Ctx.createWasmSignature();
if (parseSignature(Signature))
return ParseStatus::Failure;
TC.funcDecl(*Signature);
if (CurrentState == FunctionStart)
TC.funcDecl(*Signature);
Comment on lines +1007 to +1008
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.functype directive does not only happen at the start of a function but also just to declare a function type that does not have a definition associated with it. But we used to set the current function's signature whenever we parsed a .functype. This resulted (incorrectly) in calling funcDecl twice, and pushing to BlockInfoStack twice.

WasmSym->setSignature(Signature);
WasmSym->setType(wasm::WASM_SYMBOL_TYPE_FUNCTION);
TOut.emitFunctionType(WasmSym);
Expand Down
176 changes: 115 additions & 61 deletions llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ WebAssemblyAsmTypeCheck::WebAssemblyAsmTypeCheck(MCAsmParser &Parser,

void WebAssemblyAsmTypeCheck::funcDecl(const wasm::WasmSignature &Sig) {
LocalTypes.assign(Sig.Params.begin(), Sig.Params.end());
ReturnTypes.assign(Sig.Returns.begin(), Sig.Returns.end());
BrStack.emplace_back(Sig.Returns.begin(), Sig.Returns.end());
BlockInfoStack.push_back({Sig, 0, false});
}

void WebAssemblyAsmTypeCheck::localDecl(
Expand All @@ -64,14 +63,15 @@ void WebAssemblyAsmTypeCheck::dumpTypeStack(Twine Msg) {
}

bool WebAssemblyAsmTypeCheck::typeError(SMLoc ErrorLoc, const Twine &Msg) {
// If we're currently in unreachable code, we suppress errors completely.
if (Unreachable)
return false;
dumpTypeStack("current stack: ");
return Parser.Error(ErrorLoc, Msg);
}

bool WebAssemblyAsmTypeCheck::match(StackType TypeA, StackType TypeB) {
// These should have been filtered out in checkTypes()
assert(!std::get_if<Polymorphic>(&TypeA) &&
!std::get_if<Polymorphic>(&TypeB));

if (TypeA == TypeB)
return false;
if (std::get_if<Any>(&TypeA) || std::get_if<Any>(&TypeB))
Expand All @@ -90,6 +90,10 @@ std::string WebAssemblyAsmTypeCheck::getTypesString(ArrayRef<StackType> Types,
size_t StartPos) {
SmallVector<std::string, 4> TypeStrs;
for (auto I = Types.size(); I > StartPos; I--) {
if (std::get_if<Polymorphic>(&Types[I - 1])) {
TypeStrs.push_back("...");
break;
}
if (std::get_if<Any>(&Types[I - 1]))
TypeStrs.push_back("any");
else if (std::get_if<Ref>(&Types[I - 1]))
Expand Down Expand Up @@ -131,29 +135,48 @@ bool WebAssemblyAsmTypeCheck::checkTypes(SMLoc ErrorLoc,
bool ExactMatch) {
auto StackI = Stack.size();
auto TypeI = Types.size();
assert(!BlockInfoStack.empty());
auto BlockStackStartPos = BlockInfoStack.back().StackStartPos;
bool Error = false;
bool PolymorphicStack = false;
// Compare elements one by one from the stack top
for (; StackI > 0 && TypeI > 0; StackI--, TypeI--) {
for (;StackI > BlockStackStartPos && TypeI > 0; StackI--, TypeI--) {
// If the stack is polymorphic, we assume all types in 'Types' have been
// compared and matched
if (std::get_if<Polymorphic>(&Stack[StackI - 1])) {
TypeI = 0;
break;
}
if (match(Stack[StackI - 1], Types[TypeI - 1])) {
Error = true;
break;
}
}

// If the stack top is polymorphic, the stack is in the polymorphic state.
if (StackI > BlockStackStartPos &&
std::get_if<Polymorphic>(&Stack[StackI - 1]))
PolymorphicStack = true;

// Even if no match failure has happened in the loop above, if not all
// elements of Types has been matched, that means we don't have enough
// elements on the stack.
//
// Also, if not all elements of the Stack has been matched and when
// 'ExactMatch' is true, that means we have superfluous elements remaining on
// the stack (e.g. at the end of a function).
if (TypeI > 0 || (ExactMatch && StackI > 0))
// 'ExactMatch' is true and the current stack is not polymorphic, that means
// we have superfluous elements remaining on the stack (e.g. at the end of a
// function).
if (TypeI > 0 ||
(ExactMatch && !PolymorphicStack && StackI > BlockStackStartPos))
Error = true;

if (!Error)
return false;

auto StackStartPos =
ExactMatch ? 0 : std::max(0, (int)Stack.size() - (int)Types.size());
auto StackStartPos = ExactMatch
? BlockStackStartPos
: std::max((int)BlockStackStartPos,
(int)Stack.size() - (int)Types.size());
return typeError(ErrorLoc, "type mismatch, expected " +
getTypesString(Types, 0) + " but got " +
getTypesString(Stack, StackStartPos));
Expand All @@ -169,9 +192,13 @@ bool WebAssemblyAsmTypeCheck::popTypes(SMLoc ErrorLoc,
ArrayRef<StackType> Types,
bool ExactMatch) {
bool Error = checkTypes(ErrorLoc, Types, ExactMatch);
auto NumPops = std::min(Stack.size(), Types.size());
for (size_t I = 0, E = NumPops; I != E; I++)
auto NumPops = std::min(Stack.size() - BlockInfoStack.back().StackStartPos,
Types.size());
for (size_t I = 0, E = NumPops; I != E; I++) {
if (std::get_if<Polymorphic>(&Stack.back()))
break;
Stack.pop_back();
}
return Error;
}

Expand Down Expand Up @@ -201,25 +228,6 @@ bool WebAssemblyAsmTypeCheck::getLocal(SMLoc ErrorLoc, const MCOperand &LocalOp,
return false;
}

bool WebAssemblyAsmTypeCheck::checkBr(SMLoc ErrorLoc, size_t Level) {
if (Level >= BrStack.size())
return typeError(ErrorLoc,
StringRef("br: invalid depth ") + std::to_string(Level));
const SmallVector<wasm::ValType, 4> &Expected =
BrStack[BrStack.size() - Level - 1];
return checkTypes(ErrorLoc, Expected);
return false;
}

bool WebAssemblyAsmTypeCheck::checkEnd(SMLoc ErrorLoc, bool PopVals) {
if (!PopVals)
BrStack.pop_back();

if (PopVals)
return popTypes(ErrorLoc, LastSig.Returns);
return checkTypes(ErrorLoc, LastSig.Returns);
}

bool WebAssemblyAsmTypeCheck::checkSig(SMLoc ErrorLoc,
const wasm::WasmSignature &Sig) {
bool Error = popTypes(ErrorLoc, Sig.Params);
Expand Down Expand Up @@ -308,10 +316,11 @@ bool WebAssemblyAsmTypeCheck::getSignature(SMLoc ErrorLoc,
return false;
}

bool WebAssemblyAsmTypeCheck::endOfFunction(SMLoc ErrorLoc, bool ExactMatch) {
bool Error = popTypes(ErrorLoc, ReturnTypes, ExactMatch);
Unreachable = true;
return Error;
bool WebAssemblyAsmTypeCheck::endOfFunction(SMLoc ErrorLoc,
bool ExactMatch) {
assert(!BlockInfoStack.empty());
const auto &FuncInfo = BlockInfoStack[0];
return checkTypes(ErrorLoc, FuncInfo.Sig.Returns, ExactMatch);
}

bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
Expand Down Expand Up @@ -453,51 +462,90 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
}

if (Name == "try" || Name == "block" || Name == "loop" || Name == "if") {
if (Name == "loop")
BrStack.emplace_back(LastSig.Params.begin(), LastSig.Params.end());
else
BrStack.emplace_back(LastSig.Returns.begin(), LastSig.Returns.end());
if (Name == "if" && popType(ErrorLoc, wasm::ValType::I32))
return true;
return false;
bool Error = Name == "if" && popType(ErrorLoc, wasm::ValType::I32);
// Pop block input parameters and check their types are correct
Error |= popTypes(ErrorLoc, LastSig.Params);
// Push a new block info
BlockInfoStack.push_back({LastSig, Stack.size(), Name == "loop"});
// Push back block input parameters
pushTypes(LastSig.Params);
return Error;
}

if (Name == "end_block" || Name == "end_loop" || Name == "end_if" ||
Name == "else" || Name == "end_try" || Name == "catch" ||
Name == "catch_all" || Name == "delegate") {
bool Error = checkEnd(ErrorLoc, Name == "else" || Name == "catch" ||
Name == "catch_all");
Unreachable = false;
if (Name == "catch") {
assert(!BlockInfoStack.empty());
// Check if the types on the stack match with the block return type
const auto &LastBlockInfo = BlockInfoStack.back();
bool Error = checkTypes(ErrorLoc, LastBlockInfo.Sig.Returns, true);
// Pop all types added to the stack for the current block level
Stack.truncate(LastBlockInfo.StackStartPos);
if (Name == "else") {
// 'else' expects the block input parameters to be on the stack, in the
// same way we entered 'if'
pushTypes(LastBlockInfo.Sig.Params);
} else if (Name == "catch") {
// 'catch' instruction pushes values whose types are specified in the
// tag's 'params' part
const wasm::WasmSignature *Sig = nullptr;
if (!getSignature(Operands[1]->getStartLoc(), Inst.getOperand(0),
wasm::WASM_SYMBOL_TYPE_TAG, Sig))
// catch instruction pushes values whose types are specified in the
// tag's "params" part
pushTypes(Sig->Params);
else
Error = true;
} else if (Name == "catch_all") {
// 'catch_all' does not push anything onto the stack
} else {
// For normal end markers, push block return value types onto the stack
// and pop the block info
pushTypes(LastBlockInfo.Sig.Returns);
BlockInfoStack.pop_back();
}
return Error;
}

if (Name == "br") {
if (Name == "br" || Name == "br_if") {
bool Error = false;
if (Name == "br_if")
Error |= popType(ErrorLoc, wasm::ValType::I32); // cond
const MCOperand &Operand = Inst.getOperand(0);
if (!Operand.isImm())
return true;
return checkBr(ErrorLoc, static_cast<size_t>(Operand.getImm()));
if (Operand.isImm()) {
unsigned Level = Operand.getImm();
if (Level < BlockInfoStack.size()) {
const auto &DestBlockInfo =
BlockInfoStack[BlockInfoStack.size() - Level - 1];
if (DestBlockInfo.IsLoop)
Error |= checkTypes(ErrorLoc, DestBlockInfo.Sig.Params, false);
else
Error |= checkTypes(ErrorLoc, DestBlockInfo.Sig.Returns, false);
} else {
Error = typeError(ErrorLoc, StringRef("br: invalid depth ") +
std::to_string(Level));
}
} else {
Error =
typeError(Operands[1]->getStartLoc(), "depth should be an integer");
}
if (Name == "br")
pushType(Polymorphic{});
return Error;
}

if (Name == "return") {
return endOfFunction(ErrorLoc, false);
bool Error = endOfFunction(ErrorLoc, false);
pushType(Polymorphic{});
return Error;
}

if (Name == "call_indirect" || Name == "return_call_indirect") {
// Function value.
bool Error = popType(ErrorLoc, wasm::ValType::I32);
Error |= checkSig(ErrorLoc, LastSig);
if (Name == "return_call_indirect" && endOfFunction(ErrorLoc, false))
return true;
if (Name == "return_call_indirect") {
Error |= endOfFunction(ErrorLoc, false);
pushType(Polymorphic{});
}
return Error;
}

Expand All @@ -509,13 +557,15 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
Error |= checkSig(ErrorLoc, *Sig);
else
Error = true;
if (Name == "return_call" && endOfFunction(ErrorLoc, false))
return true;
if (Name == "return_call") {
Error |= endOfFunction(ErrorLoc, false);
pushType(Polymorphic{});
}
return Error;
}

if (Name == "unreachable") {
Unreachable = true;
pushType(Polymorphic{});
return false;
}

Expand All @@ -526,11 +576,15 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
}

if (Name == "throw") {
bool Error = false;
const wasm::WasmSignature *Sig = nullptr;
if (!getSignature(Operands[1]->getStartLoc(), Inst.getOperand(0),
wasm::WASM_SYMBOL_TYPE_TAG, Sig))
return checkSig(ErrorLoc, *Sig);
return true;
Error |= checkSig(ErrorLoc, *Sig);
else
Error = true;
pushType(Polymorphic{});
return Error;
}

// The current instruction is a stack instruction which doesn't have
Expand Down
18 changes: 9 additions & 9 deletions llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,17 @@ class WebAssemblyAsmTypeCheck final {

struct Ref : public std::monostate {};
struct Any : public std::monostate {};
using StackType = std::variant<wasm::ValType, Ref, Any>;
struct Polymorphic : public std::monostate {};
using StackType = std::variant<wasm::ValType, Ref, Any, Polymorphic>;
SmallVector<StackType, 16> Stack;
SmallVector<SmallVector<wasm::ValType, 4>, 8> BrStack;
struct BlockInfo {
wasm::WasmSignature Sig;
size_t StackStartPos;
bool IsLoop;
};
SmallVector<BlockInfo, 8> BlockInfoStack;
SmallVector<wasm::ValType, 16> LocalTypes;
SmallVector<wasm::ValType, 4> ReturnTypes;
wasm::WasmSignature LastSig;
bool Unreachable = false;
bool Is64;

// checkTypes checks 'Types' against the value stack. popTypes checks 'Types'
Expand Down Expand Up @@ -68,8 +72,6 @@ class WebAssemblyAsmTypeCheck final {
void dumpTypeStack(Twine Msg);
bool typeError(SMLoc ErrorLoc, const Twine &Msg);
bool getLocal(SMLoc ErrorLoc, const MCOperand &LocalOp, wasm::ValType &Type);
bool checkEnd(SMLoc ErrorLoc, bool PopVals = false);
bool checkBr(SMLoc ErrorLoc, size_t Level);
bool checkSig(SMLoc ErrorLoc, const wasm::WasmSignature &Sig);
bool getSymRef(SMLoc ErrorLoc, const MCOperand &SymOp,
const MCSymbolRefExpr *&SymRef);
Expand All @@ -91,10 +93,8 @@ class WebAssemblyAsmTypeCheck final {

void clear() {
Stack.clear();
BrStack.clear();
BlockInfoStack.clear();
LocalTypes.clear();
ReturnTypes.clear();
Unreachable = false;
}
};

Expand Down
8 changes: 6 additions & 2 deletions llvm/test/MC/WebAssembly/basic-assembly.s
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,13 @@ test0:
i32.const 3
end_block # "switch" exit.
if # void
i32.const 0
if i32
i32.const 0
end_if
drop
else
end_if
drop
block void
i32.const 2
return
Expand Down Expand Up @@ -222,11 +224,13 @@ empty_exnref_table:
# CHECK-NEXT: i32.const 3
# CHECK-NEXT: end_block # label2:
# CHECK-NEXT: if
# CHECK-NEXT: i32.const 0
# CHECK-NEXT: if i32
# CHECK-NEXT: i32.const 0
# CHECK-NEXT: end_if
# CHECK-NEXT: drop
# CHECK-NEXT: else
# CHECK-NEXT: end_if
# CHECK-NEXT: drop
# CHECK-NEXT: block
# CHECK-NEXT: i32.const 2
# CHECK-NEXT: return
Expand Down
Loading
Loading