Skip to content

Commit 1355f5e

Browse files
committed
[WebAssembly] Protect memory.fill and memory.copy from zero-length ranges.
WebAssembly's `memory.fill` and `memory.copy` instructions trap if the pointers are out of bounds, even if the length is zero. This is different from LLVM, which expects that it can call `memcpy` on arbitrary invalid pointers if the length is zero. To avoid spurious traps, branch around `memory.fill` and `memory.copy` when the length is zero.
1 parent a5d919b commit 1355f5e

File tree

6 files changed

+434
-71
lines changed

6 files changed

+434
-71
lines changed

llvm/lib/Target/WebAssembly/WebAssemblyISD.def

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,7 @@ HANDLE_MEM_NODETYPE(GLOBAL_GET)
5454
HANDLE_MEM_NODETYPE(GLOBAL_SET)
5555
HANDLE_MEM_NODETYPE(TABLE_GET)
5656
HANDLE_MEM_NODETYPE(TABLE_SET)
57+
58+
// Bulk memory instructions that require branching to handle empty ranges.
59+
HANDLE_NODETYPE(MEMCPY)
60+
HANDLE_NODETYPE(MEMSET)

llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,138 @@ static MachineBasicBlock *LowerFPToInt(MachineInstr &MI, DebugLoc DL,
568568
return DoneMBB;
569569
}
570570

571+
// Lower a `MEMCPY` instruction into a CFG triangle around a `MEMORY_COPY`
572+
// instruction to handle the zero-length case.
//
// `Int64` selects the 64-bit opcodes (`EQZ_I64`, `MEMORY_COPY_A64`) instead of
// the 32-bit ones. The original `MI` is erased, and the returned block is the
// new block that receives every instruction that followed `MI` in `BB`.
573+
static MachineBasicBlock *LowerMemcpy(MachineInstr &MI, DebugLoc DL,
574+
MachineBasicBlock *BB,
575+
const TargetInstrInfo &TII, bool Int64) {
576+
MachineRegisterInfo &MRI = BB->getParent()->getRegInfo();
577+
578+
// Copy the operands by value: `MI` is erased below, before the replacement
// instructions are built from these copies.
// NOTE(review): the `MEMCPY_A32`/`MEMCPY_A64` definitions name operands 0 and
// 1 `$src_idx` and `$dst_idx`, the opposite of the local names used here. The
// operands are forwarded in their original order, so only the naming differs
// -- confirm which naming is intended.
MachineOperand DstMem = MI.getOperand(0);
579+
MachineOperand SrcMem = MI.getOperand(1);
580+
MachineOperand Dst = MI.getOperand(2);
581+
MachineOperand Src = MI.getOperand(3);
582+
MachineOperand Len = MI.getOperand(4);
583+
584+
// We're going to add an extra use to `Len` to test if it's zero; that
585+
// use shouldn't be a kill, even if the original use is.
586+
MachineOperand NoKillLen = Len;
587+
NoKillLen.setIsKill(false);
588+
589+
// Decide on which `MachineInstr` opcode we're going to use.
590+
unsigned Eqz = Int64 ? WebAssembly::EQZ_I64 : WebAssembly::EQZ_I32;
591+
unsigned MemoryCopy =
592+
Int64 ? WebAssembly::MEMORY_COPY_A64 : WebAssembly::MEMORY_COPY_A32;
593+
594+
// Create two new basic blocks; one for the new `memory.copy` that we can
595+
// branch over, and one for the rest of the instructions after the original
596+
// `MEMCPY`.
597+
const BasicBlock *LLVMBB = BB->getBasicBlock();
598+
MachineFunction *F = BB->getParent();
599+
MachineBasicBlock *TrueMBB = F->CreateMachineBasicBlock(LLVMBB);
600+
MachineBasicBlock *DoneMBB = F->CreateMachineBasicBlock(LLVMBB);
601+
602+
// Place both new blocks directly after `BB`, in the order TrueMBB, DoneMBB.
MachineFunction::iterator It = ++BB->getIterator();
603+
F->insert(It, TrueMBB);
604+
F->insert(It, DoneMBB);
605+
606+
// Transfer the remainder of BB and its successor edges to DoneMBB.
607+
DoneMBB->splice(DoneMBB->begin(), BB, std::next(MI.getIterator()), BB->end());
608+
DoneMBB->transferSuccessorsAndUpdatePHIs(BB);
609+
610+
// Connect the CFG edges.
611+
BB->addSuccessor(TrueMBB);
612+
BB->addSuccessor(DoneMBB);
613+
TrueMBB->addSuccessor(DoneMBB);
614+
615+
// Create a virtual register for the `Eqz` result.
616+
unsigned EqzReg;
617+
EqzReg = MRI.createVirtualRegister(&WebAssembly::I32RegClass);
618+
619+
// Erase the original `MEMCPY` instruction; it is replaced by the
// instructions built below.
620+
MI.eraseFromParent();
621+
622+
// Test if `Len` is zero.
623+
BuildMI(BB, DL, TII.get(Eqz), EqzReg).add(NoKillLen);
624+
625+
// Insert a new `memory.copy`.
626+
BuildMI(TrueMBB, DL, TII.get(MemoryCopy))
627+
.add(DstMem)
628+
.add(SrcMem)
629+
.add(Dst)
630+
.add(Src)
631+
.add(Len);
632+
633+
// Create the CFG triangle: `BB` branches straight to `DoneMBB` when the
// zero test succeeded, otherwise control reaches `TrueMBB`, which performs
// the copy and then branches to `DoneMBB`.
634+
BuildMI(BB, DL, TII.get(WebAssembly::BR_IF)).addMBB(DoneMBB).addReg(EqzReg);
635+
BuildMI(TrueMBB, DL, TII.get(WebAssembly::BR)).addMBB(DoneMBB);
636+
637+
return DoneMBB;
638+
}
639+
640+
// Lower a `MEMSET` instruction into a CFG triangle around a `MEMORY_FILL`
641+
// instruction to handle the zero-length case.
//
// `Int64` selects the 64-bit opcodes (`EQZ_I64`, `MEMORY_FILL_A64`) instead of
// the 32-bit ones. The original `MI` is erased, and the returned block is the
// new block that receives every instruction that followed `MI` in `BB`.
642+
static MachineBasicBlock *LowerMemset(MachineInstr &MI, DebugLoc DL,
643+
MachineBasicBlock *BB,
644+
const TargetInstrInfo &TII, bool Int64) {
645+
MachineRegisterInfo &MRI = BB->getParent()->getRegInfo();
646+
647+
// Copy the operands by value: `MI` is erased below, before the replacement
// instructions are built from these copies.
MachineOperand Mem = MI.getOperand(0);
648+
MachineOperand Dst = MI.getOperand(1);
649+
MachineOperand Val = MI.getOperand(2);
650+
MachineOperand Len = MI.getOperand(3);
651+
652+
// We're going to add an extra use to `Len` to test if it's zero; that
653+
// use shouldn't be a kill, even if the original use is.
654+
MachineOperand NoKillLen = Len;
655+
NoKillLen.setIsKill(false);
656+
657+
// Decide on which `MachineInstr` opcode we're going to use.
658+
unsigned Eqz = Int64 ? WebAssembly::EQZ_I64 : WebAssembly::EQZ_I32;
659+
unsigned MemoryFill =
660+
Int64 ? WebAssembly::MEMORY_FILL_A64 : WebAssembly::MEMORY_FILL_A32;
661+
662+
// Create two new basic blocks; one for the new `memory.fill` that we can
663+
// branch over, and one for the rest of the instructions after the original
664+
// `MEMSET`.
665+
const BasicBlock *LLVMBB = BB->getBasicBlock();
666+
MachineFunction *F = BB->getParent();
667+
MachineBasicBlock *TrueMBB = F->CreateMachineBasicBlock(LLVMBB);
668+
MachineBasicBlock *DoneMBB = F->CreateMachineBasicBlock(LLVMBB);
669+
670+
// Place both new blocks directly after `BB`, in the order TrueMBB, DoneMBB.
MachineFunction::iterator It = ++BB->getIterator();
671+
F->insert(It, TrueMBB);
672+
F->insert(It, DoneMBB);
673+
674+
// Transfer the remainder of BB and its successor edges to DoneMBB.
675+
DoneMBB->splice(DoneMBB->begin(), BB, std::next(MI.getIterator()), BB->end());
676+
DoneMBB->transferSuccessorsAndUpdatePHIs(BB);
677+
678+
// Connect the CFG edges.
679+
BB->addSuccessor(TrueMBB);
680+
BB->addSuccessor(DoneMBB);
681+
TrueMBB->addSuccessor(DoneMBB);
682+
683+
// Create a virtual register for the `Eqz` result.
684+
unsigned EqzReg;
685+
EqzReg = MRI.createVirtualRegister(&WebAssembly::I32RegClass);
686+
687+
// Erase the original `MEMSET` instruction; it is replaced by the
// instructions built below.
688+
MI.eraseFromParent();
689+
690+
// Test if `Len` is zero.
691+
BuildMI(BB, DL, TII.get(Eqz), EqzReg).add(NoKillLen);
692+
693+
// Insert a new `memory.fill`.
694+
BuildMI(TrueMBB, DL, TII.get(MemoryFill)).add(Mem).add(Dst).add(Val).add(Len);
695+
696+
// Create the CFG triangle: `BB` branches straight to `DoneMBB` when the
// zero test succeeded, otherwise control reaches `TrueMBB`, which performs
// the fill and then branches to `DoneMBB`.
697+
BuildMI(BB, DL, TII.get(WebAssembly::BR_IF)).addMBB(DoneMBB).addReg(EqzReg);
698+
BuildMI(TrueMBB, DL, TII.get(WebAssembly::BR)).addMBB(DoneMBB);
699+
700+
return DoneMBB;
701+
}
702+
571703
static MachineBasicBlock *
572704
LowerCallResults(MachineInstr &CallResults, DebugLoc DL, MachineBasicBlock *BB,
573705
const WebAssemblySubtarget *Subtarget,
@@ -725,6 +857,14 @@ MachineBasicBlock *WebAssemblyTargetLowering::EmitInstrWithCustomInserter(
725857
case WebAssembly::FP_TO_UINT_I64_F64:
726858
return LowerFPToInt(MI, DL, BB, TII, true, true, true,
727859
WebAssembly::I64_TRUNC_U_F64);
860+
case WebAssembly::MEMCPY_A32:
861+
return LowerMemcpy(MI, DL, BB, TII, false);
862+
case WebAssembly::MEMCPY_A64:
863+
return LowerMemcpy(MI, DL, BB, TII, true);
864+
case WebAssembly::MEMSET_A32:
865+
return LowerMemset(MI, DL, BB, TII, false);
866+
case WebAssembly::MEMSET_A64:
867+
return LowerMemset(MI, DL, BB, TII, true);
728868
case WebAssembly::CALL_RESULTS:
729869
case WebAssembly::RET_CALL_RESULTS:
730870
return LowerCallResults(MI, DL, BB, Subtarget, TII);

llvm/lib/Target/WebAssembly/WebAssemblyInstrBulkMemory.td

Lines changed: 87 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,33 @@ multiclass BULK_I<dag oops_r, dag iops_r, dag oops_s, dag iops_s,
2121
}
2222

2323
// Bespoke types and nodes for bulk memory ops
24+
25+
// memory.copy (may trap on empty ranges)
26+
def wasm_memory_copy_t : SDTypeProfile<0, 5,
27+
[SDTCisInt<0>, SDTCisInt<1>, SDTCisPtrTy<2>, SDTCisPtrTy<3>, SDTCisInt<4>]
28+
>;
29+
def wasm_memory_copy : SDNode<"WebAssemblyISD::MEMORY_COPY", wasm_memory_copy_t,
30+
[SDNPHasChain, SDNPMayLoad, SDNPMayStore]>;
31+
32+
// memory.copy with a branch to avoid trapping
2433
def wasm_memcpy_t : SDTypeProfile<0, 5,
2534
[SDTCisInt<0>, SDTCisInt<1>, SDTCisPtrTy<2>, SDTCisPtrTy<3>, SDTCisInt<4>]
2635
>;
27-
def wasm_memcpy : SDNode<"WebAssemblyISD::MEMORY_COPY", wasm_memcpy_t,
36+
def wasm_memcpy : SDNode<"WebAssemblyISD::MEMCPY", wasm_memcpy_t,
2837
[SDNPHasChain, SDNPMayLoad, SDNPMayStore]>;
2938

39+
// memory.fill (may trap on empty ranges)
40+
def wasm_memory_fill_t : SDTypeProfile<0, 4,
41+
[SDTCisInt<0>, SDTCisPtrTy<1>, SDTCisInt<2>, SDTCisInt<3>]
42+
>;
43+
def wasm_memory_fill : SDNode<"WebAssemblyISD::MEMORY_FILL", wasm_memory_fill_t,
44+
[SDNPHasChain, SDNPMayStore]>;
45+
46+
// memory.fill with a branch to avoid trapping
3047
def wasm_memset_t : SDTypeProfile<0, 4,
3148
[SDTCisInt<0>, SDTCisPtrTy<1>, SDTCisInt<2>, SDTCisInt<3>]
3249
>;
33-
def wasm_memset : SDNode<"WebAssemblyISD::MEMORY_FILL", wasm_memset_t,
50+
def wasm_memset : SDNode<"WebAssemblyISD::MEMSET", wasm_memset_t,
3451
[SDNPHasChain, SDNPMayStore]>;
3552

3653
multiclass BulkMemoryOps<WebAssemblyRegClass rc, string B> {
@@ -51,25 +68,83 @@ defm DATA_DROP :
5168
[],
5269
"data.drop\t$seg", "data.drop\t$seg", 0x09>;
5370

71+
}
72+
73+
defm : BulkMemoryOps<I32, "32">;
74+
defm : BulkMemoryOps<I64, "64">;
75+
76+
// Define copy/fill manually instead of using the `BulkMemoryOps` multiclass
77+
// because when a multiclass defines opcodes, it gives them anonymous names
78+
// and we need opcodes with names so that we can handle them with custom code.
79+
5480
let mayLoad = 1, mayStore = 1 in
55-
defm MEMORY_COPY_A#B :
81+
defm MEMORY_COPY_A32 :
5682
BULK_I<(outs), (ins i32imm_op:$src_idx, i32imm_op:$dst_idx,
57-
rc:$dst, rc:$src, rc:$len),
83+
I32:$dst, I32:$src, I32:$len),
5884
(outs), (ins i32imm_op:$src_idx, i32imm_op:$dst_idx),
59-
[(wasm_memcpy (i32 imm:$src_idx), (i32 imm:$dst_idx),
60-
rc:$dst, rc:$src, rc:$len
85+
[(wasm_memory_copy (i32 imm:$src_idx), (i32 imm:$dst_idx),
86+
I32:$dst, I32:$src, I32:$len
6187
)],
6288
"memory.copy\t$src_idx, $dst_idx, $dst, $src, $len",
6389
"memory.copy\t$src_idx, $dst_idx", 0x0a>;
6490

6591
let mayStore = 1 in
66-
defm MEMORY_FILL_A#B :
67-
BULK_I<(outs), (ins i32imm_op:$idx, rc:$dst, I32:$value, rc:$size),
92+
defm MEMORY_FILL_A32 :
93+
BULK_I<(outs), (ins i32imm_op:$idx, I32:$dst, I32:$value, I32:$size),
6894
(outs), (ins i32imm_op:$idx),
69-
[(wasm_memset (i32 imm:$idx), rc:$dst, I32:$value, rc:$size)],
95+
[(wasm_memory_fill (i32 imm:$idx), I32:$dst, I32:$value, I32:$size)],
7096
"memory.fill\t$idx, $dst, $value, $size",
7197
"memory.fill\t$idx", 0x0b>;
72-
}
7398

74-
defm : BulkMemoryOps<I32, "32">;
75-
defm : BulkMemoryOps<I64, "64">;
99+
let mayLoad = 1, mayStore = 1 in
100+
defm MEMORY_COPY_A64 :
101+
BULK_I<(outs), (ins i32imm_op:$src_idx, i32imm_op:$dst_idx,
102+
I64:$dst, I64:$src, I64:$len),
103+
(outs), (ins i32imm_op:$src_idx, i32imm_op:$dst_idx),
104+
[(wasm_memory_copy (i32 imm:$src_idx), (i32 imm:$dst_idx),
105+
I64:$dst, I64:$src, I64:$len
106+
)],
107+
"memory.copy\t$src_idx, $dst_idx, $dst, $src, $len",
108+
"memory.copy\t$src_idx, $dst_idx", 0x0a>;
109+
110+
let mayStore = 1 in
111+
defm MEMORY_FILL_A64 :
112+
BULK_I<(outs), (ins i32imm_op:$idx, I64:$dst, I32:$value, I64:$size),
113+
(outs), (ins i32imm_op:$idx),
114+
[(wasm_memory_fill (i32 imm:$idx), I64:$dst, I32:$value, I64:$size)],
115+
"memory.fill\t$idx, $dst, $value, $size",
116+
"memory.fill\t$idx", 0x0b>;
117+
118+
// Codegen-only instruction matched from `wasm_memcpy` with I32 operands.
// Unlike `MEMORY_COPY_A32` it is given empty assembly strings and a zero
// opcode value; `usesCustomInserter` routes it through
// `EmitInstrWithCustomInserter`, which hands `MEMCPY_A32` to `LowerMemcpy`.
let usesCustomInserter = 1, isCodeGenOnly = 1, mayLoad = 1, mayStore = 1 in
119+
defm MEMCPY_A32 : I<(outs), (ins i32imm_op:$src_idx, i32imm_op:$dst_idx,
120+
I32:$dst, I32:$src, I32:$len),
121+
(outs), (ins i32imm_op:$src_idx, i32imm_op:$dst_idx),
122+
[(wasm_memcpy (i32 imm:$src_idx), (i32 imm:$dst_idx),
123+
I32:$dst, I32:$src, I32:$len
124+
)],
125+
"", "", 0>,
126+
Requires<[HasBulkMemory]>;
127+
128+
// Codegen-only instruction matched from `wasm_memset` with I32 address and
// size operands. Unlike `MEMORY_FILL_A32` it is given empty assembly strings
// and a zero opcode value; `usesCustomInserter` routes it through
// `EmitInstrWithCustomInserter`, which hands `MEMSET_A32` to `LowerMemset`.
let usesCustomInserter = 1, isCodeGenOnly = 1, mayStore = 1 in
129+
defm MEMSET_A32 : I<(outs), (ins i32imm_op:$idx, I32:$dst, I32:$value, I32:$size),
130+
(outs), (ins i32imm_op:$idx),
131+
[(wasm_memset (i32 imm:$idx), I32:$dst, I32:$value, I32:$size)],
132+
"", "", 0>,
133+
Requires<[HasBulkMemory]>;
134+
135+
// Codegen-only instruction matched from `wasm_memcpy` with I64 operands.
// Unlike `MEMORY_COPY_A64` it is given empty assembly strings and a zero
// opcode value; `usesCustomInserter` routes it through
// `EmitInstrWithCustomInserter`, which hands `MEMCPY_A64` to `LowerMemcpy`.
let usesCustomInserter = 1, isCodeGenOnly = 1, mayLoad = 1, mayStore = 1 in
136+
defm MEMCPY_A64 : I<(outs), (ins i32imm_op:$src_idx, i32imm_op:$dst_idx,
137+
I64:$dst, I64:$src, I64:$len),
138+
(outs), (ins i32imm_op:$src_idx, i32imm_op:$dst_idx),
139+
[(wasm_memcpy (i32 imm:$src_idx), (i32 imm:$dst_idx),
140+
I64:$dst, I64:$src, I64:$len
141+
)],
142+
"", "", 0>,
143+
Requires<[HasBulkMemory]>;
144+
145+
// Codegen-only instruction matched from `wasm_memset` with I64 address and
// size operands (the fill value stays I32). Unlike `MEMORY_FILL_A64` it is
// given empty assembly strings and a zero opcode value; `usesCustomInserter`
// routes it through `EmitInstrWithCustomInserter`, which hands `MEMSET_A64`
// to `LowerMemset`.
let usesCustomInserter = 1, isCodeGenOnly = 1, mayStore = 1 in
146+
defm MEMSET_A64 : I<(outs), (ins i32imm_op:$idx, I64:$dst, I32:$value, I64:$size),
147+
(outs), (ins i32imm_op:$idx),
148+
[(wasm_memset (i32 imm:$idx), I64:$dst, I32:$value, I64:$size)],
149+
"", "", 0>,
150+
Requires<[HasBulkMemory]>;

llvm/lib/Target/WebAssembly/WebAssemblySelectionDAGInfo.cpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,13 @@ SDValue WebAssemblySelectionDAGInfo::EmitTargetCodeForMemcpy(
2828

2929
SDValue MemIdx = DAG.getConstant(0, DL, MVT::i32);
3030
auto LenMVT = ST.hasAddr64() ? MVT::i64 : MVT::i32;
31-
return DAG.getNode(WebAssemblyISD::MEMORY_COPY, DL, MVT::Other,
32-
{Chain, MemIdx, MemIdx, Dst, Src,
33-
DAG.getZExtOrTrunc(Size, DL, LenMVT)});
31+
32+
// Use `MEMCPY` here instead of `MEMORY_COPY` because `memory.copy` traps
33+
// if the pointers are invalid even if the length is zero. `MEMCPY` gets
34+
// extra code to handle this in the way that LLVM IR expects.
35+
return DAG.getNode(
36+
WebAssemblyISD::MEMCPY, DL, MVT::Other,
37+
{Chain, MemIdx, MemIdx, Dst, Src, DAG.getZExtOrTrunc(Size, DL, LenMVT)});
3438
}
3539

3640
SDValue WebAssemblySelectionDAGInfo::EmitTargetCodeForMemmove(
@@ -52,8 +56,13 @@ SDValue WebAssemblySelectionDAGInfo::EmitTargetCodeForMemset(
5256

5357
SDValue MemIdx = DAG.getConstant(0, DL, MVT::i32);
5458
auto LenMVT = ST.hasAddr64() ? MVT::i64 : MVT::i32;
59+
60+
// Use `MEMSET` here instead of `MEMORY_FILL` because `memory.fill` traps
61+
// if the pointers are invalid even if the length is zero. `MEMSET` gets
62+
// extra code to handle this in the way that LLVM IR expects.
63+
//
5564
// Only low byte matters for val argument, so anyext the i8
56-
return DAG.getNode(WebAssemblyISD::MEMORY_FILL, DL, MVT::Other, Chain, MemIdx,
57-
Dst, DAG.getAnyExtOrTrunc(Val, DL, MVT::i32),
65+
return DAG.getNode(WebAssemblyISD::MEMSET, DL, MVT::Other, Chain, MemIdx, Dst,
66+
DAG.getAnyExtOrTrunc(Val, DL, MVT::i32),
5867
DAG.getZExtOrTrunc(Size, DL, LenMVT));
5968
}

0 commit comments

Comments
 (0)