[InstCombine][X86] Only demand used bits for PSHUFB mask values #106377

Merged: 2 commits into llvm:main on Aug 28, 2024

Conversation

RKSimon (Collaborator) commented Aug 28, 2024

(V)PSHUFB uses only the sign bit (for zeroing) and the lower 4 bits (to index bytes 0-15 within each 128-bit lane) of each mask byte, so use SimplifyDemandedBits to ignore anything that only touches the remaining bits.

Fixes #106256
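
For reference, a minimal scalar sketch of the per-lane semantics this relies on (an illustrative model, not code from the patch):

  #include <cstddef>
  #include <cstdint>

  // Model of one 128-bit PSHUFB lane: the mask byte's sign bit (bit 7)
  // forces the result byte to zero, and only the low 4 bits select one
  // of the 16 source bytes in the same lane. Bits 4-6 are never read,
  // so they need not be demanded from the mask operand.
  void pshufb_lane(const uint8_t src[16], const uint8_t mask[16],
                   uint8_t dst[16]) {
    for (size_t i = 0; i != 16; ++i)
      dst[i] = (mask[i] & 0x80) ? 0 : src[mask[i] & 0x0F];
  }

Only bits 0x8F (0b10001111) of each mask byte can affect the result, which is exactly the demanded mask passed to SimplifyDemandedBits below; the new tests or the mask with 16, 48, and 112, which set only bits 4-6, so the or is expected to fold away.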

llvmbot (Member) commented Aug 28, 2024

@llvm/pr-subscribers-backend-x86

@llvm/pr-subscribers-llvm-transforms

Author: Simon Pilgrim (RKSimon)


Full diff: https://github.com/llvm/llvm-project/pull/106377.diff

3 Files Affected:

  • (modified) llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp (+6-1)
  • (modified) llvm/test/Transforms/InstCombine/X86/x86-pshufb-inseltpoison.ll (+32)
  • (modified) llvm/test/Transforms/InstCombine/X86/x86-pshufb.ll (+32)
diff --git a/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp b/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
index 793d62ba2a8e79..42ad8905a703c4 100644
--- a/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
@@ -2950,11 +2950,16 @@ X86TTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
 
   case Intrinsic::x86_ssse3_pshuf_b_128:
   case Intrinsic::x86_avx2_pshuf_b:
-  case Intrinsic::x86_avx512_pshuf_b_512:
+  case Intrinsic::x86_avx512_pshuf_b_512: {
     if (Value *V = simplifyX86pshufb(II, IC.Builder)) {
       return IC.replaceInstUsesWith(II, V);
     }
+    KnownBits KnownMask(8);
+    if (IC.SimplifyDemandedBits(&II, 1, APInt(8, 0b10001111), KnownMask)) {
+      return &II;
+    }
     break;
+  }
 
   case Intrinsic::x86_avx_vpermilvar_ps:
   case Intrinsic::x86_avx_vpermilvar_ps_256:
diff --git a/llvm/test/Transforms/InstCombine/X86/x86-pshufb-inseltpoison.ll b/llvm/test/Transforms/InstCombine/X86/x86-pshufb-inseltpoison.ll
index 2f301e9e9c1072..f8cec9c152a08b 100644
--- a/llvm/test/Transforms/InstCombine/X86/x86-pshufb-inseltpoison.ll
+++ b/llvm/test/Transforms/InstCombine/X86/x86-pshufb-inseltpoison.ll
@@ -468,6 +468,38 @@ define <64 x i8> @fold_with_allpoison_elts_avx512(<64 x i8> %InVec) {
   ret <64 x i8> %1
 }
 
+; Demanded bits tests (PR106256)
+
+define <16 x i8> @demanded_bits_mask(<16 x i8> %InVec, <16 x i8> %InMask) {
+; CHECK-LABEL: @demanded_bits_mask(
+; CHECK-NEXT:    [[S:%.*]] = tail call <16 x i8> @llvm.x86.ssse3.pshuf.b.128(<16 x i8> [[INVEC:%.*]], <16 x i8> [[INMASK:%.*]])
+; CHECK-NEXT:    ret <16 x i8> [[S]]
+;
+  %m = or <16 x i8> %InMask, <i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112>
+  %s = tail call <16 x i8> @llvm.x86.ssse3.pshuf.b.128(<16 x i8> %InVec, <16 x i8> %m)
+  ret <16 x i8> %s
+}
+
+define <32 x i8> @demanded_bits_mask_avx2(<32 x i8> %InVec, <32 x i8> %InMask) {
+; CHECK-LABEL: @demanded_bits_mask_avx2(
+; CHECK-NEXT:    [[S:%.*]] = tail call <32 x i8> @llvm.x86.avx2.pshuf.b(<32 x i8> [[INVEC:%.*]], <32 x i8> [[INMASK:%.*]])
+; CHECK-NEXT:    ret <32 x i8> [[S]]
+;
+  %m = or <32 x i8> %InMask, <i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112>
+  %s = tail call <32 x i8> @llvm.x86.avx2.pshuf.b(<32 x i8> %InVec, <32 x i8> %m)
+  ret <32 x i8> %s
+}
+
+define <64 x i8> @demanded_bits_mask_avx512(<64 x i8> %InVec, <64 x i8> %InMask) {
+; CHECK-LABEL: @demanded_bits_mask_avx512(
+; CHECK-NEXT:    [[S:%.*]] = tail call <64 x i8> @llvm.x86.avx512.pshuf.b.512(<64 x i8> [[INVEC:%.*]], <64 x i8> [[INMASK:%.*]])
+; CHECK-NEXT:    ret <64 x i8> [[S]]
+;
+  %m = or <64 x i8> %InMask, <i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112>
+  %s = tail call <64 x i8> @llvm.x86.avx512.pshuf.b.512(<64 x i8> %InVec, <64 x i8> %m)
+  ret <64 x i8> %s
+}
+
 ; Demanded elts tests.
 
 define <16 x i8> @demanded_elts_insertion(<16 x i8> %InVec, <16 x i8> %BaseMask, i8 %M0, i8 %M15) {
diff --git a/llvm/test/Transforms/InstCombine/X86/x86-pshufb.ll b/llvm/test/Transforms/InstCombine/X86/x86-pshufb.ll
index cd90696eafac6f..fd99fd880a8098 100644
--- a/llvm/test/Transforms/InstCombine/X86/x86-pshufb.ll
+++ b/llvm/test/Transforms/InstCombine/X86/x86-pshufb.ll
@@ -468,6 +468,38 @@ define <64 x i8> @fold_with_allundef_elts_avx512(<64 x i8> %InVec) {
   ret <64 x i8> %1
 }
 
+; Demanded bits tests (PR106256)
+
+define <16 x i8> @demanded_bits_mask(<16 x i8> %InVec, <16 x i8> %InMask) {
+; CHECK-LABEL: @demanded_bits_mask(
+; CHECK-NEXT:    [[S:%.*]] = tail call <16 x i8> @llvm.x86.ssse3.pshuf.b.128(<16 x i8> [[INVEC:%.*]], <16 x i8> [[INMASK:%.*]])
+; CHECK-NEXT:    ret <16 x i8> [[S]]
+;
+  %m = or <16 x i8> %InMask, <i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112>
+  %s = tail call <16 x i8> @llvm.x86.ssse3.pshuf.b.128(<16 x i8> %InVec, <16 x i8> %m)
+  ret <16 x i8> %s
+}
+
+define <32 x i8> @demanded_bits_mask_avx2(<32 x i8> %InVec, <32 x i8> %InMask) {
+; CHECK-LABEL: @demanded_bits_mask_avx2(
+; CHECK-NEXT:    [[S:%.*]] = tail call <32 x i8> @llvm.x86.avx2.pshuf.b(<32 x i8> [[INVEC:%.*]], <32 x i8> [[INMASK:%.*]])
+; CHECK-NEXT:    ret <32 x i8> [[S]]
+;
+  %m = or <32 x i8> %InMask, <i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112>
+  %s = tail call <32 x i8> @llvm.x86.avx2.pshuf.b(<32 x i8> %InVec, <32 x i8> %m)
+  ret <32 x i8> %s
+}
+
+define <64 x i8> @demanded_bits_mask_avx512(<64 x i8> %InVec, <64 x i8> %InMask) {
+; CHECK-LABEL: @demanded_bits_mask_avx512(
+; CHECK-NEXT:    [[S:%.*]] = tail call <64 x i8> @llvm.x86.avx512.pshuf.b.512(<64 x i8> [[INVEC:%.*]], <64 x i8> [[INMASK:%.*]])
+; CHECK-NEXT:    ret <64 x i8> [[S]]
+;
+  %m = or <64 x i8> %InMask, <i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112, i8 16, i8 48, i8 112, i8 112>
+  %s = tail call <64 x i8> @llvm.x86.avx512.pshuf.b.512(<64 x i8> %InVec, <64 x i8> %m)
+  ret <64 x i8> %s
+}
+
 ; Demanded elts tests.
 
 define <16 x i8> @demanded_elts_insertion(<16 x i8> %InVec, <16 x i8> %BaseMask, i8 %M0, i8 %M15) {

phoebewang (Contributor) left a comment:

LGTM.

Comment on lines 2958 to 2960:

    if (IC.SimplifyDemandedBits(&II, 1, APInt(8, 0b10001111), KnownMask)) {
      return &II;
    }

Remove parentheses
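
Presumably this refers to the redundant braces around the single-statement body, which LLVM coding style omits; the suggested form would be:

  if (IC.SimplifyDemandedBits(&II, 1, APInt(8, 0b10001111), KnownMask))
    return &II;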

RKSimon force-pushed the x86-pshufb-demandedbits branch from 96b6b7e to 73de394 on August 28, 2024 at 12:51
RKSimon merged commit 51a0951 into llvm:main on Aug 28, 2024 (8 checks passed).
RKSimon deleted the x86-pshufb-demandedbits branch on August 28, 2024 at 13:57.
cvijdea-bd commented Aug 28, 2024

@RKSimon Seems great, but the same simplification also appears to be missing for vperm(b|d|w|q): they use only the lower 4, 5, or 6 bits for 128-, 256-, and 512-bit sizes respectively, and unlike pshufb the high bit is ignored.

https://godbolt.org/z/M1rWh7oWf

EDIT: actually it's more complicated: the number of bits used depends on both the register size and the instruction granularity, e.g. vpermb on 128 bits and vpermw on 256 bits both use only the lower 4 bits of the mask.
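
A small sketch of that observation (a hypothetical helper, not part of this patch): a vperm* mask element only needs log2(NumElts) index bits, where NumElts depends on both the vector width and the element size:

  // Hypothetical helper: demanded index bits for a vperm* variant.
  // NumElts = VectorBits / ElementBits; e.g. vpermb on 128 bits has
  // 16 elements (4 bits) and vpermw on 256 bits also has 16 elements
  // (4 bits), while vpermb on 512 bits has 64 elements (6 bits).
  static unsigned vpermDemandedIndexBits(unsigned VectorBits,
                                         unsigned ElementBits) {
    unsigned NumElts = VectorBits / ElementBits;
    unsigned Bits = 0;
    while ((1u << Bits) < NumElts)
      ++Bits;
    return Bits; // == log2(NumElts) for power-of-two element counts
  }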

Successfully merging this pull request may close these issues:

  • Ineffectual por with constant emitted for pshufb operand (#106256)

4 participants