Skip to content

[AVX2+] Vectorized 1 << u3 in a byte vector should turn into vpshufb #110317

Closed
@Validark

Description

@Validark

This code: (Godbolt link)

/// For every byte of `chunk`, returns `1 << (byte & 7)`.
///
/// The shift amount is the byte truncated to its low three bits (`u3`), so each
/// result byte is one of 1, 2, 4, ..., 128.
export fn foo(chunk: @Vector(32, u8)) @TypeOf(chunk) {
    const ones: @TypeOf(chunk) = @splat(1);
    const shift_amounts: @Vector(32, u3) = @truncate(chunk);
    return ones << shift_amounts;
}
; LLVM IR shown in the issue for `foo`.
; `range(i8 1, -127)` is a wrapping range [1, -127): read as unsigned bytes,
; every result element is known to lie in 1..128.
define dso_local range(i8 1, -127) <32 x i8> @foo(<32 x i8> %0) local_unnamed_addr {
Entry:
  ; %1 = per-byte shift amount: the @truncate to u3, lowered as `and 7`.
  %1 = and <32 x i8> %0, <i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7>
  ; %2 = 1 << %1 per byte. `nuw` holds because shifting 1 by at most 7 never
  ; shifts the set bit out of the byte.
  %2 = shl nuw <32 x i8> <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>, %1
  ret <32 x i8> %2
}

This compiles to the following for Zen 3 (Intel syntax):

# foo as currently compiled for Zen 3 (Intel syntax). The variable per-byte
# shift is built as a three-step ladder: conditionally shift by 4, then by 2,
# then by 1. Each step is selected with vpblendvb, which tests bit 7 of each
# byte of the mask operand.
.LCPI0_1:
# 32 bytes of 16 (= 1 << 4): the candidate result for the "shift by 4" step.
        .zero   32,16
.LCPI0_2:
# 32 bytes of 0xFC: clears the 2 low bits of each byte, where a 16-bit word
# shift by 2 carries in bits from the neighbouring byte.
        .zero   32,252
.LCPI0_3:
# 32 bytes of 0xE0: keeps only bits 5-7 of each byte.
        .zero   32,224
.LCPI0_4:
        .byte   1
foo:
# Word shift by 5: shift-amount bits 0-2 of each byte land in bits 5-7 of that
# byte, so amount bit 2 is now in bit 7 (the bit vpblendvb tests).
        vpsllw  ymm0, ymm0, 5
# ymm1 = 1 in every byte.
        vpbroadcastb    ymm1, byte ptr [rip + .LCPI0_4]
# Step 1: where amount bit 2 is set, take 16 (1 << 4) instead of 1.
        vpblendvb       ymm1, ymm1, ymmword ptr [rip + .LCPI0_1], ymm0
# Keep only bits 5-7 (amount bits 0-2), dropping bits carried in from the
# neighbouring byte by the word shift.
        vpand   ymm0, ymm0, ymmword ptr [rip + .LCPI0_3]
# Step 2 candidate: ymm1 << 2 as a word shift, then mask out cross-byte carries.
        vpsllw  ymm2, ymm1, 2
        vpand   ymm2, ymm2, ymmword ptr [rip + .LCPI0_2]
# Doubling each byte moves amount bit 1 into bit 7.
        vpaddb  ymm0, ymm0, ymm0
        vpblendvb       ymm1, ymm1, ymm2, ymm0
# Step 3: doubling again moves amount bit 0 into bit 7. The candidate
# ymm1 << 1 is computed as ymm1 + ymm1.
        vpaddb  ymm0, ymm0, ymm0
        vpaddb  ymm2, ymm1, ymm1
        vpblendvb       ymm0, ymm1, ymm2, ymm0
        ret

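However, because the shift amounts produced by `@truncate(chunk)` are in the range [0, 7], we can precompute all 8 possible answers (`1 << 0` through `1 << 7`). We can then look each one up with `vpshufb`, which selects bytes from a 16-byte table using each index byte (Godbolt, full code):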

/// Table-lookup formulation of `foo`.
///
/// `foo` is evaluated at comptime on the byte pattern 0, 1, ..., 15 repeated
/// across the whole vector, so position `j` of every 16-byte lane of the table
/// holds `1 << (j & 7)`. Each byte of `chunk` is then truncated to a `u3` and
/// used as a `vpshufb` index into that table.
export fn foo2(chunk: @Vector(32, u8)) @TypeOf(chunk) {
    const Chunk = @TypeOf(chunk);
    const lookup: Chunk = comptime blk: {
        const lane_pattern = std.simd.iota(u8, 16);
        break :blk foo(std.simd.repeat(@sizeOf(Chunk), lane_pattern));
    };
    const shift_amounts: @Vector(@sizeOf(Chunk), u3) = @truncate(chunk);
    return vpshufb(lookup, shift_amounts);
}

/// Byte shuffle with x86 `pshufb` semantics for 16-, 32- and 64-byte vectors.
///
/// For every output byte `i`:
/// - if `indices[i]` has its high bit (0x80) set, the result byte is 0;
/// - otherwise it is `table[lane_base + (indices[i] & 0x0F)]`, where `lane_base`
///   is the start of the 16-byte lane that contains byte `i`.
/// The 256- and 512-bit forms therefore never move bytes across 16-byte lanes.
///
/// At comptime the instruction is emulated in software. At runtime the matching
/// LLVM x86 intrinsic is called. A compile error is raised if the target CPU
/// lacks the required feature (SSSE3, AVX2 or AVX-512BW), or if the vector type
/// is not one of the three supported widths.
fn vpshufb(table: anytype, indices: @TypeOf(table)) @TypeOf(table) {
    if (@inComptime()) {
        const len = @bitSizeOf(@TypeOf(indices)) / 8;
        var result: @TypeOf(indices) = undefined;
        for (0..len) |i| {
            const index = indices[i];
            // Match the hardware exactly: only the low 4 bits of the index are
            // used, and selection stays inside the 16-byte lane of output byte
            // `i`. Indexing the whole vector (`index % len`) would disagree with
            // the runtime intrinsic for 32- and 64-byte vectors.
            const lane_base = i & ~@as(usize, 15);
            result[i] = if (index >= 0x80) 0 else table[lane_base | (index & 0x0F)];
        }

        return result;
    }

    const methods = struct {
        extern fn @"llvm.x86.avx512.pshuf.b.512"(@Vector(64, u8), @Vector(64, u8)) @Vector(64, u8);
        extern fn @"llvm.x86.avx2.pshuf.b"(@Vector(32, u8), @Vector(32, u8)) @Vector(32, u8);
        extern fn @"llvm.x86.ssse3.pshuf.b.128"(@Vector(16, u8), @Vector(16, u8)) @Vector(16, u8);
    };

    return switch (@TypeOf(table)) {
        @Vector(64, u8) => if (comptime std.Target.x86.featureSetHas(builtin.cpu.features, .avx512bw)) methods.@"llvm.x86.avx512.pshuf.b.512"(table, indices) else @compileError("CPU target lacks support for vpshufb512"),
        @Vector(32, u8) => if (comptime std.Target.x86.featureSetHas(builtin.cpu.features, .avx2)) methods.@"llvm.x86.avx2.pshuf.b"(table, indices) else @compileError("CPU target lacks support for vpshufb256"),
        @Vector(16, u8) => if (comptime std.Target.x86.featureSetHas(builtin.cpu.features, .ssse3)) methods.@"llvm.x86.ssse3.pshuf.b.128"(table, indices) else @compileError("CPU target lacks support for vpshufb128"),
        else => @compileError(std.fmt.comptimePrint("Invalid argument type passed to vpshufb: {}\n", .{@TypeOf(table)})),
    };
}
# foo2 as compiled: one mask, one broadcast of an 8-byte table, one vpshufb.
.LCPI0_0:
# 32 bytes of 7: the mask that implements the @truncate to u3.
        .zero   32,7
# Removed dead vector data. See https://github.com/llvm/llvm-project/issues/110305
.LCPI0_2:
# The 8 possible results, 1 << 0 through 1 << 7.
        .byte   1
        .byte   2
        .byte   4
        .byte   8
        .byte   16
        .byte   32
        .byte   64
        .byte   128
foo2:
# Reduce every byte to its low three bits (the shift amount, 0-7).
        vpand   ymm0, ymm0, ymmword ptr [rip + .LCPI0_0]
# Replicate the 8-byte table into every qword of ymm1, so each 16-byte lane
# holds the table at byte positions 0-7 (and again at 8-15).
        vpbroadcastq    ymm1, qword ptr [rip + .LCPI0_2]
# Per-byte lookup within each 16-byte lane: result byte = table[amount].
        vpshufb ymm0, ymm1, ymm0
        ret

Metadata

Metadata

Assignees

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions