@@ -100,19 +100,22 @@ void add_q_8w_linear_node(
100
100
101
101
std::vector<PushConstantDataInfo> pcs;
102
102
if (graph.is_buffer_storage (out_W_packed)) {
103
- pcs = {graph.sizes_pc_of (out_W_packed),
104
- graph.strides_pc_of (out_W_packed),
105
- graph.sizes_pc_of (mat1_W_packed),
106
- graph.strides_pc_of (mat1),
107
- graph.strides_pc_of (q_mat2),
108
- graph.strides_pc_of (scales),
109
- graph.numel_pc_of (out_W_packed)};
103
+ pcs = {
104
+ graph.sizes_pc_of (out_W_packed),
105
+ graph.strides_pc_of (out_W_packed),
106
+ graph.sizes_pc_of (mat1_W_packed),
107
+ graph.strides_pc_of (mat1),
108
+ graph.strides_pc_of (q_mat2),
109
+ graph.strides_pc_of (scales),
110
+ graph.numel_pc_of (out_W_packed)};
110
111
} else {
111
- pcs = {graph.logical_limits_pc_of (out_W_packed),
112
- graph.sizes_pc_of (mat1_W_packed)};
112
+ pcs = {
113
+ graph.logical_limits_pc_of (out_W_packed),
114
+ graph.sizes_pc_of (mat1_W_packed)};
113
115
}
114
116
115
- const utils::uvec3 global_wg = {static_cast <uint32_t >(graph.numel_of (out_W_packed)), 1 , 1 };
117
+ const utils::uvec3 global_wg = {
118
+ static_cast <uint32_t >(graph.numel_of (out_W_packed)), 1 , 1 };
116
119
const utils::uvec3 local_wg{64 , 1 , 1 };
117
120
118
121
graph.execute_nodes ().emplace_back (new DispatchNode (
0 commit comments