Skip to content

[LLVM][SVE] Extend dup(extract_elt(v,i)) isel patterns to cover more combinations. #115189

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 104 additions & 21 deletions llvm/lib/Target/AArch64/SVEInstrFormats.td
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,57 @@ class SVEType<ValueType VT> {
!eq(VT, nxv8f16): nxv2f16,
!eq(VT, nxv8bf16): nxv2bf16,
true : untyped);

// The 64-bit vector subreg of VT.
ValueType DSub = !cond(
!eq(VT, nxv16i8): v8i8,
!eq(VT, nxv8i16): v4i16,
!eq(VT, nxv4i32): v2i32,
!eq(VT, nxv2i64): v1i64,
!eq(VT, nxv2f16): v4f16,
!eq(VT, nxv4f16): v4f16,
!eq(VT, nxv8f16): v4f16,
!eq(VT, nxv2f32): v2f32,
!eq(VT, nxv4f32): v2f32,
!eq(VT, nxv2f64): v1f64,
!eq(VT, nxv2bf16): v4bf16,
!eq(VT, nxv4bf16): v4bf16,
!eq(VT, nxv8bf16): v4bf16,
true : untyped);

// The 128-bit vector subreg of VT.
ValueType ZSub = !cond(
!eq(VT, nxv16i8): v16i8,
!eq(VT, nxv8i16): v8i16,
!eq(VT, nxv4i32): v4i32,
!eq(VT, nxv2i64): v2i64,
!eq(VT, nxv2f16): v8f16,
!eq(VT, nxv4f16): v8f16,
!eq(VT, nxv8f16): v8f16,
!eq(VT, nxv2f32): v4f32,
!eq(VT, nxv4f32): v4f32,
!eq(VT, nxv2f64): v2f64,
!eq(VT, nxv2bf16): v8bf16,
!eq(VT, nxv4bf16): v8bf16,
!eq(VT, nxv8bf16): v8bf16,
true : untyped);

// The legal scalar used to hold a vector element.
ValueType EltAsScalar = !cond(
!eq(VT, nxv16i8): i32,
!eq(VT, nxv8i16): i32,
!eq(VT, nxv4i32): i32,
!eq(VT, nxv2i64): i64,
!eq(VT, nxv2f16): f16,
!eq(VT, nxv4f16): f16,
!eq(VT, nxv8f16): f16,
!eq(VT, nxv2f32): f32,
!eq(VT, nxv4f32): f32,
!eq(VT, nxv2f64): f64,
!eq(VT, nxv2bf16): bf16,
!eq(VT, nxv4bf16): bf16,
!eq(VT, nxv8bf16): bf16,
true : untyped);
}

def SDT_AArch64Setcc : SDTypeProfile<1, 4, [
Expand Down Expand Up @@ -1402,29 +1453,61 @@ multiclass sve_int_perm_dup_i<string asm> {
def : InstAlias<"mov $Zd, $Qn",
(!cast<Instruction>(NAME # _Q) ZPR128:$Zd, FPR128asZPR:$Qn, 0), 2>;

// Duplicate extracted element of vector into all vector elements
// Duplicate an extracted vector element across a vector.

def : Pat<(nxv16i8 (splat_vector (i32 (vector_extract (nxv16i8 ZPR:$vec), sve_elm_idx_extdup_b:$index)))),
(!cast<Instruction>(NAME # _B) ZPR:$vec, sve_elm_idx_extdup_b:$index)>;
def : Pat<(nxv8i16 (splat_vector (i32 (vector_extract (nxv8i16 ZPR:$vec), sve_elm_idx_extdup_h:$index)))),
(!cast<Instruction>(NAME # _H) ZPR:$vec, sve_elm_idx_extdup_h:$index)>;
def : Pat<(nxv4i32 (splat_vector (i32 (vector_extract (nxv4i32 ZPR:$vec), sve_elm_idx_extdup_s:$index)))),
(!cast<Instruction>(NAME # _S) ZPR:$vec, sve_elm_idx_extdup_s:$index)>;
def : Pat<(nxv2i64 (splat_vector (i64 (vector_extract (nxv2i64 ZPR:$vec), sve_elm_idx_extdup_d:$index)))),
(!cast<Instruction>(NAME # _D) ZPR:$vec, sve_elm_idx_extdup_d:$index)>;
def : Pat<(nxv8f16 (splat_vector (f16 (vector_extract (nxv8f16 ZPR:$vec), sve_elm_idx_extdup_h:$index)))),
(!cast<Instruction>(NAME # _H) ZPR:$vec, sve_elm_idx_extdup_h:$index)>;
def : Pat<(nxv8bf16 (splat_vector (bf16 (vector_extract (nxv8bf16 ZPR:$vec), sve_elm_idx_extdup_h:$index)))),
(!cast<Instruction>(NAME # _H) ZPR:$vec, sve_elm_idx_extdup_h:$index)>;
def : Pat<(nxv4f16 (splat_vector (f16 (vector_extract (nxv4f16 ZPR:$vec), sve_elm_idx_extdup_s:$index)))),
(!cast<Instruction>(NAME # _S) ZPR:$vec, sve_elm_idx_extdup_s:$index)>;
def : Pat<(nxv2f16 (splat_vector (f16 (vector_extract (nxv2f16 ZPR:$vec), sve_elm_idx_extdup_d:$index)))),
(!cast<Instruction>(NAME # _D) ZPR:$vec, sve_elm_idx_extdup_d:$index)>;
def : Pat<(nxv4f32 (splat_vector (f32 (vector_extract (nxv4f32 ZPR:$vec), sve_elm_idx_extdup_s:$index)))),
(!cast<Instruction>(NAME # _S) ZPR:$vec, sve_elm_idx_extdup_s:$index)>;
def : Pat<(nxv2f32 (splat_vector (f32 (vector_extract (nxv2f32 ZPR:$vec), sve_elm_idx_extdup_d:$index)))),
(!cast<Instruction>(NAME # _D) ZPR:$vec, sve_elm_idx_extdup_d:$index)>;
def : Pat<(nxv2f64 (splat_vector (f64 (vector_extract (nxv2f64 ZPR:$vec), sve_elm_idx_extdup_d:$index)))),
(!cast<Instruction>(NAME # _D) ZPR:$vec, sve_elm_idx_extdup_d:$index)>;
def : Pat<(nxv16i8 (splat_vector (i32 (vector_extract (v16i8 V128:$vec), sve_elm_idx_extdup_b:$index)))),
(!cast<Instruction>(NAME # _B) (SUBREG_TO_REG (i64 0), $vec, zsub), sve_elm_idx_extdup_b:$index)>;
def : Pat<(nxv16i8 (splat_vector (i32 (vector_extract (v8i8 V64:$vec), sve_elm_idx_extdup_b:$index)))),
(!cast<Instruction>(NAME # _B) (SUBREG_TO_REG (i64 0), $vec, dsub), sve_elm_idx_extdup_b:$index)>;

foreach VT = [nxv8i16, nxv2f16, nxv4f16, nxv8f16, nxv2bf16, nxv4bf16, nxv8bf16] in {
def : Pat<(VT (splat_vector (SVEType<VT>.EltAsScalar (vector_extract (SVEType<VT>.Packed ZPR:$vec), sve_elm_idx_extdup_h:$index)))),
(!cast<Instruction>(NAME # _H) ZPR:$vec, sve_elm_idx_extdup_h:$index)>;
def : Pat<(VT (splat_vector (SVEType<VT>.EltAsScalar (vector_extract (SVEType<VT>.ZSub V128:$vec), sve_elm_idx_extdup_h:$index)))),
(!cast<Instruction>(NAME # _H) (SUBREG_TO_REG (i64 0), $vec, zsub), sve_elm_idx_extdup_h:$index)>;
def : Pat<(VT (splat_vector (SVEType<VT>.EltAsScalar (vector_extract (SVEType<VT>.DSub V64:$vec), sve_elm_idx_extdup_h:$index)))),
(!cast<Instruction>(NAME # _H) (SUBREG_TO_REG (i64 0), $vec, dsub), sve_elm_idx_extdup_h:$index)>;
}

foreach VT = [nxv4i32, nxv2f32, nxv4f32 ] in {
def : Pat<(VT (splat_vector (SVEType<VT>.EltAsScalar (vector_extract (SVEType<VT>.Packed ZPR:$vec), sve_elm_idx_extdup_s:$index)))),
(!cast<Instruction>(NAME # _S) ZPR:$vec, sve_elm_idx_extdup_s:$index)>;
def : Pat<(VT (splat_vector (SVEType<VT>.EltAsScalar (vector_extract (SVEType<VT>.ZSub V128:$vec), sve_elm_idx_extdup_s:$index)))),
(!cast<Instruction>(NAME # _S) (SUBREG_TO_REG (i64 0), $vec, zsub), sve_elm_idx_extdup_s:$index)>;
def : Pat<(VT (splat_vector (SVEType<VT>.EltAsScalar (vector_extract (SVEType<VT>.DSub V64:$vec), sve_elm_idx_extdup_s:$index)))),
(!cast<Instruction>(NAME # _S) (SUBREG_TO_REG (i64 0), $vec, dsub), sve_elm_idx_extdup_s:$index)>;
}

foreach VT = [nxv2i64, nxv2f64] in {
def : Pat<(VT (splat_vector (SVEType<VT>.EltAsScalar (vector_extract (VT ZPR:$vec), sve_elm_idx_extdup_d:$index)))),
(!cast<Instruction>(NAME # _D) ZPR:$vec, sve_elm_idx_extdup_d:$index)>;
def : Pat<(VT (splat_vector (SVEType<VT>.EltAsScalar (vector_extract (SVEType<VT>.ZSub V128:$vec), sve_elm_idx_extdup_d:$index)))),
(!cast<Instruction>(NAME # _D) (SUBREG_TO_REG (i64 0), $vec, zsub), sve_elm_idx_extdup_d:$index)>;
def : Pat<(VT (splat_vector (SVEType<VT>.EltAsScalar (vector_extract (SVEType<VT>.DSub V64:$vec), sve_elm_idx_extdup_d:$index)))),
(!cast<Instruction>(NAME # _D) (SUBREG_TO_REG (i64 0), $vec, dsub), sve_elm_idx_extdup_d:$index)>;
}

// When extracting from an unpacked vector the index must be scaled to account
// for the "holes" in the underlying packed vector type. We get the scaling
// for free by "promoting" the element type to one whose underlying vector
// type is packed. This is only valid when extracting from a vector whose
// length is the same or bigger than the result of the splat.

foreach VT = [nxv4f16, nxv4bf16] in {
def : Pat<(SVEType<VT>.HalfLength (splat_vector (SVEType<VT>.EltAsScalar (vector_extract (VT ZPR:$vec), sve_elm_idx_extdup_s:$index)))),
(!cast<Instruction>(NAME # _S) ZPR:$vec, sve_elm_idx_extdup_s:$index)>;
def : Pat<(VT (splat_vector (SVEType<VT>.EltAsScalar (vector_extract (VT ZPR:$vec), sve_elm_idx_extdup_s:$index)))),
(!cast<Instruction>(NAME # _S) ZPR:$vec, sve_elm_idx_extdup_s:$index)>;
}

foreach VT = [nxv2f16, nxv2f32, nxv2bf16] in {
def : Pat<(VT (splat_vector (SVEType<VT>.EltAsScalar (vector_extract (VT ZPR:$vec), sve_elm_idx_extdup_d:$index)))),
(!cast<Instruction>(NAME # _D) ZPR:$vec, sve_elm_idx_extdup_d:$index)>;
}

// Duplicate an indexed 128-bit segment across a vector.

def : Pat<(nxv16i8 (AArch64duplane128 nxv16i8:$Op1, i64:$imm)),
(!cast<Instruction>(NAME # _Q) $Op1, $imm)>;
Expand Down
Loading
Loading