
Support fusion moe #10507


Merged: 5 commits into PaddlePaddle:dsv3_dev on May 22, 2025
Conversation

risemeup1 (Contributor)

Before submitting

  • Lint code. If there are lint issues, please format the code first.
# Install and register `pre-commit` in the project folder
pip install pre-commit && pre-commit install

# Process previous code files separately
pre-commit run --file XXXX.py
  • Add test cases into the tests folder. If there are codecov issues, please add test cases first.

PR types

PR changes

Description


paddle-bot (bot) commented Apr 27, 2025

Thanks for your contribution!

risemeup1 changed the base branch from develop to dsv3_dev on April 27, 2025 02:27
A-nnonymous left a comment

LGTM in tokens_unzip_and_zip.cu

             out=None,
             accumulate=False,
             use_split_accumulator=True,
             is_a_1d_scaled=is_a_1d_scaled,
             is_b_1d_scaled=is_b_1d_scaled,
         )
     else:
-        y = paddle.zeros([x_fp8.shape[0], w_fp8.shape[0]], paddle.bfloat16)
+        y = paddle.zeros([x_fp8.shape[0], w_fp8.shape[0]], paddle.float32)
Contributor:

Suggested change (parameterize the return dtype instead of hard-coding float32 in the empty-input fallback):

def kitchen_fp8_gemm(x_fp8, x_scale, w_fp8, w_scale, is_a_1d_scaled, is_b_1d_scaled, rtn_dtype=paddle.bfloat16):
    if numpy.prod(x_fp8.shape) != 0 and numpy.prod(w_fp8.shape) != 0:
        y = kitchen.ops.fp8_gemm_blockwise(
            a=x_fp8,
            a_decode_scale=x_scale,
            b=w_fp8,
            b_decode_scale=w_scale,
            out_dtype=rtn_dtype,
            out=None,
            accumulate=False,
            use_split_accumulator=True,
            is_a_1d_scaled=is_a_1d_scaled,
            is_b_1d_scaled=is_b_1d_scaled,
        )
    else:
        y = paddle.zeros([x_fp8.shape[0], w_fp8.shape[0]], rtn_dtype)
    return y
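For illustration, a hypothetical call site under this suggestion (the tensor variables are placeholders, not taken from the PR):

# With the suggested signature, the empty-input fallback returns zeros in the
# same dtype as the requested GEMM output, so both branches stay consistent.
y_bf16 = kitchen_fp8_gemm(x_fp8, x_scale, w_fp8, w_scale,
                          is_a_1d_scaled=True, is_b_1d_scaled=True)  # default bfloat16
y_fp32 = kitchen_fp8_gemm(x_fp8, x_scale, w_fp8, w_scale,
                          is_a_1d_scaled=True, is_b_1d_scaled=True,
                          rtn_dtype=paddle.float32)                  # e.g. when accumulating in float32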

o1 = paddle.stack([out_0, out_1, out_2, out_3])

return o1
o1 = paddle.zeros([expert_w_count, x_fp8.shape[1], w1_t_quant.shape[1]], dtype="bfloat16")
Contributor: For the dtype, it is recommended to use the form x_bf16.dtype instead of a hard-coded string.
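A minimal sketch of what the reviewer is asking for; x_bf16 here stands for whichever bf16 tensor already carries the desired dtype (the exact variable name is an assumption):

# Derive the dtype from an existing bf16 tensor rather than hard-coding the string "bfloat16".
o1 = paddle.zeros([expert_w_count, x_fp8.shape[1], w1_t_quant.shape[1]], dtype=x_bf16.dtype)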

else:
    _, seq_len, H1 = o2_quant.shape
    _, H2, _ = w2_quant.shape
    o3 = paddle.zeros([expert_w_count, o2_quant.shape[1], w2_quant.shape[1]], dtype=paddle.bfloat16)
Contributor: Same as above.

)
unzipped_grad_fp8 = unzipped_grad_fp8.reshape([len(expert_w2), -1, unzipped_grad_fp8.shape[-1]])
unzipped_grad_scale = unzipped_grad_scale.reshape([len(expert_w2), -1, unzipped_grad_scale.shape[-1]])
do2_s = paddle.zeros([len(expert_w2), unzipped_grad_fp8.shape[1], bw_w2_quant.shape[1]], dtype="bfloat16")
Contributor: Same as above.


# ===== do1 = swiglu_grad(o1, None, do2) =====
def bwd_swiglu(self, o1, do2):
do1, _ = paddle._C_ops.swiglu_grad(self.o1, None, do2)
def bwd_swiglu(self, o1, do2, tokens_per_expert):
Contributor: Remove the unused parameter.
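A minimal sketch of the requested cleanup, assuming tokens_per_expert is unused in the body and that the method simply returns the gradient (the return is inferred, not shown in the diff):

# Drop the unused tokens_per_expert argument from the backward swiglu helper.
def bwd_swiglu(self, o1, do2):
    do1, _ = paddle._C_ops.swiglu_grad(self.o1, None, do2)
    return do1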

input_x_fp8 = input_x_fp8.reshape([group_num, H1, -1])
input_x_scale = input_x_scale.reshape([group_num, H1, -1])
input_x_fp8 = paddle.split(input_x_fp8, num_or_sections=group_num, axis=0)
input_x_scale = paddle.split(input_x_scale, num_or_sections=group_num, axis=-1)
Contributor: axis?


# transpose do1 and quant do1
H2 = do1.shape[-1]
do1 = do1.reshape([group_num, -1, H2]).transpose([0, 2, 1]).contiguous().reshape([group_num * H2, -1])
do1_fp8, do1_scale = kitchen_quant(
do1, backend=kitchen.ops.Backend.CUBLAS, is_1d_scaled=True, return_transpose=False
)
do1_fp8 = do1_fp8.reshape([group_num, H2, -1])
# do1_scale = do1_scale.T.contiguous().reshape([group_num, H2, -1])
do1_fp8 = paddle.split(do1_fp8, num_or_sections=group_num, axis=0)
do1_scale = paddle.split(do1_scale, num_or_sections=group_num, axis=-1)
Contributor: axis?

risemeup1 (author): These tensors are consumed later by the kitchen GEMM, whose backend is CUBLAS. do1_fp8 has shape [8192, 8448] and do1_scale has shape [66, 8192], so one split uses axis=0 and the other uses axis=-1.
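A minimal sketch of the shape bookkeeping described above; group_num = 8 and H2 = 1024 are assumed for illustration (8 * 1024 = 8192), with 8448 tokens per group and 66 scale rows (possibly 8448 / 128 if the 1-D quantization uses 128-element blocks):

# do1_fp8 is [group_num * H2, tokens] = [8192, 8448]; reshaping exposes the group axis.
do1_fp8 = do1_fp8.reshape([group_num, H2, -1])                            # [8, 1024, 8448]
do1_fp8 = paddle.split(do1_fp8, num_or_sections=group_num, axis=0)        # 8 x [1, 1024, 8448]

# do1_scale stays 2-D as [scale_blocks, group_num * H2] = [66, 8192], so the
# per-group slices live along the last axis and must be split on axis=-1.
do1_scale = paddle.split(do1_scale, num_or_sections=group_num, axis=-1)   # 8 x [66, 1024]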


@paddle.no_grad()
def forward(
    self, expert_out, zipped_expertwise_rowmap, routemap_topk, unzipped_probs, total_zipped_tokens, num_experts
):
    self.expert_out = expert_out
Contributor: Delete expert_out and unzipped_probs.

risemeup1 (author): done

def bwd_dowm_input(self, expert_w2, unzipped_grad, tokens_per_expert, expected_m):
    # recompute o2
    o2 = self.fwd_swiglu(self.o1)
    o2_s = (o2 * self.unzipped_probs).cast(paddle.bfloat16)
Collaborator: Should we add a set of experiments here that do not cast back to bf16? This input is fed to the quant step.

risemeup1 (author): The very first version did not cast back to bf16; without the cast, the loss diff is larger than it is now.
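A minimal sketch of the two variants under discussion, using the names from the snippet above; the only difference is whether the probability-weighted activation is cast back to bf16 before quantization:

# Variant kept in the PR: weight by routing probabilities, cast back to bf16, then quantize.
o2_s = (o2 * self.unzipped_probs).cast(paddle.bfloat16)
o2_fp8, o2_scale = kitchen_quant(
    o2_s, backend=kitchen.ops.Backend.CUBLAS, is_1d_scaled=True, return_transpose=False
)

# Earlier variant (no cast, reportedly gives a larger loss diff):
# o2_s = o2 * self.unzipped_probs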


input_x_fp8, input_x_scale = kitchen_quant(
input_x, backend=kitchen.ops.Backend.CUBLAS, is_1d_scaled=True, return_transpose=False
)
input_x_fp8 = input_x_fp8.reshape([group_num, H1, -1])
input_x_scale = input_x_scale.reshape([group_num, H1, -1])
input_x_fp8 = paddle.split(input_x_fp8, num_or_sections=group_num, axis=0)
Collaborator: For the split along axis 0, it seems this could just be a reshape, letting the downstream op consume it through the stride mechanism, which would improve performance:
input_x_fp8 = input_x_fp8.reshape([group_num, -1, input_x_fp8.shape[-1]])
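A minimal sketch of the suggestion, assuming the downstream per-expert GEMM can consume a slice of the 3-D view directly, so Paddle's stride mechanism avoids the copies that paddle.split would materialize:

# Keep one contiguous 3-D tensor and index per-expert slices instead of splitting on axis 0.
input_x_fp8 = input_x_fp8.reshape([group_num, -1, input_x_fp8.shape[-1]])
for g in range(group_num):
    x_g = input_x_fp8[g]  # strided view of expert g's block, no extra copy kernel
    # ... pass x_g (with the matching scale slice) to the per-expert FP8 GEMM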

risemeup1 (author): OK.


# transpose do1 and quant do1
H2 = do1.shape[-1]
do1 = do1.reshape([group_num, -1, H2]).transpose([0, 2, 1]).contiguous().reshape([group_num * H2, -1])
do1_fp8, do1_scale = kitchen_quant(
do1, backend=kitchen.ops.Backend.CUBLAS, is_1d_scaled=True, return_transpose=False
)
do1_fp8 = do1_fp8.reshape([group_num, H2, -1])
# do1_scale = do1_scale.T.contiguous().reshape([group_num, H2, -1])
do1_fp8 = paddle.split(do1_fp8, num_or_sections=group_num, axis=0)
Collaborator: Same as above.

risemeup1 (author): OK.

phlrain merged commit 67b21ae into PaddlePaddle:dsv3_dev on May 22, 2025
2 of 5 checks passed