#pragma once

+ #include "cuda_utils.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
@@ -22,49 +23,49 @@ namespace vllm {

using namespace cute;

- template <typename OutType, typename MmaTileShape, typename ScalesPerTile,
-           class ClusterShape, typename EpilogueScheduler,
-           typename MainloopScheduler>
+ // clang-format off
+ template <class OutType, int ScaleGranularityM,
+           int ScaleGranularityN, int ScaleGranularityK,
+           class MmaTileShape, class ClusterShape,
+           class EpilogueScheduler, class MainloopScheduler,
+           bool swap_ab_ = false>
struct cutlass_3x_gemm_fp8_blockwise {
+   static constexpr bool swap_ab = swap_ab_;
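+   // When swap_ab is set, the GEMM is instantiated for the transposed
+   // problem (n, m, k): B takes the A-operand slot, A takes the B slot,
+   // and the epilogue writes D through transposed layouts. The dispatch
+   // heuristic at the bottom enables this for small or misaligned m.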
  using ElementAB = cutlass::float_e4m3_t;
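+   // Both operands are FP8 (e4m3); the alignments below work out to
+   // 128 bits / 8 bits = 16 elements, the usual TMA-friendly alignment.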

  using ElementA = ElementAB;
  using LayoutA = cutlass::layout::RowMajor;
+   using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

  using ElementB = ElementAB;
  using LayoutB = cutlass::layout::ColumnMajor;
+   using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

-   using ElementC = void;
  using ElementD = OutType;
  using LayoutD = cutlass::layout::RowMajor;
+   using LayoutD_Transpose = typename cutlass::layout::LayoutTranspose<LayoutD>::type;
  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;

+   using ElementC = void;  // TODO: support bias
  using LayoutC = LayoutD;
+   using LayoutC_Transpose = LayoutD_Transpose;
  static constexpr int AlignmentC = AlignmentD;

  using ElementAccumulator = float;
  using ElementCompute = float;
  using ElementBlockScale = float;

-   // MMA and Cluster Tile Shapes
-   // Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if
-   // Cluster Shape % 2 == 0, e.g. MmaTileShape_MNK = Shape<_128,_128,_128>;
-   static constexpr int ScaleMsPerTile = size<0>(ScalesPerTile{});
-   static constexpr int ScaleGranularityM =
-       size<0>(MmaTileShape{}) / ScaleMsPerTile;
-   static constexpr int ScaleGranularityN =
-       size<1>(MmaTileShape{}) / size<1>(ScalesPerTile{});
-   static constexpr int ScaleGranularityK =
-       size<2>(MmaTileShape{}) / size<2>(ScalesPerTile{});
-
-   // Shape of the threadblocks in a cluster
-   using ClusterShape_MNK = ClusterShape;
-
-   using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<
-       ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
-       cute::UMMA::Major::MN, cute::UMMA::Major::K>;
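+   // The scale granularities are now explicit template parameters (instead
+   // of being derived from a ScalesPerTile shape), and the SFA/SFB majors
+   // are exchanged when swap_ab is set so the blockwise scales track the
+   // swapped operands.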
+   using ScaleConfig = conditional_t<swap_ab,
+       cutlass::detail::Sm100BlockwiseScaleConfig<
+           ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
+           cute::UMMA::Major::K, cute::UMMA::Major::MN>,
+       cutlass::detail::Sm100BlockwiseScaleConfig<
+           ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
+           cute::UMMA::Major::MN, cute::UMMA::Major::K>>;
+
+   // layout_SFA and layout_SFB cannot be swapped since they are deduced.
  using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
  using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
@@ -73,7 +74,6 @@ struct cutlass_3x_gemm_fp8_blockwise {

  static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
  using ElementScalar = float;
-   // clang-format off
  using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;
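+   // LinearCombination computes D = alpha * accumulator + beta * C; with
+   // ElementC = void there is no C operand, so it reduces to storing the
+   // scaled accumulator.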
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag,
@@ -84,33 +84,47 @@ struct cutlass_3x_gemm_fp8_blockwise {
      ElementAccumulator,
      ElementCompute,
      ElementC,
-       LayoutC,
+       conditional_t<swap_ab, LayoutC_Transpose, LayoutC>,
      AlignmentC,
      ElementD,
-       LayoutD,
+       conditional_t<swap_ab, LayoutD_Transpose, LayoutD>,
      AlignmentD,
      EpilogueScheduler,
      DefaultOperation
  >::CollectiveOp;
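+   // Under swap_ab the epilogue writes a column-major (n, m) view of the
+   // output buffer, which in memory is identical to the row-major (m, n)
+   // result the caller expects.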

  using StageCountType = cutlass::gemm::collective::StageCountAuto;
-   using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
-       ArchTag,
-       OperatorClass,
-       ElementA,
-       cute::tuple<LayoutA, LayoutSFA>,
-       AlignmentA,
-       ElementB,
-       cute::tuple<LayoutB, LayoutSFB>,
-       AlignmentB,
-       ElementAccumulator,
-       MmaTileShape,
-       ClusterShape,
-
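+   // For swap_ab the A and B builder arguments trade places (with
+   // transposed layouts); the deduced SFA/SFB scale layouts stay in their
+   // slots, and only the scale pointers are exchanged by the caller.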
+   using CollectiveMainloop = conditional_t<swap_ab,
+       typename cutlass::gemm::collective::CollectiveBuilder<
+           ArchTag,
+           OperatorClass,
+           ElementB,
+           cute::tuple<LayoutB_Transpose, LayoutSFA>,
+           AlignmentB,
+           ElementA,
+           cute::tuple<LayoutA_Transpose, LayoutSFB>,
+           AlignmentA,
+           ElementAccumulator,
+           MmaTileShape,
+           ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
-     MainloopScheduler
-   >::CollectiveOp;
-   // clang-format on
+           MainloopScheduler
+       >::CollectiveOp,
+       typename cutlass::gemm::collective::CollectiveBuilder<
+           ArchTag,
+           OperatorClass,
+           ElementA,
+           cute::tuple<LayoutA, LayoutSFA>,
+           AlignmentA,
+           ElementB,
+           cute::tuple<LayoutB, LayoutSFB>,
+           AlignmentB,
+           ElementAccumulator,
+           MmaTileShape,
+           ClusterShape,
+           cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+           MainloopScheduler
+       >::CollectiveOp>;

  using KernelType = enable_sm100_only<cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>;
@@ -123,6 +137,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
                                   torch::Tensor const& b,
                                   torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales) {
+   static constexpr bool swap_ab = Gemm::swap_ab;
  using GemmKernel = typename Gemm::GemmKernel;
  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
@@ -136,7 +151,6 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
  using ElementD = typename Gemm::ElementD;

  int32_t m = a.size(0), n = b.size(1), k = a.size(1);
-   auto prob_shape = cute::make_shape(m, n, k, 1);

  StrideA a_stride;
  StrideB b_stride;
@@ -146,21 +160,36 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
  b_stride =
      cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
  c_stride =
-       cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
+       cutlass::make_cute_packed_stride(StrideC{}, swap_ab ? cute::make_shape(n, m, 1) : cute::make_shape(m, n, 1));

-   LayoutSFA layout_SFA =
+   LayoutSFA layout_SFA = swap_ab ?
+       ScaleConfig::tile_atom_to_shape_SFA(make_shape(n, m, k, 1)) :
        ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
-   LayoutSFB layout_SFB =
+   LayoutSFB layout_SFB = swap_ab ?
+       ScaleConfig::tile_atom_to_shape_SFB(make_shape(n, m, k, 1)) :
        ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
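+   // Everything expressed in terms of the problem shape switches to
+   // (n, m, k, 1) when swap_ab is set: the C/D stride above, both
+   // scale-factor layouts, and prob_shape below.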

  auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
  auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
  auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
  auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());

-   typename GemmKernel::MainloopArguments mainloop_args{
-       a_ptr, a_stride, b_ptr, b_stride,
-       a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB};
+   auto mainloop_args = [&](){
+     // layout_SFA and layout_SFB cannot be swapped since they are deduced.
+     if (swap_ab) {
+       return typename GemmKernel::MainloopArguments{
+           b_ptr, b_stride, a_ptr, a_stride,
+           b_scales_ptr, layout_SFA, a_scales_ptr, layout_SFB};
+     } else {
+       return typename GemmKernel::MainloopArguments{
+           a_ptr, a_stride, b_ptr, b_stride,
+           a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB};
+     }
+   }();
+   auto prob_shape = swap_ab ? cute::make_shape(n, m, k, 1) : cute::make_shape(m, n, k, 1);

  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
  typename GemmKernel::EpilogueArguments epilogue_args{
@@ -181,23 +210,71 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
  int sms;
  cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());

-   auto should_use_2sm = [&sms](int m, int n, int tile1SM = 128) {
-     return std::ceil(static_cast<float>(m) / tile1SM) *
-                std::ceil(static_cast<float>(n) / tile1SM) >=
-            sms;
-   };
-   bool use_2sm = should_use_2sm(m, n);
-   if (use_2sm) {
-     cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
-         OutType, Shape<_256, _128, _128>, Shape<_256, _1, _1>,
-         Shape<_2, _2, _1>, cutlass::epilogue::TmaWarpSpecialized2Sm,
-         cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
-         out, a, b, a_scales, b_scales);
+   constexpr int TILE_K = 128;
+   // TODO: better heuristics
+   bool swap_ab = (m < 16) || (m % 4 != 0);
+   bool use_tma_epilogue = (m * n) % 4 == 0;
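+   // swap_ab selects the transposed (B^T A^T) kernel for tiny or
+   // non-multiple-of-4 m; use_tma_epilogue picks TmaWarpSpecialized over
+   // NoSmemWarpSpecialized epilogues when the output extent allows it.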
+   if (!swap_ab) {
+     constexpr int TILE_N = 128;
+     int tile_m = 256;
+     if (cuda_utils::ceil_div(n, TILE_N) * cuda_utils::ceil_div(m, 64) <= sms) {
+       tile_m = 64;
+     } else if (cuda_utils::ceil_div(n, TILE_N) * cuda_utils::ceil_div(m, 128) <= sms) {
+       tile_m = 128;
+     }
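+     // Rationale inferred from the ceil_div checks above: prefer the
+     // smallest tile_m whose grid still fits in a single wave across the
+     // SMs, falling back to tile_m = 256 with a 2-SM cluster when the
+     // problem is large enough to saturate the device.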
+     if (tile_m == 64) {
+       if (use_tma_epilogue) {
+         cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+             OutType, 1, TILE_N, TILE_K, Shape<_64, Int<TILE_N>, Int<TILE_K>>,
+             Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
+             cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
+             out, a, b, a_scales, b_scales);
+       } else {
+         cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+             OutType, 1, TILE_N, TILE_K, Shape<_64, Int<TILE_N>, Int<TILE_K>>,
+             Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
+             cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
+             out, a, b, a_scales, b_scales);
+       }
+     } else if (tile_m == 128) {
+       if (use_tma_epilogue) {
+         cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+             OutType, 1, TILE_N, TILE_K, Shape<_128, Int<TILE_N>, Int<TILE_K>>,
+             Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
+             cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
+             out, a, b, a_scales, b_scales);
+       } else {
+         cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+             OutType, 1, TILE_N, TILE_K, Shape<_128, Int<TILE_N>, Int<TILE_K>>,
+             Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
+             cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
+             out, a, b, a_scales, b_scales);
+       }
+     } else { // tile_m == 256
+       if (use_tma_epilogue) {
+         cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+             OutType, 1, TILE_N, TILE_K, Shape<_256, Int<TILE_N>, Int<TILE_K>>,
+             Shape<_2, _1, _1>, cutlass::epilogue::TmaWarpSpecialized2Sm,
+             cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
+             out, a, b, a_scales, b_scales);
+       } else {
+         cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+             OutType, 1, TILE_N, TILE_K, Shape<_256, Int<TILE_N>, Int<TILE_K>>,
+             Shape<_2, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized2Sm,
+             cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
+             out, a, b, a_scales, b_scales);
+       }
+     }
  } else {
+     // TODO: Test more tile N configs
+     constexpr int TILE_M = 128;
+     constexpr int TILE_N = 16;
+     // TMA epilogue isn't compatible with swap A/B
    cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
-       OutType, Shape<_128, _128, _128>, Shape<_128, _1, _1>,
-       Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
-       cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
+       OutType, TILE_M, 1, TILE_K, Shape<Int<TILE_M>, Int<TILE_N>, Int<TILE_K>>,
+       Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
+       cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100, true>>(
        out, a, b, a_scales, b_scales);
  }
}