Skip to content

Commit 3bd9d4d

Browse files
authored
[msan] Implement shadow propagation for _mm_dp_pd, _mm_dp_ps, _mm256_dp_ps (#94875)
Default intrinsic handling was to report any uninitialized part of an argument. However, these intrinsics take a mask which allows parts of the input to be ignored, so it's OK for the vectors to be partially initialized.
1 parent 1b66306 commit 3bd9d4d

File tree

3 files changed

+103
-35
lines changed

3 files changed

+103
-35
lines changed

llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3289,6 +3289,75 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
32893289
setOriginForNaryOp(I);
32903290
}
32913291

3292+
// Convert `Mask` into `<n x i1>`.
3293+
Constant *createDppMask(unsigned Width, unsigned Mask) {
3294+
SmallVector<Constant *, 4> R(Width);
3295+
for (auto &M : R) {
3296+
M = ConstantInt::getBool(F.getContext(), Mask & 1);
3297+
Mask >>= 1;
3298+
}
3299+
return ConstantVector::get(R);
3300+
}
3301+
3302+
// Calculate output shadow as array of booleans `<n x i1>`, assuming if any
3303+
// arg is poisoned, entire dot product is poisoned.
3304+
Value *findDppPoisonedOutput(IRBuilder<> &IRB, Value *S, unsigned SrcMask,
3305+
unsigned DstMask) {
3306+
const unsigned Width =
3307+
cast<FixedVectorType>(S->getType())->getNumElements();
3308+
3309+
S = IRB.CreateSelect(createDppMask(Width, SrcMask), S,
3310+
Constant::getNullValue(S->getType()));
3311+
Value *SElem = IRB.CreateOrReduce(S);
3312+
Value *IsClean = IRB.CreateIsNull(SElem, "_msdpp");
3313+
Value *DstMaskV = createDppMask(Width, DstMask);
3314+
3315+
return IRB.CreateSelect(
3316+
IsClean, Constant::getNullValue(DstMaskV->getType()), DstMaskV);
3317+
}
3318+
3319+
// See `Intel Intrinsics Guide` for `_dp_p*` instructions.
3320+
//
3321+
// 2 and 4 element versions produce single scalar of dot product, and then
3322+
// puts it into elements of output vector, selected by 4 lowest bits of the
3323+
// mask. Top 4 bits of the mask control which elements of input to use for dot
3324+
// product.
3325+
//
3326+
// 8 element version mask still has only 4 bit for input, and 4 bit for output
3327+
// mask. According to the spec it just operates as 4 element version on first
3328+
// 4 elements of inputs and output, and then on last 4 elements of inputs and
3329+
// output.
3330+
void handleDppIntrinsic(IntrinsicInst &I) {
3331+
IRBuilder<> IRB(&I);
3332+
3333+
Value *S0 = getShadow(&I, 0);
3334+
Value *S1 = getShadow(&I, 1);
3335+
Value *S = IRB.CreateOr(S0, S1);
3336+
3337+
const unsigned Width =
3338+
cast<FixedVectorType>(S->getType())->getNumElements();
3339+
assert(Width == 2 || Width == 4 || Width == 8);
3340+
3341+
const unsigned Mask = cast<ConstantInt>(I.getArgOperand(2))->getZExtValue();
3342+
const unsigned SrcMask = Mask >> 4;
3343+
const unsigned DstMask = Mask & 0xf;
3344+
3345+
// Calculate shadow as `<n x i1>`.
3346+
Value *SI1 = findDppPoisonedOutput(IRB, S, SrcMask, DstMask);
3347+
if (Width == 8) {
3348+
// First 4 elements of shadow are already calculated. `makeDppShadow`
3349+
// operats on 32 bit masks, so we can just shift masks, and repeat.
3350+
SI1 = IRB.CreateOr(
3351+
SI1, findDppPoisonedOutput(IRB, S, SrcMask << 4, DstMask << 4));
3352+
}
3353+
// Extend to real size of shadow, poisoning either all or none bits of an
3354+
// element.
3355+
S = IRB.CreateSExt(SI1, S->getType(), "_msdpp");
3356+
3357+
setShadow(&I, S);
3358+
setOriginForNaryOp(I);
3359+
}
3360+
32923361
// Instrument sum-of-absolute-differences intrinsic.
32933362
void handleVectorSadIntrinsic(IntrinsicInst &I) {
32943363
const unsigned SignificantBitsPerResultElement = 16;
@@ -3644,7 +3713,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
36443713
setOriginForNaryOp(I);
36453714
}
36463715

3647-
SmallVector<int, 8> getPclmulMask(unsigned Width, bool OddElements) {
3716+
static SmallVector<int, 8> getPclmulMask(unsigned Width, bool OddElements) {
36483717
SmallVector<int, 8> Mask;
36493718
for (unsigned X = OddElements ? 1 : 0; X < Width; X += 2) {
36503719
Mask.append(2, X);
@@ -3960,6 +4029,12 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
39604029
handleVectorPackIntrinsic(I);
39614030
break;
39624031

4032+
case Intrinsic::x86_avx_dp_ps_256:
4033+
case Intrinsic::x86_sse41_dppd:
4034+
case Intrinsic::x86_sse41_dpps:
4035+
handleDppIntrinsic(I);
4036+
break;
4037+
39634038
case Intrinsic::x86_mmx_packsswb:
39644039
case Intrinsic::x86_mmx_packuswb:
39654040
handleVectorPackIntrinsic(I, 16);

llvm/test/Instrumentation/MemorySanitizer/X86/avx-intrinsics-x86.ll

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -389,18 +389,19 @@ define <8 x float> @test_x86_avx_dp_ps_256(<8 x float> %a0, <8 x float> %a1) #0
389389
; CHECK-NEXT: [[TMP1:%.*]] = load <8 x i32>, ptr @__msan_param_tls, align 8
390390
; CHECK-NEXT: [[TMP2:%.*]] = load <8 x i32>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 32) to ptr), align 8
391391
; CHECK-NEXT: call void @llvm.donothing()
392-
; CHECK-NEXT: [[TMP3:%.*]] = bitcast <8 x i32> [[TMP1]] to i256
393-
; CHECK-NEXT: [[_MSCMP:%.*]] = icmp ne i256 [[TMP3]], 0
394-
; CHECK-NEXT: [[TMP4:%.*]] = bitcast <8 x i32> [[TMP2]] to i256
395-
; CHECK-NEXT: [[_MSCMP1:%.*]] = icmp ne i256 [[TMP4]], 0
396-
; CHECK-NEXT: [[_MSOR:%.*]] = or i1 [[_MSCMP]], [[_MSCMP1]]
397-
; CHECK-NEXT: br i1 [[_MSOR]], label [[TMP5:%.*]], label [[TMP6:%.*]], !prof [[PROF0]]
398-
; CHECK: 5:
399-
; CHECK-NEXT: call void @__msan_warning_noreturn()
400-
; CHECK-NEXT: unreachable
401-
; CHECK: 6:
392+
; CHECK-NEXT: [[TMP3:%.*]] = or <8 x i32> [[TMP1]], [[TMP2]]
393+
; CHECK-NEXT: [[TMP4:%.*]] = select <8 x i1> <i1 false, i1 true, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false>, <8 x i32> [[TMP3]], <8 x i32> zeroinitializer
394+
; CHECK-NEXT: [[TMP5:%.*]] = call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> [[TMP4]])
395+
; CHECK-NEXT: [[_MSDPP:%.*]] = icmp eq i32 [[TMP5]], 0
396+
; CHECK-NEXT: [[TMP6:%.*]] = select i1 [[_MSDPP]], <8 x i1> zeroinitializer, <8 x i1> <i1 false, i1 true, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false>
397+
; CHECK-NEXT: [[TMP7:%.*]] = select <8 x i1> <i1 false, i1 false, i1 false, i1 false, i1 false, i1 true, i1 true, i1 true>, <8 x i32> [[TMP3]], <8 x i32> zeroinitializer
398+
; CHECK-NEXT: [[TMP8:%.*]] = call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> [[TMP7]])
399+
; CHECK-NEXT: [[_MSDPP1:%.*]] = icmp eq i32 [[TMP8]], 0
400+
; CHECK-NEXT: [[TMP9:%.*]] = select i1 [[_MSDPP1]], <8 x i1> zeroinitializer, <8 x i1> <i1 false, i1 false, i1 false, i1 false, i1 false, i1 true, i1 true, i1 true>
401+
; CHECK-NEXT: [[TMP10:%.*]] = or <8 x i1> [[TMP6]], [[TMP9]]
402+
; CHECK-NEXT: [[_MSDPP2:%.*]] = sext <8 x i1> [[TMP10]] to <8 x i32>
402403
; CHECK-NEXT: [[RES:%.*]] = call <8 x float> @llvm.x86.avx.dp.ps.256(<8 x float> [[A0:%.*]], <8 x float> [[A1:%.*]], i8 -18)
403-
; CHECK-NEXT: store <8 x i32> zeroinitializer, ptr @__msan_retval_tls, align 8
404+
; CHECK-NEXT: store <8 x i32> [[_MSDPP2]], ptr @__msan_retval_tls, align 8
404405
; CHECK-NEXT: ret <8 x float> [[RES]]
405406
;
406407
%res = call <8 x float> @llvm.x86.avx.dp.ps.256(<8 x float> %a0, <8 x float> %a1, i8 -18) ; <<8 x float>> [#uses=1]

llvm/test/Instrumentation/MemorySanitizer/X86/sse41-intrinsics-x86.ll

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -45,18 +45,14 @@ define <2 x double> @test_x86_sse41_dppd(<2 x double> %a0, <2 x double> %a1) #0
4545
; CHECK-NEXT: [[TMP1:%.*]] = load <2 x i64>, ptr @__msan_param_tls, align 8
4646
; CHECK-NEXT: [[TMP2:%.*]] = load <2 x i64>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 16) to ptr), align 8
4747
; CHECK-NEXT: call void @llvm.donothing()
48-
; CHECK-NEXT: [[TMP3:%.*]] = bitcast <2 x i64> [[TMP1]] to i128
49-
; CHECK-NEXT: [[_MSCMP:%.*]] = icmp ne i128 [[TMP3]], 0
50-
; CHECK-NEXT: [[TMP4:%.*]] = bitcast <2 x i64> [[TMP2]] to i128
51-
; CHECK-NEXT: [[_MSCMP1:%.*]] = icmp ne i128 [[TMP4]], 0
52-
; CHECK-NEXT: [[_MSOR:%.*]] = or i1 [[_MSCMP]], [[_MSCMP1]]
53-
; CHECK-NEXT: br i1 [[_MSOR]], label [[TMP5:%.*]], label [[TMP6:%.*]], !prof [[PROF0:![0-9]+]]
54-
; CHECK: 5:
55-
; CHECK-NEXT: call void @__msan_warning_noreturn()
56-
; CHECK-NEXT: unreachable
57-
; CHECK: 6:
48+
; CHECK-NEXT: [[TMP3:%.*]] = or <2 x i64> [[TMP1]], [[TMP2]]
49+
; CHECK-NEXT: [[TMP4:%.*]] = select <2 x i1> <i1 false, i1 true>, <2 x i64> [[TMP3]], <2 x i64> zeroinitializer
50+
; CHECK-NEXT: [[TMP5:%.*]] = call i64 @llvm.vector.reduce.or.v2i64(<2 x i64> [[TMP4]])
51+
; CHECK-NEXT: [[_MSDPP:%.*]] = icmp eq i64 [[TMP5]], 0
52+
; CHECK-NEXT: [[TMP6:%.*]] = select i1 [[_MSDPP]], <2 x i1> zeroinitializer, <2 x i1> <i1 false, i1 true>
53+
; CHECK-NEXT: [[_MSDPP1:%.*]] = sext <2 x i1> [[TMP6]] to <2 x i64>
5854
; CHECK-NEXT: [[RES:%.*]] = call <2 x double> @llvm.x86.sse41.dppd(<2 x double> [[A0:%.*]], <2 x double> [[A1:%.*]], i8 -18)
59-
; CHECK-NEXT: store <2 x i64> zeroinitializer, ptr @__msan_retval_tls, align 8
55+
; CHECK-NEXT: store <2 x i64> [[_MSDPP1]], ptr @__msan_retval_tls, align 8
6056
; CHECK-NEXT: ret <2 x double> [[RES]]
6157
;
6258
%res = call <2 x double> @llvm.x86.sse41.dppd(<2 x double> %a0, <2 x double> %a1, i8 -18) ; <<2 x double>> [#uses=1]
@@ -70,18 +66,14 @@ define <4 x float> @test_x86_sse41_dpps(<4 x float> %a0, <4 x float> %a1) #0 {
7066
; CHECK-NEXT: [[TMP1:%.*]] = load <4 x i32>, ptr @__msan_param_tls, align 8
7167
; CHECK-NEXT: [[TMP2:%.*]] = load <4 x i32>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 16) to ptr), align 8
7268
; CHECK-NEXT: call void @llvm.donothing()
73-
; CHECK-NEXT: [[TMP3:%.*]] = bitcast <4 x i32> [[TMP1]] to i128
74-
; CHECK-NEXT: [[_MSCMP:%.*]] = icmp ne i128 [[TMP3]], 0
75-
; CHECK-NEXT: [[TMP4:%.*]] = bitcast <4 x i32> [[TMP2]] to i128
76-
; CHECK-NEXT: [[_MSCMP1:%.*]] = icmp ne i128 [[TMP4]], 0
77-
; CHECK-NEXT: [[_MSOR:%.*]] = or i1 [[_MSCMP]], [[_MSCMP1]]
78-
; CHECK-NEXT: br i1 [[_MSOR]], label [[TMP5:%.*]], label [[TMP6:%.*]], !prof [[PROF0]]
79-
; CHECK: 5:
80-
; CHECK-NEXT: call void @__msan_warning_noreturn()
81-
; CHECK-NEXT: unreachable
82-
; CHECK: 6:
69+
; CHECK-NEXT: [[TMP3:%.*]] = or <4 x i32> [[TMP1]], [[TMP2]]
70+
; CHECK-NEXT: [[TMP4:%.*]] = select <4 x i1> <i1 false, i1 true, i1 true, i1 true>, <4 x i32> [[TMP3]], <4 x i32> zeroinitializer
71+
; CHECK-NEXT: [[TMP5:%.*]] = call i32 @llvm.vector.reduce.or.v4i32(<4 x i32> [[TMP4]])
72+
; CHECK-NEXT: [[_MSDPP:%.*]] = icmp eq i32 [[TMP5]], 0
73+
; CHECK-NEXT: [[TMP6:%.*]] = select i1 [[_MSDPP]], <4 x i1> zeroinitializer, <4 x i1> <i1 false, i1 true, i1 true, i1 true>
74+
; CHECK-NEXT: [[_MSDPP1:%.*]] = sext <4 x i1> [[TMP6]] to <4 x i32>
8375
; CHECK-NEXT: [[RES:%.*]] = call <4 x float> @llvm.x86.sse41.dpps(<4 x float> [[A0:%.*]], <4 x float> [[A1:%.*]], i8 -18)
84-
; CHECK-NEXT: store <4 x i32> zeroinitializer, ptr @__msan_retval_tls, align 8
76+
; CHECK-NEXT: store <4 x i32> [[_MSDPP1]], ptr @__msan_retval_tls, align 8
8577
; CHECK-NEXT: ret <4 x float> [[RES]]
8678
;
8779
%res = call <4 x float> @llvm.x86.sse41.dpps(<4 x float> %a0, <4 x float> %a1, i8 -18) ; <<4 x float>> [#uses=1]
@@ -100,7 +92,7 @@ define <4 x float> @test_x86_sse41_insertps(<4 x float> %a0, <4 x float> %a1) #0
10092
; CHECK-NEXT: [[TMP4:%.*]] = bitcast <4 x i32> [[TMP2]] to i128
10193
; CHECK-NEXT: [[_MSCMP1:%.*]] = icmp ne i128 [[TMP4]], 0
10294
; CHECK-NEXT: [[_MSOR:%.*]] = or i1 [[_MSCMP]], [[_MSCMP1]]
103-
; CHECK-NEXT: br i1 [[_MSOR]], label [[TMP5:%.*]], label [[TMP6:%.*]], !prof [[PROF0]]
95+
; CHECK-NEXT: br i1 [[_MSOR]], label [[TMP5:%.*]], label [[TMP6:%.*]], !prof [[PROF0:![0-9]+]]
10496
; CHECK: 5:
10597
; CHECK-NEXT: call void @__msan_warning_noreturn()
10698
; CHECK-NEXT: unreachable

0 commit comments

Comments
 (0)