gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp Source File

gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp Source File

Composable Kernel: gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp Source File
gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
16#include "ck/utility/env.hpp"
17
18namespace ck {
19
// Currently there is no elegant way to put the single-LDS-buffer and double-LDS-buffer pipelines
// in the same kernel function. Blockers:
// 1. Two separate declarations of __shared__ pointers are the key to making sure data accesses
//    operate on two distinct LDS chunks.
// 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C may
//    not use the same LDS buffer when __shared__ is declared inside blkgemmpipe.
26template <typename GridwiseGemm,
27 bool HasMainKBlockLoop,
28 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
29 index_t MinimumOccupancy = 1,
31__global__ void
32#if CK_USE_LAUNCH_BOUNDS
33__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
34#endif
35 // __attribute__((amdgpu_waves_per_eu(1, 1)))
36 kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
37{
38#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
39 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
40 {
41 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
42
43 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
44
45 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
46 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
47 karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
48 karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
49 karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset,
50 p_shared,
51 karg);
52 }
53#else
54 ignore = karg;
55#endif // end of if (defined(__gfx9__))
56}
57
58template <typename GridwiseGemm,
59 bool HasMainKBlockLoop,
60 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
61 index_t MinimumOccupancy = 1,
63__global__ void
64#if CK_USE_LAUNCH_BOUNDS
65__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
66#endif
67 // __attribute__((amdgpu_waves_per_eu(1, 1)))
68 kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
69{
70#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
71 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
72 {
73 // Pass two lds pointer is the key to tell compiler that ds_read/write
74 // operate on different lds chunk at same time without order dependecy
75 __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
76 __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
77
78 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
79
80 GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
81 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
82 karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
83 karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
84 karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset,
85 p_shared_0,
86 p_shared_1,
87 karg);
88 }
89#else
90 ignore = karg;
91#endif // end of if (defined(__gfx9__))
92}
93
94template <typename ALayout,
95 typename BLayout,
96 typename CLayout,
97 typename ADataType,
98 typename BDataType,
99 typename AccDataType,
100 typename CShuffleDataType,
101 typename CDataType,
102 typename AElementwiseOperation,
103 typename BElementwiseOperation,
104 typename CElementwiseOperation,
106 index_t BlockSize,
107 index_t ScaleBlockN, // scale N
108 index_t ScaleBlockK, // scale K
109 index_t MPerBlock,
110 index_t NPerBlock,
111 index_t KPerBlock,
112 index_t AK1Value,
113 index_t BK1Value,
114 index_t MPerXdl,
115 index_t NPerXdl,
116 index_t MXdlPerWave,
117 index_t NXdlPerWave,
118 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
119 typename ABlockTransferThreadClusterArrangeOrder,
120 typename ABlockTransferSrcAccessOrder,
121 index_t ABlockTransferSrcVectorDim,
122 index_t ABlockTransferSrcScalarPerVector,
123 index_t ABlockTransferDstScalarPerVector_AK1,
124 bool AThreadTransferSrcResetCoordinateAfterRun,
125 index_t ABlockLdsExtraM,
126 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
127 typename BBlockTransferThreadClusterArrangeOrder,
128 typename BBlockTransferSrcAccessOrder,
129 index_t BBlockTransferSrcVectorDim,
130 index_t BBlockTransferSrcScalarPerVector,
131 index_t BBlockTransferDstScalarPerVector_BK1,
132 bool BThreadTransferSrcResetCoordinateAfterRun,
133 index_t BBlockLdsExtraN,
134 index_t CShuffleMXdlPerWavePerShuffle,
135 index_t CShuffleNXdlPerWavePerShuffle,
136 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
137 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
140 typename ComputeTypeA = CDataType,
141 typename ComputeTypeB = ComputeTypeA,
142 bool PermuteA = false,
143 bool PermuteB = false>
145{
147
148 static constexpr auto I0 = Number<0>{};
149 static constexpr auto I1 = Number<1>{};
150 static constexpr auto I2 = Number<2>{};
151 static constexpr auto I3 = Number<3>{};
152 static constexpr auto I4 = Number<4>{};
153 static constexpr auto I5 = Number<5>{};
154 static constexpr auto I6 = Number<6>{};
155 static constexpr auto I7 = Number<7>{};
156
157 // K1 should be Number<...>
158 static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
159 static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
160 static constexpr auto AK1Number = Number<AK1Value>{};
161 static constexpr auto BK1Number = Number<BK1Value>{};
162
163 static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
164 static constexpr bool is_single_rate_mfma =
166 lcm_AK1_BK1 <= 4) ||
169 lcm_AK1_BK1 < 32))
170 ? true
171 : false;
// Passed as the final template argument of MfmaSelector when computing KPack; always false here.
static constexpr bool is_scale_mfma = false;
173 static constexpr index_t KPack =
175 MfmaSelector<ComputeTypeA,
176 MPerXdl,
177 NPerXdl,
178 ComputeTypeA,
180 is_scale_mfma>::selected_mfma.k_per_blk);
181
183
184 static constexpr index_t APackedSize = []() {
186 return 2;
187 else
188 return 1;
189 }();
190
191 static constexpr index_t BPackedSize = []() {
193 return 2;
194 else
195 return 1;
196 }();
197
198 __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
199 {
200 return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
201 }
202
203 __host__ static auto CalculateMPadded(index_t M)
204 {
205 return math::integer_least_multiple(M, MPerBlock);
206 }
207
208 __host__ static auto CalculateNPadded(index_t N)
209 {
210 return math::integer_least_multiple(N, NPerBlock);
211 }
212
213 __host__ static auto CalculateKPadded(index_t K)
214 {
215 return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
216 }
217
218 __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
219 {
220 auto K_t = K_Batch * KPerBlock;
221 return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
222 }
223
224 __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
225 {
226 auto K_t = K_Batch * KPerBlock;
227 return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
228 }
229
230 __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
231 {
232 auto K_t = K_Batch * KPerBlock;
233 return (K + K_t - 1) / K_t * KPerBlock;
234 }
235
236 __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
237 {
238 constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
239 auto K_t = K_Batch * KReadVec;
240 return (K + K_t - 1) / K_t * KReadVec;
241 }
242
243 __host__ static auto CalculateMBlock(index_t M)
244 {
245 return math::integer_divide_ceil(M, MPerBlock);
246 }
247
248 __host__ static auto CalculateNBlock(index_t N)
249 {
250 return math::integer_divide_ceil(N, NPerBlock);
251 }
252
253 template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
254 __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
255 {
256 constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
257 constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
258
260 TileDesc_K0_MN_K1{},
266 }
267
268 __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
269 index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
270 {
271 const auto a_grid_desc_mraw_kraw = [&]() {
273 {
274 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
275 }
277 {
278 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
279 }
280 }();
281
282 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
283
284 if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
285 GemmSpec == GemmSpecialization::MNKPadding)
286 {
287 // pad both M and K
288 const auto a_grid_desc_m_k =
289 transform_tensor_descriptor(a_grid_desc_mraw_kraw,
291 make_right_pad_transform(K, KPad - K)),
294
295 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
296 a_grid_desc_m_k,
301
302 return a_grid_desc_ak0_m_ak1;
303 }
304 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
305 GemmSpec == GemmSpecialization::MNPadding)
306 {
307 // pad M, but not K
308 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
309 a_grid_desc_mraw_kraw,
311 make_right_pad_transform(M, MPad - M)),
314
315 return a_grid_desc_ak0_m_ak1;
316 }
317 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
318 GemmSpec == GemmSpecialization::NKPadding)
319 {
320 // pad K, but not M
321 const auto a_grid_desc_m_k = transform_tensor_descriptor(
322 a_grid_desc_mraw_kraw,
326
327 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
328 a_grid_desc_m_k,
333
334 return a_grid_desc_ak0_m_ak1;
335 }
336 else
337 {
338 // not pad M or K
339 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
340 a_grid_desc_mraw_kraw,
345
346 return a_grid_desc_ak0_m_ak1;
347 }
348 }
349
350 __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
351 index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
352 {
353 const auto b_grid_desc_nraw_kraw = [&]() {
355 {
356 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
357 }
359 {
360 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
361 }
362 }();
363
364 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
365
367 GemmSpec != GemmSpecialization::Default),
368 "pk_i4_t does not support padding");
369
370 if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
371 GemmSpec == GemmSpecialization::MNKPadding)
372 {
373 // pad both N and K
374 const auto b_grid_desc_n_k =
375 transform_tensor_descriptor(b_grid_desc_nraw_kraw,
377 make_right_pad_transform(K, KPad - K)),
380
381 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
382 b_grid_desc_n_k,
387
388 return b_grid_desc_bk0_n_bk1;
389 }
390 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
391 GemmSpec == GemmSpecialization::MNPadding)
392 {
393 // pad N, but not K
394 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
395 b_grid_desc_nraw_kraw,
397 make_right_pad_transform(N, NPad - N)),
400
401 return b_grid_desc_bk0_n_bk1;
402 }
403 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
404 GemmSpec == GemmSpecialization::MKPadding)
405 {
406 // pad K, but not N
407 const auto b_grid_desc_n_k = transform_tensor_descriptor(
408 b_grid_desc_nraw_kraw,
412
413 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
414 b_grid_desc_n_k,
419
420 return b_grid_desc_bk0_n_bk1;
421 }
422 else
423 {
424 if constexpr(!PermuteB)
425 {
426 // not pad N or K
427 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
428 b_grid_desc_nraw_kraw,
433
434 return b_grid_desc_bk0_n_bk1;
435 }
436 else
437 {
438 // Weight Tile Permute
439 constexpr index_t BK01 = KPerBlock / BK1Value;
440 // const index_t BK00 = BK0 / BK01;
441 const index_t BK0_ = StrideB / BK1Value;
442 const index_t BK00 = BK0_ / BK01;
443
444 const auto b_grid_desc_bk00_n_bk01_bk1_permute =
445 make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value));
446
447 const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor(
448 b_grid_desc_bk00_n_bk01_bk1_permute,
454
455 return b_grid_desc_bk0_n_bk1_permute;
456 }
457 }
458 }
459
460 template <typename ABlockDesc_AK0_M_AK1>
461 __host__ __device__ static constexpr auto
462 MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
463 {
464 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
465
466 return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
467 }
468
469 template <typename BBlockDesc_BK0_N_BK1>
470 __host__ __device__ static constexpr auto
471 MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
472 {
473 constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
474
475 return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
476 }
477
478 __host__ __device__ static auto
480 {
481 const auto c_grid_desc_mraw_nraw = [&]() {
483 {
484 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
485 }
487 {
488 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
489 }
490 }();
491
492 // pad M and N
493 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
495 make_right_pad_transform(N, NPad - N)),
498#if 0
499 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
500
501 if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
502 GemmSpec == GemmSpecialization::MNKPadding)
503 {
504 // pad M and N
505 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
507 make_right_pad_transform(N, NPad - N)),
510 }
511 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
512 GemmSpec == GemmSpecialization::MKPadding)
513 {
514 // pad M, but not N
516 c_grid_desc_mraw_nraw,
520 }
521 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
522 GemmSpec == GemmSpecialization::NKPadding)
523 {
524 // pad N, but not M
526 c_grid_desc_mraw_nraw,
530 }
531 else
532 {
533 // not pad M or N
534 return c_grid_desc_mraw_nraw;
535 }
536#endif
537 }
538
539 struct Problem
540 {
541 __host__ Problem(index_t M_,
542 index_t N_,
543 index_t K_,
544 index_t StrideA_,
545 index_t StrideB_,
546 index_t StrideC_,
547 index_t StrideScaleB_,
548 index_t KBatch_)
549 : M{M_},
550 N{N_},
551 K{K_},
552 StrideA{StrideA_},
553 StrideB{StrideB_},
554 StrideC{StrideC_},
555 StrideScaleB{StrideScaleB_},
556 KBatch{KBatch_},
559 KRead{CalculateKRead(K_, KBatch_)},
560 KPadded{CalculateKPadded(K_, KBatch_)},
561 AK0{CalculateAK0Padded(K_, KBatch_)},
562 BK0{CalculateBK0Padded(K_, KBatch_)},
565 {
566 }
567
568 __host__ void Print() const
569 {
570 std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
571 << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
572 << ", " << "SScaleB:" << StrideScaleB << ", " << "MP:" << MPadded << ", "
573 << "NP:" << NPadded << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded
574 << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", "
575 << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl;
576 }
577
578 index_t M;
579 index_t N;
580 index_t K;
590 index_t AK0;
591 index_t BK0;
594 };
595
596 // Argument
598 {
599 __host__ Argument(const ADataType* p_a_grid_,
600 const BDataType* p_b_grid_,
601 CDataType* p_c_grid_,
602 index_t M_,
603 index_t N_,
604 index_t K_,
605 index_t StrideA_,
606 index_t StrideB_,
607 index_t StrideC_,
608 index_t StrideScaleB_,
609 const BScaleType* p_b_scale_grid_,
610 index_t k_batch_,
611 AElementwiseOperation a_element_op_,
612 BElementwiseOperation b_element_op_,
613 CElementwiseOperation c_element_op_,
614 bool is_reduce_ = false)
615 : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, StrideScaleB_, k_batch_},
616 p_a_grid{p_a_grid_},
617 p_b_grid{p_b_grid_},
618 p_c_grid{p_c_grid_},
619 p_b_scale_grid{p_b_scale_grid_},
623 is_reduce(is_reduce_)
624 {
625 }
626
627 __host__ __device__ inline bool IsReduceAdd() const
628 {
629 return (Problem::KBatch > 1) && is_reduce;
630 }
631
632 __host__ __device__ inline bool IsAtomicAdd() const
633 {
634 return (Problem::KBatch > 1) && (!is_reduce);
635 }
636
637 const ADataType* p_a_grid;
638 const BDataType* p_b_grid;
639 CDataType* p_c_grid;
640
642 const AElementwiseOperation a_element_op;
643 const BElementwiseOperation b_element_op;
644 const CElementwiseOperation c_element_op;
645 bool is_reduce;
646 };
647
648 struct SplitKBatchOffset
649 {
650
651 __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
652 {
654 {
655 a_k_split_offset = k_id * karg.KRead / APackedSize;
656 }
658 {
659 a_k_split_offset = k_id * karg.KRead * karg.StrideA;
660 }
661
663 {
664 b_k_split_offset = k_id * karg.KRead * karg.StrideB;
665 }
667 {
668 if constexpr(!PermuteB)
669 {
670 b_k_split_offset = k_id * karg.KRead / BPackedSize;
671 }
672 else
673 {
674 const int k0_offset = karg.KRead * karg.N;
675 b_k_split_offset = k_id * k0_offset / BPackedSize;
676 }
677 }
678
679 // Calculate B scale offset
681 {
682 scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideB;
683 }
685 {
686 scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK);
687 }
688
689 if(k_id < (karg.KBatch - 1))
690 {
691 karg.K = karg.KRead;
692 }
693 else
694 {
695 karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
696 }
697
698 if(karg.IsReduceAdd())
699 {
700 c_reduce_offset = k_id * karg.M * karg.N;
701 }
702 else
703 {
704 c_reduce_offset = 0;
705 }
706 }
707
710 index_t scale_k_split_offset; // New member for scale matrix offset
712 };
713
714 __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
715 {
716 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
717 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
718 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
719 // A matrix in LDS memory, dst of blockwise copy
720 if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
721 {
725 }
726 // xor tensor transformation request more unnecessary vgpr usage, would cause register spill
727 // in some cases.
729 {
730 constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize;
731 constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize;
732 constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
734 AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
736
737 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
738 a_lds_block_desc,
744
745 constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor(
746 a_lds_block_desc_permuted,
752
753 constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
754 a_lds_block_desc_ak0_mldslayer_m_ak1,
761
762 return a_lds_block_desc_ak0_m_ak1;
763 }
764 else // ColumnMajor A
765 {
766 // kfold and mpair dimension is not always required.
767 // more dimension in merge_transform increase the difficulty of generating immarg offset
768 // for compiler.
769 constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
770 constexpr auto M1 = MPerBlock / M0;
771
772 constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
773 constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
774 constexpr auto KThreadRead = WaveSize / MPerXdl;
775 constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
776
777 constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
778 ? 1
779 : 128 / (AK1Number * M0 * sizeof(ADataType));
780 constexpr auto KThreadReadPerm =
781 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
782 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
783 : KThreadRead;
784
785 // 1<=mpair<=n0
786 constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
787 ? 1
788 : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
789 ? M0
790 : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));
791
792 constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
796 Number<kfold * M0 / mpair>{},
798 AK1Number));
799
800 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
801 a_lds_block_desc,
806 make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
813
814 constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
815 a_lds_block_desc_permuted,
824 Sequence<1>{},
825 Sequence<2>{},
826 Sequence<3>{},
827 Sequence<4>{},
828 Sequence<5>{}),
830 Sequence<2>{},
833 Sequence<6>{},
834 Sequence<7>{}));
835
836 constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
837 a_lds_block_desc_unmerged,
840 Number<KThreadWrite / kfold / KThreadReadPerm>{},
848
849 return a_lds_block_desc_ak0_m_ak1;
850 }
851 }
852
853 __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
854 {
855 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
856 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
857 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
858 // B matrix in LDS memory, dst of blockwise copy
859 if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
860 {
864 }
866 {
867 // NLdsLayer * K0 as logical Bank
868 constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize;
869 constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize;
870 constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
872 BK0Number * Number<NLdsLayer>{}, Number<NPerBlock / NLdsLayer>{}, BK1Number),
874
875 constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
876 b_lds_block_desc,
882
883 constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
884 b_lds_block_desc_permuted,
890
891 constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
892 b_lds_block_desc_bk0_nldslayer_n_bk1,
899
900 return b_lds_block_desc_bk0_n_bk1;
901 }
902 else // RowMajor B
903 {
904 constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
905 constexpr auto N1 = NPerBlock / N0;
906
907 constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
908 constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
909 constexpr auto KThreadRead = WaveSize / NPerXdl;
910 constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
911
912 constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
913 ? 1
914 : 128 / (BK1Number * N0 * sizeof(BDataType));
915 constexpr auto KThreadReadPerm =
916 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
917 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
918 : KThreadRead;
919
920 // 1<=npair<=n0
921 constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128)
922 ? 1
923 : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0
924 ? N0
925 : 128 / (BK1Number * NPerXdl * sizeof(BDataType)));
926
927 constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
931 Number<kfold * N0 / npair>{},
933 BK1Number));
934
935 constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
936 b_lds_block_desc,
941 make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
948
949 constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
950 b_lds_block_desc_permuted,
959 Sequence<1>{},
960 Sequence<2>{},
961 Sequence<3>{},
962 Sequence<4>{},
963 Sequence<5>{}),
965 Sequence<2>{},
968 Sequence<6>{},
969 Sequence<7>{}));
970
971 constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
972 b_lds_block_desc_unmerged,
975 Number<KThreadWrite / kfold / KThreadReadPerm>{},
983
984 return b_lds_block_desc_bk0_n_bk1;
985 }
986 }
987
989 {
990 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
991 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
992
993 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
997 I1,
999
1000 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
1001 }
1002
1005 BlkGemmPipelineVer,
1006 BlkGemmPipeSched,
1007 BlockSize,
1008 ADataType,
1009 BDataType,
1010 ComputeTypeA,
1011 AccDataType,
1018 ABlockTransferSrcScalarPerVector,
1019 BBlockTransferSrcScalarPerVector,
1020 MPerBlock,
1021 NPerBlock,
1022 KPerBlock,
1023 MPerXdl,
1024 NPerXdl,
1025 MXdlPerWave,
1026 NXdlPerWave,
1027 KPack>())>;
1028
1029 __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
1030 {
1031 // LDS allocation for A and B: be careful of alignment
1032 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1033 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1034
1035 // lds max alignment
1036 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1037
1038 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1039 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1040
1041 constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
1042 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1043
1044 // LDS allocation for C shuffle in LDS
1045 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1047
1048 constexpr auto c_block_size =
1049 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
1050
1051 return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize +
1052 b_block_space_size_aligned * sizeof(BDataType) / BPackedSize),
1053 c_block_size * sizeof(CShuffleDataType));
1054 }
1055
1057
1058 // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
1059 __host__ static constexpr bool CheckValidity(const Argument& karg)
1060 {
1061 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1062 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1063 "Invalid tuning param!");
1064
1065 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
1070 {
1071 if(!(karg.M % MPerBlock == 0))
1072 {
1073 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1074 {
1075 std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
1076 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1077 << std::endl;
1078 }
1079 return false;
1080 }
1081 }
1082
1083 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
1088 {
1089 if(!(karg.N % NPerBlock == 0))
1090 {
1091 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1092 {
1093 std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
1094 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1095 << std::endl;
1096 }
1097 return false;
1098 }
1099 }
1100
1101 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1105 {
1106
1107 auto K_t = karg.KBatch * KPerBlock;
1108 if(!(karg.K % K_t == 0))
1109 {
1110 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1111 {
1112 std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1113 << karg.K << " " << __FILE__ << ":" << __LINE__
1114 << ", in function: " << __func__ << std::endl;
1115 }
1116 return false;
1117 }
1118 }
1119 else
1120 {
1121 constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1122 auto K_t = karg.KBatch * KReadVec;
1123 auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1124 if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1125 {
1126 return false;
1127 }
1128 }
1129
1131 {
1132 if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1133 {
1134 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1135 {
1136 std::cout << "Arg K (" << karg.K
1137 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1138 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1139 << __LINE__ << ", in function: " << __func__ << std::endl;
1140 }
1141 return false;
1142 }
1143 }
1144 else
1145 {
1146 if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1147 {
1148 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1149 {
1150 std::cout << "Arg M (" << karg.M
1151 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1152 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1153 << __LINE__ << ", in function: " << __func__ << std::endl;
1154 }
1155 return false;
1156 }
1157 }
1158
1160 {
1161 if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1162 {
1163 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1164 {
1165 std::cout << "Arg N (" << karg.N
1166 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1167 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1168 << __LINE__ << ", in function: " << __func__ << std::endl;
1169 }
1170 return false;
1171 }
1172 }
1173 else
1174 {
1175 if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1176 {
1177 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1178 {
1179 std::cout << "Arg K (" << karg.K
1180 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1181 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1182 << __LINE__ << ", in function: " << __func__ << std::endl;
1183 }
1184 return false;
1185 }
1186 }
1187
1189 {
1190 if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1191 {
1192 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1193 {
1194 std::cout << "Arg N (" << karg.N
1195 << ") value is not a multiple of "
1196 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1197 << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
1198 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1199 << std::endl;
1200 }
1201 return false;
1202 }
1203 }
1204 else
1205 {
1206 if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1207 {
1208 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1209 {
1210 std::cout << "Arg M (" << karg.M
1211 << ") value is not a multiple of "
1212 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1213 << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
1214 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1215 << std::endl;
1216 }
1217 return false;
1218 }
1219 }
1220
1225 {
1226 if(!karg.IsReduceAdd())
1227 {
1228 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1229 {
1230 std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__
1231 << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1232 }
1233 if(karg.KBatch > 1)
1234 {
1235 return false;
1236 }
1237 }
1238 }
1239
1240 // check gridwise gemm pipeline
1241 const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1242
1243 if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1)
1244 {
1245 if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1246 {
1247 return false;
1248 }
1249 }
1250
1251 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1252 return true;
1253 }
1254
1255 __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1256 {
1257 const index_t num_loop = K / KPerBlock;
1258
1259 return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1260 }
1261
1262 __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1263 {
1264 const index_t num_loop = K / KPerBlock;
1265
1266 return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1267 }
1268
1269 template <typename CGridDesc>
1270 __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1271 const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1272 {
1273 const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1274 c_grid_desc_m_n,
1279
1280 return c_grid_desc_mblock_mperblock_nblock_nperblock;
1281 }
1282
1283 // return block_id to C matrix tile idx (m0, n0) mapping
1284 // if arch = gfx942
1286 // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
1287
1288 template <typename AGridDesc_AK0_M_K1,
1289 typename BGridDesc_BK0_N_K1,
1290 typename BScaleGridDesc_BN_AK,
1291 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
1292 bool HasMainKBlockLoop,
1293 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1294 TailNumber TailNum = TailNumber::Odd>
1295 __device__ static void Run(const ADataType* p_a_grid,
1296 const BDataType* p_b_grid,
1297 CDataType* p_c_grid,
1298 const BScaleType* p_b_scale_grid,
1299 void* p_shared,
1300 const Problem& problem,
1301 const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
1302 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
1303 const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak,
1304 const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
1305 c_grid_desc_mblock_mperblock_nblock_nperblock)
1306 {
1307 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1308 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1309 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1310 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1312 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1313
1314 // B Scale buffer
1315 const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1316 p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1317
1318 const AElementwiseOperation a_element_op{};
1319 const BElementwiseOperation b_element_op{};
1320 const CElementwiseOperation c_element_op{};
1321
1322 // divide block work by [M, N]
1323 const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
1324
1325 const auto block_work_idx =
1326 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
1327
1328 if(!block_2_ctile_map.ValidCTileIndex(
1329 block_work_idx,
1330 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
1331 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
1332 {
1333 return;
1334 }
1335
1336 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
1337 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
1338
1339 // HACK: this force m/n_block_data_idx_on_grid into SGPR
1340 const index_t m_block_data_idx_on_grid =
1341 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1342
1343 const index_t n_block_data_idx_on_grid =
1344 __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1345
1346 // lds max alignment
1347 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1348
1349 // A matrix in LDS memory, dst of blockwise copy
1350 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1351
1352 // B matrix in LDS memory, dst of blockwise copy
1353 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1354
1355 // A matrix blockwise copy
1356 auto a_blockwise_copy =
1358 AElementwiseOperation,
1362 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1363 ABlockTransferThreadClusterArrangeOrder,
1364 ADataType,
1365 ADataType,
1366 decltype(a_grid_desc_ak0_m_ak1),
1367 decltype(a_block_desc_ak0_m_ak1),
1368 ABlockTransferSrcAccessOrder,
1370 ABlockTransferSrcVectorDim,
1371 2,
1372 ABlockTransferSrcScalarPerVector,
1373 ABlockTransferDstScalarPerVector_AK1,
1374 1,
1375 1,
1376 AThreadTransferSrcResetCoordinateAfterRun,
1377 true,
1378 BlockwiseGemmPipe::GlobalBufferNum>(
1379 a_grid_desc_ak0_m_ak1,
1380 make_multi_index(0, m_block_data_idx_on_grid, 0),
1381 a_element_op,
1382 a_block_desc_ak0_m_ak1,
1383 make_multi_index(0, 0, 0),
1385
1386 // B matrix blockwise copy
1387 auto b_blockwise_copy =
1389 BElementwiseOperation,
1393 BBlockTransferThreadClusterLengths_BK0_N_BK1,
1394 BBlockTransferThreadClusterArrangeOrder,
1395 BDataType,
1396 BDataType,
1397 decltype(b_grid_desc_bk0_n_bk1),
1398 decltype(b_block_desc_bk0_n_bk1),
1399 BBlockTransferSrcAccessOrder,
1401 BBlockTransferSrcVectorDim,
1402 2,
1403 BBlockTransferSrcScalarPerVector,
1404 BBlockTransferDstScalarPerVector_BK1,
1405 1,
1406 1,
1407 BThreadTransferSrcResetCoordinateAfterRun,
1408 true,
1409 BlockwiseGemmPipe::GlobalBufferNum>(
1410 b_grid_desc_bk0_n_bk1,
1411 make_multi_index(0, n_block_data_idx_on_grid, 0),
1412 b_element_op,
1413 b_block_desc_bk0_n_bk1,
1414 make_multi_index(0, 0, 0),
1416
1417 // LDS allocation for A and B: be careful of alignment
1418 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1419 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1420
1421 // Cast after lds
1423 static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1424
1426 reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) + a_block_space_size_aligned *
1427 sizeof(ADataType) /
1428 APackedSize),
1429 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1430
1431 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1432 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
1433
1434 // Blockwise GEMM pipeline
1435 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1436 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1437 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1438
1439 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1440 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1441 KPerBlock);
1442
1443 // b scale
1444 // static_assert(KPerBlock <= ScaleBlockK);
1445 static constexpr auto mfma =
1447 static constexpr auto KPerXdlops = mfma.GetKPerXdlops();
1448 static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops();
1449 static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;
1450 static constexpr auto KPerThread = KPerBlock / K0PerXdlops;
1451
1452 static constexpr auto ScaleSliceSizeN = NXdlPerWave;
1453 static constexpr auto ScaleSliceSizeK = (KPerThread + ScaleBlockK - 1) / ScaleBlockK;
1454 static constexpr auto KBlockScaleSliceSizeK = (KPerBlock + ScaleBlockK - 1) / ScaleBlockK;
1455
1456 constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
1458
1459 constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
1460 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
1461 constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
1462#if defined(__gfx11__)
1463 auto b_thread_offset_n = get_thread_local_1d_id() % NPerXdl +
1464 (get_thread_local_1d_id() / WaveSize) % NWaves * NPerXdl;
1465 auto b_thread_offset_k = (get_thread_local_1d_id() % 16) / NPerXdl * KPerThread;
1466#else
1467 auto b_thread_offset_n = get_thread_local_1d_id() % NPerXdl +
1468 (get_thread_local_1d_id() / WaveSize) % NWaves * NPerXdl;
1469 auto b_thread_offset_k = (get_thread_local_1d_id() % WaveSize) / NPerXdl * KPerThread;
1470#endif
1471 auto b_scale_thread_copy =
1473 BScaleType,
1474 decltype(b_scale_grid_desc_bn_ak),
1475 decltype(b_scale_thread_desc),
1478 1,
1479 ScaleSliceSizeK,
1480 1,
1481 false>(
1482 b_scale_grid_desc_bn_ak,
1483 make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset_n,
1484 b_thread_offset_k / ScaleBlockK));
1485
1486 constexpr auto b_scale_thread_slice_copy_step =
1487 make_tuple(make_multi_index(NWaves * NPerXdl, 0),
1488 make_multi_index(-NPerBlock, 0),
1489 make_multi_index(-NPerBlock, KBlockScaleSliceSizeK));
1490
1491 const index_t num_k_block_per_scale = (ScaleBlockK + KPerBlock - 1) / KPerBlock;
1492
1493 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1494 a_grid_desc_ak0_m_ak1,
1495 a_block_desc_ak0_m_ak1,
1496 a_blockwise_copy,
1497 a_grid_buf,
1498 a_block_buf,
1499 a_block_slice_copy_step,
1500 b_grid_desc_bk0_n_bk1,
1501 b_block_desc_bk0_n_bk1,
1502 b_blockwise_copy,
1503 b_grid_buf,
1504 b_block_buf,
1505 b_block_slice_copy_step,
1506 c_thread_buf,
1507 b_scale_grid_desc_bn_ak,
1508 b_scale_thread_desc,
1509 b_scale_thread_copy,
1510 b_scale_grid_buf,
1511 b_scale_thread_slice_copy_step,
1512 num_k_block_main_loop,
1513 num_k_block_per_scale);
1514
1515 // shuffle C and write out
1516 {
1517 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1518 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1519 "wrong!");
1520
1521 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1522 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1523
1524 // TODO: hacky, fix it!
1525 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1526 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1527
1528 // TODO: hacky, fix it!
1529 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1530 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1531 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1532
1533 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1534 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1535 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1536 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1537 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1538 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1539 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1540 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1541
1542 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1544
1545 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1546 static_cast<CShuffleDataType*>(p_shared),
1547 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1548
1549 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1550 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1551 make_tuple(
1554 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1555 M1, // M1 = MWave
1556 M2, // M2 * M3 * M4 = MPerXdl
1557 M3,
1558 M4)),
1561 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1562 N1, // N1 = NWave
1563 N2))), // N2 = NPerXdl
1565 make_tuple(
1567
1568 // calculate origin of thread output tensor on global memory
1569 // blockwise GEMM c matrix starting index
1570 const auto c_thread_mtx_on_block =
1571 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1572
1573 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1574 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1575
1576 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1578 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
1581
1582 const auto m_thread_data_on_block_idx =
1583 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1584 make_multi_index(m_thread_data_on_block));
1585
1586 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1591
1592 const auto n_thread_data_on_block_idx =
1593 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1594 make_multi_index(n_thread_data_on_block));
1595
1596 // shuffle: threadwise copy C from VGPR to LDS
1597 auto c_thread_copy_vgpr_to_lds =
1599 CShuffleDataType,
1600 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1601 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1603 Sequence<CShuffleMXdlPerWavePerShuffle,
1604 CShuffleNXdlPerWavePerShuffle,
1605 I1,
1606 I1,
1607 M2,
1608 I1,
1609 M4,
1610 I1>,
1612 7,
1613 1,
1615 1,
1616 true>{
1617 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1619 0,
1620 m_thread_data_on_block_idx[I1],
1621 n_thread_data_on_block_idx[I1],
1622 m_thread_data_on_block_idx[I2],
1623 m_thread_data_on_block_idx[I3],
1624 m_thread_data_on_block_idx[I4],
1625 n_thread_data_on_block_idx[I2]),
1627
1628 // shuffle: blockwise copy C from LDS to global
1629 auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
1630 ThisThreadBlock, // ThreadGroup
1631 CElementwiseOperation, // ElementwiseOperation,
1632 CGlobalMemoryDataOperation, // DstInMemOp,
1633 Sequence<1,
1634 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1635 1,
1636 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1637 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1638 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1639 CShuffleDataType, // typename SrcData,
1640 CDataType, // typename DstData,
1641 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1642 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1643 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1644 3, // index_t VectorDim,
1645 CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
1646 true, // bool ThreadTransferSrcResetCoordinateAfterRun,
1647 false> // bool ThreadTransferDstResetCoordinateAfterRun>
1648 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1649 make_multi_index(0, 0, 0, 0),
1650 c_grid_desc_mblock_mperblock_nblock_nperblock,
1651 make_multi_index(block_m_id, 0, block_n_id, 0),
1652 c_element_op};
1653
1654 // space filling curve for threadwise C in VGPR
1655 constexpr auto sfc_c_vgpr =
1658 Sequence<CShuffleMXdlPerWavePerShuffle,
1659 CShuffleNXdlPerWavePerShuffle,
1660 1,
1661 1,
1662 M2,
1663 1,
1664 M4,
1665 1>>{};
1666
1667 // space filling curve for shuffled blockwise C in global mem
1668 constexpr auto sfc_c_global =
1671 Sequence<1,
1672 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1673 1,
1674 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1675
1676 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1677
1678 static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
1679
1680 static_for<0, num_access, 1>{}([&](auto access_id) {
1681 // make sure it's safe to write to LDS
1683
1684 // each thread write its data from VGPR to LDS
1685 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1686 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1687 c_thread_buf,
1688 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1689 c_shuffle_block_buf);
1690
1691 // make sure it's safe to read from LDS
1693
1694 // each block copy its data from LDS to global
1695 c_shuffle_block_copy_lds_to_global.Run(
1696 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1697 c_shuffle_block_buf,
1698 c_grid_desc_mblock_mperblock_nblock_nperblock,
1699 c_grid_buf);
1700
1701 if constexpr(access_id < num_access - 1)
1702 {
1703 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1704
1705 // move on C
1706 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1707 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1708 }
1709 });
1710 }
1711 }
1712
1713 template <bool HasMainKBlockLoop,
1714 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1715 TailNumber TailNum = TailNumber::Odd>
1716 __device__ static void Run(const ADataType* p_a_grid,
1717 const BDataType* p_b_grid,
1718 CDataType* p_c_grid,
1719 const BScaleType* p_b_scale_grid,
1720 void* p_shared,
1721 const Problem& problem)
1722 {
1723 const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1724 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
1725 const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
1726 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
1727 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
1728 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
1729 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1731 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1732
1733 // B Scale grid
1734 const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
1735 make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
1736 math::integer_divide_ceil(problem.K, ScaleBlockK)),
1737 make_tuple(problem.StrideScaleB, 1));
1738
1739 Run<decltype(a_grid_desc_ak0_m_ak1),
1740 decltype(b_grid_desc_bk0_n_bk1),
1741 decltype(b_scale_grid_desc_bn_ak),
1742 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1743 HasMainKBlockLoop,
1744 CGlobalMemoryDataOperation,
1745 TailNum>(p_a_grid,
1746 p_b_grid,
1747 p_c_grid,
1748 p_b_scale_grid,
1749 p_shared,
1750 problem,
1751 a_grid_desc_ak0_m_ak1,
1752 b_grid_desc_bk0_n_bk1,
1753 b_scale_grid_desc_bn_ak,
1754 c_grid_desc_mblock_mperblock_nblock_nperblock);
1755 }
1756
1757 template <typename AGridDesc_AK0_M_K1,
1758 typename BGridDesc_BK0_N_K1,
1759 typename BScaleGridDesc_BN_AK,
1760 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
1761 bool HasMainKBlockLoop,
1762 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1763 TailNumber TailNum = TailNumber::Odd>
1764 __device__ static void Run_2Lds(const ADataType* p_a_grid,
1765 const BDataType* p_b_grid,
1766 CDataType* p_c_grid,
1767 const BScaleType* p_b_scale_grid,
1768 void* p_shared_0,
1769 void* p_shared_1,
1770 const Problem& problem,
1771 const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
1772 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
1773 const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak,
1774 const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
1775 c_grid_desc_mblock_mperblock_nblock_nperblock)
1776 {
1777 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1778 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1779 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1780 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1782 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1783
1784 // B Scale buffer
1785 const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1786 p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1787
1788 const AElementwiseOperation a_element_op{};
1789 const BElementwiseOperation b_element_op{};
1790 const CElementwiseOperation c_element_op{};
1791
1792 // divide block work by [M, N]
1793 const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
1794
1795 const auto block_work_idx =
1796 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
1797
1798 if(!block_2_ctile_map.ValidCTileIndex(
1799 block_work_idx,
1800 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
1801 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
1802 {
1803 return;
1804 }
1805
1806 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
1807 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
1808
1809 // HACK: this force m/n_block_data_idx_on_grid into SGPR
1810 const index_t m_block_data_idx_on_grid =
1811 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1812
1813 const index_t n_block_data_idx_on_grid =
1814 __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1815
1816 // lds max alignment
1817 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1818
1819 // A matrix in LDS memory, dst of blockwise copy
1820 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1821
1822 // B matrix in LDS memory, dst of blockwise copy
1823 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1824
1825 // A matrix blockwise copy
1826 auto a_blockwise_copy =
1828 AElementwiseOperation,
1832 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1833 ABlockTransferThreadClusterArrangeOrder,
1834 ADataType,
1835 ADataType,
1836 decltype(a_grid_desc_ak0_m_ak1),
1837 decltype(a_block_desc_ak0_m_ak1),
1838 ABlockTransferSrcAccessOrder,
1840 ABlockTransferSrcVectorDim,
1841 2,
1842 ABlockTransferSrcScalarPerVector,
1843 ABlockTransferDstScalarPerVector_AK1,
1844 1,
1845 1,
1846 AThreadTransferSrcResetCoordinateAfterRun,
1847 true,
1848 BlockwiseGemmPipe::GlobalBufferNum>(
1849 a_grid_desc_ak0_m_ak1,
1850 make_multi_index(0, m_block_data_idx_on_grid, 0),
1851 a_element_op,
1852 a_block_desc_ak0_m_ak1,
1853 make_multi_index(0, 0, 0),
1855
1856 // B matrix blockwise copy
1857 auto b_blockwise_copy =
1859 BElementwiseOperation,
1863 BBlockTransferThreadClusterLengths_BK0_N_BK1,
1864 BBlockTransferThreadClusterArrangeOrder,
1865 BDataType,
1866 BDataType,
1867 decltype(b_grid_desc_bk0_n_bk1),
1868 decltype(b_block_desc_bk0_n_bk1),
1869 BBlockTransferSrcAccessOrder,
1871 BBlockTransferSrcVectorDim,
1872 2,
1873 BBlockTransferSrcScalarPerVector,
1874 BBlockTransferDstScalarPerVector_BK1,
1875 1,
1876 1,
1877 BThreadTransferSrcResetCoordinateAfterRun,
1878 true,
1879 BlockwiseGemmPipe::GlobalBufferNum>(
1880 b_grid_desc_bk0_n_bk1,
1881 make_multi_index(0, n_block_data_idx_on_grid, 0),
1882 b_element_op,
1883 b_block_desc_bk0_n_bk1,
1884 make_multi_index(0, 0, 0),
1886
1887 // LDS allocation for A and B: be careful of alignment
1888 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1889 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1890
1891 auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1892 static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1893
1894 auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1895 bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
1896 a_block_space_size_aligned * sizeof(ADataType) / APackedSize),
1897 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1898
1899 auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1900 static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1901
1902 auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1904 a_block_space_size_aligned * sizeof(ADataType) / APackedSize),
1905 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1906
1907 auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
1908 auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
1909
1910 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1911 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
1912
1913 // Blockwise GEMM pipeline
1914 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1915 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1916 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1917
1918 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1919 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1920 KPerBlock);
1921
1922 // B scale
1923 static constexpr auto mfma =
1925 static constexpr auto KPerXdlops = mfma.GetKPerXdlops();
1926 static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops();
1927 static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;
1928 static constexpr auto KPerThread = KPerBlock / K0PerXdlops;
1929
1930 const index_t ScaleSliceSizeN = NXdlPerWave;
1931 static constexpr auto ScaleSliceSizeK = (KPerThread + ScaleBlockK - 1) / ScaleBlockK;
1932 static constexpr auto KBlockScaleSliceSizeK = (KPerBlock + ScaleBlockK - 1) / ScaleBlockK;
1933
1934 constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
1936
1937 constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
1938 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
1939 constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
1940#if defined(__gfx11__)
1941 auto b_thread_offset_n = get_thread_local_1d_id() % NPerXdl +
1942 (get_thread_local_1d_id() / WaveSize) % NWaves * NPerXdl;
1943 auto b_thread_offset_k = (get_thread_local_1d_id() % 16) / NPerXdl * KPerThread;
1944#else
1945 auto b_thread_offset_n = get_thread_local_1d_id() % NPerXdl +
1946 (get_thread_local_1d_id() / WaveSize) % NWaves * NPerXdl;
1947 auto b_thread_offset_k = (get_thread_local_1d_id() % WaveSize) / NPerXdl * KPerThread;
1948#endif
1949
1950 auto b_scale_thread_copy =
1952 BScaleType,
1953 decltype(b_scale_grid_desc_bn_ak),
1954 decltype(b_scale_thread_desc),
1957 1,
1958 ScaleSliceSizeK,
1959 1,
1960 false>(
1961 b_scale_grid_desc_bn_ak,
1962 make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset_n,
1963 b_thread_offset_k / ScaleBlockK));
1964
1965 constexpr auto b_scale_thread_slice_copy_step =
1966 make_tuple(make_multi_index(NWaves * NPerXdl, 0),
1967 make_multi_index(-NPerBlock, 0),
1968 make_multi_index(-NPerBlock, KBlockScaleSliceSizeK));
1969
1970 const index_t num_k_block_per_scale = (ScaleBlockK + KPerBlock - 1) / KPerBlock;
1971
1972 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1973 a_grid_desc_ak0_m_ak1,
1974 a_block_desc_ak0_m_ak1,
1975 a_blockwise_copy,
1976 a_grid_buf,
1977 a_block_bufs,
1978 a_block_slice_copy_step,
1979 b_grid_desc_bk0_n_bk1,
1980 b_block_desc_bk0_n_bk1,
1981 b_blockwise_copy,
1982 b_grid_buf,
1983 b_block_bufs,
1984 b_block_slice_copy_step,
1985 c_thread_buf,
1986
1987 b_scale_grid_desc_bn_ak,
1988 b_scale_thread_desc,
1989 b_scale_thread_copy,
1990 b_scale_grid_buf,
1991 b_scale_thread_slice_copy_step,
1992
1993 num_k_block_main_loop,
1994 num_k_block_per_scale);
1995
1996 // shuffle C and write out
1997 {
1998 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1999 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2000 "wrong!");
2001
2002 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2003 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
2004
2005 // TODO: hacky, fix it!
2006 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2007 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2008
2009 // TODO: hacky, fix it!
2010 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
2011 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2012 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2013
2014 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
2015 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
2016 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
2017 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
2018 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
2019 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
2020 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
2021 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
2022
2023 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2025
2026 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2027 static_cast<CShuffleDataType*>(p_shared_0),
2028 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2029
2030 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
2031 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2032 make_tuple(
2035 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
2036 M1, // M1 = MWave
2037 M2, // M2 * M3 * M4 = MPerXdl
2038 M3,
2039 M4)),
2042 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
2043 N1, // N1 = NWave
2044 N2))), // N2 = NPerXdl
2046 make_tuple(
2048
2049 // calculate origin of thread output tensor on global memory
2050 // blockwise GEMM c matrix starting index
2051 const auto c_thread_mtx_on_block =
2052 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2053
2054 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2055 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2056
2057 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2059 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
2062
2063 const auto m_thread_data_on_block_idx =
2064 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2065 make_multi_index(m_thread_data_on_block));
2066
2067 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2072
2073 const auto n_thread_data_on_block_idx =
2074 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2075 make_multi_index(n_thread_data_on_block));
2076
2077 // shuffle: threadwise copy C from VGPR to LDS
2078 auto c_thread_copy_vgpr_to_lds =
2080 CShuffleDataType,
2081 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2082 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2084 Sequence<CShuffleMXdlPerWavePerShuffle,
2085 CShuffleNXdlPerWavePerShuffle,
2086 I1,
2087 I1,
2088 M2,
2089 I1,
2090 M4,
2091 I1>,
2093 7,
2094 1,
2096 1,
2097 true>{
2098 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2100 0,
2101 m_thread_data_on_block_idx[I1],
2102 n_thread_data_on_block_idx[I1],
2103 m_thread_data_on_block_idx[I2],
2104 m_thread_data_on_block_idx[I3],
2105 m_thread_data_on_block_idx[I4],
2106 n_thread_data_on_block_idx[I2]),
2108
2109 // shuffle: blockwise copy C from LDS to global
2110 auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
2111 ThisThreadBlock, // ThreadGroup
2112 CElementwiseOperation, // ElementwiseOperation,
2113 CGlobalMemoryDataOperation, // DstInMemOp,
2114 Sequence<1,
2115 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2116 1,
2117 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2118 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
2119 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2120 CShuffleDataType, // typename SrcData,
2121 CDataType, // typename DstData,
2122 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2123 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
2124 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
2125 3, // index_t VectorDim,
2126 CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
2127 true, // bool ThreadTransferSrcResetCoordinateAfterRun,
2128 false> // bool ThreadTransferDstResetCoordinateAfterRun>
2129 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2130 make_multi_index(0, 0, 0, 0),
2131 c_grid_desc_mblock_mperblock_nblock_nperblock,
2132 make_multi_index(block_m_id, 0, block_n_id, 0),
2133 c_element_op};
2134
2135 // space filling curve for threadwise C in VGPR
2136 constexpr auto sfc_c_vgpr =
2139 Sequence<CShuffleMXdlPerWavePerShuffle,
2140 CShuffleNXdlPerWavePerShuffle,
2141 1,
2142 1,
2143 M2,
2144 1,
2145 M4,
2146 1>>{};
2147
2148 // space filling curve for shuffled blockwise C in global mem
2149 constexpr auto sfc_c_global =
2152 Sequence<1,
2153 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2154 1,
2155 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2156
2157 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2158
2159 static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
2160
2161 static_for<0, num_access, 1>{}([&](auto access_id) {
2162 // make sure it's safe to write to LDS
2164
2165 // each thread write its data from VGPR to LDS
2166 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2167 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2168 c_thread_buf,
2169 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2170 c_shuffle_block_buf);
2171
2172 // make sure it's safe to read from LDS
2174
2175 // each block copy its data from LDS to global
2176 c_shuffle_block_copy_lds_to_global.Run(
2177 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2178 c_shuffle_block_buf,
2179 c_grid_desc_mblock_mperblock_nblock_nperblock,
2180 c_grid_buf);
2181
2182 if constexpr(access_id < num_access - 1)
2183 {
2184 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
2185
2186 // move on C
2187 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
2188 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
2189 }
2190 });
2191 }
2192 }
2193
2194 template <bool HasMainKBlockLoop,
2195 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
2196 TailNumber TailNum = TailNumber::Odd>
2197 __device__ static void Run_2Lds(const ADataType* p_a_grid,
2198 const BDataType* p_b_grid,
2199 CDataType* p_c_grid,
2200 const BScaleType* p_b_scale_grid,
2201 void* p_shared_0,
2202 void* p_shared_1,
2203 const Problem& problem)
2204 {
2205 const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
2206 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
2207 const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
2208 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
2209 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
2210 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
2211
2212 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
2214 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
2215
2216 const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
2217 make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
2218 math::integer_divide_ceil(problem.K, ScaleBlockK)),
2219 make_tuple(problem.StrideScaleB, 1));
2220
2221 Run_2Lds<decltype(a_grid_desc_ak0_m_ak1),
2222 decltype(b_grid_desc_bk0_n_bk1),
2223 decltype(b_scale_grid_desc_bn_ak),
2224 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
2225 HasMainKBlockLoop,
2226 CGlobalMemoryDataOperation,
2227 TailNum>(p_a_grid,
2228 p_b_grid,
2229 p_c_grid,
2230 p_b_scale_grid,
2231 p_shared_0,
2232 p_shared_1,
2233 problem,
2234 a_grid_desc_ak0_m_ak1,
2235 b_grid_desc_bk0_n_bk1,
2236 b_scale_grid_desc_bn_ak,
2237 c_grid_desc_mblock_mperblock_nblock_nperblock);
2238 }
2239};
2240
2241} // namespace ck
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)
Definition device_base.hpp:178
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
Definition ck.hpp:268
ushort bhalf_t
Definition data_type.hpp:30
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
_Float16 half_t
Definition data_type.hpp:31
__host__ __device__ constexpr auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:185
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Full
Definition blkgemmpipe_scheduler.hpp:49
__global__ void kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:62
constexpr auto BlockGemmPipeline_Selector()
Definition blockwise_gemm_pipeline_wmma_selector.hpp:32
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__global__ void kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:38
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr Y bit_cast(const X &x)
Definition type.hpp:306
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
signed int int32_t
Definition stdint.h:123
Definition block_to_ctile_map.hpp:271
__host__ static __device__ constexpr index_t CalculateGridSize(index_t M, index_t N)
Definition block_to_ctile_map.hpp:283
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:716
const BElementwiseOperation b_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:643
const BDataType * p_b_grid
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:759
CDataType * p_c_grid
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:760
__host__ __device__ bool IsReduceAdd() const
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:627
const AElementwiseOperation a_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:642
__host__ __device__ bool IsAtomicAdd() const
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:632
const BScaleType * p_b_scale_grid
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:641
const ADataType * p_a_grid
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:758
__host__ Argument(const ADataType *p_a_grid_, const BDataType *p_b_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t StrideScaleB_, const BScaleType *p_b_scale_grid_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_, bool is_reduce_=false)
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:599
const CElementwiseOperation c_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:644
bool is_reduce
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:761
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:641
index_t N
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:695
index_t NPadded
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:702
index_t KBatch
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:700
__host__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t StrideScaleB_, index_t KBatch_)
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:541
index_t StrideA
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:697
__host__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t KBatch_, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:642
CElementwiseOperation c_element_op_
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:711
index_t BK0
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:706
index_t M
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:694
index_t NBlock
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:708
index_t MPadded
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:701
index_t K
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:696
index_t StrideB
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:698
index_t KPadded
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:704
index_t StrideC
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:699
index_t MBlock
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:707
BElementwiseOperation b_element_op_
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:710
index_t AK0
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:705
index_t KRead
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:703
AElementwiseOperation a_element_op_
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:709
index_t StrideScaleB
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:584
__host__ void Print() const
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:568
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:765
index_t a_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:814
__device__ SplitKBatchOffset(Argument &karg, index_t k_id)
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:651
index_t b_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:815
index_t scale_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:710
index_t c_reduce_offset
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:816
"Universal" GEMM kernel with SplitK support.
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:247
static constexpr auto is_scale_mfma
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:273
static constexpr auto BK1Number
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:261
static constexpr bool is_single_rate_mfma
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:264
remove_cvref_t< decltype(BlockGemmPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ADataType, BDataType, BlkGemmPipeSched, GemmAccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack >())> BlockwiseGemmPipe
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:1112
static constexpr auto I2
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:250
static constexpr index_t KPack
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:274
static constexpr auto lcm_AK1_BK1
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:263
static constexpr auto I7
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:255
static __device__ void Run_2Lds(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, const BScaleType *p_b_scale_grid, void *p_shared_0, void *p_shared_1, const Problem &problem, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const BScaleGridDesc_BN_AK &b_scale_grid_desc_bn_ak, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock)
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:1764
static constexpr auto I5
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:253
static __device__ void Run(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, const BScaleType *p_b_scale_grid, void *p_shared, const Problem &problem, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const BScaleGridDesc_BN_AK &b_scale_grid_desc_bn_ak, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock)
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:1295
static constexpr auto AK1Number
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:260
static __device__ void Run_2Lds(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, void *p_shared_0, void *p_shared_1, const Problem &problem, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:1853
static constexpr index_t BPackedSize
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:292
static __device__ void Run(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, void *p_shared, const Problem &problem, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:1437
static constexpr auto I6
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:254
static constexpr auto I1
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:249
static constexpr auto I3
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:251
static constexpr auto I4
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:252
static constexpr auto BK0Number
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:259
static constexpr auto AK0Number
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:258
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
static constexpr index_t GetKPerXdlops()
Definition xdlops_gemm.hpp:1804
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition thread_group_tensor_slice_transfer_v6r1.hpp:34
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition type.hpp:177
Definition data_type.hpp:187
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
#define CK_ENV(name)
Definition utility/env.hpp:129