gridwise_gemm_xdl_cshuffle_v1.hpp Source File

gridwise_gemm_xdl_cshuffle_v1.hpp Source File#

Composable Kernel: gridwise_gemm_xdl_cshuffle_v1.hpp Source File
gridwise_gemm_xdl_cshuffle_v1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
17
18namespace ck {
19
// GPU entry point (declared __global__) that receives one GridwiseGemm::Argument by value.
//
// When compiled for one of the gfx908 / gfx90a / gfx94 / gfx11 / gfx12 targets and
// GridwiseGemm::IsValidCompilationParameter<>() holds, it declares a __shared__ byte array
// sized by GridwiseGemm::GetSharedMemoryNumberOfByte() and calls
// GridwiseGemm::Run<HasMainKBlockLoop> with the A/B/C pointers stored in karg, the shared
// array, and karg itself. If the compile-time check fails the body is empty.
// For any other target the argument is only assigned to `ignore`.
//
// Template parameters:
//   GridwiseGemm      - type providing Argument, Run, GetSharedMemoryNumberOfByte and
//                       IsValidCompilationParameter.
//   HasMainKBlockLoop - forwarded unchanged as the template argument of GridwiseGemm::Run.
20template <typename GridwiseGemm, bool HasMainKBlockLoop>
21__global__ void
22#if CK_USE_LAUNCH_BOUNDS
24#endif
25 kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg)
26{
27#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
28 defined(__gfx12__)
29 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
30 {
31 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
32
33 GridwiseGemm::template Run<HasMainKBlockLoop>(
34 karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg);
35 }
36#else
37 ignore = karg;
38#endif // end of gfx908 / gfx90a / gfx94 / gfx11 / gfx12 target check
39}
40
// GPU entry point (declared __global__) that takes the A/B/C pointers as separate
// __restrict__ parameters plus a GridwiseGemm::Problem by value.
//
// When compiled for one of the gfx908 / gfx90a / gfx94 / gfx11 / gfx12 targets and
// GridwiseGemm::IsValidCompilationParameter<>() holds, it declares a __shared__ byte array
// sized by GridwiseGemm::GetSharedMemoryNumberOfByte() and calls
// GridwiseGemm::Run<HasMainKBlockLoop>(p_a_grid, p_b_grid, p_c_grid, p_shared, problem).
// If the compile-time check fails the body is empty.
// For any other target all four parameters are only assigned to `ignore`.
//
// Template parameters:
//   GridwiseGemm           - type providing Problem, Run, GetSharedMemoryNumberOfByte and
//                            IsValidCompilationParameter.
//   FloatA, FloatB, FloatC - element types of the three pointer parameters.
//   HasMainKBlockLoop      - forwarded unchanged as the template argument of
//                            GridwiseGemm::Run.
41template <typename GridwiseGemm,
42 typename FloatA,
43 typename FloatB,
44 typename FloatC,
45 bool HasMainKBlockLoop>
46__global__ void
47#if CK_USE_LAUNCH_BOUNDS
49#endif
50 kernel_gemm_xdl_cshuffle_v1(const FloatA* __restrict__ p_a_grid,
51 const FloatB* __restrict__ p_b_grid,
52 FloatC* __restrict__ p_c_grid,
53 typename GridwiseGemm::Problem problem)
54{
55#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
56 defined(__gfx12__)
57 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
58 {
59 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
60
61 GridwiseGemm::template Run<HasMainKBlockLoop>(
62 p_a_grid, p_b_grid, p_c_grid, p_shared, problem);
63 }
64#else
65 ignore = p_a_grid;
66 ignore = p_b_grid;
67 ignore = p_c_grid;
68 ignore = problem;
69#endif // end of gfx908 / gfx90a / gfx94 / gfx11 / gfx12 target check
70}
71
72template <typename ALayout,
73 typename BLayout,
74 typename CLayout,
75 typename FloatA,
76 typename FloatB,
77 typename FloatGemmAcc,
78 typename FloatCShuffle,
79 typename FloatC,
80 typename AElementwiseOperation,
81 typename BElementwiseOperation,
82 typename CElementwiseOperation,
84 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
85 index_t NumGemmKPrefetchStage,
86 index_t BlockSize,
87 index_t MPerBlock,
88 index_t NPerBlock,
89 index_t KPerBlock,
90 index_t AK1Value,
91 index_t BK1Value,
92 index_t MPerXdl,
93 index_t NPerXdl,
94 index_t MXdlPerWave,
95 index_t NXdlPerWave,
96 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
97 typename ABlockTransferThreadClusterArrangeOrder,
98 typename ABlockTransferSrcAccessOrder,
99 index_t ABlockTransferSrcVectorDim,
100 index_t ABlockTransferSrcScalarPerVector,
101 index_t ABlockTransferDstScalarPerVector_AK1,
102 bool AThreadTransferSrcResetCoordinateAfterRun,
103 index_t ABlockLdsExtraM,
104 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
105 typename BBlockTransferThreadClusterArrangeOrder,
106 typename BBlockTransferSrcAccessOrder,
107 index_t BBlockTransferSrcVectorDim,
108 index_t BBlockTransferSrcScalarPerVector,
109 index_t BBlockTransferDstScalarPerVector_BK1,
110 bool BThreadTransferSrcResetCoordinateAfterRun,
111 index_t BBlockLdsExtraN,
112 index_t CShuffleMXdlPerWavePerShuffle,
113 index_t CShuffleNXdlPerWavePerShuffle,
114 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
115 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
116 LoopScheduler LoopSched,
118 typename ComputeTypeA = FloatC,
119 typename ComputeTypeB = ComputeTypeA>
121{
122 static constexpr auto I0 = Number<0>{};
123 static constexpr auto I1 = Number<1>{};
124 static constexpr auto I2 = Number<2>{};
125 static constexpr auto I3 = Number<3>{};
126 static constexpr auto I4 = Number<4>{};
127 static constexpr auto I5 = Number<5>{};
128 static constexpr auto I6 = Number<6>{};
129 static constexpr auto I7 = Number<7>{};
130
131 // K1 should be Number<...>
132 static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
133 static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
134 static constexpr auto AK1Number = Number<AK1Value>{};
135 static constexpr auto BK1Number = Number<BK1Value>{};
136
138
// Host-only. Returns a 3-tuple whose first element is
// Block2CTileMap::CalculateGridSize(M, N) and whose other two elements are 1.
139 __host__ static auto CalculateGridSize(index_t M, index_t N)
140 {
141 return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1);
142 }
143
// Host-only. Returns integer_divide_ceil(M, MPerBlock) * MPerBlock, i.e. M rounded up
// to a whole number of MPerBlock-sized tiles.
144 __host__ static auto CalculateMPadded(index_t M)
145 {
146 return math::integer_divide_ceil(M, MPerBlock) * MPerBlock;
147 }
148
// Host-only. Returns integer_divide_ceil(N, NPerBlock) * NPerBlock, i.e. N rounded up
// to a whole number of NPerBlock-sized tiles.
149 __host__ static auto CalculateNPadded(index_t N)
150 {
151 return math::integer_divide_ceil(N, NPerBlock) * NPerBlock;
152 }
153
// Host-only. Returns integer_divide_ceil(K, KPerBlock) * KPerBlock, i.e. K rounded up
// to a whole number of KPerBlock-sized steps.
154 __host__ static auto CalculateKPadded(index_t K)
155 {
156 return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
157 }
158
// Host-only. Returns the K extent divided by AK1Value for the A operand.
// For the four specializations that pad K (MKPadding, MNKPadding, KPadding, NKPadding)
// the dividend is CalculateKPadded(K); for every other specialization it is K as given.
// The division is a plain `/` on index_t values, so a remainder is discarded;
// CheckValidity rejects problems where the relevant K is not a multiple of AK1Value.
159 __host__ static auto CalculateAK0(index_t K)
160 {
161 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
162
163 if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
164 GemmSpec == GemmSpecialization::MNKPadding ||
165 GemmSpec == GemmSpecialization::KPadding ||
166 GemmSpec == GemmSpecialization::NKPadding)
167 {
168 return CalculateKPadded(K) / AK1Value;
169 }
170 else
171 {
172 return K / AK1Value;
173 }
174 }
175
// Host-only. Returns the K extent divided by BK1Value for the B operand.
// The branch tests the same four K-padding specializations as CalculateAK0 (listed in a
// different order): when one of them is selected the dividend is CalculateKPadded(K),
// otherwise it is K as given.
// The division is a plain `/` on index_t values, so a remainder is discarded;
// CheckValidity rejects problems where the relevant K is not a multiple of BK1Value.
176 __host__ static auto CalculateBK0(index_t K)
177 {
178 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
179
180 if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
181 GemmSpec == GemmSpecialization::MNKPadding ||
182 GemmSpec == GemmSpecialization::KPadding ||
183 GemmSpec == GemmSpecialization::MKPadding)
184 {
185 return CalculateKPadded(K) / BK1Value;
186 }
187 else
188 {
189 return K / BK1Value;
190 }
191 }
192
// Host-only. Returns integer_divide_floor(M, MPerBlock): the number of complete
// MPerBlock-sized tiles contained in M.
// NOTE(review): CalculateMPadded rounds up whereas this rounds down, so when M is not a
// multiple of MPerBlock this value is one less than CalculateMPadded(M) / MPerBlock.
// Confirm that is intended for the M-padding specializations.
193 __host__ static auto CalculateMBlock(index_t M)
194 {
195 return math::integer_divide_floor(M, MPerBlock);
196 }
197
// Host-only. Returns integer_divide_floor(N, NPerBlock): the number of complete
// NPerBlock-sized tiles contained in N.
// NOTE(review): CalculateNPadded rounds up whereas this rounds down, so when N is not a
// multiple of NPerBlock this value is one less than CalculateNPadded(N) / NPerBlock.
// Confirm that is intended for the N-padding specializations.
198 __host__ static auto CalculateNBlock(index_t N)
199 {
200 return math::integer_divide_floor(N, NPerBlock);
201 }
202
203 __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
204 index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
205 {
206 const auto a_grid_desc_mraw_kraw = [&]() {
208 {
209 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
210 }
212 {
213 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
214 }
215 }();
216
217 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
218
219 if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
220 GemmSpec == GemmSpecialization::MNKPadding)
221 {
222 // pad both M and K
223 const auto a_grid_desc_m_k =
224 transform_tensor_descriptor(a_grid_desc_mraw_kraw,
226 make_right_pad_transform(K, KPad - K)),
229
230 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
231 a_grid_desc_m_k,
236
237 return a_grid_desc_ak0_m_ak1;
238 }
239 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
240 GemmSpec == GemmSpecialization::MNPadding)
241 {
242 // pad M, but not K
243 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
244 a_grid_desc_mraw_kraw,
246 make_right_pad_transform(M, MPad - M)),
249
250 return a_grid_desc_ak0_m_ak1;
251 }
252 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
253 GemmSpec == GemmSpecialization::NKPadding)
254 {
255 // pad K, but not M
256 const auto a_grid_desc_m_k = transform_tensor_descriptor(
257 a_grid_desc_mraw_kraw,
261
262 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
263 a_grid_desc_m_k,
268
269 return a_grid_desc_ak0_m_ak1;
270 }
271 else
272 {
273 // not pad M or K
274 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
275 a_grid_desc_mraw_kraw,
280
281 return a_grid_desc_ak0_m_ak1;
282 }
283 }
284
285 __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
286 index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
287 {
288 const auto b_grid_desc_nraw_kraw = [&]() {
290 {
291 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
292 }
294 {
295 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
296 }
297 }();
298
299 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
300
301 if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
302 GemmSpec == GemmSpecialization::MNKPadding)
303 {
304 // pad both N and K
305 const auto b_grid_desc_n_k =
306 transform_tensor_descriptor(b_grid_desc_nraw_kraw,
308 make_right_pad_transform(K, KPad - K)),
311
312 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
313 b_grid_desc_n_k,
318
319 return b_grid_desc_bk0_n_bk1;
320 }
321 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
322 GemmSpec == GemmSpecialization::MNPadding)
323 {
324 // pad N, but not K
325 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
326 b_grid_desc_nraw_kraw,
328 make_right_pad_transform(N, NPad - N)),
331
332 return b_grid_desc_bk0_n_bk1;
333 }
334 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
335 GemmSpec == GemmSpecialization::MKPadding)
336 {
337 // pad K, but not N
338 const auto b_grid_desc_n_k = transform_tensor_descriptor(
339 b_grid_desc_nraw_kraw,
343
344 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
345 b_grid_desc_n_k,
350
351 return b_grid_desc_bk0_n_bk1;
352 }
353 else
354 {
355 // not pad N or K
356 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
357 b_grid_desc_nraw_kraw,
362
363 return b_grid_desc_bk0_n_bk1;
364 }
365 }
366
367 __host__ __device__ static auto
369 {
370 const auto c_grid_desc_mraw_nraw = [&]() {
372 {
373 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
374 }
376 {
377 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
378 }
379 }();
380
381 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
382
383 if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
384 GemmSpec == GemmSpecialization::MNKPadding)
385 {
386 // pad M and N
387 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
389 make_right_pad_transform(N, NPad - N)),
392 }
393 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
394 GemmSpec == GemmSpecialization::MKPadding)
395 {
396 // pad M, but not N
398 c_grid_desc_mraw_nraw,
402 }
403 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
404 GemmSpec == GemmSpecialization::NKPadding)
405 {
406 // pad N, but not M
408 c_grid_desc_mraw_nraw,
412 }
413 else
414 {
415 // not pad M or N
416 return c_grid_desc_mraw_nraw;
417 }
418 }
419
420 struct Problem
421 {
422 __host__ Problem(index_t M_,
423 index_t N_,
424 index_t K_,
425 index_t StrideA_,
426 index_t StrideB_,
427 index_t StrideC_)
428 : M{M_},
429 N{N_},
430 K{K_},
431 StrideA{StrideA_},
432 StrideB{StrideB_},
433 StrideC{StrideC_},
437 AK0{CalculateAK0(K_)},
438 BK0{CalculateBK0(K_)},
441 {
442 }
443
// Host-only debugging aid: writes M, N, K, the three strides, the padded M/N/K sizes,
// AK0, BK0, MBlock and NBlock to std::cout as a single "problem {...}" line.
// The line is terminated with std::endl, which also flushes the stream.
444 __host__ void Print() const
445 {
446 std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
447 << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
448 << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
449 << "KP:" << KPadded << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", "
450 << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl;
451 }
452
466 };
467
468 // Argument
470 {
471 __host__ Argument(const FloatA* p_a_grid_,
472 const FloatB* p_b_grid_,
473 FloatC* p_c_grid_,
474 index_t M_,
475 index_t N_,
476 index_t K_,
477 index_t StrideA_,
478 index_t StrideB_,
479 index_t StrideC_)
480 : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
481 p_a_grid{p_a_grid_},
482 p_b_grid{p_b_grid_},
483 p_c_grid{p_c_grid_}
484 {
485 }
486
487 const FloatA* p_a_grid;
488 const FloatB* p_b_grid;
489 FloatC* p_c_grid;
490 };
491
492 // FIXME: pass GridwiseGemmPipe as a template argument into GridwiseGemm
495
496 __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
497 {
498 // A matrix in LDS memory, dst of blockwise copy
502 }
503
504 __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
505 {
506 // B matrix in LDS memory, dst of blockwise copy
510 }
511
513 {
514 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
515 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
516
517 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
521 I1,
523
524 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
525 }
526
527 __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
528 {
529 // LDS allocation for A and B: be careful of alignment
530 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
531 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
532
533 // lds max alignment
534 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
535
536 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
537 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
538
539 constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
540 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
541
542 // LDS allocation for C shuffle in LDS
543 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
545
546 constexpr auto c_block_size =
547 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
548
549 return math::max((a_block_space_size_aligned * sizeof(ComputeTypeA) +
550 b_block_space_size_aligned * sizeof(ComputeTypeB)),
551 c_block_size * sizeof(FloatCShuffle));
552 }
553
// Compile-time predicate used by the kernel entry points to decide whether to emit a body.
// Forwards BlockSize, the M/N tile and XDL parameters, FloatC and the memory data
// operation to ck::tensor_operation::device::IsValidGemmCompilationParameter and returns
// its result.
// NOTE(review): the function-template parameter CGlobalMemoryDataOperation_ (trailing
// underscore, defaulting to Set) is not referenced in the body; the value forwarded is
// the name CGlobalMemoryDataOperation without the underscore. Confirm that the
// parameter exists only so callers can write IsValidCompilationParameter<>().
554 template <
555 InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set>
556 __device__ static bool constexpr IsValidCompilationParameter()
557 {
558 return ck::tensor_operation::device::IsValidGemmCompilationParameter<
559 BlockSize,
560 MPerBlock,
561 NPerBlock,
562 MPerXdl,
563 NPerXdl,
564 MXdlPerWave,
565 NXdlPerWave,
566 FloatC,
567 CGlobalMemoryDataOperation>();
568 }
569
570 // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
571 __host__ static constexpr bool CheckValidity(const Problem& problem)
572 {
573 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
574 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
575 "Invalid tuning param!");
576
581 {
582 if(!(problem.M % MPerBlock == 0))
583 {
584 return false;
585 }
586 }
587
592 {
593 if(!(problem.N % NPerBlock == 0))
594 {
595 return false;
596 }
597 }
598
603 {
604 if(!(CalculateKPadded(problem.K) % AK1Value == 0) ||
605 !(CalculateKPadded(problem.K) % BK1Value == 0))
606 {
607 return false;
608 }
609 }
610 else
611 {
612 if(!(problem.K % AK1Value == 0) || !(problem.K % BK1Value == 0))
613 {
614 return false;
615 }
616 }
617
619 {
620 if(problem.K % ABlockTransferSrcScalarPerVector != 0)
621 {
622 return false;
623 }
624 }
625 else
626 {
627 if(problem.M % ABlockTransferSrcScalarPerVector != 0)
628 {
629 return false;
630 }
631 }
632
634 {
635 if(problem.N % BBlockTransferSrcScalarPerVector != 0)
636 {
637 return false;
638 }
639 }
640 else
641 {
642 if(problem.K % BBlockTransferSrcScalarPerVector != 0)
643 {
644 return false;
645 }
646 }
647
649 {
650 if(problem.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
651 {
652 return false;
653 }
654 }
655 else
656 {
657 if(problem.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
658 {
659 return false;
660 }
661 }
662
663 // check gridwise gemm pipeline
664 const auto num_k_loop = (CalculateAK0(problem.K) * AK1Value) / KPerBlock;
665
666 if(!GridwiseGemmPipe::IsSupported(num_k_loop))
667 {
668 return false;
669 }
670
671 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
672 return true;
673 }
674
// Host-only. Computes num_loop = K / KPerBlock (plain index_t division, remainder
// discarded) and returns GridwiseGemmPipe::CalculateHasMainLoop(num_loop).
// NOTE(review): CheckValidity derives its loop count from CalculateAK0(K) * AK1Value,
// which uses the padded K for the K-padding specializations, while this function uses K
// exactly as passed. Presumably callers pass an already padded K; confirm at call sites.
675 __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
676 {
677 const index_t num_loop = K / KPerBlock;
678
679 return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
680 }
681
682 template <typename CGridDesc>
684 const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
685 {
686 const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
687 c_grid_desc_m_n,
692
693 return c_grid_desc_mblock_mperblock_nblock_nperblock;
694 }
695
696 // return block_id to C matrix tile idx (m0, n0) mapping
698
699 template <bool HasMainKBlockLoop>
700 __device__ static void Run(const FloatA* __restrict__ p_a_grid,
701 const FloatB* __restrict__ p_b_grid,
702 FloatC* __restrict__ p_c_grid,
703 void* __restrict__ p_shared,
704 const Problem& problem)
705 {
706 const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
707 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
708 const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
709 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
710 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
711 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
712
713 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
715 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
716
717 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
718 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
719 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
720 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
722 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
723
724 const AElementwiseOperation a_element_op{};
725 const BElementwiseOperation b_element_op{};
726 const CElementwiseOperation c_element_op{};
727
728 // divide block work by [M, N]
729 const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N};
730
731 const auto block_work_idx =
732 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
733
734 if(!block_2_ctile_map.ValidCTileIndex(
735 block_work_idx,
736 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
737 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
738 {
739 return;
740 }
741
742 // HACK: this forces m/n_block_data_idx_on_grid into SGPR
743 const index_t m_block_data_idx_on_grid =
744 __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
745
746 const index_t n_block_data_idx_on_grid =
747 __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
748
749 // lds max alignment
750 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
751
752 // A matrix in LDS memory, dst of blockwise copy
753 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
754
755 // B matrix in LDS memory, dst of blockwise copy
756 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
757
758 // A matrix blockwise copy
759 auto a_blockwise_copy =
761 AElementwiseOperation,
765 ABlockTransferThreadClusterLengths_AK0_M_AK1,
766 ABlockTransferThreadClusterArrangeOrder,
767 FloatA,
768 ComputeTypeA,
769 decltype(a_grid_desc_ak0_m_ak1),
770 decltype(a_block_desc_ak0_m_ak1),
771 ABlockTransferSrcAccessOrder,
773 ABlockTransferSrcVectorDim,
774 2,
775 ABlockTransferSrcScalarPerVector,
776 ABlockTransferDstScalarPerVector_AK1,
777 1,
778 1,
779 AThreadTransferSrcResetCoordinateAfterRun,
780 true,
781 NumGemmKPrefetchStage>(
782 a_grid_desc_ak0_m_ak1,
783 make_multi_index(0, m_block_data_idx_on_grid, 0),
784 a_element_op,
785 a_block_desc_ak0_m_ak1,
786 make_multi_index(0, 0, 0),
788
789 // B matrix blockwise copy
790 auto b_blockwise_copy =
792 BElementwiseOperation,
796 BBlockTransferThreadClusterLengths_BK0_N_BK1,
797 BBlockTransferThreadClusterArrangeOrder,
798 FloatB,
799 ComputeTypeB,
800 decltype(b_grid_desc_bk0_n_bk1),
801 decltype(b_block_desc_bk0_n_bk1),
802 BBlockTransferSrcAccessOrder,
804 BBlockTransferSrcVectorDim,
805 2,
806 BBlockTransferSrcScalarPerVector,
807 BBlockTransferDstScalarPerVector_BK1,
808 1,
809 1,
810 BThreadTransferSrcResetCoordinateAfterRun,
811 true,
812 NumGemmKPrefetchStage>(
813 b_grid_desc_bk0_n_bk1,
814 make_multi_index(0, n_block_data_idx_on_grid, 0),
815 b_element_op,
816 b_block_desc_bk0_n_bk1,
817 make_multi_index(0, 0, 0),
819
820 // GEMM definition
821 // c_mtx += transpose(a_mtx) * b_mtx
822 // a_mtx[K0PerBlock, MPerBlock] is in LDS
823 // b_mtx[K0PerBlock, NPerBlock] is in LDS
824 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
825 // register
826 // sanity check
827 constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
828 constexpr bool is_single_rate_mfma =
830 lcm_AK1_BK1 <= 4) ||
831 (is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
833 lcm_AK1_BK1 < 32))
834 ? true
835 : false;
836 constexpr auto is_scale_mfma = false;
837 constexpr index_t KPack = math::max(lcm_AK1_BK1,
838 MfmaSelector<ComputeTypeA,
839 MPerXdl,
840 NPerXdl,
841 ComputeTypeB,
842 is_single_rate_mfma,
843 is_scale_mfma>::selected_mfma.k_per_blk);
844
846 BlockSize,
847 ComputeTypeA,
848 ComputeTypeB,
849 FloatGemmAcc,
850 decltype(a_block_desc_ak0_m_ak1),
851 decltype(b_block_desc_bk0_n_bk1),
852 MPerXdl,
853 NPerXdl,
854 MXdlPerWave,
855 NXdlPerWave,
856 KPack,
857 LoopSched>();
858
859 auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
860
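861 // LDS allocation for A and B: be careful of alignment
// NOTE(review): the B buffer base below is p_shared cast to ComputeTypeB* plus
// a_block_space_size_aligned, i.e. an offset counted in ComputeTypeB elements, whereas
// GetSharedMemoryNumberOfByte() sizes the A region as
// a_block_space_size_aligned * sizeof(ComputeTypeA) bytes. The two agree only when
// sizeof(ComputeTypeA) == sizeof(ComputeTypeB); confirm for mixed compute types.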
862 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
863 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
864
866 static_cast<ComputeTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
867
869 static_cast<ComputeTypeB*>(p_shared) + a_block_space_size_aligned,
870 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
871
872 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
873 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
874
875 // gridwise GEMM pipeline
876 static_assert(std::is_default_constructible_v<GridwiseGemmPipe>);
877 const auto gridwise_gemm_pipeline = GridwiseGemmPipe{};
878
879 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
880 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
881 KPerBlock);
882
883 gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
884 a_block_desc_ak0_m_ak1,
885 a_blockwise_copy,
886 a_grid_buf,
887 a_block_buf,
888 a_block_slice_copy_step,
889 b_grid_desc_bk0_n_bk1,
890 b_block_desc_bk0_n_bk1,
891 b_blockwise_copy,
892 b_grid_buf,
893 b_block_buf,
894 b_block_slice_copy_step,
895 blockwise_gemm,
896 c_thread_buf,
897 num_k_block_main_loop);
898
899 // shuffle C and write out
900 {
901 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
902 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
903 "wrong!");
904
905 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
906 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
907
908 // TODO: hacky, fix it!
909 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
910 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
911
912 // TODO: hacky, fix it!
913 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
914 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
915 blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
916
917 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
918 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
919 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
920 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
921 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
922 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
923 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
924 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
925
926 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
928
929 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
930 static_cast<FloatCShuffle*>(p_shared),
931 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
932
933 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
934 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
938 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
939 M1, // M1 = MWave
940 M2, // M2 * M3 * M4 = MPerXdl
941 M3,
942 M4)),
945 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
946 N1, // N1 = NWave
947 N2))), // N2 = NPerXdl
951
952 // calculate origin of thread output tensor on global memory
953 // blockwise GEMM c matrix starting index
954 const auto c_thread_mtx_on_block =
955 blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
956
957 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
958 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
959
960 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
962 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
965
966 const auto m_thread_data_on_block_idx =
967 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
968 make_multi_index(m_thread_data_on_block));
969
970 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
975
976 const auto n_thread_data_on_block_idx =
977 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
978 make_multi_index(n_thread_data_on_block));
979
980 // shuffle: threadwise copy C from VGPR to LDS
981 auto c_thread_copy_vgpr_to_lds =
983 FloatCShuffle,
984 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
985 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
987 Sequence<CShuffleMXdlPerWavePerShuffle,
988 CShuffleNXdlPerWavePerShuffle,
989 I1,
990 I1,
991 M2,
992 I1,
993 M4,
994 I1>,
996 7,
997 1,
999 1,
1000 true>{
1001 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1003 0,
1004 m_thread_data_on_block_idx[I1],
1005 n_thread_data_on_block_idx[I1],
1006 m_thread_data_on_block_idx[I2],
1007 m_thread_data_on_block_idx[I3],
1008 m_thread_data_on_block_idx[I4],
1009 n_thread_data_on_block_idx[I2]),
1011
1012 // shuffle: blockwise copy C from LDS to global
1013 auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
1014 ThisThreadBlock, // ThreadGroup
1015 CElementwiseOperation, // ElementwiseOperation,
1016 CGlobalMemoryDataOperation, // DstInMemOp,
1017 Sequence<1,
1018 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1019 1,
1020 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1021 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1022 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1023 FloatCShuffle, // typename SrcData,
1024 FloatC, // typename DstData,
1025 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1026 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1027 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1028 3, // index_t VectorDim,
1029 CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
1030 true, // bool ThreadTransferSrcResetCoordinateAfterRun,
1031 false> // bool ThreadTransferDstResetCoordinateAfterRun>
1032 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1033 make_multi_index(0, 0, 0, 0),
1034 c_grid_desc_mblock_mperblock_nblock_nperblock,
1035 make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
1036 c_element_op};
1037
1038 // space filling curve for threadwise C in VGPR
1039 constexpr auto sfc_c_vgpr =
1042 Sequence<CShuffleMXdlPerWavePerShuffle,
1043 CShuffleNXdlPerWavePerShuffle,
1044 1,
1045 1,
1046 M2,
1047 1,
1048 M4,
1049 1>>{};
1050
1051 // space filling curve for shuffled blockwise C in global mem
1052 constexpr auto sfc_c_global =
1055 Sequence<1,
1056 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1057 1,
1058 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1059
1060 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1061
1062 static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
1063
1064 static_for<0, num_access, 1>{}([&](auto access_id) {
1065 // make sure it's safe to write to LDS
1067
1068 // each thread write its data from VGPR to LDS
1069 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1070 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1071 c_thread_buf,
1072 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1073 c_shuffle_block_buf);
1074
1075 // make sure it's safe to read from LDS
1077
1078 // each block copy its data from LDS to global
1079 c_shuffle_block_copy_lds_to_global.Run(
1080 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1081 c_shuffle_block_buf,
1082 c_grid_desc_mblock_mperblock_nblock_nperblock,
1083 c_grid_buf);
1084
1085 if constexpr(access_id < num_access - 1)
1086 {
1087 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1088
1089 // move on C
1090 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1091 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1092 }
1093 });
1094 }
1095 }
1096};
1097
1098} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
Definition utility/math.hpp:66
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
Definition blockwise_gemm_xdlops.hpp:620
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
constexpr auto GridwiseGemmPipeline_Selector()
Definition gridwise_gemm_pipeline_selector.hpp:31
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
__global__ void kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_v1.hpp:25
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
Definition block_to_ctile_map.hpp:261
__host__ Argument(const FloatA *p_a_grid_, const FloatB *p_b_grid_, FloatC *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_)
Definition gridwise_gemm_xdl_cshuffle_v1.hpp:471
const FloatB * p_b_grid
Definition gridwise_gemm_xdl_cshuffle_v1.hpp:488
const FloatA * p_a_grid
Definition gridwise_gemm_xdl_cshuffle_v1.hpp:487
FloatC * p_c_grid
Definition gridwise_gemm_xdl_cshuffle_v1.hpp:489
index_t NPadded
Definition gridwise_gemm_xdl_cshuffle_v1.hpp:460
index_t N
Definition gridwise_gemm_xdl_cshuffle_v1.hpp:454
index_t StrideA
Definition gridwise_gemm_xdl_cshuffle_v1.hpp:456
index_t StrideB
Definition gridwise_gemm_xdl_cshuffle_v1.hpp:457
index_t NBlock
Definition gridwise_gemm_xdl_cshuffle_v1.hpp:465
index_t BK0
Definition gridwise_gemm_xdl_cshuffle_v1.hpp:463
index_t KPadded
Definition gridwise_gemm_xdl_cshuffle_v1.hpp:461
index_t MPadded
Definition gridwise_gemm_xdl_cshuffle_v1.hpp:459
index_t K
Definition gridwise_gemm_xdl_cshuffle_v1.hpp:455
index_t MBlock
Definition gridwise_gemm_xdl_cshuffle_v1.hpp:464
__host__ void Print() const
Definition gridwise_gemm_xdl_cshuffle_v1.hpp:444
__host__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_)
Definition gridwise_gemm_xdl_cshuffle_v1.hpp:422
index_t M
Definition gridwise_gemm_xdl_cshuffle_v1.hpp:453
index_t AK0
Definition gridwise_gemm_xdl_cshuffle_v1.hpp:462
index_t StrideC
Definition gridwise_gemm_xdl_cshuffle_v1.hpp:458
Definition gridwise_gemm_xdl_cshuffle_v1.hpp:121
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition thread_group_tensor_slice_transfer_v6r1.hpp:34
Definition threadwise_tensor_slice_transfer.hpp:39
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340