gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp Source File

gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp Source File#

Composable Kernel: gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp Source File
gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
17
19
20#define DEBUG_LOG 0
21
22namespace ck {
23
// Currently there is no elegant way to put the single-LDS-buffer pipeline and the
// double-LDS-buffer pipeline in the same kernel function. Blockers:
// 1. Two separate __shared__ pointer declarations are the key to making sure data accesses
//    operate on two distinct LDS chunks.
// 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C
//    may not share the same LDS buffer when __shared__ is declared inside the block GEMM
//    pipeline (blkgemmpipe).
30template <typename GridwiseGemm,
31 bool HasMainKBlockLoop,
32 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
33 index_t MinimumOccupancy = 1,
35__global__ void
36#if CK_USE_LAUNCH_BOUNDS
37__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
38#endif
39 // __attribute__((amdgpu_waves_per_eu(1, 1)))
40 kernel_gemm_xdl_cshuffle_v3_multi_d(typename GridwiseGemm::Argument karg)
41{
42#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
43 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
44 {
45 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
46
47 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
48
49 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
50 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
51 karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
52 karg.p_ds_grid,
53 karg.p_c_grid,
54 p_shared,
55 karg,
56 karg.a_element_op,
57 karg.b_element_op,
58 karg.c_element_op);
59 }
60#else
61 ignore = karg;
62#endif // end of if (defined(__gfx9__))
63}
64
65template <typename GridwiseGemm,
66 bool HasMainKBlockLoop,
67 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
68 index_t MinimumOccupancy = 1,
70__global__ void
71#if CK_USE_LAUNCH_BOUNDS
72__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
73#endif
74 // __attribute__((amdgpu_waves_per_eu(1, 1)))
75 kernel_gemm_xdl_cshuffle_v3_multi_d_2lds(typename GridwiseGemm::Argument karg)
76{
77#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
78 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
79 {
80 // Pass two lds pointer is the key to tell compiler that ds_read/write
81 // operate on different lds chunk at same time without order dependecy
82 __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
83 __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
84
85 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
86
87 GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
88 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
89 karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
90 karg.p_ds_grid,
91 karg.p_c_grid,
92 p_shared_0,
93 p_shared_1,
94 karg,
95 karg.a_element_op,
96 karg.b_element_op,
97 karg.c_element_op);
98 }
99#else
100 ignore = karg;
101#endif // end of if (defined(__gfx9__))
102}
103
104template <typename ALayout,
105 typename BLayout,
106 typename DsLayout,
107 typename CLayout,
108 typename ADataType,
109 typename BDataType,
110 typename AccDataType,
111 typename CShuffleDataType,
112 typename DsDataType,
113 typename CDataType,
114 typename AElementwiseOperation,
115 typename BElementwiseOperation,
116 typename CElementwiseOperation,
118 index_t BlockSize,
119 index_t MPerBlock,
120 index_t NPerBlock,
121 index_t KPerBlock,
122 index_t AK1Value,
123 index_t BK1Value,
124 index_t MPerXdl,
125 index_t NPerXdl,
126 index_t MXdlPerWave,
127 index_t NXdlPerWave,
128 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
129 typename ABlockTransferThreadClusterArrangeOrder,
130 typename ABlockTransferSrcAccessOrder,
131 index_t ABlockTransferSrcVectorDim,
132 index_t ABlockTransferSrcScalarPerVector,
133 index_t ABlockTransferDstScalarPerVector_AK1,
134 bool AThreadTransferSrcResetCoordinateAfterRun,
135 index_t ABlockLdsExtraMCustom,
136 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
137 typename BBlockTransferThreadClusterArrangeOrder,
138 typename BBlockTransferSrcAccessOrder,
139 index_t BBlockTransferSrcVectorDim,
140 index_t BBlockTransferSrcScalarPerVector,
141 index_t BBlockTransferDstScalarPerVector_BK1,
142 bool BThreadTransferSrcResetCoordinateAfterRun,
143 index_t BBlockLdsExtraNCustom,
144 index_t CShuffleMXdlPerWavePerShuffle,
145 index_t CShuffleNXdlPerWavePerShuffle,
146 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
147 typename CDEShuffleBlockTransferScalarPerVectors,
150 typename ComputeTypeA = CDataType,
151 typename ComputeTypeB = ComputeTypeA,
152 typename LDSTypeA = ADataType,
153 typename LDSTypeB = BDataType,
154 bool DoElementwiseBeforeCShuffle = false,
155 bool DirectLoad = false>
157{
160 !DirectLoad);
161
162 static constexpr auto I0 = Number<0>{};
163 static constexpr auto I1 = Number<1>{};
164 static constexpr auto I2 = Number<2>{};
165 static constexpr auto I3 = Number<3>{};
166 static constexpr auto I4 = Number<4>{};
167 static constexpr auto I5 = Number<5>{};
168 static constexpr auto I6 = Number<6>{};
169 static constexpr auto I7 = Number<7>{};
170
172 CDEShuffleBlockTransferScalarPerVectors{}[I0];
173
174 // K1 should be Number<...>
175 static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
176 static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
177 static constexpr auto AK1Number = Number<AK1Value>{};
178 static constexpr auto BK1Number = Number<BK1Value>{};
179
180 static constexpr bool DirectLoadEnabled = DirectLoad;
181
182 static constexpr index_t NumDTensor = DsDataType::Size();
183
184 static constexpr auto MakeDsGridPointer()
185 {
186 return generate_tuple(
187 [&](auto i) {
188 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
189
190 return static_cast<const DDataType*>(nullptr);
191 },
193 }
194
195 using DsGridPointer = decltype(MakeDsGridPointer());
196
197 static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
198 static constexpr bool is_single_rate_mfma =
200 lcm_AK1_BK1 <= 4) ||
202 // gfx950 double rate mfma16x16 require at least 128 KPerBlock to consume
204 KPerBlock < 128 && MPerXdl == 16))
205 ? true
206 : false;
207 static constexpr auto is_scale_mfma = false;
208 static constexpr index_t KPack =
210 MfmaSelector<ComputeTypeA,
211 MPerXdl,
212 NPerXdl,
213 ComputeTypeB,
215 is_scale_mfma>::selected_mfma.k_per_blk);
216
218
219 __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
220 {
221 return std::make_tuple(Block2CTileMapDefault::CalculateGridSize(M, N), 1, KBatch);
222 }
223
224 __host__ __device__ static auto CalculateMPadded(index_t M)
225 {
226 return math::integer_least_multiple(M, MPerBlock);
227 }
228
229 __host__ __device__ static auto CalculateNPadded(index_t N)
230 {
231 return math::integer_least_multiple(N, NPerBlock);
232 }
233
234 __host__ __device__ static auto CalculateKPadded(index_t K)
235 {
236 return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
237 }
238
239 __host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
240 {
241 auto K_t = K_Batch * KPerBlock;
242 return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
243 }
244
245 __host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
246 {
247 auto K_t = K_Batch * KPerBlock;
248 return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
249 }
250
251 __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
252 {
253 auto K_t = K_Batch * KPerBlock;
254 return (K + K_t - 1) / K_t * KPerBlock;
255 }
256
257 __host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
258 {
259 constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
260 auto K_t = K_Batch * KReadVec;
261 return (K + K_t - 1) / K_t * KReadVec;
262 }
263
264 __host__ __device__ static auto CalculateMBlock(index_t M)
265 {
266 return math::integer_divide_ceil(M, MPerBlock);
267 }
268
269 __host__ __device__ static auto CalculateNBlock(index_t N)
270 {
271 return math::integer_divide_ceil(N, NPerBlock);
272 }
273
274 template <typename GridDesc_K0_MN_K1_T, index_t K0Number, index_t K1Value>
275 __host__ __device__ static auto TransformGrid(GridDesc_K0_MN_K1_T& desc)
276 {
277
278 if constexpr(!DirectLoad)
279 {
280 return desc;
281 }
282 else
283 {
284 const index_t K = desc.GetLength(I0) * desc.GetLength(I2);
285 const index_t MN = desc.GetLength(I1);
286
287 const auto desc_unmerged = transform_tensor_descriptor(
288 desc,
289 make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, K0Number)),
294
295 const auto desc_permuted = transform_tensor_descriptor(
296 desc_unmerged,
302
304 desc_permuted,
306 make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, K0Number)),
311 }
312 }
313
314 template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
315 __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
316 {
317 constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
318 constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
319
320 if constexpr(!DirectLoad)
321 {
323 TileDesc_K0_MN_K1{},
330 }
331 else
332 {
333 constexpr index_t MN = TileDesc_K0_MN_K1{}.GetLength(Number<1>{});
334
335 constexpr auto desc = transform_tensor_descriptor(
336 TileDesc_K0_MN_K1{},
341
343 desc,
350 }
351 }
352
353 __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
354 index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
355 {
356 const auto a_grid_desc_mraw_kraw = [&]() {
358 {
359 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
360 }
362 {
363 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
364 }
365 }();
366
367 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
368
369 if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
370 GemmSpec == GemmSpecialization::MNKPadding)
371 {
372 // pad both M and K
373 const auto a_grid_desc_m_k =
374 transform_tensor_descriptor(a_grid_desc_mraw_kraw,
376 make_right_pad_transform(K, KPad - K)),
379
380 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
381 a_grid_desc_m_k,
386
387 return a_grid_desc_ak0_m_ak1;
388 }
389 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
390 GemmSpec == GemmSpecialization::MNPadding)
391 {
392 // pad M, but not K
393 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
394 a_grid_desc_mraw_kraw,
396 make_right_pad_transform(M, MPad - M)),
399
400 return a_grid_desc_ak0_m_ak1;
401 }
402 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
403 GemmSpec == GemmSpecialization::NKPadding)
404 {
405 // pad K, but not M
406 const auto a_grid_desc_m_k = transform_tensor_descriptor(
407 a_grid_desc_mraw_kraw,
411
412 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
413 a_grid_desc_m_k,
418
419 return a_grid_desc_ak0_m_ak1;
420 }
421 else
422 {
423 // not pad M or K
424 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
425 a_grid_desc_mraw_kraw,
430
431 return a_grid_desc_ak0_m_ak1;
432 }
433 }
434
435 __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
436 index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
437 {
438 const auto b_grid_desc_nraw_kraw = [&]() {
440 {
441 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
442 }
444 {
445 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
446 }
447 }();
448
449 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
450
451 if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
452 GemmSpec == GemmSpecialization::MNKPadding)
453 {
454 // pad both N and K
455 const auto b_grid_desc_n_k =
456 transform_tensor_descriptor(b_grid_desc_nraw_kraw,
458 make_right_pad_transform(K, KPad - K)),
461
462 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
463 b_grid_desc_n_k,
468
469 return b_grid_desc_bk0_n_bk1;
470 }
471 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
472 GemmSpec == GemmSpecialization::MNPadding)
473 {
474 // pad N, but not K
475 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
476 b_grid_desc_nraw_kraw,
478 make_right_pad_transform(N, NPad - N)),
481
482 return b_grid_desc_bk0_n_bk1;
483 }
484 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
485 GemmSpec == GemmSpecialization::MKPadding)
486 {
487 // pad K, but not N
488 const auto b_grid_desc_n_k = transform_tensor_descriptor(
489 b_grid_desc_nraw_kraw,
493
494 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
495 b_grid_desc_n_k,
500
501 return b_grid_desc_bk0_n_bk1;
502 }
503 else
504 {
505 // not pad N or K
506 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
507 b_grid_desc_nraw_kraw,
512
513 return b_grid_desc_bk0_n_bk1;
514 }
515 }
516
517 template <typename ABlockDesc_AK0_M_AK1>
518 __host__ __device__ static constexpr auto
519 MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
520 {
521 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
522
523 return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
524 }
525
526 template <typename BBlockDesc_BK0_N_BK1>
527 __host__ __device__ static constexpr auto
528 MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
529 {
530 constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
531
532 return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
533 }
534
535 template <typename ELayout>
536 __host__ __device__ static auto
538 {
539 const auto c_grid_desc_mraw_nraw = [&]() {
541 {
542 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
543 }
545 {
546 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
547 }
548 else
549 {
550 static_assert(false,
551 "The layout configuration is not supported! "
552 "Only support Row & Col major.");
553 }
554 }();
555
556 // pad M and N
557 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
559 make_right_pad_transform(N, NPad - N)),
562#if 0
563 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
564
565 if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
566 GemmSpec == GemmSpecialization::MNKPadding)
567 {
568 // pad M and N
569 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
571 make_right_pad_transform(N, NPad - N)),
574 }
575 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
576 GemmSpec == GemmSpecialization::MKPadding)
577 {
578 // pad M, but not N
580 c_grid_desc_mraw_nraw,
584 }
585 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
586 GemmSpec == GemmSpecialization::NKPadding)
587 {
588 // pad N, but not M
590 c_grid_desc_mraw_nraw,
594 }
595 else
596 {
597 // not pad M or N
598 return c_grid_desc_mraw_nraw;
599 }
600#endif
601 }
602
603 __host__ __device__ static auto MakeDsGridDescriptor_M_N(
604 index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
605 {
606 return generate_tuple(
607 [&](auto i) {
608 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
609 return MakeCGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
610 },
612 }
613
614 template <typename DsGridDesc>
616 const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
617 {
618 return generate_tuple(
619 [&](auto i) {
621 ds_grid_desc_m_n[i], MBlock, NBlock);
622 },
624 }
625
626 struct Problem
627 {
628 __host__ __device__ Problem() = default;
629 __host__ __device__ Problem(index_t M_,
630 index_t N_,
631 index_t K_,
632 index_t StrideA_,
633 index_t StrideB_,
634 std::array<index_t, NumDTensor> StrideDs_,
635 index_t StrideC_,
636 index_t KBatch_)
637 : M{M_},
638 N{N_},
639 K{K_},
640 StrideA{StrideA_},
641 StrideB{StrideB_},
642 StrideDs{StrideDs_},
643 StrideC{StrideC_},
644 KBatch{KBatch_},
647 KRead{CalculateKRead(K_, KBatch_)},
648 KPadded{CalculateKPadded(K_, KBatch_)},
649 AK0{CalculateAK0Padded(K_, KBatch_)},
650 BK0{CalculateBK0Padded(K_, KBatch_)},
653 {
654 }
655
656 __host__ void Print() const
657 {
658 std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
659 << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
660 << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
661 << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
662 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
663 << "NBlock: " << NBlock << "}" << std::endl;
664 }
665
671 std::array<index_t, NumDTensor> StrideDs;
682 };
683
684 // Argument
686 {
687 __host__ Argument() = default;
688 __host__ Argument(const ADataType* p_a_grid_,
689 const BDataType* p_b_grid_,
690 std::array<const void*, NumDTensor> p_ds_grid_,
691 CDataType* p_c_grid_,
692 index_t M_,
693 index_t N_,
694 index_t K_,
695 index_t StrideA_,
696 index_t StrideB_,
697 std::array<index_t, NumDTensor> StrideDs_,
698 index_t StrideC_,
699 index_t k_batch_,
700 AElementwiseOperation a_element_op_,
701 BElementwiseOperation b_element_op_,
702 CElementwiseOperation c_element_op_)
703 : Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_},
704 p_a_grid{p_a_grid_},
705 p_b_grid{p_b_grid_},
706 p_ds_grid{},
707 p_c_grid{p_c_grid_},
708 a_element_op{a_element_op_},
709 b_element_op{b_element_op_},
710 c_element_op{c_element_op_}
711 {
712
713 // populate pointer, desc for Ds
714 static_for<0, NumDTensor, 1>{}([&](auto i) {
715 using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
716
717 // D pointer
718 p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
719 });
720 }
721
722 const ADataType* p_a_grid;
723 const BDataType* p_b_grid;
725 CDataType* p_c_grid;
726
727 AElementwiseOperation a_element_op;
728 BElementwiseOperation b_element_op;
729 CElementwiseOperation c_element_op;
730 };
731
733 {
734 __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
735 {
737 {
738 a_k_split_offset = k_id * karg.KRead;
739 }
741 {
742 a_k_split_offset = k_id * karg.KRead * karg.StrideA;
743 }
744
746 {
747 b_k_split_offset = k_id * karg.KRead * karg.StrideB;
748 }
750 {
751 b_k_split_offset = k_id * karg.KRead;
752 }
753
754 if(k_id < karg.KBatch - 1)
755 {
756 karg.K = karg.KRead;
757 }
758 else
759 {
760 karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
761 }
762 }
763
766 };
767
768 __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
769 {
770 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
771 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
772 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
773#if defined(__gfx950__)
774 // Force use padded layout on gfx950 to reduce bank conflicts
775 constexpr index_t ABlockLdsExtraM = 1;
776#else
777 constexpr index_t ABlockLdsExtraM = ABlockLdsExtraMCustom;
778#endif
779
780 // A matrix in LDS memory, dst of blockwise copy
781 if constexpr(DirectLoad)
782 {
786 }
787 else if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
788 {
789 // bank conflict when writting the data into LDS, but don't worry, we have whole entire
790 // loop to hide it in v4. it may give you some benefit from less valu in compute address
794 }
795 // xor tensor transformation request more unnecessary vgpr usage, would cause register spill
796 // in some cases.
798 {
799 constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeA) < 1
800 ? 1
801 : 32 * 4 / KPerBlock / sizeof(LDSTypeA);
802 constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
804 AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
806
807 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
808 a_lds_block_desc,
814
815 constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor(
816 a_lds_block_desc_permuted,
822
823 constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
824 a_lds_block_desc_ak0_mldslayer_m_ak1,
831
832 return a_lds_block_desc_ak0_m_ak1;
833 }
834 else // ColumnMajor A
835 {
836 // kfold and mpair dimension is not always required.
837 // more dimension in merge_transform increase the difficulty of generating immarg offset
838 // for compiler.
839 constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
840 constexpr auto M1 = MPerBlock / M0;
841
842 constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
843 constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
844 constexpr auto KThreadRead = WaveSize / MPerXdl;
845 constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
846
847 constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128)
848 ? 1
849 : 128 / (AK1Number * M0 * sizeof(LDSTypeA));
850 constexpr auto KThreadReadPerm =
851 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
852 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
853 : KThreadRead;
854
855 // 1<=mpair<=n0
856 constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128)
857 ? 1
858 : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0
859 ? M0
860 : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA)));
861
862 constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
866 Number<kfold * M0 / mpair>{},
868 AK1Number));
869
870 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
871 a_lds_block_desc,
876 make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
883
884 constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
885 a_lds_block_desc_permuted,
894 Sequence<1>{},
895 Sequence<2>{},
896 Sequence<3>{},
897 Sequence<4>{},
898 Sequence<5>{}),
900 Sequence<2>{},
903 Sequence<6>{},
904 Sequence<7>{}));
905
906 constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
907 a_lds_block_desc_unmerged,
910 Number<KThreadWrite / kfold / KThreadReadPerm>{},
918
919 return a_lds_block_desc_ak0_m_ak1;
920 }
921 }
922
923 __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
924 {
925 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
926 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
927 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
928#if defined(__gfx950__)
929 // Force use padded layout on gfx950 to reduce bank conflicts
930 constexpr index_t BBlockLdsExtraN = 1;
931#else
932 constexpr index_t BBlockLdsExtraN = BBlockLdsExtraNCustom;
933#endif
934
935 // B matrix in LDS memory, dst of blockwise copy
936 if constexpr(DirectLoad)
937 {
941 }
942 else if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
943 {
944 // bank conflict when writting the data into LDS, but don't worry, we have whole entire
945 // loop to hide it in v4. it may give you some benefit from less valu in compute address
949 }
951 {
952 // NLdsLayer * K0 as logical Bank
953 constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeB) < 1
954 ? 1
955 : 32 * 4 / KPerBlock / sizeof(LDSTypeB);
956 ;
957 constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
959 BK0Number * Number<NLdsLayer>{}, Number<NPerBlock / NLdsLayer>{}, BK1Number),
961
962 constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
963 b_lds_block_desc,
969
970 constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
971 b_lds_block_desc_permuted,
977
978 constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
979 b_lds_block_desc_bk0_nldslayer_n_bk1,
986
987 return b_lds_block_desc_bk0_n_bk1;
988 }
989 else // RowMajor B
990 {
991 constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
992 constexpr auto N1 = NPerBlock / N0;
993
994 constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
995 constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
996 constexpr auto KThreadRead = WaveSize / NPerXdl;
997 constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
998
999 constexpr auto kfold = (BK1Number * N0 * sizeof(LDSTypeB) > 128)
1000 ? 1
1001 : 128 / (BK1Number * N0 * sizeof(LDSTypeB));
1002 constexpr auto KThreadReadPerm =
1003 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
1004 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
1005 : KThreadRead;
1006
1007 // 1<=npair<=n0
1008 constexpr auto npair = (BK1Number * NPerXdl * sizeof(LDSTypeB) > 128)
1009 ? 1
1010 : ((128 / (BK1Number * NPerXdl * sizeof(LDSTypeB))) > N0
1011 ? N0
1012 : 128 / (BK1Number * NPerXdl * sizeof(LDSTypeB)));
1013
1014 constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
1018 Number<kfold * N0 / npair>{},
1019 Number<npair>{},
1020 BK1Number));
1021
1022 constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
1023 b_lds_block_desc,
1024 make_tuple(
1028 make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
1031 make_tuple(
1033 make_tuple(
1035
1036 constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
1037 b_lds_block_desc_permuted,
1038 make_tuple(
1046 Sequence<1>{},
1047 Sequence<2>{},
1048 Sequence<3>{},
1049 Sequence<4>{},
1050 Sequence<5>{}),
1052 Sequence<2>{},
1055 Sequence<6>{},
1056 Sequence<7>{}));
1057
1058 constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
1059 b_lds_block_desc_unmerged,
1062 Number<KThreadWrite / kfold / KThreadReadPerm>{},
1063 Number<kfold>{},
1070
1071 return b_lds_block_desc_bk0_n_bk1;
1072 }
1073 }
1074
1076 {
1077 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1078 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1079
1080 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1082 make_tuple(I1,
1084 I1,
1086
1087 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
1088 }
1089
1092 BlkGemmPipelineVer,
1093 BlkGemmPipeSched,
1094 BlockSize,
1095 LDSTypeA,
1096 LDSTypeB,
1097 ComputeTypeA,
1098 AccDataType,
1105 ABlockTransferSrcScalarPerVector,
1106 BBlockTransferSrcScalarPerVector,
1107 MPerBlock,
1108 NPerBlock,
1109 KPerBlock,
1110 MPerXdl,
1111 NPerXdl,
1112 MXdlPerWave,
1113 NXdlPerWave,
1114 KPack,
1115 DirectLoad>())>;
1116
1117 __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
1118 {
1119 // LDS allocation for A and B: be careful of alignment
1120 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1121 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1122
1123 // lds max alignment
1124 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1125
1126 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1127 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1128
1129 constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
1130 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1131
1132 // LDS allocation for C shuffle in LDS
1133 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1135
1136 constexpr auto c_block_size =
1137 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
1138
1139 return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) +
1140 b_block_space_size_aligned * sizeof(LDSTypeB)),
1141 c_block_size * sizeof(CShuffleDataType));
1142 }
1143
1144 template <
1145 InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set>
1146 __device__ static bool constexpr IsValidCompilationParameter()
1147 {
1148 constexpr bool valid = ck::tensor_operation::device::IsValidGemmCompilationParameter<
1149 BlockSize,
1150 MPerBlock,
1151 NPerBlock,
1152 MPerXdl,
1153 NPerXdl,
1154 MXdlPerWave,
1155 NXdlPerWave,
1156 CDataType,
1157 CGlobalMemoryDataOperation_>();
1158 if constexpr(!valid)
1159 {
1160 return false;
1161 }
1162
1163 using MfmaInst = MfmaSelector<ComputeTypeA,
1164 MPerXdl,
1165 NPerXdl,
1166 ComputeTypeB,
1169
1170 constexpr index_t KPerThread =
1171 KPerBlock / (MfmaInst::GetKPerXdlops() / MfmaInst::GetK1PerXdlops());
1172 if constexpr(KPerThread % KPack != 0)
1173 {
1174 static_assert(0);
1175 return false;
1176 }
1177
1178 if constexpr(NXdlPerWave % CShuffleNXdlPerWavePerShuffle != 0)
1179 {
1180 return false;
1181 }
1182 return true;
1183 }
1184
1185 // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
1186 __host__ static constexpr bool CheckValidity(const Argument& karg)
1187 {
1188 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1189 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1190 "Invalid tuning param!");
1191
1192 if constexpr(NXdlPerWave % CShuffleNXdlPerWavePerShuffle != 0)
1193 {
1194 return false;
1195 }
1196
1197 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
1202 {
1203 if(!(karg.M % MPerBlock == 0))
1204 {
1205#if DEBUG_LOG
1206 std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
1207 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1208 << std::endl;
1209
1210#endif // DEBUG_LOG
1211 return false;
1212 }
1213 }
1214
1215 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
1220 {
1221 if(!(karg.N % NPerBlock == 0))
1222 {
1223#if DEBUG_LOG
1224 std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
1225 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1226 << std::endl;
1227
1228#endif // DEBUG_LOG
1229 return false;
1230 }
1231 }
1232
1233 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1237 {
1238
1239 auto K_t = karg.KBatch * KPerBlock;
1240 if(!(karg.K % K_t == 0))
1241 {
1242#if DEBUG_LOG
1243 std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1244 << karg.K << " " << __FILE__ << ":" << __LINE__
1245 << ", in function: " << __func__ << std::endl;
1246
1247#endif // DEBUG_LOG
1248 return false;
1249 }
1250 }
1251 else
1252 {
1253 constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1254 auto K_t = karg.KBatch * KReadVec;
1255 auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1256 if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1257 {
1258 return false;
1259 }
1260 }
1261
1263 {
1264 if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1265 {
1266#if DEBUG_LOG
1267 std::cout << "Arg K (" << karg.K
1268 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1269 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1270 << __LINE__ << ", in function: " << __func__ << std::endl;
1271
1272#endif // DEBUG_LOG
1273 return false;
1274 }
1275 }
1276 else
1277 {
1278 if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1279 {
1280#if DEBUG_LOG
1281 std::cout << "Arg M (" << karg.M
1282 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1283 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1284 << __LINE__ << ", in function: " << __func__ << std::endl;
1285
1286#endif // DEBUG_LOG
1287 return false;
1288 }
1289 }
1290
1292 {
1293 if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1294 {
1295#if DEBUG_LOG
1296 std::cout << "Arg N (" << karg.N
1297 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1298 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1299 << __LINE__ << ", in function: " << __func__ << std::endl;
1300
1301#endif // DEBUG_LOG
1302 return false;
1303 }
1304 }
1305 else
1306 {
1307 if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1308 {
1309#if DEBUG_LOG
1310 std::cout << "Arg K (" << karg.K
1311 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1312 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1313 << __LINE__ << ", in function: " << __func__ << std::endl;
1314
1315#endif // DEBUG_LOG
1316 return false;
1317 }
1318 }
1319
1321 {
1323 {
1324#if DEBUG_LOG
1325 std::cout << "Arg N (" << karg.N
1326 << ") value is not a multiple of "
1327 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1328 << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
1329 << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1330
1331#endif // DEBUG_LOG
1332 return false;
1333 }
1334 }
1335 else
1336 {
1338 {
1339#if DEBUG_LOG
1340 std::cout << "Arg M (" << karg.M
1341 << ") value is not a multiple of "
1342 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1343 << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
1344 << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1345
1346#endif // DEBUG_LOG
1347 return false;
1348 }
1349 }
1350
1351 // check gridwise gemm pipeline
1352 const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1353
1354 if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1)
1355 {
1356 if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1357 {
1358 return false;
1359 }
1360 }
1361
1362 constexpr long_index_t TwoGB = (long_index_t{1} << 31);
1363 if(!(karg.M * karg.K * sizeof(ADataType) <= TwoGB &&
1364 karg.N * karg.K * sizeof(BDataType) <= TwoGB &&
1365 karg.M * karg.N * sizeof(CDataType) <= TwoGB))
1366 {
1367 return false;
1368 }
1369
1370 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1371 return true;
1372 }
1373
1374 __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1375 {
1376 const index_t num_loop = K / KPerBlock;
1377
1378 return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1379 }
1380
1381 __host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1382 {
1383 const index_t num_loop = K / KPerBlock;
1384
1385 return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1386 }
1387
1388 template <typename CGridDesc>
1390 const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1391 {
1392 const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1393 c_grid_desc_m_n,
1398
1399 return c_grid_desc_mblock_mperblock_nblock_nperblock;
1400 }
1401
1402 // return block_id to C matrix tile idx (m0, n0) mapping
1403 // if arch = gfx942
1405
1406 template <bool HasMainKBlockLoop,
1407 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1408 TailNumber TailNum = TailNumber::Odd>
1409 __device__ static void Run(const ADataType* __restrict__ p_a_grid,
1410 const BDataType* __restrict__ p_b_grid,
1411 DsGridPointer& p_ds_grid,
1412 CDataType* __restrict__ p_c_grid,
1413 void* __restrict__ p_shared,
1414 const Problem& problem,
1415 AElementwiseOperation a_element_op,
1416 BElementwiseOperation b_element_op,
1417 CElementwiseOperation c_element_op)
1418 {
1419 const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4};
1421 p_a_grid,
1422 p_b_grid,
1423 p_ds_grid,
1424 p_c_grid,
1425 p_shared,
1426 problem,
1427 a_element_op,
1428 b_element_op,
1429 c_element_op,
1430 block_2_ctile_map);
1431 }
1432
1433 template <typename Block2CTileMap,
1434 bool HasMainKBlockLoop,
1435 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1436 TailNumber TailNum = TailNumber::Odd>
1437 __device__ static void Run(const ADataType* __restrict__ p_a_grid,
1438 const BDataType* __restrict__ p_b_grid,
1439 DsGridPointer& p_ds_grid,
1440 CDataType* __restrict__ p_c_grid,
1441 void* __restrict__ p_shared,
1442 const Problem& problem,
1443 AElementwiseOperation a_element_op,
1444 BElementwiseOperation b_element_op,
1445 CElementwiseOperation c_element_op,
1446 const Block2CTileMap& block_2_ctile_map)
1447 {
1448 const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1449 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
1450 const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
1451 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
1452
1453 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1454 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
1455 const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
1456 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1457
1459 p_a_grid,
1460 p_b_grid,
1461 p_ds_grid,
1462 p_c_grid,
1463 p_shared,
1464 problem,
1465 a_element_op,
1466 b_element_op,
1467 c_element_op,
1468 block_2_ctile_map,
1469 a_grid_desc_ak0_m_ak1,
1470 b_grid_desc_bk0_n_bk1,
1471 ds_grid_desc_m_n,
1472 c_grid_desc_m_n);
1473 }
1474
1475 template <bool HasMainKBlockLoop,
1476 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1477 TailNumber TailNum,
1478 typename Block2CTileMap,
1479 typename AGridDesc_AK0_M_K1,
1480 typename BGridDesc_BK0_N_K1,
1481 typename DsGridDesc_M_N,
1482 typename CGridDesc_M_N>
1483 __device__ static void Run(const ADataType* __restrict__ p_a_grid,
1484 const BDataType* __restrict__ p_b_grid,
1485 DsGridPointer& p_ds_grid,
1486 CDataType* __restrict__ p_c_grid,
1487 void* __restrict__ p_shared,
1488 const Problem& problem,
1489 [[maybe_unused]] AElementwiseOperation a_element_op,
1490 [[maybe_unused]] BElementwiseOperation b_element_op,
1491 CElementwiseOperation c_element_op,
1492 const Block2CTileMap& block_2_ctile_map,
1493 const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
1494 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
1495 const DsGridDesc_M_N& ds_grid_desc_m_n,
1496 const CGridDesc_M_N& c_grid_desc_m_n)
1497 {
1498
1499 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1500 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1501 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1502 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1503
1504 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1506 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1507
1509 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1510
1511 const auto block_work_idx =
1512 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
1513
1514 if(!block_2_ctile_map.ValidCTileIndex(
1515 block_work_idx,
1516 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
1517 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
1518 {
1519 return;
1520 }
1521
1522 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
1523 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
1524
1525 // HACK: this force m/n_block_data_idx_on_grid into SGPR
1526 const index_t m_block_data_idx_on_grid =
1527 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1528
1529 const index_t n_block_data_idx_on_grid =
1530 __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1531
1532 // lds max alignment
1533 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1534
1535 // A matrix in LDS memory, dst of blockwise copy
1536 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1537
1538 // B matrix in LDS memory, dst of blockwise copy
1539 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1540
1541 auto get_a_blockwise_copy = [&]() {
1542 if constexpr(DirectLoad)
1543 {
1547 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1548 ABlockTransferThreadClusterArrangeOrder,
1549 ADataType,
1550 ADataType,
1551 decltype(a_grid_desc_ak0_m_ak1),
1552 decltype(a_block_desc_ak0_m_ak1),
1553 ABlockTransferSrcAccessOrder,
1554 ABlockTransferSrcVectorDim,
1555 2,
1556 ABlockTransferSrcScalarPerVector>(
1557 a_grid_desc_ak0_m_ak1,
1558 make_multi_index(0, m_block_data_idx_on_grid, 0),
1559 a_block_desc_ak0_m_ak1,
1560 make_multi_index(0, 0, 0));
1561 }
1562 else
1563 {
1566 AElementwiseOperation,
1570 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1571 ABlockTransferThreadClusterArrangeOrder,
1572 ADataType,
1573 LDSTypeA,
1574 decltype(a_grid_desc_ak0_m_ak1),
1575 decltype(a_block_desc_ak0_m_ak1),
1576 ABlockTransferSrcAccessOrder,
1578 ABlockTransferSrcVectorDim,
1579 2,
1580 ABlockTransferSrcScalarPerVector,
1581 ABlockTransferDstScalarPerVector_AK1,
1582 1,
1583 1,
1584 AThreadTransferSrcResetCoordinateAfterRun,
1585 true,
1586 BlockwiseGemmPipe::GlobalBufferNum>(
1587 a_grid_desc_ak0_m_ak1,
1588 make_multi_index(0, m_block_data_idx_on_grid, 0),
1589 a_element_op,
1590 a_block_desc_ak0_m_ak1,
1591 make_multi_index(0, 0, 0),
1593 }
1594 };
1595
1596 // B matrix blockwise copy
1597 auto get_b_blockwise_copy = [&]() {
1598 if constexpr(DirectLoad)
1599 {
1603 BBlockTransferThreadClusterLengths_BK0_N_BK1,
1604 BBlockTransferThreadClusterArrangeOrder,
1605 BDataType,
1606 BDataType,
1607 decltype(b_grid_desc_bk0_n_bk1),
1608 decltype(b_block_desc_bk0_n_bk1),
1609 BBlockTransferSrcAccessOrder,
1610 BBlockTransferSrcVectorDim,
1611 2,
1612 BBlockTransferSrcScalarPerVector>(
1613 b_grid_desc_bk0_n_bk1,
1614 make_multi_index(0, n_block_data_idx_on_grid, 0),
1615 b_block_desc_bk0_n_bk1,
1616 make_multi_index(0, 0, 0));
1617 }
1618 else
1619 {
1622 BElementwiseOperation,
1626 BBlockTransferThreadClusterLengths_BK0_N_BK1,
1627 BBlockTransferThreadClusterArrangeOrder,
1628 BDataType,
1629 LDSTypeB,
1630 decltype(b_grid_desc_bk0_n_bk1),
1631 decltype(b_block_desc_bk0_n_bk1),
1632 BBlockTransferSrcAccessOrder,
1634 BBlockTransferSrcVectorDim,
1635 2,
1636 BBlockTransferSrcScalarPerVector,
1637 BBlockTransferDstScalarPerVector_BK1,
1638 1,
1639 1,
1640 BThreadTransferSrcResetCoordinateAfterRun,
1641 true,
1642 BlockwiseGemmPipe::GlobalBufferNum>(
1643 b_grid_desc_bk0_n_bk1,
1644 make_multi_index(0, n_block_data_idx_on_grid, 0),
1645 b_element_op,
1646 b_block_desc_bk0_n_bk1,
1647 make_multi_index(0, 0, 0),
1649 }
1650 };
1651
1652 auto a_blockwise_copy = get_a_blockwise_copy();
1653 auto b_blockwise_copy = get_b_blockwise_copy();
1654
1655 // LDS allocation for A and B: be careful of alignment
1656 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1657 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1658
1659 // Cast after lds
1661 static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1662
1664 static_cast<LDSTypeB*>(p_shared) +
1665 a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB),
1666 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1667
1668 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1669 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
1670
1671 // Blockwise GEMM pipeline
1672 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1673 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1674 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1675
1676 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1677 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1678 KPerBlock);
1679
1680 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
1681 a_block_desc_ak0_m_ak1,
1682 a_blockwise_copy,
1683 a_grid_buf,
1684 a_block_buf,
1685 a_block_slice_copy_step,
1686 b_grid_desc_bk0_n_bk1,
1687 b_block_desc_bk0_n_bk1,
1688 b_blockwise_copy,
1689 b_grid_buf,
1690 b_block_buf,
1691 b_block_slice_copy_step,
1692 c_thread_buf,
1693 num_k_block_main_loop);
1694
1695 // shuffle C and write out
1696 {
1697 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1698 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1699 "wrong!");
1700
1701 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1702 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1703
1704 // TODO: hacky, fix it!
1705 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1706 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1707
1708 // TODO: hacky, fix it!
1709 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1710 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1711 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1712
1713 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1714 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1715 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1716 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1717 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1718 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1719 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1720 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1721
1722 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1724
1725 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1726 static_cast<CShuffleDataType*>(p_shared),
1727 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1728
1729 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1730 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1731 make_tuple(
1734 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1735 M1, // M1 = MWave
1736 M2, // M2 * M3 * M4 = MPerXdl
1737 M3,
1738 M4)),
1741 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1742 N1, // N1 = NWave
1743 N2))), // N2 = NPerXdl
1745 make_tuple(
1747
1748 // calculate origin of thread output tensor on global memory
1749 // blockwise GEMM c matrix starting index
1750 const auto c_thread_mtx_on_block =
1751 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1752
1753 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1754 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1755
1756 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1758 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
1761
1762 const auto m_thread_data_on_block_idx =
1763 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1764 make_multi_index(m_thread_data_on_block));
1765
1766 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1771
1772 const auto n_thread_data_on_block_idx =
1773 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1774 make_multi_index(n_thread_data_on_block));
1775
1777 const auto& vpgr_to_lds_element_op = [&] {
1778 if constexpr(DoElementwiseBeforeCShuffle)
1779 {
1780 return c_element_op;
1781 }
1782 else
1783 {
1784 return pass_through;
1785 }
1786 };
1787 const auto& lds_to_global_element_op = [&] {
1788 if constexpr(!DoElementwiseBeforeCShuffle)
1789 {
1790 return c_element_op;
1791 }
1792 else
1793 {
1794 return pass_through;
1795 }
1796 };
1797
1798 // shuffle: threadwise copy C from VGPR to LDS
1799 auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
1800 AccDataType,
1801 CShuffleDataType,
1802 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1803 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1804 conditional_t<DoElementwiseBeforeCShuffle,
1805 CElementwiseOperation,
1807 Sequence<CShuffleMXdlPerWavePerShuffle,
1808 CShuffleNXdlPerWavePerShuffle,
1809 I1,
1810 I1,
1811 M2,
1812 I1,
1813 M4,
1814 I1>,
1816 7,
1817 1,
1819 1,
1820 true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1822 0,
1823 m_thread_data_on_block_idx[I1],
1824 n_thread_data_on_block_idx[I1],
1825 m_thread_data_on_block_idx[I2],
1826 m_thread_data_on_block_idx[I3],
1827 m_thread_data_on_block_idx[I4],
1828 n_thread_data_on_block_idx[I2]),
1829 vpgr_to_lds_element_op()};
1830
1831 using EDataType = CDataType;
1832
1833 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1835 ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1836
1837 const auto ds_grid_buf = generate_tuple(
1838 [&](auto i) {
1840 p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
1841 },
1843
1844 // tuple of reference to C/Ds tensor descriptors
1845 const auto c_ds_desc_refs = concat_tuple_of_reference(
1846 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1847 generate_tie([&](auto i) -> const auto& // return type should be reference
1848 { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1850
1851 // tuple of reference to C/Ds tensor descriptors
1852 const auto c_ds_buf_refs = concat_tuple_of_reference(
1853 tie(c_shuffle_block_buf),
1854 generate_tie([&](auto i) -> const auto& // return type should be reference
1855 { return ds_grid_buf[i]; },
1857
1858 // tuple of starting index of C/Ds blockwise copy
1859 const auto idx_c_ds_block_begin = container_concat(
1860 make_tuple(make_multi_index(0, 0, 0, 0)),
1862 [&](auto) {
1863 return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
1864 },
1866
1867 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1868 c_grid_desc_mblock_mperblock_nblock_nperblock;
1869
1870 using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
1871 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1872 const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1873
1874 auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3<
1876 decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
1878 decltype(c_ds_desc_refs),
1879 decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1880 conditional_t<!DoElementwiseBeforeCShuffle,
1881 CElementwiseOperation,
1883 Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
1884 // support arbitray type
1885 Sequence<1,
1886 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1887 1,
1888 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1889 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1890 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1891 Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
1892 Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
1893 3, // index_t SrcVectorDim,
1894 3, // index_t DstVectorDim,
1895 CDEShuffleBlockTransferScalarPerVectors,
1900 false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
1901 Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
1902 {c_ds_desc_refs,
1903 idx_c_ds_block_begin,
1904 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1905 make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
1906 lds_to_global_element_op()};
1907
1908 // space filling curve for threadwise C in VGPR
1909 constexpr auto sfc_c_vgpr =
1912 Sequence<CShuffleMXdlPerWavePerShuffle,
1913 CShuffleNXdlPerWavePerShuffle,
1914 1,
1915 1,
1916 M2,
1917 1,
1918 M4,
1919 1>>{};
1920
1921 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1922
1923 // space filling curve for shuffled blockwise C/D/E
1924 constexpr auto sfc_cde_block =
1927 Sequence<1,
1928 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1929 1,
1930 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1931
1932 static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
1933
1934 static_for<0, num_access, 1>{}([&](auto access_id) {
1935 // make sure it's safe to write to LDS
1937
1938 // each thread write its data from VGPR to LDS
1939 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1940 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1941 c_thread_buf,
1942 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1943 c_shuffle_block_buf);
1944
1945 // make sure it's safe to read from LDS
1947
1948 // each block copy its data from LDS to global
1949 cde_block_copy_lds_and_global.Run(
1950 c_ds_desc_refs,
1951 c_ds_buf_refs,
1952 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1953 tie(c_grid_buf));
1954
1955 if constexpr(access_id < num_access - 1)
1956 {
1957 constexpr auto cde_lds_and_global_step =
1958 sfc_cde_block.GetForwardStep(access_id);
1959
1960 // move on Ds
1961 static_for<0, NumDTensor, 1>{}([&](auto i) {
1962 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
1963 c_ds_desc_refs, i + I1, cde_lds_and_global_step);
1964 });
1965
1966 // move on E
1967 cde_block_copy_lds_and_global.MoveDstSliceWindow(
1968 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1969 I0,
1970 cde_lds_and_global_step);
1971 }
1972 });
1973 }
1974 }
1975
1976 template <bool HasMainKBlockLoop,
1977 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1978 TailNumber TailNum = TailNumber::Odd>
1979 __device__ static void Run_2Lds(const ADataType* __restrict__ p_a_grid,
1980 const BDataType* __restrict__ p_b_grid,
1981 DsGridPointer& p_ds_grid,
1982 CDataType* __restrict__ p_c_grid,
1983 void* __restrict__ p_shared_0,
1984 void* __restrict__ p_shared_1,
1985 const Problem& problem,
1986 AElementwiseOperation a_element_op,
1987 BElementwiseOperation b_element_op,
1988 CElementwiseOperation c_element_op)
1989 {
1990 // divide block work by [M, N]
1991 const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4};
1993 p_a_grid,
1994 p_b_grid,
1995 p_ds_grid,
1996 p_c_grid,
1997 p_shared_0,
1998 p_shared_1,
1999 problem,
2000 a_element_op,
2001 b_element_op,
2002 c_element_op,
2003 block_2_ctile_map);
2004 }
2005
2006 template <typename Block2CTileMap,
2007 bool HasMainKBlockLoop,
2008 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
2009 TailNumber TailNum = TailNumber::Odd>
2010 __device__ static void Run_2Lds(const ADataType* __restrict__ p_a_grid,
2011 const BDataType* __restrict__ p_b_grid,
2012 DsGridPointer& p_ds_grid,
2013 CDataType* __restrict__ p_c_grid,
2014 void* __restrict__ p_shared_0,
2015 void* __restrict__ p_shared_1,
2016 const Problem& problem,
2017 AElementwiseOperation a_element_op,
2018 BElementwiseOperation b_element_op,
2019 CElementwiseOperation c_element_op,
2020 const Block2CTileMap& block_2_ctile_map)
2021 {
2022 const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
2023 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
2024 const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
2025 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
2026
2027 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
2028 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
2029 const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
2030 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
2031
2033 p_b_grid,
2034 p_ds_grid,
2035 p_c_grid,
2036 p_shared_0,
2037 p_shared_1,
2038 problem,
2039 a_element_op,
2040 b_element_op,
2041 c_element_op,
2042 block_2_ctile_map,
2043 a_grid_desc_ak0_m_ak1,
2044 b_grid_desc_bk0_n_bk1,
2045 ds_grid_desc_m_n,
2046 c_grid_desc_m_n);
2047 }
2048
2049 template <bool HasMainKBlockLoop,
2050 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
2051 TailNumber TailNum,
2052 typename Block2CTileMap,
2053 typename AGridDesc_AK0_M_K1,
2054 typename BGridDesc_BK0_N_K1,
2055 typename DsGridDesc_M_N,
2056 typename CGridDesc_M_N>
2057 __device__ static void Run_2Lds(const ADataType* __restrict__ p_a_grid,
2058 const BDataType* __restrict__ p_b_grid,
2059 DsGridPointer& p_ds_grid,
2060 CDataType* __restrict__ p_c_grid,
2061 void* __restrict__ p_shared_0,
2062 void* __restrict__ p_shared_1,
2063 const Problem& problem,
2064 [[maybe_unused]] AElementwiseOperation a_element_op,
2065 [[maybe_unused]] BElementwiseOperation b_element_op,
2066 CElementwiseOperation c_element_op,
2067 const Block2CTileMap& block_2_ctile_map,
2068 const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
2069 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
2070 const DsGridDesc_M_N& ds_grid_desc_m_n,
2071 const CGridDesc_M_N& c_grid_desc_m_n)
2072 {
2073
2074 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
2076 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
2077
2078 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2079 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
2080 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2081 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
2083 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2084
2085 const auto block_work_idx =
2086 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
2087
2088 if(!block_2_ctile_map.ValidCTileIndex(
2089 block_work_idx,
2090 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
2091 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
2092 {
2093 return;
2094 }
2095
2096 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
2097 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
2098
2099 // HACK: this force m/n_block_data_idx_on_grid into SGPR
2100 const index_t m_block_data_idx_on_grid =
2101 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
2102
2103 const index_t n_block_data_idx_on_grid =
2104 __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
2105
2106 // lds max alignment
2107 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
2108
2109 // A matrix in LDS memory, dst of blockwise copy
2110 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
2111
2112 // B matrix in LDS memory, dst of blockwise copy
2113 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
2114
2115 auto get_a_blockwise_copy = [&]() {
2116 if constexpr(DirectLoad)
2117 {
2121 ABlockTransferThreadClusterLengths_AK0_M_AK1,
2122 ABlockTransferThreadClusterArrangeOrder,
2123 ADataType,
2124 ADataType,
2125 decltype(a_grid_desc_ak0_m_ak1),
2126 decltype(a_block_desc_ak0_m_ak1),
2127 ABlockTransferSrcAccessOrder,
2128 ABlockTransferSrcVectorDim,
2129 2,
2130 ABlockTransferSrcScalarPerVector>(
2131 a_grid_desc_ak0_m_ak1,
2132 make_multi_index(0, m_block_data_idx_on_grid, 0),
2133 a_block_desc_ak0_m_ak1,
2134 make_multi_index(0, 0, 0));
2135 }
2136 else
2137 {
2140 AElementwiseOperation,
2144 ABlockTransferThreadClusterLengths_AK0_M_AK1,
2145 ABlockTransferThreadClusterArrangeOrder,
2146 ADataType,
2147 LDSTypeA,
2148 decltype(a_grid_desc_ak0_m_ak1),
2149 decltype(a_block_desc_ak0_m_ak1),
2150 ABlockTransferSrcAccessOrder,
2152 ABlockTransferSrcVectorDim,
2153 2,
2154 ABlockTransferSrcScalarPerVector,
2155 ABlockTransferDstScalarPerVector_AK1,
2156 1,
2157 1,
2158 AThreadTransferSrcResetCoordinateAfterRun,
2159 true,
2160 BlockwiseGemmPipe::GlobalBufferNum>(
2161 a_grid_desc_ak0_m_ak1,
2162 make_multi_index(0, m_block_data_idx_on_grid, 0),
2163 a_element_op,
2164 a_block_desc_ak0_m_ak1,
2165 make_multi_index(0, 0, 0),
2167 }
2168 };
2169
2170 // B matrix blockwise copy
2171 auto get_b_blockwise_copy = [&]() {
2172 if constexpr(DirectLoad)
2173 {
2177 BBlockTransferThreadClusterLengths_BK0_N_BK1,
2178 BBlockTransferThreadClusterArrangeOrder,
2179 BDataType,
2180 BDataType,
2181 decltype(b_grid_desc_bk0_n_bk1),
2182 decltype(b_block_desc_bk0_n_bk1),
2183 BBlockTransferSrcAccessOrder,
2184 BBlockTransferSrcVectorDim,
2185 2,
2186 BBlockTransferSrcScalarPerVector>(
2187 b_grid_desc_bk0_n_bk1,
2188 make_multi_index(0, n_block_data_idx_on_grid, 0),
2189 b_block_desc_bk0_n_bk1,
2190 make_multi_index(0, 0, 0));
2191 }
2192 else
2193 {
2196 BElementwiseOperation,
2200 BBlockTransferThreadClusterLengths_BK0_N_BK1,
2201 BBlockTransferThreadClusterArrangeOrder,
2202 BDataType,
2203 LDSTypeB,
2204 decltype(b_grid_desc_bk0_n_bk1),
2205 decltype(b_block_desc_bk0_n_bk1),
2206 BBlockTransferSrcAccessOrder,
2208 BBlockTransferSrcVectorDim,
2209 2,
2210 BBlockTransferSrcScalarPerVector,
2211 BBlockTransferDstScalarPerVector_BK1,
2212 1,
2213 1,
2214 BThreadTransferSrcResetCoordinateAfterRun,
2215 true,
2216 BlockwiseGemmPipe::GlobalBufferNum>(
2217 b_grid_desc_bk0_n_bk1,
2218 make_multi_index(0, n_block_data_idx_on_grid, 0),
2219 b_element_op,
2220 b_block_desc_bk0_n_bk1,
2221 make_multi_index(0, 0, 0),
2223 }
2224 };
2225
2226 auto a_blockwise_copy = get_a_blockwise_copy();
2227 auto b_blockwise_copy = get_b_blockwise_copy();
2228
2229 // LDS allocation for A and B: be careful of alignment
2230 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
2231 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
2232
2233 auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2234 static_cast<LDSTypeA*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2235
2236 auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2237 static_cast<LDSTypeB*>(p_shared_0) +
2238 a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB),
2239 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2240
2241 auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2242 static_cast<LDSTypeA*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2243
2244 auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2245 static_cast<LDSTypeB*>(p_shared_1) +
2246 a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB),
2247 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2248
2249 auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
2250 auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
2251
2252 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
2253 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
2254
2255 // Blockwise GEMM pipeline
2256 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2257 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
2258 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2259
2260 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2261 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
2262 KPerBlock);
2263
2264 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
2265 a_block_desc_ak0_m_ak1,
2266 a_blockwise_copy,
2267 a_grid_buf,
2268 a_block_bufs,
2269 a_block_slice_copy_step,
2270 b_grid_desc_bk0_n_bk1,
2271 b_block_desc_bk0_n_bk1,
2272 b_blockwise_copy,
2273 b_grid_buf,
2274 b_block_bufs,
2275 b_block_slice_copy_step,
2276 c_thread_buf,
2277 num_k_block_main_loop);
2278
2279 // shuffle C and write out
2280 {
2281 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2282 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2283 "wrong!");
2284
2285 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2286 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
2287
2288 // TODO: hacky, fix it!
2289 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2290 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2291
2292 // TODO: hacky, fix it!
2293 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
2294 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2295 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2296
2297 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
2298 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
2299 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
2300 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
2301 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
2302 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
2303 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
2304 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
2305
2306 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2308
2309 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2310 static_cast<CShuffleDataType*>(p_shared_0),
2311 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2312
2313 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
2314 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2315 make_tuple(
2318 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
2319 M1, // M1 = MWave
2320 M2, // M2 * M3 * M4 = MPerXdl
2321 M3,
2322 M4)),
2325 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
2326 N1, // N1 = NWave
2327 N2))), // N2 = NPerXdl
2329 make_tuple(
2331
2332 // calculate origin of thread output tensor on global memory
2333 // blockwise GEMM c matrix starting index
2334 const auto c_thread_mtx_on_block =
2335 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2336
2337 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2338 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2339
2340 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2342 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
2345
2346 const auto m_thread_data_on_block_idx =
2347 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2348 make_multi_index(m_thread_data_on_block));
2349
2350 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2355
2356 const auto n_thread_data_on_block_idx =
2357 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2358 make_multi_index(n_thread_data_on_block));
2359
2361 const auto& vpgr_to_lds_element_op = [&] {
2362 if constexpr(DoElementwiseBeforeCShuffle)
2363 {
2364 return c_element_op;
2365 }
2366 else
2367 {
2368 return pass_through;
2369 }
2370 };
2371 const auto& lds_to_global_element_op = [&] {
2372 if constexpr(!DoElementwiseBeforeCShuffle)
2373 {
2374 return c_element_op;
2375 }
2376 else
2377 {
2378 return pass_through;
2379 }
2380 };
2381
2382 // shuffle: threadwise copy C from VGPR to LDS
2383 auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
2384 AccDataType,
2385 CShuffleDataType,
2386 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2387 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2388 conditional_t<DoElementwiseBeforeCShuffle,
2389 CElementwiseOperation,
2391 Sequence<CShuffleMXdlPerWavePerShuffle,
2392 CShuffleNXdlPerWavePerShuffle,
2393 I1,
2394 I1,
2395 M2,
2396 I1,
2397 M4,
2398 I1>,
2400 7,
2401 1,
2403 1,
2404 true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2406 0,
2407 m_thread_data_on_block_idx[I1],
2408 n_thread_data_on_block_idx[I1],
2409 m_thread_data_on_block_idx[I2],
2410 m_thread_data_on_block_idx[I3],
2411 m_thread_data_on_block_idx[I4],
2412 n_thread_data_on_block_idx[I2]),
2413 vpgr_to_lds_element_op()};
2414
2415 using EDataType = CDataType;
2416
2417 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
2419 ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
2420
2421 const auto ds_grid_buf = generate_tuple(
2422 [&](auto i) {
2424 p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
2425 },
2427
2428 // tuple of reference to C/Ds tensor descriptors
2429 const auto c_ds_desc_refs = concat_tuple_of_reference(
2430 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2431 generate_tie([&](auto i) -> const auto& // return type should be reference
2432 { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2434
2435 // tuple of reference to C/Ds tensor descriptors
2436 const auto c_ds_buf_refs = concat_tuple_of_reference(
2437 tie(c_shuffle_block_buf),
2438 generate_tie([&](auto i) -> const auto& // return type should be reference
2439 { return ds_grid_buf[i]; },
2441
2442 // tuple of starting index of C/Ds blockwise copy
2443 const auto idx_c_ds_block_begin = container_concat(
2444 make_tuple(make_multi_index(0, 0, 0, 0)),
2446 [&](auto) {
2447 return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
2448 },
2450
2451 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2452 c_grid_desc_mblock_mperblock_nblock_nperblock;
2453
2454 using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
2455 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2456 const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2457
2458 auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3<
2460 decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
2462 decltype(c_ds_desc_refs),
2463 decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2464 conditional_t<!DoElementwiseBeforeCShuffle,
2465 CElementwiseOperation,
2467 Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
2468 // support arbitray type
2469 Sequence<1,
2470 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2471 1,
2472 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2473 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
2474 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2475 Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2476 Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2477 3, // index_t SrcVectorDim,
2478 3, // index_t DstVectorDim,
2479 CDEShuffleBlockTransferScalarPerVectors,
2484 false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2485 Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
2486 {c_ds_desc_refs,
2487 idx_c_ds_block_begin,
2488 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2489 make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
2490 lds_to_global_element_op()};
2491
2492 // space filling curve for threadwise C in VGPR
2493 constexpr auto sfc_c_vgpr =
2496 Sequence<CShuffleMXdlPerWavePerShuffle,
2497 CShuffleNXdlPerWavePerShuffle,
2498 1,
2499 1,
2500 M2,
2501 1,
2502 M4,
2503 1>>{};
2504
2505 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2506
2507 // space filling curve for shuffled blockwise C/D/E
2508 constexpr auto sfc_cde_block =
2511 Sequence<1,
2512 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2513 1,
2514 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2515
2516 static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2517
2518 static_for<0, num_access, 1>{}([&](auto access_id) {
2519 // make sure it's safe to write to LDS
2521
2522 // each thread write its data from VGPR to LDS
2523 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2524 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2525 c_thread_buf,
2526 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2527 c_shuffle_block_buf);
2528
2529 // make sure it's safe to read from LDS
2531
2532 // each block copy its data from LDS to global
2533 cde_block_copy_lds_and_global.Run(
2534 c_ds_desc_refs,
2535 c_ds_buf_refs,
2536 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2537 tie(c_grid_buf));
2538
2539 if constexpr(access_id < num_access - 1)
2540 {
2541 constexpr auto cde_lds_and_global_step =
2542 sfc_cde_block.GetForwardStep(access_id);
2543
2544 // move on Ds
2545 static_for<0, NumDTensor, 1>{}([&](auto i) {
2546 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2547 c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2548 });
2549
2550 // move on E
2551 cde_block_copy_lds_and_global.MoveDstSliceWindow(
2552 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2553 I0,
2554 cde_lds_and_global_step);
2555 }
2556 });
2557 }
2558 }
2559};
2560
2561} // namespace ck
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition utility/sequence.hpp:928
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__host__ __device__ constexpr auto container_concat(const X &x, const Ys &... ys)
Definition utility/container_helper.hpp:320
int32_t index_t
Definition ck.hpp:299
typename conditional< predicate, X, Y >::type conditional_t
Definition utility/functional.hpp:115
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
__host__ __device__ constexpr auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:185
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Full
Definition blkgemmpipe_scheduler.hpp:49
constexpr auto BlockGemmPipeline_Selector()
Definition blockwise_gemm_pipeline_wmma_selector.hpp:32
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__global__ void kernel_gemm_xdl_cshuffle_v3_multi_d(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:40
__global__ void kernel_gemm_xdl_cshuffle_v3_multi_d_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:75
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
typename sequence_merge< Sx, Sy >::type sequence_merge_t
Definition utility/sequence.hpp:925
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition tuple_helper.hpp:42
Definition block_to_ctile_map.hpp:271
__host__ static __device__ constexpr index_t CalculateGridSize(index_t M, index_t N)
Definition block_to_ctile_map.hpp:283
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:686
AElementwiseOperation a_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:727
const ADataType * p_a_grid
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:722
const BDataType * p_b_grid
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:723
CDataType * p_c_grid
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:725
BElementwiseOperation b_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:728
CElementwiseOperation c_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:729
__host__ Argument(const ADataType *p_a_grid_, const BDataType *p_b_grid_, std::array< const void *, NumDTensor > p_ds_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:688
DsGridPointer p_ds_grid
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:724
index_t NPadded
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:675
index_t StrideC
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:672
index_t KBatch
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:673
index_t StrideB
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:670
std::array< index_t, NumDTensor > StrideDs
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:671
index_t BK0
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:679
__host__ __device__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t KBatch_)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:629
index_t StrideA
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:669
index_t K
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:668
index_t MBlock
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:680
index_t N
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:667
index_t NBlock
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:681
index_t KRead
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:676
index_t MPadded
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:674
index_t M
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:666
__host__ void Print() const
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:656
index_t AK0
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:678
index_t KPadded
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:677
index_t b_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:765
index_t a_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:764
__device__ SplitKBatchOffset(Argument &karg, index_t k_id)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:734
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:157
static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:1389
remove_cvref_t< decltype(BlockGemmPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, LDSTypeA, LDSTypeB, ComputeTypeA, GemmAccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack, false >())> BlockwiseGemmPipe
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:1090
static __device__ void Run_2Lds(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsGridPointer &p_ds_grid, CDataType *__restrict__ p_c_grid, void *__restrict__ p_shared_0, void *__restrict__ p_shared_1, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, const Block2CTileMap &block_2_ctile_map, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const DsGridDesc_M_N &ds_grid_desc_m_n, const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:2057
static __device__ void Run(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsGridPointer &p_ds_grid, CDataType *__restrict__ p_c_grid, void *__restrict__ p_shared, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, const Block2CTileMap &block_2_ctile_map, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const DsGridDesc_M_N &ds_grid_desc_m_n, const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:1483
static __device__ void Run(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsGridPointer &p_ds_grid, CDataType *__restrict__ p_c_grid, void *__restrict__ p_shared, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:1437
static __device__ void Run_2Lds(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsGridPointer &p_ds_grid, CDataType *__restrict__ p_c_grid, void *__restrict__ p_shared_0, void *__restrict__ p_shared_1, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:1979
static __device__ void Run_2Lds(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsGridPointer &p_ds_grid, CDataType *__restrict__ p_c_grid, void *__restrict__ p_shared_0, void *__restrict__ p_shared_1, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:2010
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Definition thread_group_tensor_slice_transfer_direct_load.hpp:55
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition thread_group_tensor_slice_transfer_v7r3.hpp:48
Definition threadwise_tensor_slice_transfer.hpp:39
Definition utility/tuple.hpp:117
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340