gridwise_moe_mx_gemm.hpp Source File

gridwise_moe_mx_gemm.hpp Source File#

Composable Kernel: gridwise_moe_mx_gemm.hpp Source File
gridwise_moe_mx_gemm.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7#include "ck/utility/env.hpp"
16
20
21#define DEBUG_LOG 0
22
23namespace ck {
24
25// Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer
26// pipelines in the same kernel function. Blockers:
27// 1. Two separate declarations of the __shared__ pointer are the key to making sure data accesses
28// operate on two LDS chunks.
29// 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. AB and C may not
30// use the same LDS buffer when we declare __shared__ inside blkgemmpipe.
31
33{
34 gelu_and_mul = 0,
35 silu_and_mul = 1
36};
37
38#if 0
39template <typename GridwiseGemm,
40 bool HasMainKBlockLoop,
41 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
42 index_t MinimumOccupancy = 1,
44__global__ void
45#if CK_USE_LAUNCH_BOUNDS
46__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
47#endif
48 // __attribute__((amdgpu_waves_per_eu(1, 1)))
49 kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
50{
51#if defined(__gfx9__)
52 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
53 {
54 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
55
56 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
57
58 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
59 karg.p_sorted_token_ids,
60 karg.p_sorted_expert_ids,
61 karg.p_max_token_id,
62 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
63 karg.p_a_scale_grid + splitk_batch_offset.a_k_split_offset,
64 karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
65 karg.p_b_scale_grid + splitk_batch_offset.b_k_split_offset,
66 karg.p_ds_grid,
67 karg.p_c_grid,
68 p_shared,
69 karg,
70 karg.a_element_op,
71 karg.b_element_op,
72 karg.c_element_op);
73 }
74#else
75 ignore = karg;
76#endif // end of if (defined(__gfx9__))
77}
78#endif
79
80template <typename GridwiseGemm,
81 bool HasMainKBlockLoop,
82 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
83 index_t MinimumOccupancy = 1,
85__global__ void
86#if CK_USE_LAUNCH_BOUNDS
87__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
88#endif
89 // __attribute__((amdgpu_waves_per_eu(1, 1)))
90 kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
91{
92#if defined(__gfx9__)
93 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
94 {
95 __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
96 __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
97
98 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
99
100 GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
101 karg.p_sorted_token_ids,
102 karg.p_sorted_expert_ids,
103 karg.p_max_token_id,
104 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
105 karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
106 karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
107 karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
108 karg.p_ds_grid,
109 karg.p_c_grid,
110 p_shared_0,
111 p_shared_1,
112 karg,
113 karg.a_element_op,
114 karg.b_element_op,
115 karg.c_element_op);
116 }
117#else
118 ignore = karg;
119#endif // end of if (defined(__gfx9__))
120}
121
122template <typename ALayout,
123 typename BLayout,
124 typename DsLayout,
125 typename CLayout,
126 typename ADataType,
127 typename AScaleDataType,
128 typename BDataType,
129 typename BScaleDataType,
130 typename AccDataType,
131 typename CShuffleDataType,
132 typename DsDataType,
133 typename CDataType,
134 typename AElementwiseOperation,
135 typename BElementwiseOperation,
136 typename CElementwiseOperation,
138 index_t ScaleBlockSize, // Scaling block size
139 index_t BlockSize, // Thread block size
140 index_t MPerBlock,
141 index_t NPerBlock,
142 index_t KPerBlock,
143 index_t AK1Value,
144 index_t BK1Value,
145 index_t MPerXdl,
146 index_t NPerXdl,
147 index_t MXdlPerWave,
148 index_t NXdlPerWave,
149 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
150 typename ABlockTransferThreadClusterArrangeOrder,
151 typename ABlockTransferSrcAccessOrder,
152 index_t ABlockTransferSrcVectorDim,
153 index_t ABlockTransferSrcScalarPerVector,
154 index_t ABlockTransferDstScalarPerVector_AK1,
155 bool AThreadTransferSrcResetCoordinateAfterRun,
156 index_t ABlockLdsExtraM,
157 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
158 typename BBlockTransferThreadClusterArrangeOrder,
159 typename BBlockTransferSrcAccessOrder,
160 index_t BBlockTransferSrcVectorDim,
161 index_t BBlockTransferSrcScalarPerVector,
162 index_t BBlockTransferDstScalarPerVector_BK1,
163 bool BThreadTransferSrcResetCoordinateAfterRun,
164 index_t BBlockLdsExtraN,
165 index_t CShuffleMXdlPerWavePerShuffle,
166 index_t CShuffleNXdlPerWavePerShuffle,
167 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
168 typename CDEShuffleBlockTransferScalarPerVectors,
171 index_t ActivationOperation = 0,
172 bool NSwizzle = false,
173 bool IsInputGemm = true,
174 bool MulRoutedWeight = true,
175 typename IndexType = index_t,
176 typename ComputeTypeA = ADataType,
177 typename ComputeTypeB = BDataType>
179{
180 using LDSTypeA = ADataType;
181 using LDSTypeB = BDataType;
182
183 static constexpr auto I0 = Number<0>{};
184 static constexpr auto I1 = Number<1>{};
185 static constexpr auto I2 = Number<2>{};
186 static constexpr auto I3 = Number<3>{};
187 static constexpr auto I4 = Number<4>{};
188 static constexpr auto I5 = Number<5>{};
189 static constexpr auto I6 = Number<6>{};
190 static constexpr auto I7 = Number<7>{};
191 static constexpr auto I8 = Number<8>{};
192 static constexpr auto I9 = Number<9>{};
193
195 CDEShuffleBlockTransferScalarPerVectors{}[I0];
196 // K1 should be Number<...>
197 static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
198 static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
199 static constexpr auto AK1Number = Number<AK1Value>{};
200 static constexpr auto BK1Number = Number<BK1Value>{};
201
202 static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
203 static constexpr bool is_single_rate_mfma = false;
204 static constexpr auto is_scale_mfma = true;
205
206 static constexpr index_t NumDTensor = DsDataType::Size();
207
208 static constexpr auto MXdlPack = 2;
209 static constexpr auto NXdlPack = 2;
210 static constexpr auto KXdlPack = 2;
211
212 //> KPack is at least the k_per_blk of selected mfma
213 //
214 // Should be a multiple of k_per_blk.
215 // TODO: Move this to blockwise pipeline base
216 // KPack in packed data types for pk A/B
217
220
221 using mfma_selector = MfmaSelector<ComputeTypeA,
222 MPerXdl,
223 NPerXdl,
224 ComputeTypeB,
227 static constexpr index_t KPack =
229
230 // static constexpr index_t NumTokens = 1;
231 static constexpr index_t SortedTileSize = MPerBlock;
232
233 static constexpr auto MakeDsGridPointer()
234 {
235 return generate_tuple(
236 [&](auto i) {
237 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
238
239 return static_cast<const DDataType*>(nullptr);
240 },
242 }
243
244 using DsGridPointer = decltype(MakeDsGridPointer());
245
247
248 __host__ static auto CalculateGridSize(index_t M, index_t N)
249 {
250 const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
251 const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
252 const index_t gridx = NSwizzle ? nblock * mblock : nblock;
253 const index_t gridy = NSwizzle ? 1 : mblock;
254
255 return std::make_tuple(gridx, gridy, 1);
256 }
257
258 __host__ static auto CalculateMPadded(index_t M)
259 {
260 return math::integer_least_multiple(M, MPerBlock);
261 }
262
263 __host__ static auto CalculateNPadded(index_t N)
264 {
265 return math::integer_least_multiple(N, NPerBlock);
266 }
267
268 __host__ static auto CalculateKPadded(index_t K)
269 {
270 return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
271 }
272
273 __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
274 {
275 auto K_t = K_Batch * KPerBlock;
276 return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
277 }
278
279 __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
280 {
281 auto K_t = K_Batch * KPerBlock;
282 return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
283 }
284
285 __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
286 {
287 auto K_t = K_Batch * KPerBlock;
288 return (K + K_t - 1) / K_t * KPerBlock;
289 }
290
291 __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
292 {
293 constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
294 auto K_t = K_Batch * KReadVec;
295 return (K + K_t - 1) / K_t * KReadVec;
296 }
297
298 __host__ static auto CalculateMBlock(index_t M)
299 {
300 return math::integer_divide_ceil(M, MPerBlock);
301 }
302
303 __host__ static auto CalculateNBlock(index_t N)
304 {
305 return math::integer_divide_ceil(N, NPerBlock);
306 }
307
308 template <index_t MNXdlPerWave,
309 index_t MNWaves,
310 index_t MNXdlPack,
311 index_t MNPerXdl,
312 typename TileDesc_K0_MN_K1>
313 __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
314 {
315 constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
316 constexpr index_t MN = TileDesc_K0_MN_K1{}.GetLength(Number<1>{});
317 constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
318
319 constexpr auto permuted_desc = transform_tensor_descriptor(
320 TileDesc_K0_MN_K1{},
325
327 permuted_desc,
332 Number<MNPerXdl>{}))),
335 }
336
337 __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
338 IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
339 {
340 const auto a_grid_desc_mraw_kraw = [&]() {
342 {
343 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
344 }
346 {
347 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
348 }
349 }();
350
351 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
352
353 if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
354 GemmSpec == GemmSpecialization::MNKPadding)
355 {
356 // pad both M and K
357 const auto a_grid_desc_m_k =
358 transform_tensor_descriptor(a_grid_desc_mraw_kraw,
360 make_right_pad_transform(K, KPad - K)),
363
364 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
365 a_grid_desc_m_k,
370
371 return a_grid_desc_ak0_m_ak1;
372 }
373 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
374 GemmSpec == GemmSpecialization::MNPadding)
375 {
376 // pad M, but not K
377 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
378 a_grid_desc_mraw_kraw,
379 make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)),
380 make_right_pad_transform(M, MPad - M)),
383
384 const auto a_grid_desc_permuted = transform_tensor_descriptor(
385 a_grid_desc_ak0_m_ak1,
391
392 const auto a_grid_desc = transform_tensor_descriptor(
393 a_grid_desc_permuted,
400 return a_grid_desc;
401 }
402 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
403 GemmSpec == GemmSpecialization::NKPadding)
404 {
405 // pad K, but not M
406 const auto a_grid_desc_m_k = transform_tensor_descriptor(
407 a_grid_desc_mraw_kraw,
411
412 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
413 a_grid_desc_m_k,
418
419 return a_grid_desc_ak0_m_ak1;
420 }
421 else
422 {
423 // not pad M or K
424 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
425 a_grid_desc_mraw_kraw,
426 make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)),
430
431 const auto a_grid_desc_permuted = transform_tensor_descriptor(
432 a_grid_desc_ak0_m_ak1,
438
439 const auto a_grid_desc = transform_tensor_descriptor(
440 a_grid_desc_permuted,
447
448 return a_grid_desc;
449 }
450 }
451
452 __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
453 index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
454 {
455 const auto b_grid_desc_nraw_kraw = [&]() {
457 {
458 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
459 }
461 {
462 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
463 }
464 }();
465
466 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
467
469 GemmSpec != GemmSpecialization::Default),
470 "pk_i4_t does not support padding");
472 (GemmSpec != GemmSpecialization::Default &&
473 GemmSpec != GemmSpecialization::MPadding)),
474 "f4x2_pk_t does not support K padding");
475
476 if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
477 GemmSpec == GemmSpecialization::MNKPadding)
478 {
479 // pad both N and K
480 const auto b_grid_desc_n_k =
481 transform_tensor_descriptor(b_grid_desc_nraw_kraw,
483 make_right_pad_transform(K, KPad - K)),
486
487 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
488 b_grid_desc_n_k,
493
494 return b_grid_desc_bk0_n_bk1;
495 }
496 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
497 GemmSpec == GemmSpecialization::MNPadding)
498 {
499 // pad N, but not K
500 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
501 b_grid_desc_nraw_kraw,
503 make_right_pad_transform(N, NPad - N)),
506
507 return b_grid_desc_bk0_n_bk1;
508 }
509 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
510 GemmSpec == GemmSpecialization::MKPadding)
511 {
512 // pad K, but not N
513 const auto b_grid_desc_n_k = transform_tensor_descriptor(
514 b_grid_desc_nraw_kraw,
518
519 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
520 b_grid_desc_n_k,
525
526 return b_grid_desc_bk0_n_bk1;
527 }
528 else
529 {
530 // not pad N or K
531 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
532 b_grid_desc_nraw_kraw,
533 make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, BK0Number, BK1Value)),
537
538 const auto b_grid_desc_permuted = transform_tensor_descriptor(
539 b_grid_desc_bk0_n_bk1,
545
546 const auto b_grid_desc = transform_tensor_descriptor(
547 b_grid_desc_permuted,
554
555 return b_grid_desc;
556 }
557 }
558
559 template <typename ABlockDesc_AK0_M_AK1>
560 __host__ __device__ static constexpr auto
561 MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1&)
562 {
563 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
564
566 ABlockDesc_AK0_M_AK1{});
567 }
568
569 template <typename BBlockDesc_BK0_N_BK1>
570 __host__ __device__ static constexpr auto
571 MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1&)
572 {
573 constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
574
576 BBlockDesc_BK0_N_BK1{});
577 }
578
579 template <typename ELayout>
580 __host__ __device__ static auto MakeCGridDescriptor_M_N(
581 IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
582 {
583 const auto c_grid_desc_mraw_nraw = [&]() {
585 {
586 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
587 }
589 {
590 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
591 }
592 }();
593
594 // pad M and N
595 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
597 make_right_pad_transform(N, NPad - N)),
600 }
601
602 template <typename DLayout>
603 __host__ __device__ static auto
605 {
606 const auto c_grid_desc_mraw_nraw = [&]() {
608 {
609 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0));
610 }
612 {
613 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC));
614 }
615 }();
616
617 // pad M and N
618 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
620 make_right_pad_transform(N, NPad - N)),
623 }
624
625 __host__ __device__ static auto MakeDsGridDescriptor_M_N(
626 index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
627 {
628 return generate_tuple(
629 [&](auto i) {
630 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
631 return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
632 },
634 }
635
636 template <typename DsGridDesc>
638 const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
639 {
640 return generate_tuple(
641 [&](auto i) {
643 ds_grid_desc_m_n[i], MBlock, NBlock);
644 },
646 }
647
648 struct Problem
649 {
650 __host__ Problem(index_t NumTokens_,
651 index_t TopK_,
652 index_t M_,
653 index_t N_,
654 index_t K_,
655 index_t StrideA_,
656 index_t StrideScaleA_,
657 index_t StrideB_,
658 index_t StrideScaleB_,
659 std::array<index_t, NumDTensor> StrideDs_,
660 index_t StrideC_,
661 index_t KBatch_)
662 : NumTokens{NumTokens_},
663 TopK{TopK_},
664 M{M_},
665 N{N_},
666 K{K_},
667 StrideA{StrideA_},
668 StrideScaleA{StrideScaleA_},
669 StrideB{StrideB_},
670 StrideScaleB{StrideScaleB_},
671 StrideDs{StrideDs_},
672 StrideC{StrideC_},
673 KBatch{KBatch_},
676 KRead{CalculateKRead(K_, KBatch_)},
677 KPadded{CalculateKPadded(K_, KBatch_)},
678 AK0{CalculateAK0Padded(K_, KBatch_)},
679 BK0{CalculateBK0Padded(K_, KBatch_)},
682 {
683 }
684
685 __host__ void Print() const
686 {
687 std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", "
688 << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
689 << "SA:" << StrideA << ", " << "SScaleA:" << StrideScaleA << ", "
690 << "SB:" << StrideB << ", " << "SScaleB:" << StrideScaleB << ", "
691 << "SC:" << StrideC << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded
692 << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", "
693 << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock
694 << ", " << "NBlock: " << NBlock << "}" << std::endl;
695 }
696
706 std::array<index_t, NumDTensor> StrideDs;
717 };
718
719 // Argument
721 {
722 __host__ Argument(const index_t* p_sorted_token_ids_,
723 const index_t* p_sorted_expert_ids_,
724 const index_t* p_max_token_id_,
725 const ADataType* p_a_grid_,
726 const AScaleDataType* p_a_scale_grid_,
727 const BDataType* p_b_grid_,
728 const BScaleDataType* p_b_scale_grid_,
729 std::array<const void*, NumDTensor> p_ds_grid_,
730 CDataType* p_c_grid_,
731 index_t NumTokens_,
732 index_t TopK_,
733 index_t M_,
734 index_t N_,
735 index_t K_,
736 index_t StrideA_,
737 index_t StrideScaleA_,
738 index_t StrideB_,
739 index_t StrideScaleB_,
740 std::array<index_t, NumDTensor> StrideDs_,
741 index_t StrideC_,
742 index_t k_batch_,
743 AElementwiseOperation a_element_op_,
744 BElementwiseOperation b_element_op_,
745 CElementwiseOperation c_element_op_)
746 : Problem{NumTokens_,
747 TopK_,
748 M_,
749 N_,
750 K_ / APackedSize,
751 StrideA_ / APackedSize,
752 StrideScaleA_,
753 StrideB_ / BPackedSize,
754 StrideScaleB_,
755 StrideDs_,
756 StrideC_,
757 k_batch_},
758 p_sorted_token_ids{p_sorted_token_ids_},
759 p_sorted_expert_ids{p_sorted_expert_ids_},
760 p_max_token_id{p_max_token_id_},
761 p_a_grid{p_a_grid_},
762 p_a_scale_grid{p_a_scale_grid_},
763 p_b_grid{p_b_grid_},
764 p_b_scale_grid{p_b_scale_grid_},
765 p_ds_grid{},
766 p_c_grid{p_c_grid_},
767 a_element_op{a_element_op_},
768 b_element_op{b_element_op_},
769 c_element_op{c_element_op_}
770 {
771
772 // populate pointer, desc for Ds
773 static_for<0, NumDTensor, 1>{}([&](auto i) {
774 using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
775
776 // D pointer
777 p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
778 });
779 }
780
784 const ADataType* p_a_grid;
785 const AScaleDataType* p_a_scale_grid;
786 const BDataType* p_b_grid;
787 const BScaleDataType* p_b_scale_grid;
789 CDataType* p_c_grid;
790
791 const AElementwiseOperation a_element_op;
792 const BElementwiseOperation b_element_op;
793 const CElementwiseOperation c_element_op;
794 };
795
797 {
798 __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
799 {
801 {
802 a_k_split_offset = k_id * karg.KRead;
803 }
805 {
806 a_k_split_offset = k_id * karg.KRead * karg.StrideA;
807 }
808
810 {
811 b_k_split_offset = k_id * karg.KRead * karg.StrideB;
812 }
814 {
815 // KPack * NLane * KLane * K0 * N0
816 b_k_split_offset = k_id * karg.KRead;
817 }
818
819 // Calculate A scale offset
821 {
822 a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize);
823 }
825 {
827 k_id * karg.KRead / (ScaleBlockSize / APackedSize) * karg.StrideScaleA;
828 }
829
830 // Calculate B scale offset
832 {
834 k_id * (karg.KRead / (ScaleBlockSize / BPackedSize)) * karg.StrideScaleB;
835 }
837 {
838 b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize);
839 }
840
841 if(k_id < karg.KBatch - 1)
842 {
843 karg.K = karg.KRead;
844 }
845 else
846 {
847 karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
848 }
849 }
850
855 };
856
857 __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
858 {
859 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
860 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
861 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
862
863 // A matrix in LDS memory, dst of blockwise copy
864 if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
865 {
866 // contiguous in LDS
870 }
871 // xor tensor transformation request more unnecessary vgpr usage, would cause register spill
872 // in some cases.
874 {
875 constexpr auto a_lds_block_desc =
878
879 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
880 a_lds_block_desc,
886
887 return a_lds_block_desc_permuted;
888 }
889 else // ColumnMajor A
890 {
891 // kfold and mpair dimension is not always required.
892 // more dimension in merge_transform increase the difficulty of generating immarg offset
893 // for compiler.
894 constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
895 constexpr auto M1 = MPerBlock / M0;
896
897 constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
898 constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
899 constexpr auto KThreadRead = WaveSize / MPerXdl;
900 constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
901
902 constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
903 ? 1
904 : 128 / (AK1Number * M0 * sizeof(ADataType));
905 constexpr auto KThreadReadPerm =
906 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
907 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
908 : KThreadRead;
909
910 // 1<=mpair<=n0
911 constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
912 ? 1
913 : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
914 ? M0
915 : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));
916
917 constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
921 Number<kfold * M0 / mpair>{},
923 AK1Number));
924
925 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
926 a_lds_block_desc,
931 make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
938
939 constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
940 a_lds_block_desc_permuted,
949 Sequence<1>{},
950 Sequence<2>{},
951 Sequence<3>{},
952 Sequence<4>{},
953 Sequence<5>{}),
955 Sequence<2>{},
958 Sequence<6>{},
959 Sequence<7>{}));
960
961 constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
962 a_lds_block_desc_unmerged,
965 Number<KThreadWrite / kfold / KThreadReadPerm>{},
973
974 return a_lds_block_desc_ak0_m_ak1;
975 }
976 }
977
978 __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
979 {
980 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
981 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
982 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
983
984 // B matrix in LDS memory, dst of blockwise copy
985 if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
986 {
987 // contiguous in lds
991 }
993 {
994 // NLdsLayer * K0 as logical Bank
995 constexpr auto b_lds_block_desc =
998
999 constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
1000 b_lds_block_desc,
1006
1007 return b_lds_block_desc_permuted;
1008 }
1009 else // RowMajor B
1010 {
1011 constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
1012 constexpr auto N1 = NPerBlock / N0;
1013
1014 constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
1015 constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
1016 constexpr auto KThreadRead = WaveSize / NPerXdl;
1017 constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
1018
1019 constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
1020 ? 1
1021 : 128 / (BK1Number * N0 * sizeof(BDataType));
1022 constexpr auto KThreadReadPerm =
1023 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
1024 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
1025 : KThreadRead;
1026
1027 // 1<=npair<=n0
1028 constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128)
1029 ? 1
1030 : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0
1031 ? N0
1032 : 128 / (BK1Number * NPerXdl * sizeof(BDataType)));
1033
1034 constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
1038 Number<kfold * N0 / npair>{},
1039 Number<npair>{},
1040 BK1Number));
1041
1042 constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
1043 b_lds_block_desc,
1044 make_tuple(
1048 make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
1051 make_tuple(
1053 make_tuple(
1055
1056 constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
1057 b_lds_block_desc_permuted,
1058 make_tuple(
1066 Sequence<1>{},
1067 Sequence<2>{},
1068 Sequence<3>{},
1069 Sequence<4>{},
1070 Sequence<5>{}),
1072 Sequence<2>{},
1075 Sequence<6>{},
1076 Sequence<7>{}));
1077
1078 constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
1079 b_lds_block_desc_unmerged,
1082 Number<KThreadWrite / kfold / KThreadReadPerm>{},
1083 Number<kfold>{},
1090
1091 return b_lds_block_desc_bk0_n_bk1;
1092 }
1093 }
1094
1096 {
1097 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1098 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1099
1100 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1102 make_tuple(I1,
1104 I1,
1106
1107 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
1108 }
1109
1112 BlkGemmPipelineVer,
1113 BlkGemmPipeSched,
1114 BlockSize,
1115 ScaleBlockSize,
1116 ADataType,
1117 AScaleDataType,
1118 BDataType,
1119 BScaleDataType,
1120 ComputeTypeA,
1121 AccDataType,
1128 ABlockTransferSrcScalarPerVector,
1129 BBlockTransferSrcScalarPerVector,
1130 MPerBlock,
1131 NPerBlock,
1132 KPerBlock,
1133 MPerXdl,
1134 NPerXdl,
1135 MXdlPerWave,
1136 NXdlPerWave,
1137 KPack,
1138 IsInputGemm>())>;
1139
1140 __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
1141 {
1142 // LDS allocation for A and B: be careful of alignment
1143 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1144 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1145
1146 // lds max alignment
1147 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1148
1149 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1150 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1151
1152 constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
1153 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1154
1155 // LDS allocation for C shuffle in LDS
1156 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1158
1159 constexpr auto c_block_size =
1160 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
1161
1162 if constexpr(IsInputGemm)
1163 {
1164 return math::max(a_block_space_size_aligned * sizeof(ADataType) +
1165 b_block_space_size_aligned * sizeof(BDataType) * 2,
1166 c_block_size * sizeof(CShuffleDataType));
1167 }
1168 else
1169 {
1170 return math::max((a_block_space_size_aligned * sizeof(ADataType) +
1171 b_block_space_size_aligned * sizeof(BDataType)),
1172 c_block_size * sizeof(CShuffleDataType));
1173 }
1174 }
1175
1177
1178 // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
// Host-side validation of a kernel Argument against the compile-time tile configuration.
// The checks return false on failure (but see the NOTE(review) in the last check) and the
// function returns true when none of them fails. When the CK_LOGGING environment switch is
// enabled, a failing check is also reported on std::cout.
// NOTE(review): this listing has gaps (the embedded line numbers skip), so the full conditions
// of several `if constexpr` / `if` statements below are not visible here.
1179 __host__ static constexpr bool CheckValidity(const Argument& karg)
1180 {
// Compile-time sanity checks on the tile parameters.
1181 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1182 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1183 "Invalid tuning param!");
1184
1185 static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0,
1186 "KPerBlock should be multiple of ScaleBlockSize");
1187
// Unless the GemmSpec is one of those listed in the condition, M must be a multiple of MPerBlock.
1188 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
1193 {
1194 if(!(karg.M % MPerBlock == 0))
1195 {
1196 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1197 {
1198 std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
1199 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1200 << std::endl;
1201 }
1202 return false;
1203 }
1204 }
1205
// Unless the GemmSpec is one of those listed in the condition, N must be a multiple of NPerBlock.
1206 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
1211 {
1212 if(!(karg.N % NPerBlock == 0))
1213 {
1214 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1215 {
1216 std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
1217 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1218 << std::endl;
1219 }
1220 return false;
1221 }
1222 }
1223
// Unless the GemmSpec is one of those listed in the condition, K must be a multiple of
// KBatch * KPerBlock.
1224 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1228 {
1229 auto K_t = karg.KBatch * KPerBlock;
1230 if(!(karg.K % K_t == 0))
1231 {
1232 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1233 {
1234 std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1235 << karg.K << " " << __FILE__ << ":" << __LINE__
1236 << ", in function: " << __func__ << std::endl;
1237 }
1238 return false;
1239 }
1240 }
1241 else
1242 {
// Otherwise: reject the split when the first KBatch - 1 batches, each reading the rounded-up
// per-batch length, would already cover all of K. This branch does not log.
1243 constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1244 auto K_t = karg.KBatch * KReadVec;
1245 auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1246 if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1247 {
1248 return false;
1249 }
1250 }
1251
// A source vector width: K or M (selected by a condition not shown in this listing) must be
// a multiple of ABlockTransferSrcScalarPerVector.
1253 {
1254 if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1255 {
1256 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1257 {
1258 std::cout << "Arg K (" << karg.K
1259 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1260 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1261 << __LINE__ << ", in function: " << __func__ << std::endl;
1262 }
1263 return false;
1264 }
1265 }
1266 else
1267 {
1268 if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1269 {
1270 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1271 {
1272 std::cout << "Arg M (" << karg.M
1273 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1274 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1275 << __LINE__ << ", in function: " << __func__ << std::endl;
1276 }
1277 return false;
1278 }
1279 }
1280
// B source vector width: N or K (selected by a condition not shown in this listing) must be
// a multiple of BBlockTransferSrcScalarPerVector.
1282 {
1283 if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1284 {
1285 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1286 {
1287 std::cout << "Arg N (" << karg.N
1288 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1289 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1290 << __LINE__ << ", in function: " << __func__ << std::endl;
1291 }
1292 return false;
1293 }
1294 }
1295 else
1296 {
1297 if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1298 {
1299 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1300 {
1301 std::cout << "Arg K (" << karg.K
1302 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1303 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1304 << __LINE__ << ", in function: " << __func__ << std::endl;
1305 }
1306 return false;
1307 }
1308 }
1309
// C shuffle vector width check on N or M; the tested conditions are not shown in this listing.
1311 {
1313 {
1314 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1315 {
1316 std::cout << "Arg N (" << karg.N
1317 << ") value is not a multiple of "
1318 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1320 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1321 << std::endl;
1322 }
1323 return false;
1324 }
1325 }
1326 else
1327 {
1329 {
1330 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1331 {
1332 std::cout << "Arg M (" << karg.M
1333 << ") value is not a multiple of "
1334 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1336 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1337 << std::endl;
1338
// NOTE(review): unlike every other check in this function, this `return false` is inside the
// CK_LOGGING block (the logging brace closes on the next line). As written, this check only
// rejects the argument when logging is enabled; it looks like the return should follow the
// closing brace of the logging block, as in the N branch above. Confirm and fix.
1339 return false;
1340 }
1341 }
1342 }
1343
1344 // check gridwise gemm pipeline
// The prefetch-stage check below is compiled out.
1345#if 0
1346 const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1347
1348 if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1349 {
1350 return false;
1351 }
1352#endif
1353 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1354 return true;
1355 }
1356
1357 __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1358 {
1359 const index_t num_loop = K / KPerBlock;
1360
1361 return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1362 }
1363
1364 __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1365 {
1366 const index_t num_loop = K / KPerBlock;
1367
1368 return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1369 }
1370
// Builds a new C descriptor from `c_grid_desc_m_n` by calling transform_tensor_descriptor
// and returns the result by value. `MBlock` and `NBlock` are supplied by the caller.
// NOTE(review): the transform/dimension arguments of the call (listing lines 1377-1380)
// are missing from this extract, so how MBlock/NBlock are used cannot be confirmed here.
// Presumably M is split into (MBlock, MPerBlock) and N into (NBlock, NPerBlock), as the
// function name suggests - verify against the full header.
1371 template <typename CGridDesc>
1372 __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1373 const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1374 {
1375 const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1376 c_grid_desc_m_n,
1381
1382 return c_grid_desc_mblock_mperblock_nblock_nperblock;
1383 }
1384
1385 // return block_id to C matrix tile idx (m0, n0) mapping
1386 // if arch = gfx942
1387 // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock,
1388 // NPerBlock>;
1389
// Size ratio between the A/B scale storage types and a single mx_scale_t, i.e. how many
// mx_scale_t-sized values one AScaleDataType / BScaleDataType element occupies.
1391 static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
1392 static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
// KXdlPack * MXdlPack (resp. KXdlPack * NXdlPack) must be an exact multiple of the pack
// size; other code in this file divides these products by scale_pack_size_a/b.
1393 static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
1394 "A scale pack data type too large!");
1395 static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
1396 "B scale pack data type too large!");
1397
// NOTE(review): the opening of the next static_assert (listing lines 1398-1399) is missing
// from this extract; only its message is visible, so the asserted condition cannot be
// confirmed here.
1400 "A/B ElementwiseOperation should be PassThrough as load_to_lds is used!");
1401
1402#if 0
1403 template <bool HasMainKBlockLoop,
1404 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1405 TailNumber TailNum = TailNumber::Odd>
1406 __device__ static void Run(const index_t* p_sorted_token_ids,
1407 const index_t* p_sorted_expert_ids,
1408 const index_t* p_max_token_id,
1409 const ADataType* p_a_grid,
1410 const AScaleDataType* p_a_scale_grid,
1411 const BDataType* p_b_grid,
1412 const BScaleDataType* p_b_scale_grid,
1413 DsGridPointer& p_ds_grid,
1414 CDataType* p_c_grid,
1415 void* p_shared,
1416 const Problem& problem,
1417 AElementwiseOperation a_element_op,
1418 BElementwiseOperation b_element_op,
1419 CElementwiseOperation c_element_op)
1420 {
1421 ignore = a_element_op;
1422 ignore = b_element_op;
1423 const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1424 IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1425 problem.MPadded,
1426 problem.K,
1427 problem.KPadded,
1428 problem.StrideA,
1429 problem.AK0);
1430 const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
1431 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
1432 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1433 IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1434 problem.MPadded,
1435 problem.N,
1436 problem.NPadded,
1437 problem.StrideC);
1438
1439 const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
1440 make_tuple(problem.M / (MXdlPack * MPerXdl),
1441 math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
1442 (KXdlPack * 64 / MPerXdl),
1444
1445 const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
1446 make_tuple(problem.N / (NXdlPack * NPerXdl),
1447 math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
1448 (KXdlPack * 64 / NPerXdl),
1450
1451 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1453 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1454
1455 const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1456 const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1457 if(expert_block_id * MPerBlock >= max_token_id)
1458 return;
1459 const index_t expert_id =
1460 __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1461
1462 const auto block_mn = [&]() -> std::pair<int, int> {
1463 if constexpr(NSwizzle)
1464 {
1465 const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1466 const index_t prefix_block = ecnt_prefix * problem.NBlock;
1467 const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1468 const index_t expert_swizzle =
1469 ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
1470 const index_t bid_new = blockIdx.x - prefix_block;
1471 const index_t nid = __builtin_amdgcn_readfirstlane(
1472 bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1473 const index_t mid =
1474 __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1475 return {nid, mid};
1476 }
1477 else
1478 {
1479 return {blockIdx.x, blockIdx.y};
1480 }
1481 }();
1482
1483 const index_t block_n_id = block_mn.first;
1484 const index_t block_m_id = block_mn.second;
1485 const index_t token0 =
1486 __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1487
1488 // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1489 constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1490 constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1491 constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1492 constexpr auto AKThreads = AK0Threads * AK1Threads;
1493 constexpr auto AMRepeats = MPerBlock / AMThreads;
1494 const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1495
1496 if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1497 return;
1499 static_for<0, AMRepeats, 1>{}([&](auto m0) {
1500 const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1501 index_t token_offset = fused_token & 0xffffff;
1502 if constexpr(!IsInputGemm)
1503 {
1504 token_offset = token_offset * problem.TopK + (fused_token >> 24);
1505 }
1506 gather_offsets(m0) = static_cast<IndexType>(token_offset);
1507 });
1508
1509 const index_t expert_stride =
1510 __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1511 const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
1512 problem.N * (IsInputGemm ? 2 : 1) *
1513 math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
1514
1515 // N0, K0, Blocksize*KPack
1516 const index_t n_block_data_idx_on_grid =
1517 __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1518
1519 // Grid buffer creation
1520 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1521 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1522 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1523 p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1524
1525 // A, B scale buffer
1526 const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1527 p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
1528 const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1529 p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
1530 b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1531
1532 // lds max alignment
1533 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1534
1535 // A matrix in LDS memory, dst of blockwise copy
1536 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1537
1538 // B matrix in LDS memory, dst of blockwise copy
1539 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1540
1541 // A matrix blockwise direct to LDS copy
1545 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1546 ABlockTransferThreadClusterArrangeOrder,
1547 ADataType,
1548 ADataType,
1549 decltype(a_grid_desc_ak0_m_ak1),
1550 decltype(a_block_desc_ak0_m_ak1),
1551 ABlockTransferSrcAccessOrder,
1552 ABlockTransferSrcVectorDim,
1553 2,
1554 ABlockTransferSrcScalarPerVector,
1555 IndexType,
1556 1>(a_grid_desc_ak0_m_ak1,
1557 make_multi_index(0, 0, 0),
1558 a_block_desc_ak0_m_ak1,
1559 make_multi_index(0, 0, 0),
1560 gather_offsets);
1561
1562 // B matrix blockwise copy
1563 auto b_blockwise_copy =
1566 BBlockTransferThreadClusterLengths_BK0_N_BK1,
1567 BBlockTransferThreadClusterArrangeOrder,
1568 BDataType,
1569 BDataType,
1570 decltype(b_grid_desc_bk0_n_bk1),
1571 decltype(b_block_desc_bk0_n_bk1),
1572 BBlockTransferSrcAccessOrder,
1573 BBlockTransferSrcVectorDim,
1574 2,
1575 BBlockTransferSrcScalarPerVector>(
1576 b_grid_desc_bk0_n_bk1,
1577 make_multi_index(0, n_block_data_idx_on_grid, 0),
1578 b_block_desc_bk0_n_bk1,
1579 make_multi_index(0, 0, 0));
1580
1581 // LDS allocation for A and B: be careful of alignment
1582 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1583 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1584
1585 // Cast after lds
1587 static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1588
1590 reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
1591 a_block_space_size_aligned * sizeof(ADataType)),
1592 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1593
1594 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1595 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
1596
1597 // Blockwise GEMM pipeline
1598 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1599 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1600 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1601 decltype(c_thread_buf) c_thread_buf_up;
1602
1604 float,
1605 c_thread_buf.num_of_v_,
1606 c_thread_buf.s_per_v,
1607 true>
1608 c_thread_buf_fp32;
1609
1610 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1611 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1612 KPerBlock);
1613
1614 // a and b scale processing
1615 const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
1616 const auto waveId_m = wave_idx[I0];
1617 const auto waveId_n = wave_idx[I1];
1618
1619 auto thread_offset_shuffled =
1620 get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
1621
1622 auto a_thread_offset_m = waveId_m;
1623
1624 auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1625 AScaleDataType,
1626 AScaleDataType,
1627 decltype(a_scale_grid_desc_am_ak),
1628 decltype(BlockwiseGemmPipe::a_scale_thread_desc),
1629 Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
1630 Sequence<0, 1, 2>, // DimAccessOrder
1631 2, // SrcVectorDim
1632 KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
1633 1, // SrcScalarStrideInVector
1634 true>(a_scale_grid_desc_am_ak,
1635 make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
1636 0,
1637 thread_offset_shuffled / scale_pack_size_a));
1638
1639 // B scale load
1640 auto b_thread_offset_n = waveId_n;
1641
1642 auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1643 BScaleDataType,
1644 BScaleDataType,
1645 decltype(b_scale_grid_desc_bn_ak),
1646 decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1647 Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
1648 Sequence<0, 1, 2>, // DimAccessOrder
1649 2, // SrcVectorDim
1650 KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
1651 1, // SrcScalarStrideInVector
1652 true>(b_scale_grid_desc_bn_ak,
1653 make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
1654 0,
1655 thread_offset_shuffled / scale_pack_size_b));
1656
1657 if constexpr(IsInputGemm)
1658 {
1659 constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
1660 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1661 auto b_block_buf_up = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1662 reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
1663 a_block_space_size_aligned * sizeof(ADataType) +
1664 b_block_space_size_aligned * sizeof(BDataType)),
1665 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1666
1667 const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
1668 const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1669 p_b_grid_up + expert_id * expert_stride,
1670 b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1671
1672 auto b_blockwise_copy_up = ThreadGroupTensorSliceTransfer_DirectLoad<
1675 BBlockTransferThreadClusterLengths_BK0_N_BK1,
1676 BBlockTransferThreadClusterArrangeOrder,
1677 BDataType,
1678 BDataType,
1679 decltype(b_grid_desc_bk0_n_bk1),
1680 decltype(b_block_desc_bk0_n_bk1),
1681 BBlockTransferSrcAccessOrder,
1682 BBlockTransferSrcVectorDim,
1683 2,
1684 BBlockTransferSrcScalarPerVector>(b_grid_desc_bk0_n_bk1,
1685 make_multi_index(0, n_block_data_idx_on_grid, 0),
1686 b_block_desc_bk0_n_bk1,
1687 make_multi_index(0, 0, 0));
1688
1689 const BScaleDataType* p_b_scale_grid_up =
1690 p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
1691 const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1692 p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
1693 b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1694
1695 auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
1696 BScaleDataType,
1697 BScaleDataType,
1698 decltype(b_scale_grid_desc_bn_ak),
1699 decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1700 Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
1701 Sequence<0, 1, 2>, // DimAccessOrder
1702 2, // SrcVectorDim
1703 KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
1704 1, // SrcScalarStrideInVector
1705 true>(
1706 b_scale_grid_desc_bn_ak,
1707 make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
1708 0,
1709 thread_offset_shuffled / scale_pack_size_b));
1710
1711 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1712 // A
1713 a_grid_desc_ak0_m_ak1,
1714 a_block_desc_ak0_m_ak1,
1715 a_blockwise_copy,
1716 a_grid_buf,
1717 a_block_buf,
1718 a_block_slice_copy_step,
1719 // Gate and Up
1720 b_grid_desc_bk0_n_bk1,
1721 b_block_desc_bk0_n_bk1,
1722 b_blockwise_copy,
1723 b_blockwise_copy_up,
1724 b_grid_buf,
1725 b_grid_buf_up,
1726 b_block_buf,
1727 b_block_buf_up,
1728 b_block_slice_copy_step,
1729 // C
1730 c_thread_buf,
1731 c_thread_buf_up,
1732 // A scale
1733 a_scale_grid_desc_am_ak,
1734 a_scale_thread_copy,
1735 a_scale_grid_buf,
1736 // Gate and Up scale
1737 b_scale_grid_desc_bn_ak,
1738 b_scale_thread_copy,
1739 b_scale_thread_copy_up,
1740 b_scale_grid_buf,
1741 b_scale_grid_buf_up,
1742 num_k_block_main_loop);
1743 }
1744 else
1745 {
1746 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1747 a_grid_desc_ak0_m_ak1, // A
1748 a_block_desc_ak0_m_ak1,
1749 a_blockwise_copy,
1750 a_grid_buf,
1751 a_block_buf,
1752 a_block_slice_copy_step,
1753 b_grid_desc_bk0_n_bk1, // B
1754 b_block_desc_bk0_n_bk1,
1755 b_blockwise_copy,
1756 b_grid_buf,
1757 b_block_buf,
1758 b_block_slice_copy_step,
1759 c_thread_buf, // C
1760 a_scale_grid_desc_am_ak, // A scale
1761 a_scale_thread_copy,
1762 a_scale_grid_buf,
1763 b_scale_grid_desc_bn_ak, // B scale
1764 b_scale_thread_copy,
1765 b_scale_grid_buf,
1766 num_k_block_main_loop);
1767 }
1768
1769 // shuffle C and write out
1770 {
1771 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1772 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1773 "wrong!");
1774 static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
1775 CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
1776 "wrong!");
1777
1778 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1779 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1780
1781 // TODO: hacky, fix it!
1782 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1783 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
1784
1785 // TODO: hacky, fix it!
1786 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1787 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1788 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
1789
1790 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1791 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1792 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1793 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1794 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1795 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1796 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1797 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1798 constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
1799 constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
1800
1801 // mul scales
1802 static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
1803 static_assert(M5 == 4);
1804 const index_t m1 = get_warp_local_1d_id() / NWave; // Mwave id
1805 const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl;
1806
1807 vector_type<float, 4> topk_weights; // for gemm2 only
1808 static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) {
1809 static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack
1810 static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave
1811 static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack
1812 static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk
1813 const index_t m_pos = block_m_id * MPerBlock +
1814 m0 * M2 * M1 * M3 * M4 * M5 +
1815 m1 * M2 * M3 * M4 * M5 +
1816 imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
1817 if constexpr(MulRoutedWeight)
1818 {
1819 topk_weights =
1821 p_ds_grid[I2] + m_pos);
1822 }
1823 static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size
1824 constexpr index_t c_offset =
1825 blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1826 make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
1827 constexpr auto cidx = Number<c_offset>{};
1828
1829 if constexpr(IsInputGemm) // gu fusion
1830 {
1831 if constexpr(ActivationOperation ==
1832 Activation::silu_and_mul)
1833 {
1834 float gate = c_thread_buf[cidx];
1835 float up = c_thread_buf_up[cidx];
1836 if constexpr(MulRoutedWeight)
1837 {
1838 gate = gate * topk_weights.AsType<float>()[m5];
1839 up = up * topk_weights.AsType<float>()[m5];
1840 }
1842 c_thread_buf_fp32(cidx) = gate * up;
1843 }
1844 else if(ActivationOperation == Activation::gelu_and_mul)
1845 {
1846 float gate = c_thread_buf[cidx];
1847 float up = c_thread_buf_up[cidx];
1848 if constexpr(MulRoutedWeight)
1849 {
1850 gate = gate * topk_weights.AsType<float>()[m5];
1851 up = up * topk_weights.AsType<float>()[m5];
1852 }
1854 c_thread_buf_fp32(cidx) = gate * up;
1855
1856 /*float gate = c_thread_buf[cidx];
1857 float up = c_thread_buf_up[cidx];
1858 if constexpr(MulRoutedWeight)
1859 {
1860 gate = gate * topk_weights.AsType<float>()[m5];
1861 //up = up * topk_weights.AsType<float>()[m5];
1862 }
1863 tensor_operation::element_wise::Gelu{}(gate, gate);
1864 c_thread_buf_fp32(cidx) = up;*/
1865 }
1866 }
1867 else
1868 {
1869 c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
1870 if constexpr(MulRoutedWeight)
1871 {
1872 c_thread_buf_fp32(cidx) =
1873 topk_weights.AsType<float>()[m5] *
1874 c_thread_buf_fp32[cidx];
1875 }
1876 }
1877 });
1878 });
1879 });
1880 });
1881 });
1882 });
1883
1884 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1886
1887 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1888 static_cast<CShuffleDataType*>(p_shared),
1889 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1890
1891 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1892 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1893 make_tuple(
1897 // per shuffle
1898 M1, // M1 = MWave
1899 M2, // M2 = MXdlPack
1900 M3, // M3 * M4 * M5 = MPerXdl
1901 M4,
1902 M5)),
1906 // per shuffle
1907 N1, // N1 = NWave
1908 N2, // N2 = NXdlPack
1909 N3))), // N3 = NPerXdl
1913 Sequence<>{},
1915
1916 // calculate origin of thread output tensor on global memory
1917 // blockwise GEMM c matrix starting index
1918 const auto c_thread_mtx_on_block =
1919 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1920
1921 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1922 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1923
1924 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1926 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))),
1929
1930 const auto m_thread_data_on_block_idx =
1931 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1932 make_multi_index(m_thread_data_on_block));
1933
1934 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1936 make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))),
1939
1940 const auto n_thread_data_on_block_idx =
1941 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1942 make_multi_index(n_thread_data_on_block));
1943
1944 // shuffle: threadwise copy C from VGPR to LDS
1945 auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
1946 AccDataType,
1947 CShuffleDataType,
1948 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1949 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1951 Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
1952 CShuffleNXdlPerWavePerShuffle / NXdlPack,
1953 I1,
1954 I1,
1955 M2,
1956 N2,
1957 M3,
1958 I1,
1959 M5,
1960 I1>,
1962 9,
1963 1,
1965 1,
1966 true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1968 0,
1969 m_thread_data_on_block_idx[I1],
1970 n_thread_data_on_block_idx[I1],
1971 m_thread_data_on_block_idx[I2],
1972 n_thread_data_on_block_idx[I2],
1973 m_thread_data_on_block_idx[I3],
1974 m_thread_data_on_block_idx[I4],
1975 m_thread_data_on_block_idx[I5],
1976 n_thread_data_on_block_idx[I3]),
1978
1979 using EDataType = CDataType;
1980
1981 const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
1982 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1983
1984 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1986 ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1987
1988 const auto ds_grid_buf = generate_tuple(
1989 [&](auto i) {
1991 p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
1992 },
1994
1995 // tuple of reference to C/Ds tensor descriptors
1996 const auto c_ds_desc_refs = concat_tuple_of_reference(
1997 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1998 generate_tie([&](auto i) -> const auto& // return type should be reference
1999 { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2001
2002 // tuple of reference to C/Ds tensor descriptors
2003 const auto c_ds_buf_refs = concat_tuple_of_reference(
2004 tie(c_shuffle_block_buf),
2005 generate_tie([&](auto i) -> const auto& // return type should be reference
2006 { return ds_grid_buf[i]; },
2008
2009 // tuple of starting index of C/Ds blockwise copy
2010 const auto idx_c_ds_block_begin =
2013 [&](auto) {
2014 return make_multi_index(block_m_id, 0, block_n_id, 0);
2015 // return make_multi_index(block_work_idx[I0], 0,
2016 // block_work_idx[I1], 0);
2017 },
2019
2020 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2021 c_grid_desc_mblock_mperblock_nblock_nperblock;
2022
2023 using CDEBlockTransferCluster =
2024 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2025 const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2026 constexpr index_t scatter_weight_idx = 3; // hack fix felix
2027 auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
2029 decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
2031 decltype(c_ds_desc_refs),
2032 decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2033 CElementwiseOperation,
2034 Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
2035 // Sequence support
2036 // arbitrary type
2037 Sequence<1,
2038 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2039 1,
2040 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2041 CDEBlockTransferCluster,
2042 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2043 Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2044 Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2045 3, // index_t SrcVectorDim,
2046 3, // index_t DstVectorDim,
2047 CDEShuffleBlockTransferScalarPerVectors,
2052 false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2053 Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2054 IndexType,
2055 1, // ScatterDim
2056 true, // OutputScatter: false, only use scatter weights
2057 scatter_weight_idx // ScatterWeightIdx: ascale
2058 >{c_ds_desc_refs,
2059 idx_c_ds_block_begin,
2060 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2061 make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2062 c_element_op};
2063
2065 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2066
2067 constexpr auto sfc_c_vgpr =
2068 SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
2069 NXdlPerWave / NXdlPack,
2070 1,
2071 1,
2072 MXdlPack,
2073 NXdlPack,
2074 M2,
2075 1,
2076 M4,
2077 1>,
2079 Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
2080 CShuffleNXdlPerWavePerShuffle / NXdlPack,
2081 1,
2082 1,
2083 MXdlPack,
2084 NXdlPack,
2085 M2,
2086 1,
2087 M4,
2088 1>>{};
2089
2090 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2091
2092 // space filling curve for shuffled blockwise C/D/E
2093 constexpr auto sfc_cde_block =
2096 Sequence<1,
2097 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2098 1,
2099 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2100
2101 static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2102 constexpr auto EMThreads =
2103 CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2104 constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2105 constexpr auto ENThreads =
2106 CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
2107 static_for<0, num_access, 1>{}([&](auto access_id) {
2108 // make sure it's safe to write to LDS
2110
2111 auto dstidx = sfc_cde_block.GetIndex(access_id);
2112 const index_t c_token_pos =
2113 block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2114 static_for<0, EMRepeats, 1>{}([&](auto m0) {
2115 const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2116 IndexType token_offset = fused_token & 0xffffff;
2117 if constexpr(IsInputGemm)
2118 {
2119 token_offset = token_offset * problem.TopK + (fused_token >> 24);
2120 }
2121 scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
2122 });
2123
2125
2126 // each thread write its data from VGPR to LDS
2127 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2128 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2129 c_thread_buf_fp32,
2130 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2131 c_shuffle_block_buf);
2132
2133 // make sure it's safe to read from LDS
2135
2136 // each block copy its data from LDS to global
2137 cde_block_copy_lds_and_global.Run(
2138 c_ds_desc_refs,
2139 c_ds_buf_refs,
2140 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2141 tie(c_grid_buf),
2142 scatter_offsets);
2143
2144 if constexpr(access_id < num_access - 1)
2145 {
2146 constexpr auto cde_lds_and_global_step =
2147 sfc_cde_block.GetForwardStep(access_id);
2148
2149 // move on Ds
2150 static_for<0, NumDTensor, 1>{}([&](auto i) {
2151 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2152 c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2153 });
2154
2155 // move on E
2156 cde_block_copy_lds_and_global.MoveDstSliceWindow(
2157 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2158 I0,
2159 cde_lds_and_global_step);
2160 }
2161 });
2162 }
2163 }
2164#endif
2165
2166 template <bool HasMainKBlockLoop,
2167 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
2168 TailNumber TailNum = TailNumber::Odd>
2169 __device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
2170 const index_t* p_sorted_expert_ids,
2171 const index_t* p_max_token_id,
2172 const ADataType* p_a_grid,
2173 const AScaleDataType* p_a_scale_grid,
2174 const BDataType* p_b_grid,
2175 const BScaleDataType* p_b_scale_grid,
2176 DsGridPointer& p_ds_grid,
2177 CDataType* p_c_grid,
2178 void* p_shared_0,
2179 void* p_shared_1,
2180 const Problem& problem,
2181 AElementwiseOperation a_element_op,
2182 BElementwiseOperation b_element_op,
2183 CElementwiseOperation c_element_op)
2184 {
2185 ignore = a_element_op;
2186 ignore = b_element_op;
2187 const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
2188 IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
2189 problem.MPadded,
2190 problem.K,
2191 problem.KPadded,
2192 problem.StrideA,
2193 problem.AK0);
2194 const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
2195 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
2196 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
2197 IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
2198 problem.MPadded,
2199 problem.N,
2200 problem.NPadded,
2201 problem.StrideC);
2202
2203 const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
2204 make_tuple(problem.M / (MXdlPack * MPerXdl),
2205 math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
2206 (KXdlPack * 64 / MPerXdl),
2208
2209 const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
2210 make_tuple(problem.N / (NXdlPack * NPerXdl),
2211 math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
2212 (KXdlPack * 64 / NPerXdl),
2214
2215 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
2217 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
2218
2219 const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
2220 const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
2221 if(expert_block_id * MPerBlock >= max_token_id)
2222 return;
2223 const index_t expert_id =
2224 __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
2225 const auto block_mn = [&]() -> std::pair<int, int> {
2226 if constexpr(NSwizzle)
2227 {
2228 const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
2229 const index_t prefix_block = ecnt_prefix * problem.NBlock;
2230 const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
2231 const index_t expert_swizzle =
2232 ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
2233 const index_t bid_new = blockIdx.x - prefix_block;
2234 const index_t nid = __builtin_amdgcn_readfirstlane(
2235 bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
2236 const index_t mid =
2237 __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
2238 return {nid, mid};
2239 }
2240 else
2241 {
2242 return {blockIdx.x, blockIdx.y};
2243 }
2244 }();
2245
2246 const index_t block_n_id = block_mn.first;
2247 const index_t block_m_id = block_mn.second;
2248 const index_t token0 =
2249 __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
2250
2251 // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2252 constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2253 constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
2254 constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
2255 constexpr auto AKThreads = AK0Threads * AK1Threads;
2256 constexpr auto AMRepeats = MPerBlock / AMThreads;
2257 const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads;
2258
2259 if(token_pos >= max_token_id || token0 >= problem.NumTokens)
2260 return;
2262 static_for<0, AMRepeats, 1>{}([&](auto m0) {
2263 const index_t fused_token = p_sorted_token_ids[token_pos + m0 * AMThreads];
2264 index_t token_offset = fused_token & 0xffffff;
2265 if constexpr(!IsInputGemm)
2266 {
2267 token_offset = token_offset * problem.TopK + (fused_token >> 24);
2268 }
2269 gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
2270 });
2271
2272 const index_t expert_stride =
2273 __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
2274 const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
2275 problem.N * (IsInputGemm ? 2 : 1) *
2276 math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
2277
2278 // N0, K0, Blocksize*KPack
2279 const index_t n_block_data_idx_on_grid =
2280 __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
2281
2282 // Grid buffer creation
2283 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2284 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
2285 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2286 p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
2287
2288 // A, B scale buffer
2289 const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2290 p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
2291 const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2292 p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
2293 b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2294
2295 // lds max alignment
2296 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
2297
2298 // A matrix in LDS memory, dst of blockwise copy
2299 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
2300
2301 // B matrix in LDS memory, dst of blockwise copy
2302 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
2303
2304 // A matrix blockwise direct to LDS copy
2308 ABlockTransferThreadClusterLengths_AK0_M_AK1,
2309 ABlockTransferThreadClusterArrangeOrder,
2310 ADataType,
2311 ADataType,
2312 decltype(a_grid_desc_ak0_m_ak1),
2313 decltype(a_block_desc_ak0_m_ak1),
2314 ABlockTransferSrcAccessOrder,
2315 ABlockTransferSrcVectorDim,
2316 2,
2317 ABlockTransferSrcScalarPerVector,
2318 IndexType,
2319 1>(a_grid_desc_ak0_m_ak1,
2320 make_multi_index(0, 0, 0),
2321 a_block_desc_ak0_m_ak1,
2322 make_multi_index(0, 0, 0),
2323 gather_offsets);
2324
2325 // B matrix blockwise copy
2326 auto b_blockwise_copy =
2329 BBlockTransferThreadClusterLengths_BK0_N_BK1,
2330 BBlockTransferThreadClusterArrangeOrder,
2331 BDataType,
2332 BDataType,
2333 decltype(b_grid_desc_bk0_n_bk1),
2334 decltype(b_block_desc_bk0_n_bk1),
2335 BBlockTransferSrcAccessOrder,
2336 BBlockTransferSrcVectorDim,
2337 2,
2338 BBlockTransferSrcScalarPerVector>(
2339 b_grid_desc_bk0_n_bk1,
2340 make_multi_index(0, n_block_data_idx_on_grid, 0),
2341 b_block_desc_bk0_n_bk1,
2342 make_multi_index(0, 0, 0));
2343
2344 // LDS allocation for A and B: be careful of alignment
2345 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
2346 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
2347
2348 auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2349 static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2350
2351 auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2352 bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
2353 a_block_space_size_aligned * sizeof(ADataType)),
2354 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2355
2356 auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2357 static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2358
2359 auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2361 a_block_space_size_aligned * sizeof(ADataType)),
2362 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2363
2364 auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
2365 auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
2366
2367 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
2368 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
2369
2370 // Blockwise GEMM pipeline
2371 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2372 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
2373 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2374 decltype(c_thread_buf) c_thread_buf_up;
2375
2377 float,
2378 c_thread_buf.num_of_v_,
2379 c_thread_buf.s_per_v,
2380 true>
2381 c_thread_buf_fp32;
2382
2383 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2384 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
2385 KPerBlock);
2386
2387 // a and b scale processing
2388 const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
2389 const auto waveId_m = wave_idx[I0];
2390 const auto waveId_n = wave_idx[I1];
2391
2392 auto thread_offset_shuffled =
2393 get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
2394
2395 auto a_thread_offset_m = waveId_m;
2396
2397 // get each thread's offset in the scale tensor
2398 const index_t token_scale_pos = block_m_id * MPerBlock;
2399 if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
2400 return;
2401
2402 auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2403 AScaleDataType,
2404 AScaleDataType,
2405 decltype(a_scale_grid_desc_am_ak),
2406 decltype(BlockwiseGemmPipe::a_scale_thread_desc),
2407 Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
2408 Sequence<0, 1, 2>, // DimAccessOrder
2409 2, // SrcVectorDim
2410 KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
2411 1, // SrcScalarStrideInVector
2412 true>(a_scale_grid_desc_am_ak,
2413 make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
2414 0,
2415 thread_offset_shuffled / scale_pack_size_a));
2416
2417 // B scale load
2418 auto b_thread_offset_n = waveId_n;
2419
2420 auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2421 BScaleDataType,
2422 BScaleDataType,
2423 decltype(b_scale_grid_desc_bn_ak),
2424 decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2425 Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
2426 Sequence<0, 1, 2>, // DimAccessOrder
2427 2, // SrcVectorDim
2428 KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
2429 1, // SrcScalarStrideInVector
2430 true>(b_scale_grid_desc_bn_ak,
2431 make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
2432 0,
2433 thread_offset_shuffled / scale_pack_size_b));
2434
2435 if constexpr(IsInputGemm)
2436 {
2437 const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
2438 const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2439 p_b_grid_up + expert_id * expert_stride,
2440 b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
2441
2442 // lds ping pong buffers for up
2443 constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
2444 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
2445 auto b_block_buf_up_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2446 bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
2447 a_block_space_size_aligned * sizeof(ADataType) +
2448 b_block_space_size_aligned * sizeof(BDataType)),
2449 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2450 auto b_block_buf_up_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2452 a_block_space_size_aligned * sizeof(ADataType) +
2453 b_block_space_size_aligned * sizeof(BDataType)),
2454 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2455
2456 auto b_block_bufs_up = make_tuple(b_block_buf_up_ping, b_block_buf_up_pong);
2457
2458 auto b_blockwise_copy_up = ThreadGroupTensorSliceTransfer_DirectLoad<
2461 BBlockTransferThreadClusterLengths_BK0_N_BK1,
2462 BBlockTransferThreadClusterArrangeOrder,
2463 BDataType,
2464 BDataType,
2465 decltype(b_grid_desc_bk0_n_bk1),
2466 decltype(b_block_desc_bk0_n_bk1),
2467 BBlockTransferSrcAccessOrder,
2468 BBlockTransferSrcVectorDim,
2469 2,
2470 BBlockTransferSrcScalarPerVector>(b_grid_desc_bk0_n_bk1,
2471 make_multi_index(0, n_block_data_idx_on_grid, 0),
2472 b_block_desc_bk0_n_bk1,
2473 make_multi_index(0, 0, 0));
2474
2475 const BScaleDataType* p_b_scale_grid_up =
2476 p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
2477 const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2478 p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
2479 b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2480
2481 auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
2482 BScaleDataType,
2483 BScaleDataType,
2484 decltype(b_scale_grid_desc_bn_ak),
2485 decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2486 Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
2487 Sequence<0, 1, 2>, // DimAccessOrder
2488 2, // SrcVectorDim
2489 KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
2490 1, // SrcScalarStrideInVector
2491 true>(
2492 b_scale_grid_desc_bn_ak,
2493 make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
2494 0,
2495 thread_offset_shuffled / scale_pack_size_b));
2496
2497 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2498 // A
2499 a_grid_desc_ak0_m_ak1,
2500 a_block_desc_ak0_m_ak1,
2501 a_blockwise_copy,
2502 a_grid_buf,
2503 a_block_bufs,
2504 a_block_slice_copy_step,
2505 // Gate and Up
2506 b_grid_desc_bk0_n_bk1,
2507 b_block_desc_bk0_n_bk1,
2508 b_blockwise_copy,
2509 b_blockwise_copy_up,
2510 b_grid_buf,
2511 b_grid_buf_up,
2512 b_block_bufs,
2513 b_block_bufs_up,
2514 b_block_slice_copy_step,
2515 // C
2516 c_thread_buf,
2517 c_thread_buf_up,
2518 // A scale
2519 a_scale_grid_desc_am_ak,
2520 a_scale_thread_copy,
2521 a_scale_grid_buf,
2522 // B scale
2523 b_scale_grid_desc_bn_ak,
2524 b_scale_thread_copy,
2525 b_scale_thread_copy_up,
2526 b_scale_grid_buf,
2527 b_scale_grid_buf_up,
2528 num_k_block_main_loop);
2529 }
2530 else
2531 {
2532 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2533 a_grid_desc_ak0_m_ak1, // A
2534 a_block_desc_ak0_m_ak1,
2535 a_blockwise_copy,
2536 a_grid_buf,
2537 a_block_bufs,
2538 a_block_slice_copy_step,
2539 b_grid_desc_bk0_n_bk1, // B
2540 b_block_desc_bk0_n_bk1,
2541 b_blockwise_copy,
2542 b_grid_buf,
2543 b_block_bufs,
2544 b_block_slice_copy_step,
2545 c_thread_buf, // C
2546 a_scale_grid_desc_am_ak, // A scale
2547 a_scale_thread_copy,
2548 a_scale_grid_buf,
2549 b_scale_grid_desc_bn_ak, // B scale
2550 b_scale_thread_copy,
2551 b_scale_grid_buf,
2552 num_k_block_main_loop);
2553 }
2554
2555 // shuffle C and write out
2556 {
2557 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2558 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2559 "wrong!");
2560 static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
2561 CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
2562 "wrong!");
2563
2564 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2565 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
2566
2567 // TODO: hacky, fix it!
2568 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2569 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
2570
2571 // TODO: hacky, fix it!
2572 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
2573 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2574 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
2575
2576 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
2577 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
2578 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
2579 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
2580 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
2581 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
2582 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
2583 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
2584 constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
2585 constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
2586
2587 // mul scales
2588
2589 static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
2590 static_assert(M5 == 4);
2591 const index_t m1 = get_warp_local_1d_id() / NWave;
2592 const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl;
2593
2594 vector_type<float, 4> topk_weights; // for gemm2 only
2595 static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) {
2596 static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack
2597 static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave
2598 static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack
2599 static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk
2600 const index_t m_pos = block_m_id * MPerBlock +
2601 m0 * M2 * M1 * M3 * M4 * M5 +
2602 m1 * M2 * M3 * M4 * M5 +
2603 imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
2604 if constexpr(MulRoutedWeight)
2605 {
2606 topk_weights =
2608 p_ds_grid[I2] + m_pos);
2609 }
2610 static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size
2611 constexpr index_t c_offset =
2612 blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2613 make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
2614 constexpr auto cidx = Number<c_offset>{};
2615
2616 if constexpr(IsInputGemm) // gu fusion
2617 {
2618 if constexpr(ActivationOperation ==
2619 Activation::silu_and_mul)
2620 {
2621 float gate = c_thread_buf[cidx];
2622 float up = c_thread_buf_up[cidx];
2623 if constexpr(MulRoutedWeight)
2624 {
2625 gate = gate * topk_weights.AsType<float>()[m5];
2626 up = up * topk_weights.AsType<float>()[m5];
2627 }
2629 c_thread_buf_fp32(cidx) = gate * up;
2630 }
2631 else if(ActivationOperation == Activation::gelu_and_mul)
2632 {
2633 float gate = c_thread_buf[cidx];
2634 float up = c_thread_buf_up[cidx];
2635 if constexpr(MulRoutedWeight)
2636 {
2637 gate = gate * topk_weights.AsType<float>()[m5];
2638 up = up * topk_weights.AsType<float>()[m5];
2639 }
2641 c_thread_buf_fp32(cidx) = gate * up;
2642 }
2643 }
2644 else
2645 {
2646 c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
2647 if constexpr(MulRoutedWeight)
2648 {
2649 c_thread_buf_fp32(cidx) =
2650 topk_weights.AsType<float>()[m5] *
2651 c_thread_buf_fp32[cidx];
2652 }
2653 }
2654 });
2655 });
2656 });
2657 });
2658 });
2659 });
2660
2661 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2663
2664 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2665 static_cast<CShuffleDataType*>(p_shared_0),
2666 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2667
2668 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
2669 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2670 make_tuple(
2673 Number<CShuffleMXdlPerWavePerShuffle / MXdlPack>{}, // M0 (MXdlPerWave) per
2674 // shuffle
2675 M1, // M1 = MWave
2676 M2, // M2 * M3 * M4 = MPerXdl
2677 M3,
2678 M4,
2679 M5)),
2683 // per shuffle
2684 N1, // N1 = NWave
2685 N2, // N2 = NXdlPack
2686 N3))), // N3 = NPerXdl
2690 Sequence<>{},
2692
2693 // calculate origin of thread output tensor on global memory
2694 // blockwise GEMM c matrix starting index
2695 const auto c_thread_mtx_on_block =
2696 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2697
2698 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2699 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2700
2701 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2703 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))),
2706
2707 const auto m_thread_data_on_block_idx =
2708 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2709 make_multi_index(m_thread_data_on_block));
2710
2711 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2713 make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))),
2716
2717 const auto n_thread_data_on_block_idx =
2718 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2719 make_multi_index(n_thread_data_on_block));
2720
2721 // shuffle: threadwise copy C from VGPR to LDS
2722 auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
2723 AccDataType,
2724 CShuffleDataType,
2725 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2726 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2728 Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
2729 CShuffleNXdlPerWavePerShuffle / NXdlPack,
2730 I1,
2731 I1,
2732 M2,
2733 N2,
2734 M3,
2735 I1,
2736 M5,
2737 I1>,
2739 9,
2740 1,
2742 1,
2743 true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2745 0,
2746 m_thread_data_on_block_idx[I1],
2747 n_thread_data_on_block_idx[I1],
2748 m_thread_data_on_block_idx[I2],
2749 n_thread_data_on_block_idx[I2],
2750 m_thread_data_on_block_idx[I3],
2751 m_thread_data_on_block_idx[I4],
2752 m_thread_data_on_block_idx[I5],
2753 n_thread_data_on_block_idx[I3]),
2755
2756 using EDataType = CDataType;
2757
2758 const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
2759 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
2760
2761 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
2763 ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
2764
2765 const auto ds_grid_buf = generate_tuple(
2766 [&](auto i) {
2768 p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
2769 },
2771
2772 // tuple of reference to C/Ds tensor descriptors
2773 const auto c_ds_desc_refs = concat_tuple_of_reference(
2774 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2775 generate_tie([&](auto i) -> const auto& // return type should be reference
2776 { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2778
2779 // tuple of reference to C/Ds tensor buffers
2780 const auto c_ds_buf_refs = concat_tuple_of_reference(
2781 tie(c_shuffle_block_buf),
2782 generate_tie([&](auto i) -> const auto& // return type should be reference
2783 { return ds_grid_buf[i]; },
2785
2786 // tuple of starting index of C/Ds blockwise copy
2787 const auto idx_c_ds_block_begin =
2790 [&](auto) {
2791 return make_multi_index(block_m_id, 0, block_n_id, 0);
2792 // return make_multi_index(block_work_idx[I0], 0,
2793 // block_work_idx[I1], 0);
2794 },
2796
2797 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2798 c_grid_desc_mblock_mperblock_nblock_nperblock;
2799
2800 using CDEBlockTransferCluster =
2801 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2802 const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2803 constexpr index_t scatter_weight_idx = 3; // hack fix felix
2804 auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
2806 decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
2808 decltype(c_ds_desc_refs),
2809 decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2810 CElementwiseOperation,
2811 Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
2812 // Sequence support
2813 // arbitrary type
2814 Sequence<1,
2815 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2816 1,
2817 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2818 CDEBlockTransferCluster,
2819 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2820 Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2821 Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2822 3, // index_t SrcVectorDim,
2823 3, // index_t DstVectorDim,
2824 CDEShuffleBlockTransferScalarPerVectors,
2829 false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2830 Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2831 IndexType,
2832 1, // ScatterDim
2833 true, // OutputScatter: false, only use scatter weights
2834 scatter_weight_idx // ScatterWeightIdx: ascale
2835 >{c_ds_desc_refs,
2836 idx_c_ds_block_begin,
2837 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2838 make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2839 c_element_op};
2840
2842 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2843
2844 constexpr auto sfc_c_vgpr =
2845 SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
2846 NXdlPerWave / NXdlPack,
2847 1,
2848 1,
2849 MXdlPack,
2850 NXdlPack,
2851 M2,
2852 1,
2853 M4,
2854 1>,
2856 Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
2857 CShuffleNXdlPerWavePerShuffle / NXdlPack,
2858 1,
2859 1,
2860 MXdlPack,
2861 NXdlPack,
2862 M2,
2863 1,
2864 M4,
2865 1>>{};
2866
2867 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2868
2869 // space filling curve for shuffled blockwise C/D/E
2870 constexpr auto sfc_cde_block =
2873 Sequence<1,
2874 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2875 1,
2876 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2877
2878 static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2879 constexpr auto EMThreads =
2880 CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2881 constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2882 constexpr auto ENThreads =
2883 CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
2884 static_for<0, num_access, 1>{}([&](auto access_id) {
2885 // make sure it's safe to write to LDS
2887
2888 auto dstidx = sfc_cde_block.GetIndex(access_id);
2889 const index_t c_token_pos =
2890 block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2891 static_for<0, EMRepeats, 1>{}([&](auto m0) {
2892 const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2893 IndexType token_offset = fused_token & 0xffffff;
2894 if constexpr(IsInputGemm)
2895 {
2896 token_offset = token_offset * problem.TopK + (fused_token >> 24);
2897 }
2898 scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
2899 });
2900
2902
2903 // each thread write its data from VGPR to LDS
2904 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2905 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2906 c_thread_buf_fp32,
2907 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2908 c_shuffle_block_buf);
2909
2910 // make sure it's safe to read from LDS
2912
2913 // each block copy its data from LDS to global
2914 cde_block_copy_lds_and_global.Run(
2915 c_ds_desc_refs,
2916 c_ds_buf_refs,
2917 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2918 tie(c_grid_buf),
2919 scatter_offsets);
2920
2921 if constexpr(access_id < num_access - 1)
2922 {
2923 constexpr auto cde_lds_and_global_step =
2924 sfc_cde_block.GetForwardStep(access_id);
2925
2926 // move on Ds
2927 static_for<0, NumDTensor, 1>{}([&](auto i) {
2928 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2929 c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2930 });
2931
2932 // move on E
2933 cde_block_copy_lds_and_global.MoveDstSliceWindow(
2934 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2935 I0,
2936 cde_lds_and_global_step);
2937 }
2938 });
2939 }
2940 }
2941};
2942
2943} // namespace ck
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)
Definition device_base.hpp:178
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition utility/sequence.hpp:928
__device__ index_t get_warp_local_1d_id()
Definition get_id.hpp:45
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__host__ __device__ constexpr auto container_concat(const X &x, const Ys &... ys)
Definition utility/container_helper.hpp:320
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
__global__ void kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
Definition gridwise_moe_mx_gemm_bns.hpp:48
int32_t index_t
Definition ck.hpp:299
constexpr auto BlockGemmMXPipeline_Selector()
Definition blockwise_gemm_pipeline_xdlops_mx_moe_selector.hpp:36
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
__host__ __device__ constexpr auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:185
integral_constant< index_t, N > Number
Definition number.hpp:12
__global__ void kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_moe_mx_gemm.hpp:90
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Vgpr
Definition amd_address_space.hpp:20
__host__ __device__ PY c_style_pointer_cast(PX p_x)
Definition c_style_pointer_cast.hpp:15
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
Activation
Definition gridwise_moe_gemm.hpp:31
@ silu_and_mul
Definition gridwise_moe_gemm.hpp:33
@ gelu_and_mul
Definition gridwise_moe_gemm.hpp:32
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
typename sequence_merge< Sx, Sy >::type sequence_merge_t
Definition utility/sequence.hpp:925
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr Y bit_cast(const X &x)
Definition type.hpp:306
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
constexpr index_t packed_size_v
Definition data_type.hpp:411
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition tuple_helper.hpp:42
Definition gridwise_moe_mx_gemm.hpp:721
const index_t * p_max_token_id
Definition gridwise_moe_mx_gemm.hpp:783
CDataType * p_c_grid
Definition gridwise_moe_mx_gemm.hpp:789
const AElementwiseOperation a_element_op
Definition gridwise_moe_mx_gemm.hpp:791
const index_t * p_sorted_expert_ids
Definition gridwise_moe_mx_gemm.hpp:782
const CElementwiseOperation c_element_op
Definition gridwise_moe_mx_gemm.hpp:793
const BDataType * p_b_grid
Definition gridwise_moe_mx_gemm.hpp:786
const BScaleDataType * p_b_scale_grid
Definition gridwise_moe_mx_gemm.hpp:787
DsGridPointer p_ds_grid
Definition gridwise_moe_mx_gemm.hpp:788
__host__ Argument(const index_t *p_sorted_token_ids_, const index_t *p_sorted_expert_ids_, const index_t *p_max_token_id_, const ADataType *p_a_grid_, const AScaleDataType *p_a_scale_grid_, const BDataType *p_b_grid_, const BScaleDataType *p_b_scale_grid_, std::array< const void *, NumDTensor > p_ds_grid_, CDataType *p_c_grid_, index_t NumTokens_, index_t TopK_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideScaleA_, index_t StrideB_, index_t StrideScaleB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_)
Definition gridwise_moe_mx_gemm.hpp:722
const index_t * p_sorted_token_ids
Definition gridwise_moe_mx_gemm.hpp:781
const BElementwiseOperation b_element_op
Definition gridwise_moe_mx_gemm.hpp:792
const AScaleDataType * p_a_scale_grid
Definition gridwise_moe_mx_gemm.hpp:785
const ADataType * p_a_grid
Definition gridwise_moe_mx_gemm.hpp:784
index_t MBlock
Definition gridwise_moe_mx_gemm.hpp:715
index_t NPadded
Definition gridwise_moe_mx_gemm.hpp:710
index_t K
Definition gridwise_moe_mx_gemm.hpp:701
index_t N
Definition gridwise_moe_mx_gemm.hpp:700
index_t NumTokens
Definition gridwise_moe_mx_gemm.hpp:697
index_t M
Definition gridwise_moe_mx_gemm.hpp:699
index_t StrideA
Definition gridwise_moe_mx_gemm.hpp:702
index_t StrideScaleB
Definition gridwise_moe_mx_gemm.hpp:705
index_t KRead
Definition gridwise_moe_mx_gemm.hpp:711
index_t NBlock
Definition gridwise_moe_mx_gemm.hpp:716
index_t StrideC
Definition gridwise_moe_mx_gemm.hpp:707
index_t StrideB
Definition gridwise_moe_mx_gemm.hpp:704
__host__ void Print() const
Definition gridwise_moe_mx_gemm.hpp:685
index_t BK0
Definition gridwise_moe_mx_gemm.hpp:714
index_t StrideScaleA
Definition gridwise_moe_mx_gemm.hpp:703
std::array< index_t, NumDTensor > StrideDs
Definition gridwise_moe_mx_gemm.hpp:706
index_t MPadded
Definition gridwise_moe_mx_gemm.hpp:709
index_t KBatch
Definition gridwise_moe_mx_gemm.hpp:708
index_t KPadded
Definition gridwise_moe_mx_gemm.hpp:712
index_t TopK
Definition gridwise_moe_mx_gemm.hpp:698
index_t AK0
Definition gridwise_moe_mx_gemm.hpp:713
__host__ Problem(index_t NumTokens_, index_t TopK_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideScaleA_, index_t StrideB_, index_t StrideScaleB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t KBatch_)
Definition gridwise_moe_mx_gemm.hpp:650
__device__ SplitKBatchOffset(Argument &karg, index_t k_id)
Definition gridwise_moe_mx_gemm.hpp:798
index_t b_k_split_offset
Definition gridwise_moe_mx_gemm.hpp:852
index_t b_scale_k_split_offset
Definition gridwise_moe_mx_gemm.hpp:854
index_t a_k_split_offset
Definition gridwise_moe_mx_gemm.hpp:851
index_t a_scale_k_split_offset
Definition gridwise_moe_mx_gemm.hpp:853
Definition gridwise_moe_mx_gemm.hpp:179
remove_cvref_t< decltype(BlockGemmMXPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ScaleBlockSize, ADataType, AScaleDataType, BDataType, BScaleDataType, ComputeTypeA, GemmAccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(MakeAMmaTileDescriptor_M0_M1_M2_M3_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_N3_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack, IsInputGemm >())> BlockwiseGemmPipe
Definition gridwise_moe_mx_gemm.hpp:1110
static __device__ void Run_2Lds(const index_t *p_sorted_token_ids, const index_t *p_sorted_expert_ids, const index_t *p_max_token_id, const ADataType *p_a_grid, const AScaleDataType *p_a_scale_grid, const BDataType *p_b_grid, const BScaleDataType *p_b_scale_grid, DsGridPointer &p_ds_grid, CDataType *p_c_grid, void *p_shared_0, void *p_shared_1, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition gridwise_moe_mx_gemm.hpp:2169
__host__ static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition gridwise_moe_mx_gemm.hpp:1372
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
Definition utility/sequence.hpp:43
__host__ static __device__ constexpr index_t At(index_t I)
Definition utility/sequence.hpp:53
Definition tensor_space_filling_curve.hpp:20
Definition static_buffer.hpp:75
Definition thread_group_tensor_slice_transfer_direct_load.hpp:55
Definition thread_group_tensor_slice_transfer_gather_direct_load.hpp:57
Definition thread_group_tensor_slice_transfer_v7r3_scatter.hpp:51
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition utility/tuple.hpp:117
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Unsigned representation of a conventional biased Float32 exponent.
Definition utility/e8m0.hpp:26
Definition data_type.hpp:42
Definition data_type.hpp:187
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1041
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1087
Definition dtype_vector.hpp:10
#define CK_ENV(name)
Definition utility/env.hpp:129