Skip to content

Commit a268bda

Browse files
authored
[WebAssembly] Handle block and polymorphic stack in AsmTypeCheck (#110770)
This makes the type checker handle blocks with input parameters and return types, branches, and polymorphic stacks correctly. We maintain the stack of "block info", which contains its input parameter type, return type, and whether it is a loop or not. And this is used when checking the validity of the value stack at the `end` marker and all branches targeting the block. `StackType` now supports a new variant `Polymorphic`, which indicates the stack is in the polymorphic state. `Polymorphic`s are not popped even when `popType` is executed; they are only popped when the current block ends. When popping from the value stack, we ensure we don't pop more than we are allowed to at the given block level and print appropriate error messages instead. Also after a block ends, the value stack is guaranteed to have the right types based on the block return type. For example, ```wast block i32 unreachable end_block ;; You can expect to have an i32 on the stack here ``` This also adds handling for `br_if`. Previously only `br`s were checked. `checkEnd` and `checkBr` were removed and their contents have been inlined to the main `typeCheck` function, because they are called only from a single callsite. This also fixes two existing bugs in AsmParser, which were required to make the tests passing. I added Github comments about them inline. This modifies several existing invalid tests, those that passed (incorrectly) before but do not pass with the new type checker anymore. Fixes #107524.
1 parent ca57e8f commit a268bda

File tree

7 files changed

+343
-79
lines changed

7 files changed

+343
-79
lines changed

llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmParser.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,9 @@ class WebAssemblyAsmParser final : public MCTargetAsmParser {
498498

499499
void addBlockTypeOperand(OperandVector &Operands, SMLoc NameLoc,
500500
WebAssembly::BlockType BT) {
501-
if (BT != WebAssembly::BlockType::Void) {
501+
if (BT == WebAssembly::BlockType::Void) {
502+
TC.setLastSig(wasm::WasmSignature{});
503+
} else {
502504
wasm::WasmSignature Sig({static_cast<wasm::ValType>(BT)}, {});
503505
TC.setLastSig(Sig);
504506
NestingStack.back().Sig = Sig;
@@ -1002,7 +1004,8 @@ class WebAssemblyAsmParser final : public MCTargetAsmParser {
10021004
auto *Signature = Ctx.createWasmSignature();
10031005
if (parseSignature(Signature))
10041006
return ParseStatus::Failure;
1005-
TC.funcDecl(*Signature);
1007+
if (CurrentState == FunctionStart)
1008+
TC.funcDecl(*Signature);
10061009
WasmSym->setSignature(Signature);
10071010
WasmSym->setType(wasm::WASM_SYMBOL_TYPE_FUNCTION);
10081011
TOut.emitFunctionType(WasmSym);

llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp

Lines changed: 116 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,7 @@ WebAssemblyAsmTypeCheck::WebAssemblyAsmTypeCheck(MCAsmParser &Parser,
5050

5151
void WebAssemblyAsmTypeCheck::funcDecl(const wasm::WasmSignature &Sig) {
5252
LocalTypes.assign(Sig.Params.begin(), Sig.Params.end());
53-
ReturnTypes.assign(Sig.Returns.begin(), Sig.Returns.end());
54-
BrStack.emplace_back(Sig.Returns.begin(), Sig.Returns.end());
53+
BlockInfoStack.push_back({Sig, 0, false});
5554
}
5655

5756
void WebAssemblyAsmTypeCheck::localDecl(
@@ -64,14 +63,15 @@ void WebAssemblyAsmTypeCheck::dumpTypeStack(Twine Msg) {
6463
}
6564

6665
bool WebAssemblyAsmTypeCheck::typeError(SMLoc ErrorLoc, const Twine &Msg) {
67-
// If we're currently in unreachable code, we suppress errors completely.
68-
if (Unreachable)
69-
return false;
7066
dumpTypeStack("current stack: ");
7167
return Parser.Error(ErrorLoc, Msg);
7268
}
7369

7470
bool WebAssemblyAsmTypeCheck::match(StackType TypeA, StackType TypeB) {
71+
// These should have been filtered out in checkTypes()
72+
assert(!std::get_if<Polymorphic>(&TypeA) &&
73+
!std::get_if<Polymorphic>(&TypeB));
74+
7575
if (TypeA == TypeB)
7676
return false;
7777
if (std::get_if<Any>(&TypeA) || std::get_if<Any>(&TypeB))
@@ -90,6 +90,10 @@ std::string WebAssemblyAsmTypeCheck::getTypesString(ArrayRef<StackType> Types,
9090
size_t StartPos) {
9191
SmallVector<std::string, 4> TypeStrs;
9292
for (auto I = Types.size(); I > StartPos; I--) {
93+
if (std::get_if<Polymorphic>(&Types[I - 1])) {
94+
TypeStrs.push_back("...");
95+
break;
96+
}
9397
if (std::get_if<Any>(&Types[I - 1]))
9498
TypeStrs.push_back("any");
9599
else if (std::get_if<Ref>(&Types[I - 1]))
@@ -131,29 +135,48 @@ bool WebAssemblyAsmTypeCheck::checkTypes(SMLoc ErrorLoc,
131135
bool ExactMatch) {
132136
auto StackI = Stack.size();
133137
auto TypeI = Types.size();
138+
assert(!BlockInfoStack.empty());
139+
auto BlockStackStartPos = BlockInfoStack.back().StackStartPos;
134140
bool Error = false;
141+
bool PolymorphicStack = false;
135142
// Compare elements one by one from the stack top
136-
for (; StackI > 0 && TypeI > 0; StackI--, TypeI--) {
143+
for (; StackI > BlockStackStartPos && TypeI > 0; StackI--, TypeI--) {
144+
// If the stack is polymorphic, we assume all types in 'Types' have been
145+
// compared and matched
146+
if (std::get_if<Polymorphic>(&Stack[StackI - 1])) {
147+
TypeI = 0;
148+
break;
149+
}
137150
if (match(Stack[StackI - 1], Types[TypeI - 1])) {
138151
Error = true;
139152
break;
140153
}
141154
}
155+
156+
// If the stack top is polymorphic, the stack is in the polymorphic state.
157+
if (StackI > BlockStackStartPos &&
158+
std::get_if<Polymorphic>(&Stack[StackI - 1]))
159+
PolymorphicStack = true;
160+
142161
// Even if no match failure has happened in the loop above, if not all
143162
// elements of Types has been matched, that means we don't have enough
144163
// elements on the stack.
145164
//
146165
// Also, if not all elements of the Stack has been matched and when
147-
// 'ExactMatch' is true, that means we have superfluous elements remaining on
148-
// the stack (e.g. at the end of a function).
149-
if (TypeI > 0 || (ExactMatch && StackI > 0))
166+
// 'ExactMatch' is true and the current stack is not polymorphic, that means
167+
// we have superfluous elements remaining on the stack (e.g. at the end of a
168+
// function).
169+
if (TypeI > 0 ||
170+
(ExactMatch && !PolymorphicStack && StackI > BlockStackStartPos))
150171
Error = true;
151172

152173
if (!Error)
153174
return false;
154175

155-
auto StackStartPos =
156-
ExactMatch ? 0 : std::max(0, (int)Stack.size() - (int)Types.size());
176+
auto StackStartPos = ExactMatch
177+
? BlockStackStartPos
178+
: std::max((int)BlockStackStartPos,
179+
(int)Stack.size() - (int)Types.size());
157180
return typeError(ErrorLoc, "type mismatch, expected " +
158181
getTypesString(Types, 0) + " but got " +
159182
getTypesString(Stack, StackStartPos));
@@ -169,9 +192,13 @@ bool WebAssemblyAsmTypeCheck::popTypes(SMLoc ErrorLoc,
169192
ArrayRef<StackType> Types,
170193
bool ExactMatch) {
171194
bool Error = checkTypes(ErrorLoc, Types, ExactMatch);
172-
auto NumPops = std::min(Stack.size(), Types.size());
173-
for (size_t I = 0, E = NumPops; I != E; I++)
195+
auto NumPops = std::min(Stack.size() - BlockInfoStack.back().StackStartPos,
196+
Types.size());
197+
for (size_t I = 0, E = NumPops; I != E; I++) {
198+
if (std::get_if<Polymorphic>(&Stack.back()))
199+
break;
174200
Stack.pop_back();
201+
}
175202
return Error;
176203
}
177204

@@ -201,25 +228,6 @@ bool WebAssemblyAsmTypeCheck::getLocal(SMLoc ErrorLoc, const MCOperand &LocalOp,
201228
return false;
202229
}
203230

204-
bool WebAssemblyAsmTypeCheck::checkBr(SMLoc ErrorLoc, size_t Level) {
205-
if (Level >= BrStack.size())
206-
return typeError(ErrorLoc,
207-
StringRef("br: invalid depth ") + std::to_string(Level));
208-
const SmallVector<wasm::ValType, 4> &Expected =
209-
BrStack[BrStack.size() - Level - 1];
210-
return checkTypes(ErrorLoc, Expected);
211-
return false;
212-
}
213-
214-
bool WebAssemblyAsmTypeCheck::checkEnd(SMLoc ErrorLoc, bool PopVals) {
215-
if (!PopVals)
216-
BrStack.pop_back();
217-
218-
if (PopVals)
219-
return popTypes(ErrorLoc, LastSig.Returns);
220-
return checkTypes(ErrorLoc, LastSig.Returns);
221-
}
222-
223231
bool WebAssemblyAsmTypeCheck::checkSig(SMLoc ErrorLoc,
224232
const wasm::WasmSignature &Sig) {
225233
bool Error = popTypes(ErrorLoc, Sig.Params);
@@ -309,9 +317,9 @@ bool WebAssemblyAsmTypeCheck::getSignature(SMLoc ErrorLoc,
309317
}
310318

311319
bool WebAssemblyAsmTypeCheck::endOfFunction(SMLoc ErrorLoc, bool ExactMatch) {
312-
bool Error = popTypes(ErrorLoc, ReturnTypes, ExactMatch);
313-
Unreachable = true;
314-
return Error;
320+
assert(!BlockInfoStack.empty());
321+
const auto &FuncInfo = BlockInfoStack[0];
322+
return checkTypes(ErrorLoc, FuncInfo.Sig.Returns, ExactMatch);
315323
}
316324

317325
bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
@@ -452,52 +460,91 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
452460
return popType(ErrorLoc, Any{});
453461
}
454462

455-
if (Name == "try" || Name == "block" || Name == "loop" || Name == "if") {
456-
if (Name == "loop")
457-
BrStack.emplace_back(LastSig.Params.begin(), LastSig.Params.end());
458-
else
459-
BrStack.emplace_back(LastSig.Returns.begin(), LastSig.Returns.end());
460-
if (Name == "if" && popType(ErrorLoc, wasm::ValType::I32))
461-
return true;
462-
return false;
463+
if (Name == "block" || Name == "loop" || Name == "if" || Name == "try") {
464+
bool Error = Name == "if" && popType(ErrorLoc, wasm::ValType::I32);
465+
// Pop block input parameters and check their types are correct
466+
Error |= popTypes(ErrorLoc, LastSig.Params);
467+
// Push a new block info
468+
BlockInfoStack.push_back({LastSig, Stack.size(), Name == "loop"});
469+
// Push back block input parameters
470+
pushTypes(LastSig.Params);
471+
return Error;
463472
}
464473

465474
if (Name == "end_block" || Name == "end_loop" || Name == "end_if" ||
466-
Name == "else" || Name == "end_try" || Name == "catch" ||
467-
Name == "catch_all" || Name == "delegate") {
468-
bool Error = checkEnd(ErrorLoc, Name == "else" || Name == "catch" ||
469-
Name == "catch_all");
470-
Unreachable = false;
471-
if (Name == "catch") {
475+
Name == "end_try" || Name == "delegate" || Name == "else" ||
476+
Name == "catch" || Name == "catch_all") {
477+
assert(!BlockInfoStack.empty());
478+
// Check if the types on the stack match with the block return type
479+
const auto &LastBlockInfo = BlockInfoStack.back();
480+
bool Error = checkTypes(ErrorLoc, LastBlockInfo.Sig.Returns, true);
481+
// Pop all types added to the stack for the current block level
482+
Stack.truncate(LastBlockInfo.StackStartPos);
483+
if (Name == "else") {
484+
// 'else' expects the block input parameters to be on the stack, in the
485+
// same way we entered 'if'
486+
pushTypes(LastBlockInfo.Sig.Params);
487+
} else if (Name == "catch") {
488+
// 'catch' instruction pushes values whose types are specified in the
489+
// tag's 'params' part
472490
const wasm::WasmSignature *Sig = nullptr;
473491
if (!getSignature(Operands[1]->getStartLoc(), Inst.getOperand(0),
474492
wasm::WASM_SYMBOL_TYPE_TAG, Sig))
475-
// catch instruction pushes values whose types are specified in the
476-
// tag's "params" part
477493
pushTypes(Sig->Params);
478494
else
479495
Error = true;
496+
} else if (Name == "catch_all") {
497+
// 'catch_all' does not push anything onto the stack
498+
} else {
499+
// For normal end markers, push block return value types onto the stack
500+
// and pop the block info
501+
pushTypes(LastBlockInfo.Sig.Returns);
502+
BlockInfoStack.pop_back();
480503
}
481504
return Error;
482505
}
483506

484-
if (Name == "br") {
507+
if (Name == "br" || Name == "br_if") {
508+
bool Error = false;
509+
if (Name == "br_if")
510+
Error |= popType(ErrorLoc, wasm::ValType::I32); // cond
485511
const MCOperand &Operand = Inst.getOperand(0);
486-
if (!Operand.isImm())
487-
return true;
488-
return checkBr(ErrorLoc, static_cast<size_t>(Operand.getImm()));
512+
if (Operand.isImm()) {
513+
unsigned Level = Operand.getImm();
514+
if (Level < BlockInfoStack.size()) {
515+
const auto &DestBlockInfo =
516+
BlockInfoStack[BlockInfoStack.size() - Level - 1];
517+
if (DestBlockInfo.IsLoop)
518+
Error |= checkTypes(ErrorLoc, DestBlockInfo.Sig.Params, false);
519+
else
520+
Error |= checkTypes(ErrorLoc, DestBlockInfo.Sig.Returns, false);
521+
} else {
522+
Error = typeError(ErrorLoc, StringRef("br: invalid depth ") +
523+
std::to_string(Level));
524+
}
525+
} else {
526+
Error =
527+
typeError(Operands[1]->getStartLoc(), "depth should be an integer");
528+
}
529+
if (Name == "br")
530+
pushType(Polymorphic{});
531+
return Error;
489532
}
490533

491534
if (Name == "return") {
492-
return endOfFunction(ErrorLoc, false);
535+
bool Error = endOfFunction(ErrorLoc, false);
536+
pushType(Polymorphic{});
537+
return Error;
493538
}
494539

495540
if (Name == "call_indirect" || Name == "return_call_indirect") {
496541
// Function value.
497542
bool Error = popType(ErrorLoc, wasm::ValType::I32);
498543
Error |= checkSig(ErrorLoc, LastSig);
499-
if (Name == "return_call_indirect" && endOfFunction(ErrorLoc, false))
500-
return true;
544+
if (Name == "return_call_indirect") {
545+
Error |= endOfFunction(ErrorLoc, false);
546+
pushType(Polymorphic{});
547+
}
501548
return Error;
502549
}
503550

@@ -509,13 +556,15 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
509556
Error |= checkSig(ErrorLoc, *Sig);
510557
else
511558
Error = true;
512-
if (Name == "return_call" && endOfFunction(ErrorLoc, false))
513-
return true;
559+
if (Name == "return_call") {
560+
Error |= endOfFunction(ErrorLoc, false);
561+
pushType(Polymorphic{});
562+
}
514563
return Error;
515564
}
516565

517566
if (Name == "unreachable") {
518-
Unreachable = true;
567+
pushType(Polymorphic{});
519568
return false;
520569
}
521570

@@ -526,11 +575,15 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
526575
}
527576

528577
if (Name == "throw") {
578+
bool Error = false;
529579
const wasm::WasmSignature *Sig = nullptr;
530580
if (!getSignature(Operands[1]->getStartLoc(), Inst.getOperand(0),
531581
wasm::WASM_SYMBOL_TYPE_TAG, Sig))
532-
return checkSig(ErrorLoc, *Sig);
533-
return true;
582+
Error |= checkSig(ErrorLoc, *Sig);
583+
else
584+
Error = true;
585+
pushType(Polymorphic{});
586+
return Error;
534587
}
535588

536589
// The current instruction is a stack instruction which doesn't have

llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,17 @@ class WebAssemblyAsmTypeCheck final {
3131

3232
struct Ref : public std::monostate {};
3333
struct Any : public std::monostate {};
34-
using StackType = std::variant<wasm::ValType, Ref, Any>;
34+
struct Polymorphic : public std::monostate {};
35+
using StackType = std::variant<wasm::ValType, Ref, Any, Polymorphic>;
3536
SmallVector<StackType, 16> Stack;
36-
SmallVector<SmallVector<wasm::ValType, 4>, 8> BrStack;
37+
struct BlockInfo {
38+
wasm::WasmSignature Sig;
39+
size_t StackStartPos;
40+
bool IsLoop;
41+
};
42+
SmallVector<BlockInfo, 8> BlockInfoStack;
3743
SmallVector<wasm::ValType, 16> LocalTypes;
38-
SmallVector<wasm::ValType, 4> ReturnTypes;
3944
wasm::WasmSignature LastSig;
40-
bool Unreachable = false;
4145
bool Is64;
4246

4347
// checkTypes checks 'Types' against the value stack. popTypes checks 'Types'
@@ -68,8 +72,6 @@ class WebAssemblyAsmTypeCheck final {
6872
void dumpTypeStack(Twine Msg);
6973
bool typeError(SMLoc ErrorLoc, const Twine &Msg);
7074
bool getLocal(SMLoc ErrorLoc, const MCOperand &LocalOp, wasm::ValType &Type);
71-
bool checkEnd(SMLoc ErrorLoc, bool PopVals = false);
72-
bool checkBr(SMLoc ErrorLoc, size_t Level);
7375
bool checkSig(SMLoc ErrorLoc, const wasm::WasmSignature &Sig);
7476
bool getSymRef(SMLoc ErrorLoc, const MCOperand &SymOp,
7577
const MCSymbolRefExpr *&SymRef);
@@ -91,10 +93,8 @@ class WebAssemblyAsmTypeCheck final {
9193

9294
void clear() {
9395
Stack.clear();
94-
BrStack.clear();
96+
BlockInfoStack.clear();
9597
LocalTypes.clear();
96-
ReturnTypes.clear();
97-
Unreachable = false;
9898
}
9999
};
100100

llvm/test/MC/WebAssembly/basic-assembly.s

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,13 @@ test0:
8282
i32.const 3
8383
end_block # "switch" exit.
8484
if # void
85+
i32.const 0
8586
if i32
87+
i32.const 0
8688
end_if
89+
drop
8790
else
8891
end_if
89-
drop
9092
block void
9193
i32.const 2
9294
return
@@ -222,11 +224,13 @@ empty_exnref_table:
222224
# CHECK-NEXT: i32.const 3
223225
# CHECK-NEXT: end_block # label2:
224226
# CHECK-NEXT: if
227+
# CHECK-NEXT: i32.const 0
225228
# CHECK-NEXT: if i32
229+
# CHECK-NEXT: i32.const 0
226230
# CHECK-NEXT: end_if
231+
# CHECK-NEXT: drop
227232
# CHECK-NEXT: else
228233
# CHECK-NEXT: end_if
229-
# CHECK-NEXT: drop
230234
# CHECK-NEXT: block
231235
# CHECK-NEXT: i32.const 2
232236
# CHECK-NEXT: return

0 commit comments

Comments
 (0)