Support fusion moe #10507
Conversation
Thanks for your contribution!
LGTM in tokens_unzip_and_zip.cu
            out=None,
            accumulate=False,
            use_split_accumulator=True,
            is_a_1d_scaled=is_a_1d_scaled,
            is_b_1d_scaled=is_b_1d_scaled,
        )
    else:
-        y = paddle.zeros([x_fp8.shape[0], w_fp8.shape[0]], paddle.bfloat16)
+        y = paddle.zeros([x_fp8.shape[0], w_fp8.shape[0]], paddle.float32)
Suggested change:
-    y = paddle.zeros([x_fp8.shape[0], w_fp8.shape[0]], paddle.float32)
+def kitchen_fp8_gemm(x_fp8, x_scale, w_fp8, w_scale, is_a_1d_scaled, is_b_1d_scaled, rtn_dtype=paddle.bfloat16):
+    if numpy.prod(x_fp8.shape) != 0 and numpy.prod(w_fp8.shape) != 0:
+        y = kitchen.ops.fp8_gemm_blockwise(
+            a=x_fp8,
+            a_decode_scale=x_scale,
+            b=w_fp8,
+            b_decode_scale=w_scale,
+            out_dtype=rtn_dtype,
+            out=None,
+            accumulate=False,
+            use_split_accumulator=True,
+            is_a_1d_scaled=is_a_1d_scaled,
+            is_b_1d_scaled=is_b_1d_scaled,
+        )
+    else:
+        y = paddle.zeros([x_fp8.shape[0], w_fp8.shape[0]], rtn_dtype)
+    return y
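If that suggestion is adopted, call sites could look roughly like the following; these calls are hypothetical illustrations of the proposed rtn_dtype parameter, not code from this PR:

```python
# Hypothetical call sites for the suggested helper: the default keeps the
# current bf16 behaviour, while a caller that needs fp32 output can ask for it.
y_bf16 = kitchen_fp8_gemm(x_fp8, x_scale, w_fp8, w_scale, is_a_1d_scaled=True, is_b_1d_scaled=True)
y_fp32 = kitchen_fp8_gemm(
    x_fp8, x_scale, w_fp8, w_scale, is_a_1d_scaled=True, is_b_1d_scaled=True, rtn_dtype=paddle.float32
)
```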
paddlenlp/transformers/fp8_utils.py (Outdated)
o1 = paddle.stack([out_0, out_1, out_2, out_3])
return o1
o1 = paddle.zeros([expert_w_count, x_fp8.shape[1], w1_t_quant.shape[1]], dtype="bfloat16")
For dtype, it is recommended to use the x_bf16.dtype form instead of hard-coding it.
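A minimal sketch of that suggestion, assuming x_bf16 is the bf16 activation tensor the dtype should be taken from (the other names come from the diff above):

```python
# Derive the output dtype from the activation tensor instead of the hard-coded
# string "bfloat16", so the buffer follows whatever dtype the inputs carry.
o1 = paddle.zeros([expert_w_count, x_fp8.shape[1], w1_t_quant.shape[1]], dtype=x_bf16.dtype)
```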
paddlenlp/transformers/fp8_utils.py (Outdated)
else:
    _, seq_len, H1 = o2_quant.shape
    _, H2, _ = w2_quant.shape
    o3 = paddle.zeros([expert_w_count, o2_quant.shape[1], w2_quant.shape[1]], dtype=paddle.bfloat16)
Same as above.
paddlenlp/transformers/fp8_utils.py (Outdated)
)
unzipped_grad_fp8 = unzipped_grad_fp8.reshape([len(expert_w2), -1, unzipped_grad_fp8.shape[-1]])
unzipped_grad_scale = unzipped_grad_scale.reshape([len(expert_w2), -1, unzipped_grad_scale.shape[-1]])
do2_s = paddle.zeros([len(expert_w2), unzipped_grad_fp8.shape[1], bw_w2_quant.shape[1]], dtype="bfloat16")
Same as above.
paddlenlp/transformers/fp8_utils.py (Outdated)
# ===== do1 = swiglu_grad(o1, None, do2) =====
def bwd_swiglu(self, o1, do2):
    do1, _ = paddle._C_ops.swiglu_grad(self.o1, None, do2)
def bwd_swiglu(self, o1, do2, tokens_per_expert):
Remove the unused parameter.
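A minimal sketch of the trimmed method, assuming the unused argument is the newly added tokens_per_expert (the body shown in the diff only reads self.o1 and do2); the actual cleanup is up to the author:

```python
# Sketch only: drop the parameter the body never touches and keep the rest as-is.
def bwd_swiglu(self, o1, do2):
    # ===== do1 = swiglu_grad(o1, None, do2) =====
    do1, _ = paddle._C_ops.swiglu_grad(self.o1, None, do2)
    return do1
```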
input_x_fp8 = input_x_fp8.reshape([group_num, H1, -1])
input_x_scale = input_x_scale.reshape([group_num, H1, -1])
input_x_fp8 = paddle.split(input_x_fp8, num_or_sections=group_num, axis=0)
input_x_scale = paddle.split(input_x_scale, num_or_sections=group_num, axis=-1)
axis ?
# transpose do1 and quant do1
H2 = do1.shape[-1]
do1 = do1.reshape([group_num, -1, H2]).transpose([0, 2, 1]).contiguous().reshape([group_num * H2, -1])
do1_fp8, do1_scale = kitchen_quant(
    do1, backend=kitchen.ops.Backend.CUBLAS, is_1d_scaled=True, return_transpose=False
)
do1_fp8 = do1_fp8.reshape([group_num, H2, -1])
# do1_scale = do1_scale.T.contiguous().reshape([group_num, H2, -1])
do1_fp8 = paddle.split(do1_fp8, num_or_sections=group_num, axis=0)
do1_scale = paddle.split(do1_scale, num_or_sections=group_num, axis=-1)
axis ?
Because these are later fed to the kitchen GEMM with the CUBLAS backend, do1_fp8 has shape [8192, 8448] while do1_scale has shape [66, 8192], so one split uses axis=0 and the other axis=-1.
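To make the asymmetry concrete, here is a toy illustration of the two splits; all shapes below are made-up stand-ins for the [8192, 8448] / [66, 8192] tensors mentioned above, and the dtypes are placeholders rather than real FP8 storage:

```python
import paddle

# Toy shapes: the fp8 data carries its group dimension on axis 0 after the
# reshape, while the per-block scales carry it on the last axis.
group_num, H2, cols = 4, 8, 16
do1_fp8 = paddle.zeros([group_num * H2, cols], dtype="bfloat16").reshape([group_num, H2, -1])
do1_scale = paddle.zeros([cols // 8, group_num * H2], dtype="float32")

fp8_chunks = paddle.split(do1_fp8, num_or_sections=group_num, axis=0)       # groups on axis 0
scale_chunks = paddle.split(do1_scale, num_or_sections=group_num, axis=-1)  # groups on the last axis

assert fp8_chunks[0].shape == [1, H2, cols]
assert scale_chunks[0].shape == [cols // 8, H2]
```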
paddlenlp/transformers/moe_utils.py (Outdated)
@paddle.no_grad()
def forward(
    self, expert_out, zipped_expertwise_rowmap, routemap_topk, unzipped_probs, total_zipped_tokens, num_experts
):
    self.expert_out = expert_out
Delete expert_out and unzipped_probs.
done
def bwd_dowm_input(self, expert_w2, unzipped_grad, tokens_per_expert, expected_m):
    # recomput o2
    o2 = self.fwd_swiglu(self.o1)
    o2_s = (o2 * self.unzipped_probs).cast(paddle.bfloat16)
Should we add an experiment here that skips the cast back to bf16? This input is fed to the quant op.
The very first version did not cast back to bf16. Without the cast back to bf16, the loss diff is larger than it is now.
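For reference, a minimal standalone sketch of the two variants being discussed; the function name and the assumption that the product is fp32 before the cast are illustrative, not taken from the PR:

```python
import paddle

def scale_before_quant(o2: paddle.Tensor, unzipped_probs: paddle.Tensor, cast_back_to_bf16: bool = True):
    # Weight the swiglu output by the routing probabilities.
    o2_s = o2 * unzipped_probs
    if cast_back_to_bf16:
        # Variant kept in the PR: cast back to bf16 before quantization
        # (reported by the author to give the smaller loss diff).
        o2_s = o2_s.cast(paddle.bfloat16)
    # Either way, o2_s is what gets handed to the quant op downstream.
    return o2_s
```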
paddlenlp/transformers/fp8_utils.py (Outdated)
input_x_fp8, input_x_scale = kitchen_quant(
    input_x, backend=kitchen.ops.Backend.CUBLAS, is_1d_scaled=True, return_transpose=False
)
input_x_fp8 = input_x_fp8.reshape([group_num, H1, -1])
input_x_scale = input_x_scale.reshape([group_num, H1, -1])
input_x_fp8 = paddle.split(input_x_fp8, num_or_sections=group_num, axis=0)
For the split along axis 0, it feels like a single reshape would be enough, relying on the stride mechanism so the downstream ops can consume the slices directly, which should improve performance:
input_x_fp8 = input_x_fp8.reshape([group_num, -1, input_x_fp8.shape[-1]])
OK.
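A hedged sketch of that idea (group_num and the buffer shape below are illustrative assumptions); whether indexing returns a true strided view here is something to verify against the actual tensors:

```python
import paddle

group_num = 4
input_x_fp8 = paddle.zeros([group_num * 128, 64], dtype="bfloat16")  # stand-in for the quantized activations

# Reshape once so the group dimension is explicit, then let each downstream op
# pick up its slice instead of materializing group_num split outputs.
input_x_fp8 = input_x_fp8.reshape([group_num, -1, input_x_fp8.shape[-1]])
for g in range(group_num):
    x_g = input_x_fp8[g]  # per-group slice, expected to be a view rather than a copy
    # ... feed x_g to the per-expert GEMM
```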
paddlenlp/transformers/fp8_utils.py (Outdated)
# transpose do1 and quant do1
H2 = do1.shape[-1]
do1 = do1.reshape([group_num, -1, H2]).transpose([0, 2, 1]).contiguous().reshape([group_num * H2, -1])
do1_fp8, do1_scale = kitchen_quant(
    do1, backend=kitchen.ops.Backend.CUBLAS, is_1d_scaled=True, return_transpose=False
)
do1_fp8 = do1_fp8.reshape([group_num, H2, -1])
# do1_scale = do1_scale.T.contiguous().reshape([group_num, H2, -1])
do1_fp8 = paddle.split(do1_fp8, num_or_sections=group_num, axis=0)
Same as above.
OK.