Skip to content

Commit 27f66c9

Browse files
committed
[WebAssembly] Protect memory.fill and memory.copy from zero-length ranges.
WebAssembly's `memory.fill` and `memory.copy` instructions trap if the pointers are out of bounds, even if the length is zero. This is different from LLVM, which expects that it can call `memcpy` on arbitrary invalid pointers if the length is zero. To avoid spurious traps, branch around `memory.fill` and `memory.copy` when the length is zero.
1 parent 6fcea43 commit 27f66c9

File tree

6 files changed

+434
-71
lines changed

6 files changed

+434
-71
lines changed

llvm/lib/Target/WebAssembly/WebAssemblyISD.def

+4
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,7 @@ HANDLE_MEM_NODETYPE(GLOBAL_GET)
5050
HANDLE_MEM_NODETYPE(GLOBAL_SET)
5151
HANDLE_MEM_NODETYPE(TABLE_GET)
5252
HANDLE_MEM_NODETYPE(TABLE_SET)
53+
54+
// Bulk memory instructions that require branching to handle empty ranges.
55+
HANDLE_NODETYPE(MEMCPY)
56+
HANDLE_NODETYPE(MEMSET)

llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp

+140
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,138 @@ static MachineBasicBlock *LowerFPToInt(MachineInstr &MI, DebugLoc DL,
561561
return DoneMBB;
562562
}
563563

564+
// Lower a `MEMCPY` instruction into a CFG triangle around a `MEMORY_COPY`
565+
// instruction to handle the zero-length case.
//
// `Int64` selects the 64-bit (`EQZ_I64`/`MEMORY_COPY_A64`) opcodes instead of
// the 32-bit ones. `MI` is erased, and the returned block is the newly created
// `DoneMBB`, which receives every instruction that followed `MI` in `BB`.
566+
static MachineBasicBlock *LowerMemcpy(MachineInstr &MI, DebugLoc DL,
567+
MachineBasicBlock *BB,
568+
const TargetInstrInfo &TII, bool Int64) {
569+
MachineRegisterInfo &MRI = BB->getParent()->getRegInfo();
570+
571+
// Copy the operands by value; `MI` is erased before they are reused below.
MachineOperand DstMem = MI.getOperand(0);
572+
MachineOperand SrcMem = MI.getOperand(1);
573+
MachineOperand Dst = MI.getOperand(2);
574+
MachineOperand Src = MI.getOperand(3);
575+
MachineOperand Len = MI.getOperand(4);
576+
577+
// We're going to add an extra use to `Len` to test if it's zero; that
578+
// use shouldn't be a kill, even if the original use is.
579+
MachineOperand NoKillLen = Len;
580+
NoKillLen.setIsKill(false);
581+
582+
// Decide on which `MachineInstr` opcode we're going to use.
583+
unsigned Eqz = Int64 ? WebAssembly::EQZ_I64 : WebAssembly::EQZ_I32;
584+
unsigned MemoryCopy =
585+
Int64 ? WebAssembly::MEMORY_COPY_A64 : WebAssembly::MEMORY_COPY_A32;
586+
587+
// Create two new basic blocks; one for the new `memory.copy` that we can
588+
// branch over, and one for the rest of the instructions after the original
589+
// `memory.copy`.
590+
const BasicBlock *LLVMBB = BB->getBasicBlock();
591+
MachineFunction *F = BB->getParent();
592+
MachineBasicBlock *TrueMBB = F->CreateMachineBasicBlock(LLVMBB);
593+
MachineBasicBlock *DoneMBB = F->CreateMachineBasicBlock(LLVMBB);
594+
595+
// Place both new blocks immediately after `BB`, in the order TrueMBB, DoneMBB.
MachineFunction::iterator It = ++BB->getIterator();
596+
F->insert(It, TrueMBB);
597+
F->insert(It, DoneMBB);
598+
599+
// Transfer the remainder of BB and its successor edges to DoneMBB.
600+
DoneMBB->splice(DoneMBB->begin(), BB, std::next(MI.getIterator()), BB->end());
601+
DoneMBB->transferSuccessorsAndUpdatePHIs(BB);
602+
603+
// Connect the CFG edges.
604+
BB->addSuccessor(TrueMBB);
605+
BB->addSuccessor(DoneMBB);
606+
TrueMBB->addSuccessor(DoneMBB);
607+
608+
// Create a virtual register for the `Eqz` result. It is an i32 register
// regardless of `Int64`.
609+
unsigned EqzReg;
610+
EqzReg = MRI.createVirtualRegister(&WebAssembly::I32RegClass);
611+
612+
// Erase the original `MEMCPY` instruction; its operands were copied into
// locals above.
613+
MI.eraseFromParent();
614+
615+
// Test if `Len` is zero.
616+
BuildMI(BB, DL, TII.get(Eqz), EqzReg).add(NoKillLen);
617+
618+
// Insert a new `memory.copy`.
619+
BuildMI(TrueMBB, DL, TII.get(MemoryCopy))
620+
.add(DstMem)
621+
.add(SrcMem)
622+
.add(Dst)
623+
.add(Src)
624+
.add(Len);
625+
626+
// Create the CFG triangle: `BB` ends with a conditional branch to `DoneMBB`
// on the `Eqz` result, skipping `TrueMBB` when `Len` is zero, and `TrueMBB`
// ends with an unconditional branch to `DoneMBB`.
627+
BuildMI(BB, DL, TII.get(WebAssembly::BR_IF)).addMBB(DoneMBB).addReg(EqzReg);
628+
BuildMI(TrueMBB, DL, TII.get(WebAssembly::BR)).addMBB(DoneMBB);
629+
630+
return DoneMBB;
631+
}
632+
633+
// Lower a `MEMSET` instruction into a CFG triangle around a `MEMORY_FILL`
634+
// instruction to handle the zero-length case.
//
// `Int64` selects the 64-bit (`EQZ_I64`/`MEMORY_FILL_A64`) opcodes instead of
// the 32-bit ones. `MI` is erased, and the returned block is the newly created
// `DoneMBB`, which receives every instruction that followed `MI` in `BB`.
635+
static MachineBasicBlock *LowerMemset(MachineInstr &MI, DebugLoc DL,
636+
MachineBasicBlock *BB,
637+
const TargetInstrInfo &TII, bool Int64) {
638+
MachineRegisterInfo &MRI = BB->getParent()->getRegInfo();
639+
640+
// Copy the operands by value; `MI` is erased before they are reused below.
MachineOperand Mem = MI.getOperand(0);
641+
MachineOperand Dst = MI.getOperand(1);
642+
MachineOperand Val = MI.getOperand(2);
643+
MachineOperand Len = MI.getOperand(3);
644+
645+
// We're going to add an extra use to `Len` to test if it's zero; that
646+
// use shouldn't be a kill, even if the original use is.
647+
MachineOperand NoKillLen = Len;
648+
NoKillLen.setIsKill(false);
649+
650+
// Decide on which `MachineInstr` opcode we're going to use.
651+
unsigned Eqz = Int64 ? WebAssembly::EQZ_I64 : WebAssembly::EQZ_I32;
652+
unsigned MemoryFill =
653+
Int64 ? WebAssembly::MEMORY_FILL_A64 : WebAssembly::MEMORY_FILL_A32;
654+
655+
// Create two new basic blocks; one for the new `memory.fill` that we can
656+
// branch over, and one for the rest of the instructions after the original
657+
// `memory.fill`.
658+
const BasicBlock *LLVMBB = BB->getBasicBlock();
659+
MachineFunction *F = BB->getParent();
660+
MachineBasicBlock *TrueMBB = F->CreateMachineBasicBlock(LLVMBB);
661+
MachineBasicBlock *DoneMBB = F->CreateMachineBasicBlock(LLVMBB);
662+
663+
// Place both new blocks immediately after `BB`, in the order TrueMBB, DoneMBB.
MachineFunction::iterator It = ++BB->getIterator();
664+
F->insert(It, TrueMBB);
665+
F->insert(It, DoneMBB);
666+
667+
// Transfer the remainder of BB and its successor edges to DoneMBB.
668+
DoneMBB->splice(DoneMBB->begin(), BB, std::next(MI.getIterator()), BB->end());
669+
DoneMBB->transferSuccessorsAndUpdatePHIs(BB);
670+
671+
// Connect the CFG edges.
672+
BB->addSuccessor(TrueMBB);
673+
BB->addSuccessor(DoneMBB);
674+
TrueMBB->addSuccessor(DoneMBB);
675+
676+
// Create a virtual register for the `Eqz` result. It is an i32 register
// regardless of `Int64`.
677+
unsigned EqzReg;
678+
EqzReg = MRI.createVirtualRegister(&WebAssembly::I32RegClass);
679+
680+
// Erase the original `MEMSET` instruction; its operands were copied into
// locals above.
681+
MI.eraseFromParent();
682+
683+
// Test if `Len` is zero.
684+
BuildMI(BB, DL, TII.get(Eqz), EqzReg).add(NoKillLen);
685+
686+
// Insert a new `memory.fill`.
687+
BuildMI(TrueMBB, DL, TII.get(MemoryFill)).add(Mem).add(Dst).add(Val).add(Len);
688+
689+
// Create the CFG triangle: `BB` ends with a conditional branch to `DoneMBB`
// on the `Eqz` result, skipping `TrueMBB` when `Len` is zero, and `TrueMBB`
// ends with an unconditional branch to `DoneMBB`.
690+
BuildMI(BB, DL, TII.get(WebAssembly::BR_IF)).addMBB(DoneMBB).addReg(EqzReg);
691+
BuildMI(TrueMBB, DL, TII.get(WebAssembly::BR)).addMBB(DoneMBB);
692+
693+
return DoneMBB;
694+
}
695+
564696
static MachineBasicBlock *
565697
LowerCallResults(MachineInstr &CallResults, DebugLoc DL, MachineBasicBlock *BB,
566698
const WebAssemblySubtarget *Subtarget,
@@ -718,6 +850,14 @@ MachineBasicBlock *WebAssemblyTargetLowering::EmitInstrWithCustomInserter(
718850
case WebAssembly::FP_TO_UINT_I64_F64:
719851
return LowerFPToInt(MI, DL, BB, TII, true, true, true,
720852
WebAssembly::I64_TRUNC_U_F64);
853+
case WebAssembly::MEMCPY_A32:
854+
return LowerMemcpy(MI, DL, BB, TII, false);
855+
case WebAssembly::MEMCPY_A64:
856+
return LowerMemcpy(MI, DL, BB, TII, true);
857+
case WebAssembly::MEMSET_A32:
858+
return LowerMemset(MI, DL, BB, TII, false);
859+
case WebAssembly::MEMSET_A64:
860+
return LowerMemset(MI, DL, BB, TII, true);
721861
case WebAssembly::CALL_RESULTS:
722862
case WebAssembly::RET_CALL_RESULTS:
723863
return LowerCallResults(MI, DL, BB, Subtarget, TII);

llvm/lib/Target/WebAssembly/WebAssemblyInstrBulkMemory.td

+87-12
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,33 @@ multiclass BULK_I<dag oops_r, dag iops_r, dag oops_s, dag iops_s,
2121
}
2222

2323
// Bespoke types and nodes for bulk memory ops
24+
25+
// memory.copy (may trap on empty ranges)
26+
def wasm_memory_copy_t : SDTypeProfile<0, 5,
27+
[SDTCisInt<0>, SDTCisInt<1>, SDTCisPtrTy<2>, SDTCisPtrTy<3>, SDTCisInt<4>]
28+
>;
29+
def wasm_memory_copy : SDNode<"WebAssemblyISD::MEMORY_COPY", wasm_memory_copy_t,
30+
[SDNPHasChain, SDNPMayLoad, SDNPMayStore]>;
31+
32+
// memory.copy with a branch to avoid trapping
2433
def wasm_memcpy_t : SDTypeProfile<0, 5,
2534
[SDTCisInt<0>, SDTCisInt<1>, SDTCisPtrTy<2>, SDTCisPtrTy<3>, SDTCisInt<4>]
2635
>;
27-
def wasm_memcpy : SDNode<"WebAssemblyISD::MEMORY_COPY", wasm_memcpy_t,
36+
def wasm_memcpy : SDNode<"WebAssemblyISD::MEMCPY", wasm_memcpy_t,
2837
[SDNPHasChain, SDNPMayLoad, SDNPMayStore]>;
2938

39+
// memory.fill (may trap on empty ranges)
40+
def wasm_memory_fill_t : SDTypeProfile<0, 4,
41+
[SDTCisInt<0>, SDTCisPtrTy<1>, SDTCisInt<2>, SDTCisInt<3>]
42+
>;
43+
def wasm_memory_fill : SDNode<"WebAssemblyISD::MEMORY_FILL", wasm_memory_fill_t,
44+
[SDNPHasChain, SDNPMayStore]>;
45+
46+
// memory.fill with a branch to avoid trapping
3047
def wasm_memset_t : SDTypeProfile<0, 4,
3148
[SDTCisInt<0>, SDTCisPtrTy<1>, SDTCisInt<2>, SDTCisInt<3>]
3249
>;
33-
def wasm_memset : SDNode<"WebAssemblyISD::MEMORY_FILL", wasm_memset_t,
50+
def wasm_memset : SDNode<"WebAssemblyISD::MEMSET", wasm_memset_t,
3451
[SDNPHasChain, SDNPMayStore]>;
3552

3653
multiclass BulkMemoryOps<WebAssemblyRegClass rc, string B> {
@@ -51,25 +68,83 @@ defm DATA_DROP :
5168
[],
5269
"data.drop\t$seg", "data.drop\t$seg", 0x09>;
5370

71+
}
72+
73+
defm : BulkMemoryOps<I32, "32">;
74+
defm : BulkMemoryOps<I64, "64">;
75+
76+
// Define copy/fill manually instead of using the `BulkMemoryOps` multiclass
77+
// because when a multiclass defines opcodes, it gives them anonymous names
78+
// and we need opcodes with names so that we can handle them with custom code.
79+
5480
let mayLoad = 1, mayStore = 1 in
55-
defm MEMORY_COPY_A#B :
81+
defm MEMORY_COPY_A32 :
5682
BULK_I<(outs), (ins i32imm_op:$src_idx, i32imm_op:$dst_idx,
57-
rc:$dst, rc:$src, rc:$len),
83+
I32:$dst, I32:$src, I32:$len),
5884
(outs), (ins i32imm_op:$src_idx, i32imm_op:$dst_idx),
59-
[(wasm_memcpy (i32 imm:$src_idx), (i32 imm:$dst_idx),
60-
rc:$dst, rc:$src, rc:$len
85+
[(wasm_memory_copy (i32 imm:$src_idx), (i32 imm:$dst_idx),
86+
I32:$dst, I32:$src, I32:$len
6187
)],
6288
"memory.copy\t$src_idx, $dst_idx, $dst, $src, $len",
6389
"memory.copy\t$src_idx, $dst_idx", 0x0a>;
6490

6591
let mayStore = 1 in
66-
defm MEMORY_FILL_A#B :
67-
BULK_I<(outs), (ins i32imm_op:$idx, rc:$dst, I32:$value, rc:$size),
92+
defm MEMORY_FILL_A32 :
93+
BULK_I<(outs), (ins i32imm_op:$idx, I32:$dst, I32:$value, I32:$size),
6894
(outs), (ins i32imm_op:$idx),
69-
[(wasm_memset (i32 imm:$idx), rc:$dst, I32:$value, rc:$size)],
95+
[(wasm_memory_fill (i32 imm:$idx), I32:$dst, I32:$value, I32:$size)],
7096
"memory.fill\t$idx, $dst, $value, $size",
7197
"memory.fill\t$idx", 0x0b>;
72-
}
7398

74-
defm : BulkMemoryOps<I32, "32">;
75-
defm : BulkMemoryOps<I64, "64">;
99+
let mayLoad = 1, mayStore = 1 in
100+
defm MEMORY_COPY_A64 :
101+
BULK_I<(outs), (ins i32imm_op:$src_idx, i32imm_op:$dst_idx,
102+
I64:$dst, I64:$src, I64:$len),
103+
(outs), (ins i32imm_op:$src_idx, i32imm_op:$dst_idx),
104+
[(wasm_memory_copy (i32 imm:$src_idx), (i32 imm:$dst_idx),
105+
I64:$dst, I64:$src, I64:$len
106+
)],
107+
"memory.copy\t$src_idx, $dst_idx, $dst, $src, $len",
108+
"memory.copy\t$src_idx, $dst_idx", 0x0a>;
109+
110+
let mayStore = 1 in
111+
defm MEMORY_FILL_A64 :
112+
BULK_I<(outs), (ins i32imm_op:$idx, I64:$dst, I32:$value, I64:$size),
113+
(outs), (ins i32imm_op:$idx),
114+
[(wasm_memory_fill (i32 imm:$idx), I64:$dst, I32:$value, I64:$size)],
115+
"memory.fill\t$idx, $dst, $value, $size",
116+
"memory.fill\t$idx", 0x0b>;
117+
118+
// Pseudo-instruction selected for `wasm_memcpy` with I32 address and length
// operands. It has empty assembly strings and opcode 0, and sets
// `usesCustomInserter`, so it is expanded by custom inserter code rather than
// emitted as-is.
let usesCustomInserter = 1, isCodeGenOnly = 1, mayLoad = 1, mayStore = 1 in
119+
defm MEMCPY_A32 : I<(outs), (ins i32imm_op:$src_idx, i32imm_op:$dst_idx,
120+
I32:$dst, I32:$src, I32:$len),
121+
(outs), (ins i32imm_op:$src_idx, i32imm_op:$dst_idx),
122+
[(wasm_memcpy (i32 imm:$src_idx), (i32 imm:$dst_idx),
123+
I32:$dst, I32:$src, I32:$len
124+
)],
125+
"", "", 0>,
126+
Requires<[HasBulkMemory]>;
127+
128+
// Pseudo-instruction selected for `wasm_memset` with I32 address and size
// operands (the fill value is always I32). It has empty assembly strings and
// opcode 0, and sets `usesCustomInserter`, so it is expanded by custom
// inserter code rather than emitted as-is.
let usesCustomInserter = 1, isCodeGenOnly = 1, mayStore = 1 in
129+
defm MEMSET_A32 : I<(outs), (ins i32imm_op:$idx, I32:$dst, I32:$value, I32:$size),
130+
(outs), (ins i32imm_op:$idx),
131+
[(wasm_memset (i32 imm:$idx), I32:$dst, I32:$value, I32:$size)],
132+
"", "", 0>,
133+
Requires<[HasBulkMemory]>;
134+
135+
// Pseudo-instruction selected for `wasm_memcpy` with I64 address and length
// operands. It has empty assembly strings and opcode 0, and sets
// `usesCustomInserter`, so it is expanded by custom inserter code rather than
// emitted as-is.
let usesCustomInserter = 1, isCodeGenOnly = 1, mayLoad = 1, mayStore = 1 in
136+
defm MEMCPY_A64 : I<(outs), (ins i32imm_op:$src_idx, i32imm_op:$dst_idx,
137+
I64:$dst, I64:$src, I64:$len),
138+
(outs), (ins i32imm_op:$src_idx, i32imm_op:$dst_idx),
139+
[(wasm_memcpy (i32 imm:$src_idx), (i32 imm:$dst_idx),
140+
I64:$dst, I64:$src, I64:$len
141+
)],
142+
"", "", 0>,
143+
Requires<[HasBulkMemory]>;
144+
145+
// Pseudo-instruction selected for `wasm_memset` with I64 address and size
// operands (the fill value is still I32). It has empty assembly strings and
// opcode 0, and sets `usesCustomInserter`, so it is expanded by custom
// inserter code rather than emitted as-is.
let usesCustomInserter = 1, isCodeGenOnly = 1, mayStore = 1 in
146+
defm MEMSET_A64 : I<(outs), (ins i32imm_op:$idx, I64:$dst, I32:$value, I64:$size),
147+
(outs), (ins i32imm_op:$idx),
148+
[(wasm_memset (i32 imm:$idx), I64:$dst, I32:$value, I64:$size)],
149+
"", "", 0>,
150+
Requires<[HasBulkMemory]>;

llvm/lib/Target/WebAssembly/WebAssemblySelectionDAGInfo.cpp

+14-5
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,13 @@ SDValue WebAssemblySelectionDAGInfo::EmitTargetCodeForMemcpy(
2828

2929
SDValue MemIdx = DAG.getConstant(0, DL, MVT::i32);
3030
auto LenMVT = ST.hasAddr64() ? MVT::i64 : MVT::i32;
31-
return DAG.getNode(WebAssemblyISD::MEMORY_COPY, DL, MVT::Other,
32-
{Chain, MemIdx, MemIdx, Dst, Src,
33-
DAG.getZExtOrTrunc(Size, DL, LenMVT)});
31+
32+
// Use `MEMCPY` here instead of `MEMORY_COPY` because `memory.copy` traps
33+
// if the pointers are invalid even if the length is zero. `MEMCPY` gets
34+
// extra code to handle this in the way that LLVM IR expects.
35+
return DAG.getNode(
36+
WebAssemblyISD::MEMCPY, DL, MVT::Other,
37+
{Chain, MemIdx, MemIdx, Dst, Src, DAG.getZExtOrTrunc(Size, DL, LenMVT)});
3438
}
3539

3640
SDValue WebAssemblySelectionDAGInfo::EmitTargetCodeForMemmove(
@@ -52,8 +56,13 @@ SDValue WebAssemblySelectionDAGInfo::EmitTargetCodeForMemset(
5256

5357
SDValue MemIdx = DAG.getConstant(0, DL, MVT::i32);
5458
auto LenMVT = ST.hasAddr64() ? MVT::i64 : MVT::i32;
59+
60+
// Use `MEMSET` here instead of `MEMORY_FILL` because `memory.fill` traps
61+
// if the pointers are invalid even if the length is zero. `MEMSET` gets
62+
// extra code to handle this in the way that LLVM IR expects.
63+
//
5564
// Only low byte matters for val argument, so anyext the i8
56-
return DAG.getNode(WebAssemblyISD::MEMORY_FILL, DL, MVT::Other, Chain, MemIdx,
57-
Dst, DAG.getAnyExtOrTrunc(Val, DL, MVT::i32),
65+
return DAG.getNode(WebAssemblyISD::MEMSET, DL, MVT::Other, Chain, MemIdx, Dst,
66+
DAG.getAnyExtOrTrunc(Val, DL, MVT::i32),
5867
DAG.getZExtOrTrunc(Size, DL, LenMVT));
5968
}

0 commit comments

Comments
 (0)