Skip to content

Commit 59d5e36

Browse files
committed
swap A/B and naive heuristic
Signed-off-by: Lain <[email protected]>
1 parent 54631f8 commit 59d5e36

File tree

3 files changed

+139
-80
lines changed

3 files changed

+139
-80
lines changed

csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,6 @@ void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
99
torch::Tensor const& b,
1010
torch::Tensor const& a_scales,
1111
torch::Tensor const& b_scales) {
12-
TORCH_CHECK(
13-
a.size(0) % 4 == 0,
14-
"Input tensor must have a number of rows that is a multiple of 4. ",
15-
"but got: ", a.size(0), " rows.");
1612
if (out.dtype() == torch::kBFloat16) {
1713
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>(
1814
out, a, b, a_scales, b_scales);

csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh

Lines changed: 139 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22

3+
#include "cuda_utils.h"
34
#include "cutlass/cutlass.h"
45
#include "cutlass/numeric_types.h"
56

@@ -22,49 +23,49 @@ namespace vllm {
2223

2324
using namespace cute;
2425

25-
template <typename OutType, typename MmaTileShape, typename ScalesPerTile,
26-
class ClusterShape, typename EpilogueScheduler,
27-
typename MainloopScheduler>
26+
// clang-format off
27+
template <class OutType, int ScaleGranularityM,
28+
int ScaleGranularityN, int ScaleGranularityK,
29+
class MmaTileShape, class ClusterShape,
30+
class EpilogueScheduler, class MainloopScheduler,
31+
bool swap_ab_ = false>
2832
struct cutlass_3x_gemm_fp8_blockwise {
33+
static constexpr bool swap_ab = swap_ab_;
2934
using ElementAB = cutlass::float_e4m3_t;
3035

3136
using ElementA = ElementAB;
3237
using LayoutA = cutlass::layout::RowMajor;
38+
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
3339
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
3440

3541
using ElementB = ElementAB;
3642
using LayoutB = cutlass::layout::ColumnMajor;
43+
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
3744
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
3845

39-
using ElementC = void;
4046
using ElementD = OutType;
4147
using LayoutD = cutlass::layout::RowMajor;
48+
using LayoutD_Transpose = typename cutlass::layout::LayoutTranspose<LayoutD>::type;
4249
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
4350

51+
using ElementC = void; // TODO: support bias
4452
using LayoutC = LayoutD;
53+
using LayoutC_Transpose = LayoutD_Transpose;
4554
static constexpr int AlignmentC = AlignmentD;
4655

4756
using ElementAccumulator = float;
4857
using ElementCompute = float;
4958
using ElementBlockScale = float;
5059

51-
// MMA and Cluster Tile Shapes
52-
// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster
53-
// Shape %2 == 0 using MmaTileShape_MNK = Shape<_128,_128,_128>;
54-
static constexpr int ScaleMsPerTile = size<0>(ScalesPerTile{});
55-
static constexpr int ScaleGranularityM =
56-
size<0>(MmaTileShape{}) / ScaleMsPerTile;
57-
static constexpr int ScaleGranularityN =
58-
size<1>(MmaTileShape{}) / size<1>(ScalesPerTile{});
59-
static constexpr int ScaleGranularityK =
60-
size<2>(MmaTileShape{}) / size<2>(ScalesPerTile{});
61-
62-
// Shape of the threadblocks in a cluster
63-
using ClusterShape_MNK = ClusterShape;
64-
65-
using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<
66-
ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
67-
cute::UMMA::Major::MN, cute::UMMA::Major::K>;
60+
using ScaleConfig = conditional_t<swap_ab,
61+
cutlass::detail::Sm100BlockwiseScaleConfig<
62+
ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
63+
cute::UMMA::Major::K, cute::UMMA::Major::MN>,
64+
cutlass::detail::Sm100BlockwiseScaleConfig<
65+
ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
66+
cute::UMMA::Major::MN, cute::UMMA::Major::K>>;
67+
68+
// layout_SFA and layout_SFB cannot be swapped since they are deduced.
6869
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
6970
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
7071

@@ -73,7 +74,6 @@ struct cutlass_3x_gemm_fp8_blockwise {
7374

7475
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
7576
using ElementScalar = float;
76-
// clang-format off
7777
using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;
7878
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
7979
ArchTag,
@@ -84,33 +84,47 @@ struct cutlass_3x_gemm_fp8_blockwise {
8484
ElementAccumulator,
8585
ElementCompute,
8686
ElementC,
87-
LayoutC,
87+
conditional_t<swap_ab, LayoutC_Transpose, LayoutC>,
8888
AlignmentC,
8989
ElementD,
90-
LayoutD,
90+
conditional_t<swap_ab, LayoutD_Transpose, LayoutD>,
9191
AlignmentD,
9292
EpilogueScheduler,
9393
DefaultOperation
9494
>::CollectiveOp;
9595

9696
using StageCountType = cutlass::gemm::collective::StageCountAuto;
97-
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
98-
ArchTag,
99-
OperatorClass,
100-
ElementA,
101-
cute::tuple<LayoutA, LayoutSFA>,
102-
AlignmentA,
103-
ElementB,
104-
cute::tuple<LayoutB, LayoutSFB>,
105-
AlignmentB,
106-
ElementAccumulator,
107-
MmaTileShape,
108-
ClusterShape,
109-
97+
using CollectiveMainloop = conditional_t<swap_ab,
98+
typename cutlass::gemm::collective::CollectiveBuilder<
99+
ArchTag,
100+
OperatorClass,
101+
ElementB,
102+
cute::tuple<LayoutB_Transpose, LayoutSFA>,
103+
AlignmentB,
104+
ElementA,
105+
cute::tuple<LayoutA_Transpose, LayoutSFB>,
106+
AlignmentA,
107+
ElementAccumulator,
108+
MmaTileShape,
109+
ClusterShape,
110110
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
111-
MainloopScheduler
112-
>::CollectiveOp;
113-
// clang-format on
111+
MainloopScheduler
112+
>::CollectiveOp,
113+
typename cutlass::gemm::collective::CollectiveBuilder<
114+
ArchTag,
115+
OperatorClass,
116+
ElementA,
117+
cute::tuple<LayoutA, LayoutSFA>,
118+
AlignmentA,
119+
ElementB,
120+
cute::tuple<LayoutB, LayoutSFB>,
121+
AlignmentB,
122+
ElementAccumulator,
123+
MmaTileShape,
124+
ClusterShape,
125+
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
126+
MainloopScheduler
127+
>::CollectiveOp>;
114128

115129
using KernelType = enable_sm100_only<cutlass::gemm::kernel::GemmUniversal<
116130
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>;
@@ -123,6 +137,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
123137
torch::Tensor const& b,
124138
torch::Tensor const& a_scales,
125139
torch::Tensor const& b_scales) {
140+
static constexpr bool swap_ab = Gemm::swap_ab;
126141
using GemmKernel = typename Gemm::GemmKernel;
127142
using StrideA = typename Gemm::GemmKernel::StrideA;
128143
using StrideB = typename Gemm::GemmKernel::StrideB;
@@ -136,7 +151,6 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
136151
using ElementD = typename Gemm::ElementD;
137152

138153
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
139-
auto prob_shape = cute::make_shape(m, n, k, 1);
140154

141155
StrideA a_stride;
142156
StrideB b_stride;
@@ -146,21 +160,36 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
146160
b_stride =
147161
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
148162
c_stride =
149-
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
163+
cutlass::make_cute_packed_stride(StrideC{}, swap_ab ? cute::make_shape(n, m, 1) : cute::make_shape(m, n, 1));
150164

151-
LayoutSFA layout_SFA =
165+
LayoutSFA layout_SFA = swap_ab ?
166+
ScaleConfig::tile_atom_to_shape_SFA(make_shape(n, m, k, 1)) :
152167
ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
153-
LayoutSFB layout_SFB =
168+
LayoutSFB layout_SFB = swap_ab ?
169+
ScaleConfig::tile_atom_to_shape_SFB(make_shape(n, m, k, 1)) :
154170
ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
155171

156172
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
157173
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
158174
auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
159175
auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());
160176

161-
typename GemmKernel::MainloopArguments mainloop_args{
162-
a_ptr, a_stride, b_ptr, b_stride,
163-
a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB};
177+
auto mainloop_args = [&](){
178+
// layout_SFA and layout_SFB cannot be swapped since they are deduced.
179+
if (swap_ab) {
180+
return typename GemmKernel::MainloopArguments{
181+
b_ptr, b_stride, a_ptr, a_stride,
182+
b_scales_ptr, layout_SFA, a_scales_ptr, layout_SFB
183+
};
184+
}
185+
else {
186+
return typename GemmKernel::MainloopArguments{
187+
a_ptr, a_stride, b_ptr, b_stride,
188+
a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB
189+
};
190+
}
191+
}();
192+
auto prob_shape = swap_ab ? cute::make_shape(n, m, k, 1) : cute::make_shape(m, n, k, 1);
164193

165194
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
166195
typename GemmKernel::EpilogueArguments epilogue_args{
@@ -181,23 +210,71 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
181210
int sms;
182211
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());
183212

184-
auto should_use_2sm = [&sms](int m, int n, int tile1SM = 128) {
185-
return std::ceil(static_cast<float>(m) / tile1SM) *
186-
std::ceil(static_cast<float>(n) / tile1SM) >=
187-
sms;
188-
};
189-
bool use_2sm = should_use_2sm(m, n);
190-
if (use_2sm) {
191-
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
192-
OutType, Shape<_256, _128, _128>, Shape<_256, _1, _1>,
193-
Shape<_2, _2, _1>, cutlass::epilogue::TmaWarpSpecialized2Sm,
194-
cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
195-
out, a, b, a_scales, b_scales);
213+
constexpr int TILE_K = 128;
214+
// TODO: better heuristics
215+
bool swap_ab = (m < 16) || (m % 4 != 0);
216+
bool use_tma_epilogue = (m * n) % 4 == 0;
217+
if (!swap_ab) {
218+
constexpr int TILE_N = 128;
219+
int tile_m = 256;
220+
if (cuda_utils::ceil_div(n, TILE_N) * cuda_utils::ceil_div(m, 64) <= sms) {
221+
tile_m = 64;
222+
}
223+
else if (cuda_utils::ceil_div(n, TILE_N) * cuda_utils::ceil_div(m, 128) <= sms) {
224+
tile_m = 128;
225+
}
226+
if (tile_m == 64) {
227+
if (use_tma_epilogue) {
228+
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
229+
OutType, 1, TILE_N, TILE_K, Shape<_64, Int<TILE_N>, Int<TILE_K>>,
230+
Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
231+
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
232+
out, a, b, a_scales, b_scales);
233+
} else {
234+
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
235+
OutType, 1, TILE_N, TILE_K, Shape<_64, Int<TILE_N>, Int<TILE_K>>,
236+
Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
237+
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
238+
out, a, b, a_scales, b_scales);
239+
}
240+
} else if (tile_m == 128) {
241+
if (use_tma_epilogue) {
242+
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
243+
OutType, 1, TILE_N, TILE_K, Shape<_128, Int<TILE_N>, Int<TILE_K>>,
244+
Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
245+
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
246+
out, a, b, a_scales, b_scales);
247+
} else {
248+
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
249+
OutType, 1, TILE_N, TILE_K, Shape<_128, Int<TILE_N>, Int<TILE_K>>,
250+
Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
251+
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
252+
out, a, b, a_scales, b_scales);
253+
}
254+
} else { // tile_m == 256
255+
if (use_tma_epilogue) {
256+
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
257+
OutType, 1, TILE_N, TILE_K, Shape<_256, Int<TILE_N>, Int<TILE_K>>,
258+
Shape<_2, _1, _1>, cutlass::epilogue::TmaWarpSpecialized2Sm,
259+
cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
260+
out, a, b, a_scales, b_scales);
261+
} else {
262+
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
263+
OutType, 1, TILE_N, TILE_K, Shape<_256, Int<TILE_N>, Int<TILE_K>>,
264+
Shape<_2, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized2Sm,
265+
cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
266+
out, a, b, a_scales, b_scales);
267+
}
268+
}
196269
} else {
270+
// TODO: Test more tile N configs
271+
constexpr int TILE_M = 128;
272+
constexpr int TILE_N = 16;
273+
// TMA epilogue isn't compatible with Swap A/B
197274
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
198-
OutType, Shape<_128, _128, _128>, Shape<_128, _1, _1>,
199-
Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
200-
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
275+
OutType, TILE_M, 1, TILE_K, Shape<Int<TILE_M>, Int<TILE_N>, Int<TILE_K>>,
276+
Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
277+
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100, true>>(
201278
out, a, b, a_scales, b_scales);
202279
}
203280
}

vllm/model_executor/layers/quantization/utils/fp8_utils.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -135,24 +135,10 @@ def ceil_div(x: int, y: int) -> int:
135135
use_cutlass, use_aiter_and_is_supported)
136136

137137
if use_cutlass:
138-
rows, cols = input_2d.shape
139-
# Blackwell GPUs (SM100) require row dimensions to be multiple of 4 for
140-
# optimal tensor core usage. Can be removed when targeting platforms
141-
# without this constraint.
142-
should_pad = current_platform.has_device_capability(
143-
100) and rows % 4 != 0
144-
if should_pad:
145-
input_2d = torch.nn.functional.pad(input_2d,
146-
(0, 0, 0, 4 - (rows % 4)),
147-
value=0).contiguous()
148-
149138
q_input, x_scale = per_token_group_quant_fp8(
150139
input_2d, block_size[1], column_major_scales=use_cutlass)
151-
152140
output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
153141
block_size, input.dtype)
154-
if should_pad:
155-
output = output[:rows, :]
156142

157143
else:
158144
q_input, x_scale = per_token_group_quant_fp8(

0 commit comments

Comments (0)