Skip to content

[ET-VK] Use performant tiled algorithm for 4 bit weight only quantized linear #10236

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 16, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions backends/vulkan/runtime/api/containers/Tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,26 @@ vkapi::VulkanImage allocate_image(
return vkapi::VulkanImage();
}

// TODO(ssjia): change to always check that the image extents do not exceed
// physical limits. Adding the check now based on `maxImageDimension3D` will
// cause some existing models to break. Anecdotally, on Adreno and
// SwiftShader devices, using 3D textures that exceed `maxImageDimension3D`
appears to be ok. So we need to figure out if it is undefined behaviour
// or if there's a better way to figure out what the limit is. For now, only
// check during debug build so that we can detect when exceeding physical
// limits could be a potential cause for model outputs to be wrong. In the
// meantime, the threshold for using texture storage can be configured at
// export time.
#ifdef VULKAN_DEBUG
uint32_t max_extent = storage_type == utils::kTexture3D
? adapter_ptr->max_texture3d_dim()
: adapter_ptr->max_texture2d_dim();

VK_CHECK_COND(
image_extents[0] <= max_extent && image_extents[1] <= max_extent &&
image_extents[2] <= max_extent);
#endif

VkSampler sampler = adapter_ptr->sampler_cache().retrieve(sampler_props);

return adapter_ptr->vma().create_image(
Expand Down Expand Up @@ -291,6 +311,8 @@ vkapi::VulkanBuffer allocate_buffer(
return vkapi::VulkanBuffer();
}

VK_CHECK_COND(numel <= context_ptr->adapter_ptr()->max_buffer_numel());

return adapter_ptr->vma().create_storage_buffer(
element_size(dtype) * numel, allocate_memory);
}
Expand Down
2 changes: 2 additions & 0 deletions backends/vulkan/runtime/gen_vulkan_spv.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ def buffer_gvec_type(dtype: str, n: int) -> str:

if dtype == "float":
return f"vec{n}"
if dtype == "uint":
return f"uvec{n}"
elif dtype == "half":
return f"f16vec{n}"
elif dtype == "int":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,32 +10,52 @@

#define PRECISION ${PRECISION}

${define_required_extensions("uint8")}
${define_required_extensions("int8")}
$if not NO_INT8_BUFFERS:
${define_required_extensions("uint8")}
$if STORAGE == "buffer":
${define_required_extensions("int8")}

layout(std430) buffer;

${layout_declare_tensor(B, "w", "t_qmat2", "uint8", STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "nchw_4x2", "uint8", "buffer")}
$if NO_INT8_BUFFERS:
${layout_declare_tensor(B, "r", "nchw_4x2", "uint", "buffer")}
$else:
${layout_declare_tensor(B, "r", "nchw_4x2", "uint8", "buffer")}

layout(push_constant) uniform restrict Block {
ivec4 qmat2_sizes;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

uint8_t get_first(const uint8_t packed) {
return uint8_t((packed & 0xF0) >> 4);
$if NO_INT8_BUFFERS:
#define BUF_T uint
$else:
#define BUF_T uint8_t

$if STORAGE == "buffer":
#define UVEC4_T u8vec4
$else:
#define UVEC4_T uvec4

uint get_first(const BUF_T packed) {
return (packed & 0xF0) >> 4;
}

uint8_t get_second(const uint8_t packed) {
return uint8_t(packed & 0x0F);
uint get_second(const BUF_T packed) {
return packed & 0x0F;
}

uint8_t combine(const uint8_t first, const uint8_t second) {
return uint8_t(first << 4 | second);
uint combine(const uint first, const uint second) {
return (first << 4 | second);
}

$if NO_INT8_BUFFERS:
uint extract_comp(const uint packed4, const uint idx) {
return (packed4 >> (idx * 8)) & 0xFF;
}

/*
* This shader packs the weight tensor into a texture.
*
Expand Down Expand Up @@ -102,25 +122,32 @@ void main() {
int in_numcols = qmat2_sizes.y;
int in_num_int8_cols = qmat2_sizes.y >> 1;

uint8_t in_vals[8][2];
uint in_vals[8][2];
for (int r = 0; r < 8; ++r) {
if (in_row + r < in_numrows) {
uint8_t in_val_packed = nchw_4x2[(in_row + r) * in_num_int8_cols + in_int8_col];
uint scalar_idx = (in_row + r) * in_num_int8_cols + in_int8_col;
$if NO_INT8_BUFFERS:
BUF_T in_val_packed_texel = nchw_4x2[scalar_idx >> 2];
const uint packed_idx = scalar_idx % 4;
uint in_val_packed = extract_comp(in_val_packed_texel, packed_idx);
$else:
BUF_T in_val_packed = nchw_4x2[scalar_idx];

in_vals[r][0] = get_first(in_val_packed);
in_vals[r][1] = get_second(in_val_packed);
} else {
in_vals[r][0] = uint8_t(254);
in_vals[r][1] = uint8_t(254);
in_vals[r][0] = uint(0);
in_vals[r][1] = uint(0);
}
}

u8vec4 out_tex_1 = u8vec4(
UVEC4_T out_tex_1 = UVEC4_T(
combine(in_vals[0][0], in_vals[4][0]),
combine(in_vals[1][0], in_vals[5][0]),
combine(in_vals[2][0], in_vals[6][0]),
combine(in_vals[3][0], in_vals[7][0]));

u8vec4 out_tex_2 = u8vec4(
UVEC4_T out_tex_2 = UVEC4_T(
combine(in_vals[0][1], in_vals[4][1]),
combine(in_vals[1][1], in_vals[5][1]),
combine(in_vals[2][1], in_vals[6][1]),
Expand All @@ -131,6 +158,6 @@ void main() {
t_qmat2[packed_pos.y * stride + packed_pos.x] = out_tex_1;
t_qmat2[(packed_pos.y + 1) * stride + packed_pos.x] = out_tex_2;
$else:
imageStore(t_qmat2, ivec3(packed_pos.xy, 0), out_tex_1);
imageStore(t_qmat2, ivec3(packed_pos.x, packed_pos.y + 1, 0), out_tex_2);
imageStore(t_qmat2, packed_pos.xy, out_tex_1);
imageStore(t_qmat2, ivec2(packed_pos.x, packed_pos.y + 1), out_tex_2);
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@

pack_int4_linear_weight_transposed_interleaved:
parameter_names_with_default_values:
STORAGE: texture3d
STORAGE: texture2d
NO_INT8_BUFFERS: false
shader_variants:
- NAME: pack_int4_linear_weight_transposed_interleaved_texture3d
- NAME: pack_int4_linear_weight_transposed_interleaved_texture2d
- NAME: pack_int4_linear_weight_transposed_interleaved_buffer
STORAGE: buffer
- NAME: pack_int4_linear_weight_transposed_interleaved_nobitw8buffer_texture2d
NO_INT8_BUFFERS: true
122 changes: 0 additions & 122 deletions backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl

This file was deleted.

Loading
Loading