Closed
Description
This Zig code (Godbolt link), shown together with the optimized LLVM IR it produces:
/// For every byte lane of `chunk`, returns `1 << (byte & 7)`.
///
/// The shift amount is obtained by truncating each byte to the 3-bit
/// shift-amount type of `u8`, so only the low three bits of each input
/// byte participate.
export fn foo(chunk: @Vector(32, u8)) @TypeOf(chunk) {
    const Chunk = @TypeOf(chunk);
    // One `1` per lane: the value that gets shifted.
    const ones: Chunk = @splat(1);
    // Narrow every byte to u3, i.e. keep only bits 0..2 as the shift count.
    const shift_amounts: @Vector(32, u3) = @truncate(chunk);
    return ones << shift_amounts;
}
; Optimized LLVM IR for `foo`: mask each input byte to its low 3 bits, then
; shift a vector of all-1 bytes left by those amounts.
; The `range(i8 1, -127)` return attribute states that every result byte lies
; in the half-open range [1, -127), i.e. 1..128 when read as unsigned.
define dso_local range(i8 1, -127) <32 x i8> @foo(<32 x i8> %0) local_unnamed_addr {
Entry:
; %1 = per-byte shift count: input & 7 (the truncation to a 3-bit shift amount).
%1 = and <32 x i8> %0, <i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7>
; %2 = 1 << %1 in every byte lane; `nuw` marks that no set bit is shifted out.
%2 = shl nuw <32 x i8> <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>, %1
ret <32 x i8> %2
}
This compiles to the following assembly for Zen 3:
# Generated code for `foo`: a per-byte variable left shift of 1, built from
# three conditional steps (shift by 4, by 2, by 1), each selected with
# vpblendvb on one bit of the shift count.
# 32 bytes of 16: the value 1 already shifted left by 4.
.LCPI0_1:
.zero 32,16
# 32 bytes of 252 (0xFC): clears the low 2 bits of each byte after a word-wise shift by 2.
.LCPI0_2:
.zero 32,252
# 32 bytes of 224 (0xE0): keeps only the top 3 bits of each byte.
.LCPI0_3:
.zero 32,224
# The single byte 1 that gets broadcast to every lane.
.LCPI0_4:
.byte 1
foo:
# Word-wise shift by 5 moves bit 2 of each byte's shift count into that byte's top bit,
# which is the bit vpblendvb tests.
vpsllw ymm0, ymm0, 5
# ymm1 = 1 in every byte.
vpbroadcastb ymm1, byte ptr [rip + .LCPI0_4]
# Where count bit 2 is set take 16 (1 << 4), otherwise keep 1.
vpblendvb ymm1, ymm1, ymmword ptr [rip + .LCPI0_1], ymm0
# Drop the bits that the word-wise shift carried in from the neighbouring byte.
vpand ymm0, ymm0, ymmword ptr [rip + .LCPI0_3]
# ymm2 = ymm1 << 2 per byte (word shift followed by masking off cross-byte bits).
vpsllw ymm2, ymm1, 2
vpand ymm2, ymm2, ymmword ptr [rip + .LCPI0_2]
# Doubling each byte moves the next count bit (bit 1) into the top bit.
vpaddb ymm0, ymm0, ymm0
# Where count bit 1 is set take the value shifted by 2.
vpblendvb ymm1, ymm1, ymm2, ymm0
# Move count bit 0 into the top bit.
vpaddb ymm0, ymm0, ymm0
# ymm2 = ymm1 << 1 per byte (x + x).
vpaddb ymm2, ymm1, ymm1
# Where count bit 0 is set take the value shifted by 1; the result is returned in ymm0.
vpblendvb ymm0, ymm1, ymm2, ymm0
ret
However, because the bytes resulting from `@truncate(chunk)` are in the range [0, 7], we can precompute all 8 possible answers and use `vpshufb` as a table lookup instead (Godbolt, full code); the assembly it produces is shown after the code:
/// Table-lookup variant of `foo`: the answers for every possible shift count
/// are computed at compile time by calling `foo`, and the result is then
/// fetched at runtime with a single byte shuffle.
export fn foo2(chunk: @Vector(32, u8)) @TypeOf(chunk) {
    const Chunk = @TypeOf(chunk);

    // Evaluate `foo` at comptime on the byte pattern 0..15 repeated to fill the vector.
    const table: Chunk = comptime blk: {
        const lane_indices = std.simd.repeat(@sizeOf(Chunk), std.simd.iota(u8, 16));
        break :blk foo(lane_indices);
    };

    // Only the low 3 bits of each byte are used as the lookup index.
    const shift_amounts: @Vector(32, u3) = @truncate(chunk);
    return vpshufb(table, shift_amounts);
}
/// Byte shuffle with the semantics of the x86 `pshufb` / `vpshufb` instructions.
///
/// `table` must be a `@Vector(16, u8)`, `@Vector(32, u8)` or `@Vector(64, u8)`;
/// `indices` has the same type. For every output byte `i`:
///   * if the top bit of `indices[i]` is set, the result byte is 0;
///   * otherwise the result byte is taken from `table`, using only the low
///     4 bits of `indices[i]` and staying inside the 16-byte (128-bit) lane
///     that contains `i`. The wider forms do not shuffle across lanes.
///
/// When evaluated at comptime the shuffle is emulated in Zig; at runtime the
/// matching LLVM intrinsic is called, and a compile error is raised if the
/// target CPU lacks the required feature or the argument type is unsupported.
fn vpshufb(table: anytype, indices: @TypeOf(table)) @TypeOf(table) {
    if (@inComptime()) {
        const byte_count = @bitSizeOf(@TypeOf(table)) / 8;
        var result: @TypeOf(table) = undefined;
        for (0..byte_count) |i| {
            const index = indices[i];
            // Start of the 16-byte lane that output byte `i` belongs to.
            const lane_base = i & ~@as(usize, 0x0F);
            // Top bit set => zero; otherwise select within the same lane using
            // the low 4 index bits, matching the hardware instruction.
            result[i] = if (index >= 0x80) 0 else table[lane_base | (index & 0x0F)];
        }
        return result;
    }

    const methods = struct {
        extern fn @"llvm.x86.avx512.pshuf.b.512"(@Vector(64, u8), @Vector(64, u8)) @Vector(64, u8);
        extern fn @"llvm.x86.avx2.pshuf.b"(@Vector(32, u8), @Vector(32, u8)) @Vector(32, u8);
        extern fn @"llvm.x86.ssse3.pshuf.b.128"(@Vector(16, u8), @Vector(16, u8)) @Vector(16, u8);
    };

    // Dispatch on the vector width; each width needs its own CPU feature.
    return switch (@TypeOf(table)) {
        @Vector(64, u8) => if (comptime std.Target.x86.featureSetHas(builtin.cpu.features, .avx512bw))
            methods.@"llvm.x86.avx512.pshuf.b.512"(table, indices)
        else
            @compileError("CPU target lacks support for vpshufb512"),
        @Vector(32, u8) => if (comptime std.Target.x86.featureSetHas(builtin.cpu.features, .avx2))
            methods.@"llvm.x86.avx2.pshuf.b"(table, indices)
        else
            @compileError("CPU target lacks support for vpshufb256"),
        @Vector(16, u8) => if (comptime std.Target.x86.featureSetHas(builtin.cpu.features, .ssse3))
            methods.@"llvm.x86.ssse3.pshuf.b.128"(table, indices)
        else
            @compileError("CPU target lacks support for vpshufb128"),
        else => @compileError(std.fmt.comptimePrint("Invalid argument type passed to vpshufb: {}\n", .{@TypeOf(table)})),
    };
}
# Generated code for `foo2`: mask the indices, broadcast an 8-byte table, shuffle.
# 32 bytes of 7: mask that reduces every input byte to a 3-bit index.
.LCPI0_0:
.zero 32,7
# Removed dead vector data. See https://github.com/llvm/llvm-project/issues/110305
# 8-byte lookup table: entry k holds 1 << k.
.LCPI0_2:
.byte 1
.byte 2
.byte 4
.byte 8
.byte 16
.byte 32
.byte 64
.byte 128
foo2:
# ymm0 = input & 7, so every index is in [0, 7].
vpand ymm0, ymm0, ymmword ptr [rip + .LCPI0_0]
# Replicate the 8-byte table into every 64-bit element of ymm1.
vpbroadcastq ymm1, qword ptr [rip + .LCPI0_2]
# Per-byte table lookup: ymm0[i] = ymm1[index i], within each 128-bit lane.
vpshufb ymm0, ymm1, ymm0
ret