Skip to content

[NVPTX] Fix the error in a pattern match in v4i8 comparisons. #81308

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 69 additions & 23 deletions llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -1886,10 +1886,14 @@ multiclass PRMT<ValueType T, RegisterClass RC> {
}

let hasSideEffects = false in {
defm BFE_S32 : BFE<"bfe.s32", i32, Int32Regs>;
// Order is somewhat important here: the signed and unsigned variants match
// the same patterns, so the first one defined wins. Preferring the unsigned
// byte extraction has the benefit of always having zero in the unused bits,
// which makes some optimizations easier (e.g. no need to mask them).
defm BFE_U32 : BFE<"bfe.u32", i32, Int32Regs>;
defm BFE_S64 : BFE<"bfe.s64", i64, Int64Regs>;
defm BFE_S32 : BFE<"bfe.s32", i32, Int32Regs>;
defm BFE_U64 : BFE<"bfe.u64", i64, Int64Regs>;
defm BFE_S64 : BFE<"bfe.s64", i64, Int64Regs>;

defm BFI_B32 : BFI<"bfi.b32", i32, Int32Regs, i32imm>;
defm BFI_B64 : BFI<"bfi.b64", i64, Int64Regs, i64imm>;
Expand Down Expand Up @@ -2259,27 +2263,69 @@ def : Pat<(setueq Int1Regs:$a, Int1Regs:$b),
(NOT1 (XORb1rr Int1Regs:$a, Int1Regs:$b))>;

// comparisons of i8 extracted with BFE as i32
// NOTE(review): the matched DAG compares only the low byte of $a and $b
// (trunc followed by sext_inreg ..., i8), but the output hands the full
// 32-bit $a/$b registers to SETP_s32rr unchanged. That is only equivalent
// when the upper bits of both registers already hold the sign extension of
// the low byte; nothing in this pattern checks for that.
def: Pat<(setgt (sext_inreg (trunc Int32Regs:$a), i8), (sext_inreg (trunc Int32Regs:$b), i8)),
(SETP_s32rr Int32Regs:$a, Int32Regs:$b, CmpGT)>;
Comment on lines -2262 to -2263
Copy link
Member Author

@Artem-B Artem-B Feb 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was the root cause of the problem. We should've used bfe($a/$b, 8, 8) instead of passing $a/$b to SETP as is.
The tests worked because $a/$b happened to be the result of the bfe() we had lowered for extractelement(), so the PTX output looked sane.

// Signed i8 comparisons (setge/setlt/setle) on values truncated from i32.
// NOTE(review): as written, the source DAG only looks at the sign-extended
// low byte, while the emitted SETP_s32rr compares the untouched 32-bit
// $a/$b registers, so any bits above the low byte take part in the compare.
def: Pat<(setge (sext_inreg (trunc Int32Regs:$a), i8), (sext_inreg (trunc Int32Regs:$b), i8)),
(SETP_s32rr Int32Regs:$a, Int32Regs:$b, CmpGE)>;
def: Pat<(setlt (sext_inreg (trunc Int32Regs:$a), i8), (sext_inreg (trunc Int32Regs:$b), i8)),
(SETP_s32rr Int32Regs:$a, Int32Regs:$b, CmpLT)>;
def: Pat<(setle (sext_inreg (trunc Int32Regs:$a), i8), (sext_inreg (trunc Int32Regs:$b), i8)),
(SETP_s32rr Int32Regs:$a, Int32Regs:$b, CmpLE)>;

// Unsigned and equality i8 comparisons on values truncated from i32.
// NOTE(review): the source DAG masks each operand with 255 (low byte only),
// but the emitted SETP_u32rr again receives the full 32-bit $a/$b, so the
// mask is dropped in the selected code.
def: Pat<(setugt (i16 (and (trunc Int32Regs:$a), 255)), (i16 (and (trunc Int32Regs:$b), 255))),
(SETP_u32rr Int32Regs:$a, Int32Regs:$b, CmpHI)>;
def: Pat<(setuge (i16 (and (trunc Int32Regs:$a), 255)), (i16 (and (trunc Int32Regs:$b), 255))),
(SETP_u32rr Int32Regs:$a, Int32Regs:$b, CmpHS)>;
def: Pat<(setult (i16 (and (trunc Int32Regs:$a), 255)), (i16 (and (trunc Int32Regs:$b), 255))),
(SETP_u32rr Int32Regs:$a, Int32Regs:$b, CmpLO)>;
def: Pat<(setule (i16 (and (trunc Int32Regs:$a), 255)), (i16 (and (trunc Int32Regs:$b), 255))),
(SETP_u32rr Int32Regs:$a, Int32Regs:$b, CmpLS)>;
def: Pat<(seteq (i16 (and (trunc Int32Regs:$a), 255)), (i16 (and (trunc Int32Regs:$b), 255))),
(SETP_u32rr Int32Regs:$a, Int32Regs:$b, CmpEQ)>;
def: Pat<(setne (i16 (and (trunc Int32Regs:$a), 255)), (i16 (and (trunc Int32Regs:$b), 255))),
(SETP_u32rr Int32Regs:$a, Int32Regs:$b, CmpNE)>;
// It's faster to do comparison directly on i32 extracted by BFE,
// instead of the long conversion and sign extending.
//
// Signed i8 comparisons (setgt/setge/setlt/setle). Each pattern requires both
// operands to be a `bfe` of an i32 register with a third operand of 8
// (presumably the field width in bits -- confirm against the bfe node
// definition), truncated and then sign-extended in-register from i8.
// The output re-extracts each field with BFE_S32 using the matched source
// register and offset, and compares the two results with SETP_s32rr.
//
// Every comparison comes in two forms: one where both offsets ($oa/$ob) are
// registers (selects BFE_S32rri) and one where both are immediates (selects
// BFE_S32rii). No pattern here covers one register and one immediate offset.

// setgt: register offsets, then immediate offsets.
def: Pat<(setgt (i16 (sext_inreg (i16 (trunc (bfe Int32Regs:$a, Int32Regs:$oa, 8))), i8)),
(i16 (sext_inreg (i16 (trunc (bfe Int32Regs:$b, Int32Regs:$ob, 8))), i8))),
(SETP_s32rr (BFE_S32rri $a, $oa, 8), (BFE_S32rri $b, $ob, 8), CmpGT)>;
def: Pat<(setgt (i16 (sext_inreg (trunc (bfe Int32Regs:$a, imm:$oa, 8)), i8)),
(i16 (sext_inreg (trunc (bfe Int32Regs:$b, imm:$ob, 8)), i8))),
(SETP_s32rr (BFE_S32rii $a, imm:$oa, 8), (BFE_S32rii $b, imm:$ob, 8), CmpGT)>;
// setge: register offsets, then immediate offsets.
def: Pat<(setge (i16 (sext_inreg (i16 (trunc (bfe Int32Regs:$a, Int32Regs:$oa, 8))), i8)),
(i16 (sext_inreg (i16 (trunc (bfe Int32Regs:$b, Int32Regs:$ob, 8))), i8))),
(SETP_s32rr (BFE_S32rri $a, $oa, 8), (BFE_S32rri $b, $ob, 8), CmpGE)>;
def: Pat<(setge (i16 (sext_inreg (trunc (bfe Int32Regs:$a, imm:$oa, 8)), i8)),
(i16 (sext_inreg (trunc (bfe Int32Regs:$b, imm:$ob, 8)), i8))),
(SETP_s32rr (BFE_S32rii $a, imm:$oa, 8), (BFE_S32rii $b, imm:$ob, 8), CmpGE)>;
// setlt: register offsets, then immediate offsets.
def: Pat<(setlt (i16 (sext_inreg (i16 (trunc (bfe Int32Regs:$a, Int32Regs:$oa, 8))), i8)),
(i16 (sext_inreg (i16 (trunc (bfe Int32Regs:$b, Int32Regs:$ob, 8))), i8))),
(SETP_s32rr (BFE_S32rri $a, $oa, 8), (BFE_S32rri $b, $ob, 8), CmpLT)>;
def: Pat<(setlt (i16 (sext_inreg (trunc (bfe Int32Regs:$a, imm:$oa, 8)), i8)),
(i16 (sext_inreg (trunc (bfe Int32Regs:$b, imm:$ob, 8)), i8))),
(SETP_s32rr (BFE_S32rii $a, imm:$oa, 8), (BFE_S32rii $b, imm:$ob, 8), CmpLT)>;
// setle: register offsets, then immediate offsets.
def: Pat<(setle (i16 (sext_inreg (i16 (trunc (bfe Int32Regs:$a, Int32Regs:$oa, 8))), i8)),
(i16 (sext_inreg (i16 (trunc (bfe Int32Regs:$b, Int32Regs:$ob, 8))), i8))),
(SETP_s32rr (BFE_S32rri $a, $oa, 8), (BFE_S32rri $b, $ob, 8), CmpLE)>;
def: Pat<(setle (i16 (sext_inreg (trunc (bfe Int32Regs:$a, imm:$oa, 8)), i8)),
(i16 (sext_inreg (trunc (bfe Int32Regs:$b, imm:$ob, 8)), i8))),
(SETP_s32rr (BFE_S32rii $a, imm:$oa, 8), (BFE_S32rii $b, imm:$ob, 8), CmpLE)>;

// Unsigned and equality i8 comparisons (setugt/setuge/setult/setule/seteq/
// setne). Each pattern requires both operands to be a `bfe` of an i32
// register with a third operand of 8, truncated and masked with 255 as i16.
// The output re-extracts each field with BFE_U32 using the matched source
// register and offset, and compares the two results with SETP_u32rr.
//
// Every comparison comes in two forms: one where both offsets ($oa/$ob) are
// registers (selects BFE_U32rri) and one where both are immediates (selects
// BFE_U32rii). No pattern here covers one register and one immediate offset.

// setugt -> CmpHI
def: Pat<(setugt (i16 (and (trunc (bfe Int32Regs:$a, Int32Regs:$oa, 8)), 255)),
(i16 (and (trunc (bfe Int32Regs:$b, Int32Regs:$ob, 8)), 255))),
(SETP_u32rr (BFE_U32rri $a, $oa, 8), (BFE_U32rri $b, $ob, 8), CmpHI)>;
def: Pat<(setugt (i16 (and (trunc (bfe Int32Regs:$a, imm:$oa, 8)), 255)),
(i16 (and (trunc (bfe Int32Regs:$b, imm:$ob, 8)), 255))),
(SETP_u32rr (BFE_U32rii $a, imm:$oa, 8), (BFE_U32rii $b, imm:$ob, 8), CmpHI)>;
// setuge -> CmpHS
def: Pat<(setuge (i16 (and (trunc (bfe Int32Regs:$a, Int32Regs:$oa, 8)), 255)),
(i16 (and (trunc (bfe Int32Regs:$b, Int32Regs:$ob, 8)), 255))),
(SETP_u32rr (BFE_U32rri $a, $oa, 8), (BFE_U32rri $b, $ob, 8), CmpHS)>;
def: Pat<(setuge (i16 (and (trunc (bfe Int32Regs:$a, imm:$oa, 8)), 255)),
(i16 (and (trunc (bfe Int32Regs:$b, imm:$ob, 8)), 255))),
(SETP_u32rr (BFE_U32rii $a, imm:$oa, 8), (BFE_U32rii $b, imm:$ob, 8), CmpHS)>;
// setult -> CmpLO
def: Pat<(setult (i16 (and (trunc (bfe Int32Regs:$a, Int32Regs:$oa, 8)), 255)),
(i16 (and (trunc (bfe Int32Regs:$b, Int32Regs:$ob, 8)), 255))),
(SETP_u32rr (BFE_U32rri $a, $oa, 8), (BFE_U32rri $b, $ob, 8), CmpLO)>;
def: Pat<(setult (i16 (and (trunc (bfe Int32Regs:$a, imm:$oa, 8)), 255)),
(i16 (and (trunc (bfe Int32Regs:$b, imm:$ob, 8)), 255))),
(SETP_u32rr (BFE_U32rii $a, imm:$oa, 8), (BFE_U32rii $b, imm:$ob, 8), CmpLO)>;
// setule -> CmpLS
def: Pat<(setule (i16 (and (trunc (bfe Int32Regs:$a, Int32Regs:$oa, 8)), 255)),
(i16 (and (trunc (bfe Int32Regs:$b, Int32Regs:$ob, 8)), 255))),
(SETP_u32rr (BFE_U32rri $a, $oa, 8), (BFE_U32rri $b, $ob, 8), CmpLS)>;
def: Pat<(setule (i16 (and (trunc (bfe Int32Regs:$a, imm:$oa, 8)), 255)),
(i16 (and (trunc (bfe Int32Regs:$b, imm:$ob, 8)), 255))),
(SETP_u32rr (BFE_U32rii $a, imm:$oa, 8), (BFE_U32rii $b, imm:$ob, 8), CmpLS)>;
// seteq -> CmpEQ (handled with the unsigned extraction and compare)
def: Pat<(seteq (i16 (and (trunc (bfe Int32Regs:$a, Int32Regs:$oa, 8)), 255)),
(i16 (and (trunc (bfe Int32Regs:$b, Int32Regs:$ob, 8)), 255))),
(SETP_u32rr (BFE_U32rri $a, $oa, 8), (BFE_U32rri $b, $ob, 8), CmpEQ)>;
def: Pat<(seteq (i16 (and (trunc (bfe Int32Regs:$a, imm:$oa, 8)), 255)),
(i16 (and (trunc (bfe Int32Regs:$b, imm:$ob, 8)), 255))),
(SETP_u32rr (BFE_U32rii $a, imm:$oa, 8), (BFE_U32rii $b, imm:$ob, 8), CmpEQ)>;
// setne -> CmpNE (handled with the unsigned extraction and compare)
def: Pat<(setne (i16 (and (trunc (bfe Int32Regs:$a, Int32Regs:$oa, 8)), 255)),
(i16 (and (trunc (bfe Int32Regs:$b, Int32Regs:$ob, 8)), 255))),
(SETP_u32rr (BFE_U32rri $a, $oa, 8), (BFE_U32rri $b, $ob, 8), CmpNE)>;
def: Pat<(setne (i16 (and (trunc (bfe Int32Regs:$a, imm:$oa, 8)), 255)),
(i16 (and (trunc (bfe Int32Regs:$b, imm:$ob, 8)), 255))),
(SETP_u32rr (BFE_U32rii $a, imm:$oa, 8), (BFE_U32rii $b, imm:$ob, 8), CmpNE)>;

// i1 compare -> i32
def : Pat<(i32 (setne Int1Regs:$a, Int1Regs:$b)),
Expand Down
Loading