Commit 368cb42

[ASAN] Support memory checks on scalable vector typed masked load and store
This takes the approach of using the loop-based formation for scalable vectors only. We could potentially use the loop form for fixed vectors as well, but we'd lose the unroll-and-specialize-on-constant-vector logic that is already present. I don't have a strong opinion on whether the existing logic is worthwhile; I kept it mostly to minimize test churn.

Worth noting is that a better lowering is available. The plain vector lowering appears to check only the first and last byte; by analogy, we should be able to check only the first active and last active byte in the masked op. That is a more invasive change to asan, and I decided that simply supporting scalable vectors at all was a better starting place.

Differential Revision: https://reviews.llvm.org/D145198
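To make the loop-based formation concrete, here is a rough sketch of the instrumentation the pass now emits for a masked load of <vscale x 4 x float>, condensed from the CHECK lines of the test added below. Block and value names are illustrative, and the __asan_load4 call form corresponds to the calls-based instrumentation mode that test exercises.

declare i64 @llvm.vscale.i64()
declare void @__asan_load4(i64)
declare <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr, i32, <vscale x 4 x i1>, <vscale x 4 x float>)

define <vscale x 4 x float> @instrumented.masked.load(ptr %p, <vscale x 4 x i1> %mask) {
entry:
  ; Runtime trip count: vscale times the minimum element count (4 here).
  %vscale = call i64 @llvm.vscale.i64()
  %n = mul i64 %vscale, 4
  br label %loop

loop:
  %iv = phi i64 [ 0, %entry ], [ %iv.next, %latch ]
  ; Only lanes whose mask bit is set are checked.
  %m = extractelement <vscale x 4 x i1> %mask, i64 %iv
  br i1 %m, label %check, label %latch

check:
  ; Shadow check for the one element this lane would access.
  %elt = getelementptr <vscale x 4 x float>, ptr %p, i64 0, i64 %iv
  %elt.int = ptrtoint ptr %elt to i64
  call void @__asan_load4(i64 %elt.int)
  br label %latch

latch:
  %iv.next = add nuw nsw i64 %iv, 1
  %done = icmp eq i64 %iv.next, %n
  br i1 %done, label %exit, label %loop

exit:
  ; The original masked load is left in place after the checks.
  %res = tail call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr %p, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x float> undef)
  ret <vscale x 4 x float> %res
}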
1 parent 100a3c3 commit 368cb42

2 files changed: +150, -25 lines

llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp

Lines changed: 85 additions & 25 deletions
@@ -1439,43 +1439,103 @@ static void doInstrumentAddress(AddressSanitizer *Pass, Instruction *I,
                                            IsWrite, nullptr, UseCalls, Exp);
 }
 
+static void SplitBlockAndInsertSimpleForLoop(Value *End,
+                                             Instruction *SplitBefore,
+                                             Instruction *&BodyIP,
+                                             Value *&Index) {
+  BasicBlock *LoopPred = SplitBefore->getParent();
+  BasicBlock *LoopBody = SplitBlock(SplitBefore->getParent(), SplitBefore);
+  BasicBlock *LoopExit = SplitBlock(SplitBefore->getParent(), SplitBefore);
+
+  auto *Ty = End->getType();
+  auto &DL = SplitBefore->getModule()->getDataLayout();
+  const unsigned Bitwidth = DL.getTypeSizeInBits(Ty);
+
+  IRBuilder<> Builder(LoopBody->getTerminator());
+  auto *IV = Builder.CreatePHI(Ty, 2, "iv");
+  auto *IVNext =
+      Builder.CreateAdd(IV, ConstantInt::get(Ty, 1), IV->getName() + ".next",
+                        /*HasNUW=*/true, /*HasNSW=*/Bitwidth != 2);
+  auto *IVCheck = Builder.CreateICmpEQ(IVNext, End,
+                                       IV->getName() + ".check");
+  Builder.CreateCondBr(IVCheck, LoopExit, LoopBody);
+  LoopBody->getTerminator()->eraseFromParent();
+
+  // Populate the IV PHI.
+  IV->addIncoming(ConstantInt::get(Ty, 0), LoopPred);
+  IV->addIncoming(IVNext, LoopBody);
+
+  BodyIP = LoopBody->getFirstNonPHI();
+  Index = IV;
+}
+
+
 static void instrumentMaskedLoadOrStore(AddressSanitizer *Pass,
                                         const DataLayout &DL, Type *IntptrTy,
                                         Value *Mask, Instruction *I,
                                         Value *Addr, MaybeAlign Alignment,
                                         unsigned Granularity, Type *OpType,
                                         bool IsWrite, Value *SizeArgument,
                                         bool UseCalls, uint32_t Exp) {
-  auto *VTy = cast<FixedVectorType>(OpType);
-  uint64_t ElemTypeSize = DL.getTypeStoreSizeInBits(VTy->getScalarType());
-  unsigned Num = VTy->getNumElements();
+  auto *VTy = cast<VectorType>(OpType);
+
+  TypeSize ElemTypeSize = DL.getTypeStoreSizeInBits(VTy->getScalarType());
   auto Zero = ConstantInt::get(IntptrTy, 0);
-  for (unsigned Idx = 0; Idx < Num; ++Idx) {
-    Value *InstrumentedAddress = nullptr;
-    Instruction *InsertBefore = I;
-    if (auto *Vector = dyn_cast<ConstantVector>(Mask)) {
-      // dyn_cast as we might get UndefValue
-      if (auto *Masked = dyn_cast<ConstantInt>(Vector->getOperand(Idx))) {
-        if (Masked->isZero())
-          // Mask is constant false, so no instrumentation needed.
-          continue;
-        // If we have a true or undef value, fall through to doInstrumentAddress
-        // with InsertBefore == I
+
+  // For fixed length vectors, it's legal to fallthrough into the generic loop
+  // lowering below, but we chose to unroll and specialize instead. We might want
+  // to revisit this heuristic decision.
+  if (auto *FVTy = dyn_cast<FixedVectorType>(VTy)) {
+    unsigned Num = FVTy->getNumElements();
+    for (unsigned Idx = 0; Idx < Num; ++Idx) {
+      Value *InstrumentedAddress = nullptr;
+      Instruction *InsertBefore = I;
+      if (auto *Vector = dyn_cast<ConstantVector>(Mask)) {
+        // dyn_cast as we might get UndefValue
+        if (auto *Masked = dyn_cast<ConstantInt>(Vector->getOperand(Idx))) {
+          if (Masked->isZero())
+            // Mask is constant false, so no instrumentation needed.
+            continue;
+          // If we have a true or undef value, fall through to doInstrumentAddress
+          // with InsertBefore == I
+        }
+      } else {
+        IRBuilder<> IRB(I);
+        Value *MaskElem = IRB.CreateExtractElement(Mask, Idx);
+        Instruction *ThenTerm = SplitBlockAndInsertIfThen(MaskElem, I, false);
+        InsertBefore = ThenTerm;
       }
-    } else {
-      IRBuilder<> IRB(I);
-      Value *MaskElem = IRB.CreateExtractElement(Mask, Idx);
-      Instruction *ThenTerm = SplitBlockAndInsertIfThen(MaskElem, I, false);
-      InsertBefore = ThenTerm;
-    }
 
-    IRBuilder<> IRB(InsertBefore);
-    InstrumentedAddress =
+      IRBuilder<> IRB(InsertBefore);
+      InstrumentedAddress =
         IRB.CreateGEP(VTy, Addr, {Zero, ConstantInt::get(IntptrTy, Idx)});
-    doInstrumentAddress(Pass, I, InsertBefore, InstrumentedAddress, Alignment,
-                        Granularity, TypeSize::Fixed(ElemTypeSize), IsWrite,
-                        SizeArgument, UseCalls, Exp);
+      doInstrumentAddress(Pass, I, InsertBefore, InstrumentedAddress, Alignment,
+                          Granularity, ElemTypeSize, IsWrite,
+                          SizeArgument, UseCalls, Exp);
+    }
+    return;
   }
+
+
+  IRBuilder<> IRB(I);
+  Constant *MinNumElem =
+      ConstantInt::get(IntptrTy, VTy->getElementCount().getKnownMinValue());
+  assert(isa<ScalableVectorType>(VTy) && "generalize if reused for fixed length");
+  Value *NumElements = IRB.CreateVScale(MinNumElem);
+
+  Instruction *BodyIP;
+  Value *Index;
+  SplitBlockAndInsertSimpleForLoop(NumElements, I, BodyIP, Index);
+
+  IRB.SetInsertPoint(BodyIP);
+  Value *MaskElem = IRB.CreateExtractElement(Mask, Index);
+  Instruction *ThenTerm = SplitBlockAndInsertIfThen(MaskElem, BodyIP, false);
+  IRB.SetInsertPoint(ThenTerm);
+
+  Value *InstrumentedAddress = IRB.CreateGEP(VTy, Addr, {Zero, Index});
+  doInstrumentAddress(Pass, I, &*IRB.GetInsertPoint(), InstrumentedAddress, Alignment,
+                      Granularity, ElemTypeSize, IsWrite, SizeArgument,
+                      UseCalls, Exp);
 }
 
 void AddressSanitizer::instrumentMop(ObjectSizeOffsetVisitor &ObjSizeVis,
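For contrast with the loop lowering just added, the unroll-and-specialize path that the comment in the hunk above refers to handles fixed-length vectors by skipping constant-false mask lanes outright and checking each remaining lane inline. A minimal sketch of its output for a constant <i1 true, i1 false, i1 false, i1 true> mask, based on the pre-existing fixed-vector tests in asan-masked-load-store.ll rather than on this diff; names are illustrative.

declare void @__asan_load4(i64)
declare <4 x float> @llvm.masked.load.v4f32.p0(ptr, i32, <4 x i1>, <4 x float>)

define <4 x float> @instrumented.fixed.load.1001(ptr %p, <4 x float> %arg) {
  ; Lane 0: mask bit is constant true, so the check is emitted straight-line.
  %a0 = getelementptr <4 x float>, ptr %p, i64 0, i64 0
  %a0.int = ptrtoint ptr %a0 to i64
  call void @__asan_load4(i64 %a0.int)
  ; Lanes 1 and 2: mask bits are constant false, so no instrumentation is emitted.
  ; Lane 3: constant true, checked inline as well.
  %a3 = getelementptr <4 x float>, ptr %p, i64 0, i64 3
  %a3.int = ptrtoint ptr %a3 to i64
  call void @__asan_load4(i64 %a3.int)
  %res = tail call <4 x float> @llvm.masked.load.v4f32.p0(ptr %p, i32 4, <4 x i1> <i1 true, i1 false, i1 false, i1 true>, <4 x float> %arg)
  ret <4 x float> %res
}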

llvm/test/Instrumentation/AddressSanitizer/asan-masked-load-store.ll

Lines changed: 65 additions & 0 deletions
@@ -308,3 +308,68 @@ define <4 x float> @load.v4f32.1001.after.full.load(ptr %p, <4 x float> %arg) sa
   %res2 = tail call <4 x float> @llvm.masked.load.v4f32.p0(ptr %p, i32 4, <4 x i1> <i1 false, i1 false, i1 false, i1 true>, <4 x float> %arg)
   ret <4 x float> %res2
 }
+
+;; Scalable vector tests
+;; ---------------------------
+declare <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr, i32, <vscale x 4 x i1>, <vscale x 4 x float>)
+declare void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float>, ptr, i32, <vscale x 4 x i1>)
+
+define <vscale x 4 x float> @scalable.load.nxv4f32(ptr %p, <vscale x 4 x i1> %mask) sanitize_address {
+; CHECK-LABEL: @scalable.load.nxv4f32(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP2:%.*]] = mul i64 [[TMP1]], 4
+; CHECK-NEXT:    br label [[DOTSPLIT:%.*]]
+; CHECK:       .split:
+; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ 0, [[TMP0:%.*]] ], [ [[IV_NEXT:%.*]], [[TMP7:%.*]] ]
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <vscale x 4 x i1> [[MASK:%.*]], i64 [[IV]]
+; CHECK-NEXT:    br i1 [[TMP3]], label [[TMP4:%.*]], label [[TMP7]]
+; CHECK:       4:
+; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr <vscale x 4 x float>, ptr [[P:%.*]], i64 0, i64 [[IV]]
+; CHECK-NEXT:    [[TMP6:%.*]] = ptrtoint ptr [[TMP5]] to i64
+; CHECK-NEXT:    call void @__asan_load4(i64 [[TMP6]])
+; CHECK-NEXT:    br label [[TMP7]]
+; CHECK:       7:
+; CHECK-NEXT:    [[IV_NEXT]] = add nuw nsw i64 [[IV]], 1
+; CHECK-NEXT:    [[IV_CHECK:%.*]] = icmp eq i64 [[IV_NEXT]], [[TMP2]]
+; CHECK-NEXT:    br i1 [[IV_CHECK]], label [[DOTSPLIT_SPLIT:%.*]], label [[DOTSPLIT]]
+; CHECK:       .split.split:
+; CHECK-NEXT:    [[RES:%.*]] = tail call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr [[P]], i32 4, <vscale x 4 x i1> [[MASK]], <vscale x 4 x float> undef)
+; CHECK-NEXT:    ret <vscale x 4 x float> [[RES]]
+;
+; DISABLED-LABEL: @scalable.load.nxv4f32(
+; DISABLED-NEXT:    [[RES:%.*]] = tail call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr [[P:%.*]], i32 4, <vscale x 4 x i1> [[MASK:%.*]], <vscale x 4 x float> undef)
+; DISABLED-NEXT:    ret <vscale x 4 x float> [[RES]]
+;
+  %res = tail call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr %p, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x float> undef)
+  ret <vscale x 4 x float> %res
+}
+
+define void @scalable.store.nxv4f32(ptr %p, <vscale x 4 x float> %arg, <vscale x 4 x i1> %mask) sanitize_address {
+; CHECK-LABEL: @scalable.store.nxv4f32(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP2:%.*]] = mul i64 [[TMP1]], 4
+; CHECK-NEXT:    br label [[DOTSPLIT:%.*]]
+; CHECK:       .split:
+; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ 0, [[TMP0:%.*]] ], [ [[IV_NEXT:%.*]], [[TMP7:%.*]] ]
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <vscale x 4 x i1> [[MASK:%.*]], i64 [[IV]]
+; CHECK-NEXT:    br i1 [[TMP3]], label [[TMP4:%.*]], label [[TMP7]]
+; CHECK:       4:
+; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr <vscale x 4 x float>, ptr [[P:%.*]], i64 0, i64 [[IV]]
+; CHECK-NEXT:    [[TMP6:%.*]] = ptrtoint ptr [[TMP5]] to i64
+; CHECK-NEXT:    call void @__asan_store4(i64 [[TMP6]])
+; CHECK-NEXT:    br label [[TMP7]]
+; CHECK:       7:
+; CHECK-NEXT:    [[IV_NEXT]] = add nuw nsw i64 [[IV]], 1
+; CHECK-NEXT:    [[IV_CHECK:%.*]] = icmp eq i64 [[IV_NEXT]], [[TMP2]]
+; CHECK-NEXT:    br i1 [[IV_CHECK]], label [[DOTSPLIT_SPLIT:%.*]], label [[DOTSPLIT]]
+; CHECK:       .split.split:
+; CHECK-NEXT:    tail call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> [[ARG:%.*]], ptr [[P]], i32 4, <vscale x 4 x i1> [[MASK]])
+; CHECK-NEXT:    ret void
+;
+; DISABLED-LABEL: @scalable.store.nxv4f32(
+; DISABLED-NEXT:    tail call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> [[ARG:%.*]], ptr [[P:%.*]], i32 4, <vscale x 4 x i1> [[MASK:%.*]])
+; DISABLED-NEXT:    ret void
+;
+  tail call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %arg, ptr %p, i32 4, <vscale x 4 x i1> %mask)
+  ret void
+}
