gridwise_moe_gemm_blockscale.hpp Source File

gridwise_moe_gemm_blockscale.hpp Source File

Composable Kernel: gridwise_moe_gemm_blockscale.hpp Source File
gridwise_moe_gemm_blockscale.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
16
18
19#define DEBUG_LOG 0
20
21namespace ck {
22
23// Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer
24// pipelines in the same kernel function. Blockers:
25// 1. Two separate __shared__ declarations are the key to making sure data accesses operate on
26// two distinct LDS chunks.
27// 2. Occupied __shared__ memory is not released until the whole shader ends, so A/B and C may
28// not share the same LDS buffer when __shared__ is declared inside blkgemmpipe.
29
31{
32 gelu_and_mul = 0,
33 silu_and_mul = 1
34};
35
36template <typename GridwiseGemm,
37 bool HasMainKBlockLoop,
38 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
39 index_t MinimumOccupancy = 1,
41__global__ void
42#if CK_USE_LAUNCH_BOUNDS
43__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
44#endif
45 // __attribute__((amdgpu_waves_per_eu(1, 1)))
46 kernel_moe_gemm(typename GridwiseGemm::Argument karg)
47{
48#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
49 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
50 {
51 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
52
53 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
54
55 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
56 karg.p_sorted_token_ids,
57 karg.p_sorted_expert_ids,
58 karg.p_max_token_id,
59 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
60 karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
61 karg.p_ds_grid,
62 karg.p_c_grid,
63 karg.p_a_scale_grid,
64 karg.p_b_scale_grid,
65 p_shared,
66 karg,
67 karg.a_element_op,
68 karg.b_element_op,
69 karg.c_element_op);
70 }
71#else
72 ignore = karg;
73#endif // end of if (defined(__gfx9__))
74}
75
76template <typename GridwiseGemm,
77 bool HasMainKBlockLoop,
78 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
79 index_t MinimumOccupancy = 1,
81__global__ void
82#if CK_USE_LAUNCH_BOUNDS
83__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
84#endif
85 // __attribute__((amdgpu_waves_per_eu(1, 1)))
86 kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg)
87{
88#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
89 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
90 {
91 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
92 __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
93
94 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
95
96 GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
97 karg.p_sorted_token_ids,
98 karg.p_sorted_expert_ids,
99 karg.p_max_token_id,
100 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
101 karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
102 karg.p_ds_grid,
103 karg.p_c_grid,
104 karg.p_a_scale_grid,
105 karg.p_b_scale_grid,
106 p_shared,
107 p_shared1,
108 karg,
109 karg.a_element_op,
110 karg.b_element_op,
111 karg.c_element_op);
112 }
113#else
114 ignore = karg;
115#endif // end of if (defined(__gfx9__))
116}
117
118template <typename ALayout,
119 typename BLayout,
120 typename DsLayout,
121 typename CLayout,
122 typename ADataType,
123 typename BDataType,
124 typename AccDataType,
125 typename CShuffleDataType,
126 typename DsDataType,
127 typename CDataType,
128 typename AElementwiseOperation,
129 typename BElementwiseOperation,
130 typename CElementwiseOperation,
132 index_t BlockSize,
133 index_t ScaleBlockM,
134 index_t ScaleBlockN,
135 index_t ScaleBlockK,
136 index_t MPerBlock,
137 index_t NPerBlock,
138 index_t KPerBlock,
139 index_t AK1Value,
140 index_t BK1Value,
141 index_t MPerXdl,
142 index_t NPerXdl,
143 index_t MXdlPerWave,
144 index_t NXdlPerWave,
145 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
146 typename ABlockTransferThreadClusterArrangeOrder,
147 typename ABlockTransferSrcAccessOrder,
148 index_t ABlockTransferSrcVectorDim,
149 index_t ABlockTransferSrcScalarPerVector,
150 index_t ABlockTransferDstScalarPerVector_AK1,
151 bool AThreadTransferSrcResetCoordinateAfterRun,
152 index_t ABlockLdsExtraM,
153 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
154 typename BBlockTransferThreadClusterArrangeOrder,
155 typename BBlockTransferSrcAccessOrder,
156 index_t BBlockTransferSrcVectorDim,
157 index_t BBlockTransferSrcScalarPerVector,
158 index_t BBlockTransferDstScalarPerVector_BK1,
159 bool BThreadTransferSrcResetCoordinateAfterRun,
160 index_t BBlockLdsExtraN,
161 index_t CShuffleMXdlPerWavePerShuffle,
162 index_t CShuffleNXdlPerWavePerShuffle,
163 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
164 typename CDEShuffleBlockTransferScalarPerVectors,
167 index_t ActivationOperation = 0,
168 bool NSwizzle = false,
169 bool IsInputGemm = true,
170 bool MulRoutedWeight = true,
171 typename IndexType = index_t,
172 typename ComputeTypeA = CDataType,
173 typename ComputeTypeB = ComputeTypeA,
174 typename LDSTypeA = ADataType,
175 typename LDSTypeB = BDataType>
177{
178 using AScaleType = float;
179 using BScaleType = float;
180
181 static constexpr auto I0 = Number<0>{};
182 static constexpr auto I1 = Number<1>{};
183 static constexpr auto I2 = Number<2>{};
184 static constexpr auto I3 = Number<3>{};
185 static constexpr auto I4 = Number<4>{};
186 static constexpr auto I5 = Number<5>{};
187 static constexpr auto I6 = Number<6>{};
188 static constexpr auto I7 = Number<7>{};
189
191 CDEShuffleBlockTransferScalarPerVectors{}[I0];
192 // K1 should be Number<...>
193 static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
194 static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
195 static constexpr auto AK1Number = Number<AK1Value>{};
196 static constexpr auto BK1Number = Number<BK1Value>{};
197 static constexpr auto BlockSizeNumber = Number<BlockSize>{};
198
199 static constexpr index_t NumDTensor = DsDataType::Size();
200
202 static constexpr index_t KPack =
204 static constexpr index_t KGroup = []() {
206 // On gfx950, we have a mfma that required 32 f8 elements as input,
207 // splited into 2 groups of 16 f8 elements.
208 // the 2 groups is not contiguous in the B preshuffed layout.
209 // and we do not want it to be contiguous in the B preshuffled layout
210 // because a memory instruction can only read 16 f8 elements at a time.
211 return mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1;
212 else
213 return 1;
214 }();
215 static constexpr index_t KLane =
217 static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup);
218 static constexpr index_t NLane = NPerXdl;
219 static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
220 // static constexpr index_t NumTokens = 1;
221 static constexpr index_t SortedTileSize = MPerBlock;
222
223 static constexpr auto MakeDsGridPointer()
224 {
225 return generate_tuple(
226 [&](auto i) {
227 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
228
229 return static_cast<const DDataType*>(nullptr);
230 },
232 }
233
234 using DsGridPointer = decltype(MakeDsGridPointer());
235
237
238 static constexpr index_t APackedSize = []() {
240 return 2;
241 else
242 return 1;
243 }();
244
245 static constexpr index_t BPackedSize = []() {
247 return 2;
248 else
249 return 1;
250 }();
251
252 __host__ static auto CalculateGridSize(index_t M, index_t N)
253 {
254 const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
255 const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
256 const index_t gridx = NSwizzle ? nblock * mblock : nblock;
257 const index_t gridy = NSwizzle ? 1 : mblock;
258 return std::make_tuple(gridx, gridy, 1);
259 }
260
261 __host__ __device__ static auto CalculateMPadded(index_t M)
262 {
263 return math::integer_least_multiple(M, MPerBlock);
264 }
265
266 __host__ __device__ static auto CalculateNPadded(index_t N)
267 {
268 return math::integer_least_multiple(N, NPerBlock);
269 }
270
271 __host__ __device__ static auto CalculateBN0Shuffled(index_t N)
272 {
274 }
275 __host__ __device__ static auto CalculateBK0Shuffled(index_t K)
276 {
278 }
279
280 __host__ __device__ static auto CalculateKPadded(index_t K)
281 {
282 return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
283 }
284
285 __host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
286 {
287 auto K_t = K_Batch * KPerBlock;
288 return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
289 }
290
291 __host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
292 {
293 auto K_t = K_Batch * KPerBlock;
294 return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
295 }
296
297 __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
298 {
299 auto K_t = K_Batch * KPerBlock;
300 return (K + K_t - 1) / K_t * KPerBlock;
301 }
302
303 __host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
304 {
305 constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
306 auto K_t = K_Batch * KReadVec;
307 return (K + K_t - 1) / K_t * KReadVec;
308 }
309
310 __host__ __device__ static auto CalculateMBlock(index_t M)
311 {
312 return math::integer_divide_ceil(M, MPerBlock);
313 }
314
315 __host__ __device__ static auto CalculateNBlock(index_t N)
316 {
317 return math::integer_divide_ceil(N, NPerBlock);
318 }
319
320 template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
321 __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
322 {
323 constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
324 constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
325
327 TileDesc_K0_MN_K1{},
333 }
334
335 __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
336 IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
337 {
338 const auto a_grid_desc_mraw_kraw = [&]() {
340 {
341 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
342 }
344 {
345 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
346 }
347 }();
348
349 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
350
351 if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
352 GemmSpec == GemmSpecialization::MNKPadding)
353 {
354 // pad both M and K
355 const auto a_grid_desc_m_k =
356 transform_tensor_descriptor(a_grid_desc_mraw_kraw,
358 make_right_pad_transform(K, KPad - K)),
361
362 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
363 a_grid_desc_m_k,
368
369 return a_grid_desc_ak0_m_ak1;
370 }
371 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
372 GemmSpec == GemmSpecialization::MNPadding)
373 {
374 // pad M, but not K
375 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
376 a_grid_desc_mraw_kraw,
378 make_right_pad_transform(M, MPad - M)),
381
382 return a_grid_desc_ak0_m_ak1;
383 }
384 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
385 GemmSpec == GemmSpecialization::NKPadding)
386 {
387 // pad K, but not M
388 const auto a_grid_desc_m_k = transform_tensor_descriptor(
389 a_grid_desc_mraw_kraw,
393
394 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
395 a_grid_desc_m_k,
400
401 return a_grid_desc_ak0_m_ak1;
402 }
403 else
404 {
405 // not pad M or K
406 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
407 a_grid_desc_mraw_kraw,
412
413 return a_grid_desc_ak0_m_ak1;
414 }
415 }
416
417 __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
418 {
419 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
420 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
421 constexpr index_t NkSwizzleNumber = Number<WaveSize * KPack / KGroup>{};
423 make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
424 make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
425 }
426
427 __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
428 index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
429 {
430 const auto b_grid_desc_nraw_kraw = [&]() {
432 {
433 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
434 }
436 {
437 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
438 }
439 }();
440
441 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
442
444 GemmSpec != GemmSpecialization::Default),
445 "pk_i4_t does not support padding");
446
447 if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
448 GemmSpec == GemmSpecialization::MNKPadding)
449 {
450 // pad both N and K
451 const auto b_grid_desc_n_k =
452 transform_tensor_descriptor(b_grid_desc_nraw_kraw,
454 make_right_pad_transform(K, KPad - K)),
457
458 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
459 b_grid_desc_n_k,
464
465 return b_grid_desc_bk0_n_bk1;
466 }
467 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
468 GemmSpec == GemmSpecialization::MNPadding)
469 {
470 // pad N, but not K
471 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
472 b_grid_desc_nraw_kraw,
474 make_right_pad_transform(N, NPad - N)),
477
478 return b_grid_desc_bk0_n_bk1;
479 }
480 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
481 GemmSpec == GemmSpecialization::MKPadding)
482 {
483 // pad K, but not N
484 const auto b_grid_desc_n_k = transform_tensor_descriptor(
485 b_grid_desc_nraw_kraw,
489
490 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
491 b_grid_desc_n_k,
496
497 return b_grid_desc_bk0_n_bk1;
498 }
499 else
500 {
501 // not pad N or K
502 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
503 b_grid_desc_nraw_kraw,
508
509 return b_grid_desc_bk0_n_bk1;
510 }
511 }
512
513 template <typename ABlockDesc_AK0_M_AK1>
514 __host__ __device__ static constexpr auto
515 MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
516 {
517 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
518
519 return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
520 }
521
522 template <typename BBlockDesc_BK0_N_BK1>
523 __host__ __device__ static constexpr auto
524 MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
525 {
526 return MakeGemmMmaTileDescriptor<NXdlPerWave, NWave, NPerXdl>(BBlockDesc_BK0_N_BK1{});
527 }
528
529 template <typename ELayout>
530 __host__ __device__ static auto MakeCGridDescriptor_M_N(
531 IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
532 {
533 const auto c_grid_desc_mraw_nraw = [&]() {
535 {
536 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
537 }
539 {
540 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
541 }
542 }();
543
544 // pad M and N
545 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
547 make_right_pad_transform(N, NPad - N)),
550 }
551
552 template <typename DLayout>
553 __host__ __device__ static auto
555 {
556 const auto c_grid_desc_mraw_nraw = [&]() {
558 {
559 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0));
560 }
562 {
563 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC));
564 }
565 }();
566
567 // pad M and N
568 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
570 make_right_pad_transform(N, NPad - N)),
573 }
574
575 __host__ __device__ static auto MakeDsGridDescriptor_M_N(
576 index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
577 {
578 return generate_tuple(
579 [&](auto i) {
580 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
581 return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
582 },
584 }
585
586 template <typename DsGridDesc>
588 const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
589 {
590 return generate_tuple(
591 [&](auto i) {
593 ds_grid_desc_m_n[i], MBlock, NBlock);
594 },
596 }
597
598 using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(0, 0, 0, 0, {}))>;
599
600 struct Problem
601 {
602 __host__ __device__ Problem(index_t NumTokens_,
603 index_t TopK_,
604 index_t M_,
605 index_t N_,
606 index_t K_,
607 index_t StrideA_,
608 index_t StrideB_,
609 std::array<index_t, NumDTensor> StrideDs_,
610 index_t StrideC_,
611 index_t KBatch_)
612 : NumTokens{NumTokens_},
613 TopK{TopK_},
614 M{M_},
615 N{N_},
616 K{K_},
617 StrideA{StrideA_},
618 StrideB{StrideB_},
619 StrideDs{StrideDs_},
620 StrideC{StrideC_},
621 KBatch{KBatch_},
624 KRead{CalculateKRead(K_, KBatch_)},
625 KPadded{CalculateKPadded(K_, KBatch_)},
626 AK0{CalculateAK0Padded(K_, KBatch_)},
627 BK0{CalculateBK0Padded(K_, KBatch_)},
630 {
631 }
632
633 __host__ void Print() const
634 {
635 std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", "
636 << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
637 << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
638 << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
639 << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
640 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
641 << "NBlock: " << NBlock << "}" << std::endl;
642 }
643
651 std::array<index_t, NumDTensor> StrideDs;
662 };
663
664 // Argument
666 {
667 __host__ Argument(const index_t* p_sorted_token_ids_,
668 const index_t* p_sorted_expert_ids_,
669 const index_t* p_max_token_id_,
670 const ADataType* p_a_grid_,
671 const BDataType* p_b_grid_,
672 std::array<const void*, NumDTensor> p_ds_grid_,
673 CDataType* p_c_grid_,
674 index_t NumTokens_,
675 index_t TopK_,
676 index_t M_,
677 index_t N_,
678 index_t K_,
679 index_t StrideA_,
680 index_t StrideB_,
681 std::array<index_t, NumDTensor> StrideDs_,
682 index_t StrideC_,
683 const AScaleType* p_a_scale_grid_,
684 const BScaleType* p_b_scale_grid_,
685 index_t k_batch_,
686 AElementwiseOperation a_element_op_,
687 BElementwiseOperation b_element_op_,
688 CElementwiseOperation c_element_op_)
689 : Problem{NumTokens_,
690 TopK_,
691 M_,
692 N_,
693 K_,
694 StrideA_,
695 StrideB_,
696 StrideDs_,
697 StrideC_,
698 k_batch_},
699 p_sorted_token_ids{p_sorted_token_ids_},
700 p_sorted_expert_ids{p_sorted_expert_ids_},
701 p_max_token_id{p_max_token_id_},
702 p_a_grid{p_a_grid_},
703 p_b_grid{p_b_grid_},
704 p_ds_grid{},
705 p_c_grid{p_c_grid_},
706 p_a_scale_grid{p_a_scale_grid_},
707 p_b_scale_grid{p_b_scale_grid_},
708 a_element_op{a_element_op_},
709 b_element_op{b_element_op_},
710 c_element_op{c_element_op_}
711 {
712
713 // populate pointer, desc for Ds
714 static_for<0, NumDTensor, 1>{}([&](auto i) {
715 using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
716
717 // D pointer
718 p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
719 });
720 }
721
725 const ADataType* p_a_grid;
726 const BDataType* p_b_grid;
728 CDataType* p_c_grid;
729
732
733 const AElementwiseOperation a_element_op;
734 const BElementwiseOperation b_element_op;
735 const CElementwiseOperation c_element_op;
736 };
737
739 {
740 __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
741 {
743 {
744 a_k_split_offset = k_id * karg.KRead / APackedSize;
745 }
747 {
748 a_k_split_offset = k_id * karg.KRead * karg.StrideA;
749 }
750
752 {
753 b_k_split_offset = k_id * karg.KRead * karg.StrideB;
754 }
756 {
757 // KPack * NLane * KLane * K0 * N0
758 b_k_split_offset = k_id * karg.KRead * NLane / BPackedSize;
759 }
760
761 if(k_id < karg.KBatch - 1)
762 {
763 karg.K = karg.KRead;
764 }
765 else
766 {
767 karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
768 }
769 }
770
773 };
774
775 __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
776 {
777 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
778 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
779 // A matrix in LDS memory, dst of blockwise copy
780 if constexpr(ABlockLdsExtraM)
781 {
785 }
786 // xor tensor transformation request more unnecessary vgpr usage, would cause register spill
787 // in some cases.
789 {
790 constexpr auto a_lds_block_desc =
793
794 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
795 a_lds_block_desc,
801
802 return a_lds_block_desc_permuted;
803 }
804 else // ColumnMajor A
805 {
806 // kfold and mpair dimension is not always required.
807 // more dimension in merge_transform increase the difficulty of generating immarg offset
808 // for compiler.
809 constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
810 constexpr auto M1 = MPerBlock / M0;
811
812 constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
813 constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
814 constexpr auto KThreadRead = WaveSize / MPerXdl;
815 constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
816
817 constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128)
818 ? 1
819 : 128 / (AK1Number * M0 * sizeof(LDSTypeA));
820 constexpr auto KThreadReadPerm =
821 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
822 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
823 : KThreadRead;
824
825 // 1<=mpair<=n0
826 constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128)
827 ? 1
828 : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0
829 ? M0
830 : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA)));
831
832 constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
836 Number<kfold * M0 / mpair>{},
838 AK1Number));
839
840 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
841 a_lds_block_desc,
846 make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
853
854 constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
855 a_lds_block_desc_permuted,
864 Sequence<1>{},
865 Sequence<2>{},
866 Sequence<3>{},
867 Sequence<4>{},
868 Sequence<5>{}),
870 Sequence<2>{},
873 Sequence<6>{},
874 Sequence<7>{}));
875
876 constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
877 a_lds_block_desc_unmerged,
880 Number<KThreadWrite / kfold / KThreadReadPerm>{},
888
889 return a_lds_block_desc_ak0_m_ak1;
890 }
891 }
892
893 __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
894 {
895 // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
898 }
899
901 {
902 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
903
904 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
908 I1,
910
911 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
912 }
913
916 BlkGemmPipelineVer,
917 BlkGemmPipeSched,
918 BlockSize,
919 ADataType,
920 BDataType,
921 ComputeTypeA,
922 AccDataType,
929 ABlockTransferSrcScalarPerVector,
930 BBlockTransferSrcScalarPerVector,
931 MPerBlock,
932 NPerBlock,
933 KPerBlock,
934 ScaleBlockM,
935 ScaleBlockN,
936 ScaleBlockK,
937 MPerXdl,
938 NPerXdl,
939 MXdlPerWave,
940 NXdlPerWave,
941 KPack,
942 IsInputGemm>())>;
943
944 __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
945 {
946 // LDS allocation for A and B: be careful of alignment
947 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
948 // lds max alignment
949 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
950
951 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
952 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
953
954 // LDS allocation for C shuffle in LDS
955 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
957
958 constexpr auto c_block_size =
959 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
960
961 return math::max(a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize,
962 c_block_size * sizeof(CShuffleDataType));
963 }
964
966
967 // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
968 __host__ static constexpr bool CheckValidity(const Argument& karg)
969 {
970 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
971 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
972 "Invalid tuning param!");
973
979 {
980 if(!(karg.M % MPerBlock == 0))
981 {
982#if DEBUG_LOG
983 std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
984 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
985 << std::endl;
986
987#endif // DEBUG_LOG
988 return false;
989 }
990 }
991
997 {
998 if(!(karg.N % NPerBlock == 0))
999 {
1000#if DEBUG_LOG
1001 std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
1002 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1003 << std::endl;
1004
1005#endif // DEBUG_LOG
1006 return false;
1007 }
1008 }
1009
1010 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1014 {
1015
1016 auto K_t = karg.KBatch * KPerBlock;
1017 if(!(karg.K % K_t == 0))
1018 {
1019#if DEBUG_LOG
1020 std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1021 << karg.K << " " << __FILE__ << ":" << __LINE__
1022 << ", in function: " << __func__ << std::endl;
1023
1024#endif // DEBUG_LOG
1025 return false;
1026 }
1027 }
1028 else
1029 {
1030 constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1031 auto K_t = karg.KBatch * KReadVec;
1032 auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1033 if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1034 {
1035 return false;
1036 }
1037 }
1038
1040 {
1041 if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1042 {
1043#if DEBUG_LOG
1044 std::cout << "Arg K (" << karg.K
1045 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1046 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1047 << __LINE__ << ", in function: " << __func__ << std::endl;
1048
1049#endif // DEBUG_LOG
1050 return false;
1051 }
1052 }
1053 else
1054 {
1055 if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1056 {
1057#if DEBUG_LOG
1058 std::cout << "Arg M (" << karg.M
1059 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1060 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1061 << __LINE__ << ", in function: " << __func__ << std::endl;
1062
1063#endif // DEBUG_LOG
1064 return false;
1065 }
1066 }
1067
1069 {
1070 if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1071 {
1072#if DEBUG_LOG
1073 std::cout << "Arg N (" << karg.N
1074 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1075 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1076 << __LINE__ << ", in function: " << __func__ << std::endl;
1077
1078#endif // DEBUG_LOG
1079 return false;
1080 }
1081 }
1082 else
1083 {
1084 if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1085 {
1086#if DEBUG_LOG
1087 std::cout << "Arg K (" << karg.K
1088 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1089 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1090 << __LINE__ << ", in function: " << __func__ << std::endl;
1091
1092#endif // DEBUG_LOG
1093 return false;
1094 }
1095 }
1096
1098 {
1100 {
1101#if DEBUG_LOG
1102 std::cout << "Arg N (" << karg.N
1103 << ") value is not a multiple of "
1104 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1105 << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
1106 << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1107
1108#endif // DEBUG_LOG
1109 return false;
1110 }
1111 }
1112 else
1113 {
1115 {
1116#if DEBUG_LOG
1117 std::cout << "Arg M (" << karg.M
1118 << ") value is not a multiple of "
1119 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1120 << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
1121 << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1122
1123#endif // DEBUG_LOG
1124 return false;
1125 }
1126 }
1127
1128 // check gridwise gemm pipeline
1129#if 0
1130 const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1131
1132 if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1133 {
1134 return false;
1135 }
1136#endif
1137 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1138 return true;
1139 }
1140
1141 __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1142 {
1143 const index_t num_loop = K / KPerBlock;
1144
1145 return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1146 }
1147
1148 __host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1149 {
1150 const index_t num_loop = K / KPerBlock;
1151
1152 return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1153 }
1154
1155 template <typename CGridDesc>
1157 const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1158 {
1159 const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1160 c_grid_desc_m_n,
1165
1166 return c_grid_desc_mblock_mperblock_nblock_nperblock;
1167 }
1168
1169 // return block_id to C matrix tile idx (m0, n0) mapping
1170 // if arch = gfx942
1171 // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock,
1172 // NPerBlock>;
1173
1174 template <bool HasMainKBlockLoop,
1175 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1176 TailNumber TailNum = TailNumber::Odd>
1177 __device__ static void Run(const index_t* p_sorted_token_ids,
1178 const index_t* p_sorted_expert_ids,
1179 const index_t* p_max_token_id,
1180 const ADataType* p_a_grid,
1181 const BDataType* p_b_grid,
1182 DsGridPointer& p_ds_grid,
1183 CDataType* p_c_grid,
1184 const AScaleType* p_a_scale_grid,
1185 const BScaleType* p_b_scale_grid,
1186 void* p_shared,
1187 const Problem& problem,
1188 AElementwiseOperation a_element_op,
1189 BElementwiseOperation b_element_op,
1190 CElementwiseOperation c_element_op)
1191 {
1192 ignore = b_element_op;
1193 index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
1194 index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
1195 const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1196 IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1197 problem.MPadded,
1198 problem.K,
1199 problem.KPadded,
1200 problem.StrideA,
1201 problem.AK0);
1202 const auto b_grid_desc_bpreshuffled =
1203 MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled);
1204 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1205 IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1206 problem.MPadded,
1207 problem.N,
1208 problem.NPadded,
1209 problem.StrideC);
1210
1211 const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
1212 make_tuple(math::integer_divide_ceil(IsInputGemm ? problem.NumTokens
1213 : problem.NumTokens * problem.TopK,
1214 ScaleBlockM),
1215 math::integer_divide_ceil(problem.K, ScaleBlockK)),
1216 make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
1217 const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
1218 make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
1219 math::integer_divide_ceil(problem.K, ScaleBlockK)),
1220 make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
1221
1222 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1224 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1225 const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1226 // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged");
1227 const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1228 if(expert_block_id * MPerBlock >= max_token_id)
1229 return;
1230 const index_t expert_id =
1231 __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1232 const auto block_mn = [&]() -> std::pair<int, int> {
1233 if constexpr(NSwizzle)
1234 {
1235 const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1236 const index_t prefix_block = ecnt_prefix * problem.NBlock;
1237 const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1238 const index_t expert_swizzle =
1239 ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
1240 const index_t bid_new = blockIdx.x - prefix_block;
1241 const index_t nid = __builtin_amdgcn_readfirstlane(
1242 bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1243 const index_t mid =
1244 __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1245 return {nid, mid};
1246 }
1247 else
1248 {
1249 return {blockIdx.x, blockIdx.y};
1250 }
1251 }();
1252 const index_t block_n_id = block_mn.first;
1253 const index_t block_m_id = block_mn.second;
1254 const index_t token0 =
1255 __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1256
1257 // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1258 constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1259 constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1260 constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1261 constexpr auto AKThreads = AK0Threads * AK1Threads;
1262 constexpr auto AMRepeats = MPerBlock / AMThreads;
1263 const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1264
1265 if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1266 return;
1268 static_for<0, AMRepeats, 1>{}([&](auto m0) {
1269 const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1270 index_t token_offset = fused_token & 0xffffff;
1271 if constexpr(!IsInputGemm)
1272 {
1273 token_offset = token_offset * problem.TopK + (fused_token >> 24);
1274 }
1275 gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
1276 });
1277 const index_t expert_stride =
1278 __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1279 const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
1280 math::integer_divide_ceil(problem.N, ScaleBlockN) * (IsInputGemm ? 2 : 1) *
1281 math::integer_divide_ceil(problem.K, ScaleBlockK));
1282
1283 // N0, K0, Blocksize*KPack
1284 const index_t n_block_data_idx_on_grid =
1285 __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
1286
1287 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1288 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1289 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1290 p_b_grid + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
1291 b_grid_desc_bpreshuffled.GetElementSpaceSize());
1292
1293 const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1294 p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
1295 const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1296 p_b_scale_grid + expert_id * expert_scale_stride,
1297 b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1298
1299 // A matrix in LDS memory, dst of blockwise copy
1300 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1301
1302 // B matrix in LDS memory, dst of blockwise copy
1303 // dummy
1304 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1305 // A matrix blockwise copy
1306 auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
1308 AElementwiseOperation,
1312 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1313 ABlockTransferThreadClusterArrangeOrder,
1314 ADataType,
1315 LDSTypeA,
1316 decltype(a_grid_desc_ak0_m_ak1),
1317 decltype(a_block_desc_ak0_m_ak1),
1318 ABlockTransferSrcAccessOrder,
1320 ABlockTransferSrcVectorDim,
1321 2,
1322 ABlockTransferSrcScalarPerVector,
1323 ABlockTransferDstScalarPerVector_AK1,
1324 1,
1325 1,
1326 AThreadTransferSrcResetCoordinateAfterRun,
1327 true,
1328 IndexType,
1329 1,
1330 BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
1331 make_multi_index(0, 0, 0),
1332 a_element_op,
1333 a_block_desc_ak0_m_ak1,
1334 make_multi_index(0, 0, 0),
1336 gather_offsets);
1337
1338 // Thread-wise copy
1339 // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
1341 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1342
1343 auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
1344 BDataType,
1345 BDataType,
1346 decltype(b_grid_desc_bpreshuffled),
1347 decltype(b_block_desc_bk0_n_bk1),
1350 3,
1351 BBlockTransferSrcScalarPerVector,
1352 BThreadTransferSrcResetCoordinateAfterRun,
1353 true>(b_grid_desc_bpreshuffled,
1354 make_multi_index(n_block_data_idx_on_grid,
1356 0,
1357 KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
1358
1359 // LDS allocation for A and B: be careful of alignment
1360 // Cast after lds
1362 static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1363
1364 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1365 constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
1366
1367 // Blockwise GEMM pipeline
1368 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1369 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1370 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1371 decltype(c_thread_buf) c_thread_buf_up;
1372
1373 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1374 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1375 KPerBlock);
1376
1377 constexpr index_t ScaleSliceSizeM = MXdlPerWave;
1378 constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN);
1379 constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);
1380
1381 // ScaleSliceSizeK is last dimension in A/B scale for vector memory access
1382 // ScaleSliceSizeK is first dimension in C scale for packed math
1383 constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
1385
1386 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
1387 constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
1388 constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
1389 auto a_thread_offset = get_thread_local_1d_id() % MPerXdl +
1390 (get_thread_local_1d_id() / WaveSize) / NWaves * MPerXdl;
1391
1392 constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
1394
1395 constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
1397
1398 // get each thread's offset in the scale tensor
1399 // A scale
1400 const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockM;
1401
1402 if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
1403 return;
1405 static_for<0, MXdlPerWave, 1>{}([&](auto m0) {
1406 const index_t fused_token =
1407 p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWaves + a_thread_offset];
1408 index_t token_offset = fused_token & 0xffffff;
1409 if constexpr(!IsInputGemm)
1410 {
1411 token_offset = token_offset * problem.TopK + (fused_token >> 24);
1412 }
1413 scale_gather_offsets(m0) =
1414 token_offset * math::integer_divide_ceil(problem.K, ScaleBlockK);
1415 });
1416
1417 auto a_scale_thread_copy =
1419 AScaleType,
1420 decltype(a_scale_grid_desc_am_ak),
1421 decltype(a_scale_thread_desc),
1424 1,
1425 ScaleSliceSizeK,
1426 1,
1427 false,
1428 MXdlPerWave>(
1429 a_scale_grid_desc_am_ak, make_multi_index(0, 0), scale_gather_offsets);
1430
1431 auto b_scale_thread_copy =
1433 BScaleType,
1434 decltype(b_scale_grid_desc_bn_ak),
1435 decltype(b_scale_thread_desc),
1438 1,
1439 ScaleSliceSizeK,
1440 1,
1441 false>(
1442 b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
1443
1444 // constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1);
1445 constexpr auto a_scale_thread_slice_copy_step =
1446 make_tuple(make_multi_index(0, 0), make_multi_index(0, ScaleSliceSizeK));
1447 constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
1448
1449 constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock);
1450 if constexpr(IsInputGemm)
1451 {
1452 const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
1453 const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1454 p_b_grid_up + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
1455 b_grid_desc_bpreshuffled.GetElementSpaceSize());
1456 auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
1457 BDataType,
1458 BDataType,
1459 decltype(b_grid_desc_bpreshuffled),
1460 decltype(b_block_desc_bk0_n_bk1),
1463 3,
1464 BBlockTransferSrcScalarPerVector,
1465 BThreadTransferSrcResetCoordinateAfterRun,
1466 true>(b_grid_desc_bpreshuffled,
1467 make_multi_index(n_block_data_idx_on_grid,
1469 0,
1470 KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
1471 const BScaleType* p_b_scale_grid_up =
1472 p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
1473 const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1474 p_b_scale_grid_up + expert_id * expert_scale_stride,
1475 b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1476 auto b_scale_thread_copy_up =
1478 BScaleType,
1479 decltype(b_scale_grid_desc_bn_ak),
1480 decltype(b_scale_thread_desc),
1483 1,
1484 ScaleSliceSizeK,
1485 1,
1486 false>(
1487 b_scale_grid_desc_bn_ak,
1488 make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
1489
1490 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
1491 a_grid_desc_ak0_m_ak1,
1492 a_block_desc_ak0_m_ak1,
1493 a_blockwise_copy,
1494 a_grid_buf,
1495 a_block_buf,
1496 a_block_slice_copy_step,
1497
1498 b_grid_desc_bpreshuffled,
1499 b_block_desc_bk0_n_bk1,
1500 b_blockwise_copy,
1501 b_blockwise_copy_up,
1502 b_grid_buf,
1503 b_grid_buf_up,
1504 b_block_buf,
1505 b_block_slice_copy_step,
1506
1507 c_scale_thread_desc,
1508 c_thread_buf,
1509 c_thread_buf_up,
1510
1511 a_scale_grid_desc_am_ak,
1512 a_scale_thread_desc,
1513 a_scale_thread_copy,
1514 a_scale_grid_buf,
1515 a_scale_thread_slice_copy_step,
1516
1517 b_scale_grid_desc_bn_ak,
1518 b_scale_thread_desc,
1519 b_scale_thread_copy,
1520 b_scale_thread_copy_up,
1521 b_scale_grid_buf,
1522 b_scale_grid_buf_up,
1523 b_scale_thread_slice_copy_step,
1524
1525 num_k_block_main_loop);
1526 }
1527 else
1528 {
1529 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
1530 a_grid_desc_ak0_m_ak1,
1531 a_block_desc_ak0_m_ak1,
1532 a_blockwise_copy,
1533 a_grid_buf,
1534 a_block_buf,
1535 a_block_slice_copy_step,
1536
1537 b_grid_desc_bpreshuffled,
1538 b_block_desc_bk0_n_bk1,
1539 b_blockwise_copy,
1540 b_grid_buf,
1541 b_block_buf,
1542 b_block_slice_copy_step,
1543
1544 c_scale_thread_desc,
1545 c_thread_buf,
1546
1547 a_scale_grid_desc_am_ak,
1548 a_scale_thread_desc,
1549 a_scale_thread_copy,
1550 a_scale_grid_buf,
1551 a_scale_thread_slice_copy_step,
1552
1553 b_scale_grid_desc_bn_ak,
1554 b_scale_thread_desc,
1555 b_scale_thread_copy,
1556 b_scale_grid_buf,
1557 b_scale_thread_slice_copy_step,
1558
1559 num_k_block_main_loop);
1560 }
1561
1562 // shuffle C and write out
1563 {
1564 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1565 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1566 "wrong!");
1567
1568 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1569
1570 // transposed XDL
1571 // TODO: hacky, fix it!
1572 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
1573 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
1574
1575 // TODO: hacky, fix it!
1576 // c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths
1577 constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
1578 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
1579
1580 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
1581 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
1582 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
1583 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
1584 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
1585 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
1586 constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
1587 constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
1588
1589 static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock);
1590 static_assert(M0 * M1 * M2 == MPerBlock);
1591 static_assert(N4 == 4 || N4 == 8);
1592 const index_t m1 = get_warp_local_1d_id() / NWave;
1593 const index_t m2 = threadIdx.x % get_warp_size() % M2;
1594
1595 float topk_weight;
1596 static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
1597 static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
1598 if constexpr(MulRoutedWeight)
1599 {
1600 const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2;
1601 topk_weight = p_ds_grid[I0][m_pos];
1602 }
1603 static_for<0, N2, 1>{}([&](auto n2) { // num_groups_per_blk
1604 static_for<0, N4, 1>{}([&](auto n4) { // inst_group_size
1605 constexpr index_t c_offset =
1606 blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1607 make_tuple(m0, n0, n2 * N4 + n4));
1608 constexpr auto cidx = Number<c_offset>{};
1609 if constexpr(IsInputGemm) // gu fusion, elementwise
1610 {
1611 if constexpr(ActivationOperation == Activation::silu_and_mul)
1612 {
1613 float gate = c_thread_buf[cidx];
1614 float up = c_thread_buf_up[cidx];
1615 if constexpr(MulRoutedWeight)
1616 {
1617 gate = gate * topk_weight;
1618 up = up * topk_weight;
1619 }
1621 {
1622 gate *= 16;
1623 up *= 16;
1624 }
1626 c_thread_buf(cidx) = gate * up;
1627 }
1628 else if(ActivationOperation == Activation::gelu_and_mul)
1629 {
1630 float gate = c_thread_buf[cidx];
1631 float up = c_thread_buf_up[cidx];
1632 if constexpr(MulRoutedWeight)
1633 {
1634 gate = gate * topk_weight;
1635 up = up * topk_weight;
1636 }
1638 {
1639 gate *= 16;
1640 up *= 16;
1641 }
1643 c_thread_buf(cidx) = gate * up;
1644 }
1645 }
1646 else
1647 {
1648 if constexpr(MulRoutedWeight)
1649 {
1650 c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight;
1651 }
1652 }
1653 });
1654 });
1655 });
1656 });
1657
1658 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1660
1661 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1662 static_cast<CShuffleDataType*>(p_shared),
1663 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1664
1665 constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
1666 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1667 make_tuple(
1670 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1671 M1, // M1 = MWave
1672 M2)), // M2 = MPerXdl
1675 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1676 N1, // N1 = NWave
1677 N2, // N2 * N3 * N4 = NPerXdl
1678 N3,
1679 N4))),
1681 make_tuple(
1683
1684 // calculate origin of thread output tensor on global memory
1685 // blockwise GEMM c matrix starting index
1686 const auto c_thread_mtx_on_block =
1687 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1688
1689 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1690 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1691
1692 const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
1697
1698 const auto m_thread_data_on_block_idx =
1699 m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
1700 make_multi_index(m_thread_data_on_block));
1701
1702 const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
1704 make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))),
1707
1708 const auto n_thread_data_on_block_idx =
1709 n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
1710 make_multi_index(n_thread_data_on_block));
1711
1712 // shuffle: threadwise copy C from VGPR to LDS
1713 auto c_thread_copy_vgpr_to_lds =
1715 CShuffleDataType,
1716 decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
1717 decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
1719 Sequence<CShuffleMXdlPerWavePerShuffle,
1720 CShuffleNXdlPerWavePerShuffle,
1721 I1,
1722 I1,
1723 I1,
1724 N2,
1725 I1,
1726 N4>,
1728 7,
1729 1,
1731 1,
1732 true>{
1733 c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1735 0,
1736 m_thread_data_on_block_idx[I1],
1737 n_thread_data_on_block_idx[I1],
1738 m_thread_data_on_block_idx[I2],
1739 n_thread_data_on_block_idx[I2],
1740 n_thread_data_on_block_idx[I3],
1741 n_thread_data_on_block_idx[I4]),
1743
1744 using EDataType = CDataType;
1745
1746 const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
1747 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1748
1749 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1751 ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1752
1753 const auto ds_grid_buf = generate_tuple(
1754 [&](auto i) {
1755 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
1756 const DDataType* ptr_ = p_ds_grid[i];
1757 // hack logic here to support different kind of strides. todo fix it.
1758 // ascale t, 1; bscale E, N, 1, move ptr to E
1760 ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize());
1761 },
1763
1764 // tuple of reference to C/Ds tensor descriptors
1765 const auto c_ds_desc_refs = concat_tuple_of_reference(
1766 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1767 generate_tie([&](auto i) -> const auto& // return type should be reference
1768 { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1770
1771 // tuple of reference to C/Ds tensor descriptors
1772 const auto c_ds_buf_refs = concat_tuple_of_reference(
1773 tie(c_shuffle_block_buf),
1774 generate_tie([&](auto i) -> const auto& // return type should be reference
1775 { return ds_grid_buf[i]; },
1777
1778 // tuple of starting index of C/Ds blockwise copy
1779 const auto idx_c_ds_block_begin =
1782 [&](auto) {
1783 return make_multi_index(block_m_id, 0, block_n_id, 0);
1784 // return make_multi_index(block_work_idx[I0], 0,
1785 // block_work_idx[I1], 0);
1786 },
1788
1789 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1790 c_grid_desc_mblock_mperblock_nblock_nperblock;
1791
1792 using CDEBlockTransferCluster =
1793 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1794 const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1795 constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 1; // hack fix felix
1796 auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
1798 decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
1800 decltype(c_ds_desc_refs),
1801 decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1802 CElementwiseOperation,
1803 Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
1804 // support arbitray type
1805 Sequence<1,
1806 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1807 1,
1808 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1809 CDEBlockTransferCluster,
1810 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1811 Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
1812 Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
1813 3, // index_t SrcVectorDim,
1814 3, // index_t DstVectorDim,
1815 CDEShuffleBlockTransferScalarPerVectors,
1820 false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
1821 Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
1822 IndexType,
1823 1, // ScatterDim
1824 true, // OutputScatter: false, only use scatter weights
1825 scatter_weight_idx // ScatterWeightIdx: ascale
1826 >{c_ds_desc_refs,
1827 idx_c_ds_block_begin,
1828 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1829 make_tuple(make_multi_index(0, 0, block_n_id, 0)),
1830 c_element_op};
1831
1833 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1834 // space filling curve for threadwise C in VGPR
1835 constexpr auto sfc_c_vgpr =
1838 Sequence<CShuffleMXdlPerWavePerShuffle,
1839 CShuffleNXdlPerWavePerShuffle,
1840 1,
1841 1,
1842 1,
1843 N2,
1844 1,
1845 N4>>{};
1846
1847 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1848
1849 // space filling curve for shuffled blockwise C/D/E
1850 constexpr auto sfc_cde_block =
1853 Sequence<1,
1854 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1855 1,
1856 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1857
1858 static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
1859 constexpr auto EMThreads =
1860 CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
1861 constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
1862 constexpr auto ENThreads =
1863 CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
1864 static_for<0, num_access, 1>{}([&](auto access_id) {
1865 // make sure it's safe to write to LDS
1867
1868 auto dstidx = sfc_cde_block.GetIndex(access_id);
1869 const index_t c_token_pos =
1870 block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
1871 static_for<0, EMRepeats, 1>{}([&](auto m0) {
1872 const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
1873 index_t token_offset = fused_token & 0xffffff;
1874 if constexpr(IsInputGemm)
1875 {
1876 token_offset = token_offset * problem.TopK + (fused_token >> 24);
1877 }
1878 scatter_offsets(m0) = token_offset * problem.N;
1879 });
1880
1882
1883 // each thread write its data from VGPR to LDS
1884 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1885 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1886 c_thread_buf,
1887 c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1888 c_shuffle_block_buf);
1889
1890 // make sure it's safe to read from LDS
1892
1893 // each block copy its data from LDS to global
1894 cde_block_copy_lds_and_global.Run(
1895 c_ds_desc_refs,
1896 c_ds_buf_refs,
1897 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1898 tie(c_grid_buf),
1899 scatter_offsets);
1900
1901 if constexpr(access_id < num_access - 1)
1902 {
1903 constexpr auto cde_lds_and_global_step =
1904 sfc_cde_block.GetForwardStep(access_id);
1905
1906 // move on Ds
1907 static_for<0, NumDTensor, 1>{}([&](auto i) {
1908 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
1909 c_ds_desc_refs, i + I1, cde_lds_and_global_step);
1910 });
1911
1912 // move on E
1913 cde_block_copy_lds_and_global.MoveDstSliceWindow(
1914 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1915 I0,
1916 cde_lds_and_global_step);
1917 }
1918 });
1919 }
1920 }
1921
1922 template <bool HasMainKBlockLoop,
1923 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1924 TailNumber TailNum = TailNumber::Odd>
1925 __device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
1926 const index_t* p_sorted_expert_ids,
1927 const index_t* p_max_token_id,
1928 const ADataType* p_a_grid,
1929 const BDataType* p_b_grid,
1930 DsGridPointer& p_ds_grid,
1931 CDataType* p_c_grid,
1932 const AScaleType* p_a_scale_grid,
1933 const BScaleType* p_b_scale_grid,
1934 void* p_shared,
1935 void* p_shared1,
1936 const Problem& problem,
1937 AElementwiseOperation a_element_op,
1938 BElementwiseOperation b_element_op,
1939 CElementwiseOperation c_element_op)
1940 {
1941 ignore = b_element_op;
1942 index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
1943 index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
1944 const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1945 IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1946 problem.MPadded,
1947 problem.K,
1948 problem.KPadded,
1949 problem.StrideA,
1950 problem.AK0);
1951 const auto b_grid_desc_bpreshuffled =
1952 MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled);
1953 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1954 IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1955 problem.MPadded,
1956 problem.N,
1957 problem.NPadded,
1958 problem.StrideC);
1959
1960 const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
1961 make_tuple(math::integer_divide_ceil(IsInputGemm ? problem.NumTokens
1962 : problem.NumTokens * problem.TopK,
1963 ScaleBlockM),
1964 math::integer_divide_ceil(problem.K, ScaleBlockK)),
1965 make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
1966 const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
1967 make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
1968 math::integer_divide_ceil(problem.K, ScaleBlockK)),
1969 make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
1970 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1972 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1973 const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1974 const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1975 if(expert_block_id * MPerBlock >= max_token_id)
1976 return;
1977 const index_t expert_id =
1978 __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1979 const auto block_mn = [&]() -> std::pair<int, int> {
1980 if constexpr(NSwizzle)
1981 {
1982 const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1983 const index_t prefix_block = ecnt_prefix * problem.NBlock;
1984 const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1985 const index_t expert_swizzle = ecnt > 0 ? ecnt : 1;
1986 const index_t bid_new = blockIdx.x - prefix_block;
1987 const index_t nid = __builtin_amdgcn_readfirstlane(
1988 bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1989 const index_t mid =
1990 __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1991 return {nid, mid};
1992 }
1993 else
1994 {
1995 return {blockIdx.x, blockIdx.y};
1996 }
1997 }();
1998 const index_t block_n_id = block_mn.first;
1999 const index_t block_m_id = block_mn.second;
2000
2001 const index_t token0 =
2002 __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
2003
2004 // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2005 constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2006 constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
2007 constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
2008 constexpr auto AKThreads = AK0Threads * AK1Threads;
2009 constexpr auto AMRepeats = MPerBlock / AMThreads;
2010 const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
2011
2012 if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id ||
2013 token0 >= problem.NumTokens)
2014 return;
2016 gather_offsets; //= p_sorted_token_ids[token_pos];
2017 static_for<0, AMRepeats, 1>{}([&](auto m0) {
2018 const index_t fused_token = p_sorted_token_ids[token_pos + m0];
2019 index_t token_offset = fused_token & 0xffffff;
2020 if constexpr(!IsInputGemm)
2021 {
2022 token_offset = token_offset * problem.TopK + (fused_token >> 24);
2023 }
2024 gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
2025 });
2026 const index_t expert_stride =
2027 __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
2028 const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
2029 math::integer_divide_ceil(problem.N, ScaleBlockN) * (IsInputGemm ? 2 : 1) *
2030 math::integer_divide_ceil(problem.K, ScaleBlockK));
2031 // N0, K0, Blocksize*KPack
2032 const index_t n_block_data_idx_on_grid =
2033 __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
2034
2035 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2036 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
2037 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2038 p_b_grid + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
2039 b_grid_desc_bpreshuffled.GetElementSpaceSize());
2040
2041 const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2042 p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
2043 const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2044 p_b_scale_grid + expert_id * expert_scale_stride,
2045 b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2046
2047 // A matrix in LDS memory, dst of blockwise copy
2048 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
2049
2050 // B matrix in LDS memory, dst of blockwise copy
2051 // dummy
2052 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
2053 // A matrix blockwise copy
2054 auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
2056 AElementwiseOperation,
2060 ABlockTransferThreadClusterLengths_AK0_M_AK1,
2061 ABlockTransferThreadClusterArrangeOrder,
2062 ADataType,
2063 LDSTypeA,
2064 decltype(a_grid_desc_ak0_m_ak1),
2065 decltype(a_block_desc_ak0_m_ak1),
2066 ABlockTransferSrcAccessOrder,
2068 ABlockTransferSrcVectorDim,
2069 2,
2070 ABlockTransferSrcScalarPerVector,
2071 ABlockTransferDstScalarPerVector_AK1,
2072 1,
2073 1,
2074 AThreadTransferSrcResetCoordinateAfterRun,
2075 true,
2076 IndexType,
2077 1,
2078 BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
2079 make_multi_index(0, 0, 0),
2080 a_element_op,
2081 a_block_desc_ak0_m_ak1,
2082 make_multi_index(0, 0, 0),
2084 gather_offsets);
2085
2086 // Thread-wise copy
2087 // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
2089 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2091 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2092 auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
2093
2094 auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
2095 BDataType,
2096 BDataType,
2097 decltype(b_grid_desc_bpreshuffled),
2098 decltype(b_block_desc_bk0_n_bk1),
2101 3,
2102 BBlockTransferSrcScalarPerVector,
2103 BThreadTransferSrcResetCoordinateAfterRun,
2104 true>(b_grid_desc_bpreshuffled,
2105 make_multi_index(n_block_data_idx_on_grid,
2107 0,
2108 KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
2109
2110 // LDS allocation for A and B: be careful of alignment
2111 // Cast after lds
2112 auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2113 static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2114 auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2115 static_cast<ADataType*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2116 auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
2117
2118 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
2119 constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
2120
2121 // Blockwise GEMM pipeline
2122 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2123 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
2124 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2125 decltype(c_thread_buf) c_thread_buf_up;
2126
2127 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2128 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
2129 KPerBlock);
2130
2131 // scale
2132 constexpr index_t ScaleSliceSizeM = MXdlPerWave;
2133 constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN);
2134 constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);
2135
2136 // ScaleSliceSizeK is last dimension in A/B scale for vector memory access
2137 // ScaleSliceSizeK is first dimension in C scale for packed math
2138 constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
2140
2141 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
2142 constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
2143 constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
2144 auto a_thread_offset = get_thread_local_1d_id() % MPerXdl +
2145 (get_thread_local_1d_id() / WaveSize) / NWaves * MPerXdl;
2146
2147 constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
2149
2150 constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
2152
2153 // get each thread's offset in the scale tensor
2154 // A scale
2155 const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockM;
2156
2157 if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
2158 return;
2160 static_for<0, MXdlPerWave, 1>{}([&](auto m0) {
2161 const index_t fused_token =
2162 p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWaves + a_thread_offset];
2163 index_t token_offset = fused_token & 0xffffff;
2164 if constexpr(!IsInputGemm)
2165 {
2166 token_offset = token_offset * problem.TopK + (fused_token >> 24);
2167 }
2168 scale_gather_offsets(m0) = static_cast<IndexType>(token_offset) *
2169 math::integer_divide_ceil(problem.K, ScaleBlockK);
2170 });
2171
2172 auto a_scale_thread_copy =
2174 AScaleType,
2175 decltype(a_scale_grid_desc_am_ak),
2176 decltype(a_scale_thread_desc),
2179 1,
2180 ScaleSliceSizeK,
2181 1,
2182 false,
2183 MXdlPerWave>(
2184 a_scale_grid_desc_am_ak, make_multi_index(0, 0), scale_gather_offsets);
2185
2186 auto b_scale_thread_copy =
2188 BScaleType,
2189 decltype(b_scale_grid_desc_bn_ak),
2190 decltype(b_scale_thread_desc),
2193 1,
2194 ScaleSliceSizeK,
2195 1,
2196 false>(
2197 b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
2198
2199 // constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1);
2200 constexpr auto a_scale_thread_slice_copy_step =
2201 make_tuple(make_multi_index(0, 0), make_multi_index(0, ScaleSliceSizeK));
2202 constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
2203
2204 constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock);
2205 if constexpr(IsInputGemm)
2206 {
2207 const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
2208 const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2209 p_b_grid_up + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
2210 b_grid_desc_bpreshuffled.GetElementSpaceSize());
2211 auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
2212 BDataType,
2213 BDataType,
2214 decltype(b_grid_desc_bpreshuffled),
2215 decltype(b_block_desc_bk0_n_bk1),
2218 3,
2219 BBlockTransferSrcScalarPerVector,
2220 BThreadTransferSrcResetCoordinateAfterRun,
2221 true>(b_grid_desc_bpreshuffled,
2222 make_multi_index(n_block_data_idx_on_grid,
2224 0,
2225 KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
2226 const BScaleType* p_b_scale_grid_up =
2227 p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
2228 const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2229 p_b_scale_grid_up + expert_id * expert_scale_stride / BPackedSize,
2230 b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2231 auto b_scale_thread_copy_up =
2233 BScaleType,
2234 decltype(b_scale_grid_desc_bn_ak),
2235 decltype(b_scale_thread_desc),
2238 1,
2239 ScaleSliceSizeK,
2240 1,
2241 false>(
2242 b_scale_grid_desc_bn_ak,
2243 make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
2244
2245 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
2246 a_grid_desc_ak0_m_ak1,
2247 a_block_desc_ak0_m_ak1,
2248 a_blockwise_copy,
2249 a_grid_buf,
2250 a_block_bufs,
2251 a_block_slice_copy_step,
2252 b_grid_desc_bpreshuffled,
2253 b_block_desc_bk0_n_bk1,
2254 b_blockwise_copy,
2255 b_blockwise_copy_up,
2256 b_grid_buf,
2257 b_grid_buf_up,
2258 b_block_bufs,
2259 b_block_slice_copy_step,
2260 c_scale_thread_desc,
2261 c_thread_buf,
2262 c_thread_buf_up,
2263 a_scale_grid_desc_am_ak,
2264 a_scale_thread_desc,
2265 a_scale_thread_copy,
2266 a_scale_grid_buf,
2267 a_scale_thread_slice_copy_step,
2268 b_scale_grid_desc_bn_ak,
2269 b_scale_thread_desc,
2270 b_scale_thread_copy,
2271 b_scale_thread_copy_up,
2272 b_scale_grid_buf,
2273 b_scale_grid_buf_up,
2274 b_scale_thread_slice_copy_step,
2275 num_k_block_main_loop);
2276 }
2277 else
2278 {
2279 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
2280 a_grid_desc_ak0_m_ak1,
2281 a_block_desc_ak0_m_ak1,
2282 a_blockwise_copy,
2283 a_grid_buf,
2284 a_block_bufs,
2285 a_block_slice_copy_step,
2286 b_grid_desc_bpreshuffled,
2287 b_block_desc_bk0_n_bk1,
2288 b_blockwise_copy,
2289 b_grid_buf,
2290 b_block_bufs,
2291 b_block_slice_copy_step,
2292 c_scale_thread_desc,
2293 c_thread_buf,
2294 a_scale_grid_desc_am_ak,
2295 a_scale_thread_desc,
2296 a_scale_thread_copy,
2297 a_scale_grid_buf,
2298 a_scale_thread_slice_copy_step,
2299 b_scale_grid_desc_bn_ak,
2300 b_scale_thread_desc,
2301 b_scale_thread_copy,
2302 b_scale_grid_buf,
2303 b_scale_thread_slice_copy_step,
2304 num_k_block_main_loop);
2305 }
2306
2307 // shuffle C and write out
2308 {
2309
2310 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2311 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2312 "wrong!");
2313
2314 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2315
2316 // transposed XDL
2317 // TODO: hacky, fix it!
2318 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
2319 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
2320
2321 // TODO: hacky, fix it!
2322 // only used to get lengths
2323 constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
2324 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
2325
2326 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
2327 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
2328 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
2329 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
2330 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
2331 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
2332 constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
2333 constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
2334
2335 static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock);
2336 static_assert(M0 * M1 * M2 == MPerBlock);
2337 static_assert(N4 == 4 || N4 == 8);
2338 const index_t m1 = get_warp_local_1d_id() / NWave;
2339 const index_t m2 = threadIdx.x % get_warp_size() % M2;
2340
2341 float topk_weight;
2342 static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
2343 static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
2344 if constexpr(MulRoutedWeight)
2345 {
2346 const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2;
2347 topk_weight = p_ds_grid[I0][m_pos];
2348 }
2349 static_for<0, N2, 1>{}([&](auto n2) { // num_groups_per_blk
2350 static_for<0, N4, 1>{}([&](auto n4) { // inst_group_size
2351 constexpr index_t c_offset =
2352 blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2353 make_tuple(m0, n0, n2 * N4 + n4));
2354 constexpr auto cidx = Number<c_offset>{};
2355 if constexpr(IsInputGemm) // gu fusion, elementwise
2356 {
2357 if constexpr(ActivationOperation == Activation::silu_and_mul)
2358 {
2359 float gate = c_thread_buf[cidx];
2360 float up = c_thread_buf_up[cidx];
2361 if constexpr(MulRoutedWeight)
2362 {
2363 gate = gate * topk_weight;
2364 up = up * topk_weight;
2365 }
2367 {
2368 gate *= 16;
2369 up *= 16;
2370 }
2372 c_thread_buf(cidx) = gate * up;
2373 }
2374 else if(ActivationOperation == Activation::gelu_and_mul)
2375 {
2376 float gate = c_thread_buf[cidx];
2377 float up = c_thread_buf_up[cidx];
2378 if constexpr(MulRoutedWeight)
2379 {
2380 gate = gate * topk_weight;
2381 up = up * topk_weight;
2382 }
2384 {
2385 gate *= 16;
2386 up *= 16;
2387 }
2389 c_thread_buf(cidx) = gate * up;
2390 }
2391 }
2392 else
2393 {
2394 if constexpr(MulRoutedWeight)
2395 {
2396 c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight;
2397 }
2398 }
2399
2400 });
2401 });
2402 });
2403 });
2404
2405 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2407
2408 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2409 static_cast<CShuffleDataType*>(p_shared),
2410 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2411
2412 constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
2413 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2414 make_tuple(
2417 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
2418 M1, // M1 = MWave
2419 M2)), // M2 = MPerXdl
2422 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
2423 N1, // N1 = NWave
2424 N2, // N2 * N3 * N4 = NPerXdl
2425 N3,
2426 N4))),
2428 make_tuple(
2430
2431 // calculate origin of thread output tensor on global memory
2432 // blockwise GEMM c matrix starting index
2433 const auto c_thread_mtx_on_block =
2434 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2435
2436 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2437 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2438
2439 const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
2444
2445 const auto m_thread_data_on_block_idx =
2446 m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
2447 make_multi_index(m_thread_data_on_block));
2448
2449 const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
2451 make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))),
2454
2455 const auto n_thread_data_on_block_idx =
2456 n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
2457 make_multi_index(n_thread_data_on_block));
2458
2459 // shuffle: threadwise copy C from VGPR to LDS
2460 auto c_thread_copy_vgpr_to_lds =
2462 CShuffleDataType,
2463 decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
2464 decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
2466 Sequence<CShuffleMXdlPerWavePerShuffle,
2467 CShuffleNXdlPerWavePerShuffle,
2468 I1,
2469 I1,
2470 I1,
2471 N2,
2472 I1,
2473 N4>,
2475 7,
2476 1,
2478 1,
2479 true>{
2480 c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
2482 0,
2483 m_thread_data_on_block_idx[I1],
2484 n_thread_data_on_block_idx[I1],
2485 m_thread_data_on_block_idx[I2],
2486 n_thread_data_on_block_idx[I2],
2487 n_thread_data_on_block_idx[I3],
2488 n_thread_data_on_block_idx[I4]),
2490
2491 using EDataType = CDataType;
2492
2493 const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
2494 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
2495
2496 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
2498 ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
2499
2500 const auto ds_grid_buf = generate_tuple(
2501 [&](auto i) {
2503 p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
2504 },
2506
2507 // tuple of reference to C/Ds tensor descriptors
2508 const auto c_ds_desc_refs = concat_tuple_of_reference(
2509 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2510 generate_tie([&](auto i) -> const auto& // return type should be reference
2511 { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2513
2514 // tuple of reference to C/Ds tensor descriptors
2515 const auto c_ds_buf_refs = concat_tuple_of_reference(
2516 tie(c_shuffle_block_buf),
2517 generate_tie([&](auto i) -> const auto& // return type should be reference
2518 { return ds_grid_buf[i]; },
2520
2521 // tuple of starting index of C/Ds blockwise copy
2522 const auto idx_c_ds_block_begin =
2525 [&](auto) {
2526 return make_multi_index(block_m_id, 0, block_n_id, 0);
2527 // return make_multi_index(block_work_idx[I0], 0,
2528 // block_work_idx[I1], 0);
2529 },
2531
2532 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2533 c_grid_desc_mblock_mperblock_nblock_nperblock;
2534
2535 using CDEBlockTransferCluster =
2536 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2537 const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2538 constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 1; // hack fix felix
2539 auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
2541 decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
2543 decltype(c_ds_desc_refs),
2544 decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2545 CElementwiseOperation,
2546 Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
2547 // support arbitray type
2548 Sequence<1,
2549 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2550 1,
2551 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2552 CDEBlockTransferCluster,
2553 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2554 Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2555 Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2556 3, // index_t SrcVectorDim,
2557 3, // index_t DstVectorDim,
2558 CDEShuffleBlockTransferScalarPerVectors,
2563 false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2564 Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2565 IndexType,
2566 1, // ScatterDim
2567 true, // OutputScatter: false, only use scatter weights
2568 scatter_weight_idx // ScatterWeightIdx: ascale
2569 >{c_ds_desc_refs,
2570 idx_c_ds_block_begin,
2571 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2572 make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2573 c_element_op};
2574
2576 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2577 // space filling curve for threadwise C in VGPR
2578 constexpr auto sfc_c_vgpr =
2581 Sequence<CShuffleMXdlPerWavePerShuffle,
2582 CShuffleNXdlPerWavePerShuffle,
2583 1,
2584 1,
2585 1,
2586 N2,
2587 1,
2588 N4>>{};
2589
2590 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2591
2592 // space filling curve for shuffled blockwise C/D/E
2593 constexpr auto sfc_cde_block =
2596 Sequence<1,
2597 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2598 1,
2599 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2600
2601 static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2602 constexpr auto EMThreads =
2603 CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2604 constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2605 constexpr auto ENThreads =
2606 CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
2607 static_for<0, num_access, 1>{}([&](auto access_id) {
2608 // make sure it's safe to write to LDS
2610 scatter_offsets; //= p_sorted_token_ids[c_token_pos];
2611
2612 auto dstidx = sfc_cde_block.GetIndex(access_id);
2613 const index_t c_token_pos =
2614 block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2615 static_for<0, EMRepeats, 1>{}([&](auto m0) {
2616 const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2617 index_t token_offset = fused_token & 0xffffff;
2618 if constexpr(IsInputGemm)
2619 {
2620 token_offset = token_offset * problem.TopK + (fused_token >> 24);
2621 }
2622 scatter_offsets(m0) = token_offset * problem.N;
2623 });
2624
2626
2627 // each thread write its data from VGPR to LDS
2628 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
2629 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2630 c_thread_buf,
2631 c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
2632 c_shuffle_block_buf);
2633
2634 // make sure it's safe to read from LDS
2636
2637 // each block copy its data from LDS to global
2638 cde_block_copy_lds_and_global.Run(
2639 c_ds_desc_refs,
2640 c_ds_buf_refs,
2641 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2642 tie(c_grid_buf),
2643 scatter_offsets);
2644
2645 if constexpr(access_id < num_access - 1)
2646 {
2647 constexpr auto cde_lds_and_global_step =
2648 sfc_cde_block.GetForwardStep(access_id);
2649
2650 // move on Ds
2651 static_for<0, NumDTensor, 1>{}([&](auto i) {
2652 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2653 c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2654 });
2655
2656 // move on E
2657 cde_block_copy_lds_and_global.MoveDstSliceWindow(
2658 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2659 I0,
2660 cde_lds_and_global_step);
2661 }
2662 });
2663 }
2664 }
2665};
2666
2667} // namespace ck
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)
Definition device_base.hpp:178
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition utility/sequence.hpp:928
__device__ index_t get_warp_local_1d_id()
Definition get_id.hpp:45
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
f8_fnuz_t f8_t
Definition amd_ck_fp8.hpp:1762
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__host__ __device__ constexpr auto container_concat(const X &x, const Ys &... ys)
Definition utility/container_helper.hpp:320
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v1
Definition blkgemmpipe_scheduler.hpp:14
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
__host__ __device__ constexpr auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:185
integral_constant< index_t, N > Number
Definition number.hpp:12
__global__ void kernel_moe_gemm(typename GridwiseGemm::Argument karg)
Definition gridwise_moe_gemm.hpp:46
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
Activation
Definition gridwise_moe_gemm.hpp:31
@ silu_and_mul
Definition gridwise_moe_gemm.hpp:33
@ gelu_and_mul
Definition gridwise_moe_gemm.hpp:32
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
typename sequence_merge< Sx, Sy >::type sequence_merge_t
Definition utility/sequence.hpp:925
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
constexpr auto BlockGemmBlockMoeScaleBPreshufflePipeline_Selector()
Definition blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_selector.hpp:37
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__global__ void kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_moe_gemm.hpp:84
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition tuple_helper.hpp:42
Definition gridwise_moe_gemm_blockscale.hpp:666
const index_t * p_sorted_token_ids
Definition gridwise_moe_gemm_blockscale.hpp:722
CDataType * p_c_grid
Definition gridwise_moe_gemm_blockscale.hpp:728
const BScaleType * p_b_scale_grid
Definition gridwise_moe_gemm_blockscale.hpp:731
const index_t * p_max_token_id
Definition gridwise_moe_gemm_blockscale.hpp:724
DsGridPointer p_ds_grid
Definition gridwise_moe_gemm_blockscale.hpp:727
const CElementwiseOperation c_element_op
Definition gridwise_moe_gemm_blockscale.hpp:735
__host__ Argument(const index_t *p_sorted_token_ids_, const index_t *p_sorted_expert_ids_, const index_t *p_max_token_id_, const ADataType *p_a_grid_, const BDataType *p_b_grid_, std::array< const void *, NumDTensor > p_ds_grid_, CDataType *p_c_grid_, index_t NumTokens_, index_t TopK_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, const AScaleType *p_a_scale_grid_, const BScaleType *p_b_scale_grid_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_)
Definition gridwise_moe_gemm_blockscale.hpp:667
const ADataType * p_a_grid
Definition gridwise_moe_gemm_blockscale.hpp:725
const AElementwiseOperation a_element_op
Definition gridwise_moe_gemm_blockscale.hpp:733
const index_t * p_sorted_expert_ids
Definition gridwise_moe_gemm_blockscale.hpp:723
const BDataType * p_b_grid
Definition gridwise_moe_gemm_blockscale.hpp:726
const AScaleType * p_a_scale_grid
Definition gridwise_moe_gemm_blockscale.hpp:730
const BElementwiseOperation b_element_op
Definition gridwise_moe_gemm_blockscale.hpp:734
index_t K
Definition gridwise_moe_gemm_blockscale.hpp:648
__host__ __device__ Problem(index_t NumTokens_, index_t TopK_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t KBatch_)
Definition gridwise_moe_gemm_blockscale.hpp:602
index_t TopK
Definition gridwise_moe_gemm_blockscale.hpp:645
index_t NPadded
Definition gridwise_moe_gemm_blockscale.hpp:655
index_t StrideB
Definition gridwise_moe_gemm_blockscale.hpp:650
__host__ void Print() const
Definition gridwise_moe_gemm_blockscale.hpp:633
index_t BK0
Definition gridwise_moe_gemm_blockscale.hpp:659
index_t KRead
Definition gridwise_moe_gemm_blockscale.hpp:656
index_t N
Definition gridwise_moe_gemm_blockscale.hpp:647
index_t StrideC
Definition gridwise_moe_gemm_blockscale.hpp:652
index_t KBatch
Definition gridwise_moe_gemm_blockscale.hpp:653
index_t MBlock
Definition gridwise_moe_gemm_blockscale.hpp:660
index_t KPadded
Definition gridwise_moe_gemm_blockscale.hpp:657
index_t NumTokens
Definition gridwise_moe_gemm_blockscale.hpp:644
index_t StrideA
Definition gridwise_moe_gemm_blockscale.hpp:649
index_t AK0
Definition gridwise_moe_gemm_blockscale.hpp:658
index_t M
Definition gridwise_moe_gemm_blockscale.hpp:646
index_t MPadded
Definition gridwise_moe_gemm_blockscale.hpp:654
index_t NBlock
Definition gridwise_moe_gemm_blockscale.hpp:661
std::array< index_t, NumDTensor > StrideDs
Definition gridwise_moe_gemm_blockscale.hpp:651
__device__ SplitKBatchOffset(Argument &karg, index_t k_id)
Definition gridwise_moe_gemm_blockscale.hpp:740
index_t a_k_split_offset
Definition gridwise_moe_gemm_blockscale.hpp:771
index_t b_k_split_offset
Definition gridwise_moe_gemm_blockscale.hpp:772
Definition gridwise_moe_gemm_blockscale.hpp:177
static __device__ void Run_2Lds(const index_t *p_sorted_token_ids, const index_t *p_sorted_expert_ids, const index_t *p_max_token_id, const ADataType *p_a_grid, const BDataType *p_b_grid, DsGridPointer &p_ds_grid, CDataType *p_c_grid, const AScaleType *p_a_scale_grid, const BScaleType *p_b_scale_grid, void *p_shared, void *p_shared1, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition gridwise_moe_gemm_blockscale.hpp:1925
remove_cvref_t< decltype(BlockGemmBlockMoeScaleBPreshufflePipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ADataType, BDataType, ComputeTypeA, GemmAccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, ScaleBlockM, ScaleBlockN, ScaleBlockK, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack, IsInputGemm >())> BlockwiseGemmPipe
Definition gridwise_moe_gemm_blockscale.hpp:914
static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition gridwise_moe_gemm_blockscale.hpp:1156
static __device__ void Run(const index_t *p_sorted_token_ids, const index_t *p_sorted_expert_ids, const index_t *p_max_token_id, const ADataType *p_a_grid, const BDataType *p_b_grid, DsGridPointer &p_ds_grid, CDataType *p_c_grid, const AScaleType *p_a_scale_grid, const BScaleType *p_b_scale_grid, void *p_shared, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition gridwise_moe_gemm_blockscale.hpp:1177
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
static constexpr index_t GetK1PerXdlops()
Definition xdlops_gemm.hpp:1810
static constexpr auto selected_mfma
Definition xdlops_gemm.hpp:1757
static constexpr index_t GetKPerXdlops()
Definition xdlops_gemm.hpp:1804
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1_gather.hpp:48
Definition thread_group_tensor_slice_transfer_v7r3_scatter.hpp:51
Definition threadwise_tensor_slice_transfer.hpp:39
Definition threadwise_tensor_slice_transfer.hpp:440
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition utility/tuple.hpp:117
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition data_type.hpp:187
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1041
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1087