gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp Source File

gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp Source File#

Composable Kernel: gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp Source File
gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
16
18
19#define DEBUG_LOG 0
20
21namespace ck {
22
23// Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer
24// pipelines in the same kernel function. Blockers:
25// 1. Two separate declarations of the __shared__ pointer are the key to making sure data accesses
26// operate on two LDS chunks.
27// 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. AB and C may not
28// use the same LDS buffer when we declare __shared__ inside blkgemmpipe.
// Device entry point for the B-preshuffled multi-D GEMM that uses a single __shared__ buffer.
//
// karg is received by value. Inside the architecture guard the kernel
//   - declares a __shared__ char array sized by GridwiseGemm::GetSharedMemoryNumberOfByte(),
//   - saves karg.K into Kt, then constructs GridwiseGemm::SplitKBatchOffset from karg and
//     blockIdx.z,
//   - derives k_id from blockIdx.z and GridwiseGemm::CalculateBK0Shuffled(karg.K),
//   - forwards everything to GridwiseGemm::Run.
// When none of __gfx9__ / __gfx11__ / __gfx12__ is defined, the body only marks karg as used.
// NOTE(review): TailNum is passed to Run below but is not declared in the template parameter
// list visible in this listing -- a parameter line appears to be missing; confirm upstream.
29template <typename GridwiseGemm,
30 bool HasMainKBlockLoop,
31 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
32 index_t MinimumOccupancy = 1,
34__global__ void
35#if CK_USE_LAUNCH_BOUNDS
36__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
37#endif
38 // __attribute__((amdgpu_waves_per_eu(1, 1)))
39 kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle(typename GridwiseGemm::Argument karg)
40{
41#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
 // Nothing is emitted when the tuning parameters are rejected at compile time.
42 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
43 {
44 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
45
46 // Full K needed for matrix B
 // NOTE(review): the SplitKBatchOffset constructor shown in this listing overwrites karg.K,
 // so Kt has to be read before it is constructed, and CalculateBK0Shuffled(karg.K) below
 // then sees the updated K -- confirm GridwiseGemm uses that definition.
47 const index_t Kt = karg.K;
48
49 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
50
51 const index_t num_k_per_block = GridwiseGemm::CalculateBK0Shuffled(karg.K);
52 const index_t k_id = blockIdx.z * num_k_per_block;
53
 // Only the A pointer is advanced by the split-K offset; p_b_grid is passed unshifted
 // together with k_id and the full Kt.
54 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
55 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
56 karg.p_b_grid,
57 karg.p_ds_grid,
58 karg.p_c_grid,
59 p_shared,
60 karg,
61 karg.a_element_op,
62 karg.b_element_op,
63 karg.c_element_op,
64 k_id,
65 Kt);
66 }
67#else
68 ignore = karg;
69#endif // end of if (defined(__gfx9__))
70}
71
// Device entry point for the B-preshuffled multi-D GEMM that uses two __shared__ buffers.
//
// Same flow as the single-buffer kernel, except that two separate __shared__ char arrays
// (p_shared and p_shared1), each sized by GridwiseGemm::GetSharedMemoryNumberOfByte(), are
// declared and both are handed to GridwiseGemm::Run_2Lds.
// When none of __gfx9__ / __gfx11__ / __gfx12__ is defined, the body only marks karg as used.
// NOTE(review): TailNum is passed to Run_2Lds below but is not declared in the template
// parameter list visible in this listing -- a parameter line appears to be missing; confirm.
72template <typename GridwiseGemm,
73 bool HasMainKBlockLoop,
74 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
75 index_t MinimumOccupancy = 1,
77__global__ void
78#if CK_USE_LAUNCH_BOUNDS
79__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
80#endif
81 // __attribute__((amdgpu_waves_per_eu(1, 1)))
82 kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds(typename GridwiseGemm::Argument karg)
83{
84#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
 // Nothing is emitted when the tuning parameters are rejected at compile time.
85 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
86 {
 // Two independent declarations, so the total static __shared__ footprint is twice
 // GetSharedMemoryNumberOfByte().
87 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
88 __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
89
90 // Full K needed for matrix B
 // NOTE(review): the SplitKBatchOffset constructor shown in this listing overwrites karg.K,
 // so Kt has to be read before it is constructed, and CalculateBK0Shuffled(karg.K) below
 // then sees the updated K -- confirm GridwiseGemm uses that definition.
91 const index_t Kt = karg.K;
92
93 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
94
95 const index_t num_k_per_block = GridwiseGemm::CalculateBK0Shuffled(karg.K);
96 const index_t k_id = blockIdx.z * num_k_per_block;
97
 // Only the A pointer is advanced by the split-K offset; p_b_grid is passed unshifted
 // together with k_id and the full Kt.
98 GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
99 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
100 karg.p_b_grid,
101 karg.p_ds_grid,
102 karg.p_c_grid,
103 p_shared,
104 p_shared1,
105 karg,
106 karg.a_element_op,
107 karg.b_element_op,
108 karg.c_element_op,
109 k_id,
110 Kt);
111 }
112#else
113 ignore = karg;
114#endif // end of if (defined(__gfx9__))
115}
116
117template <typename ALayout,
118 typename BLayout,
119 typename DsLayout,
120 typename CLayout,
121 typename ADataType,
122 typename BDataType,
123 typename AccDataType,
124 typename CShuffleDataType,
125 typename DsDataType,
126 typename CDataType,
127 typename AElementwiseOperation,
128 typename BElementwiseOperation,
129 typename CElementwiseOperation,
131 index_t BlockSize,
132 index_t MPerBlock,
133 index_t NPerBlock,
134 index_t KPerBlock,
135 index_t AK1Value,
136 index_t BK1Value,
137 index_t MPerXdl,
138 index_t NPerXdl,
139 index_t MXdlPerWave,
140 index_t NXdlPerWave,
141 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
142 typename ABlockTransferThreadClusterArrangeOrder,
143 typename ABlockTransferSrcAccessOrder,
144 index_t ABlockTransferSrcVectorDim,
145 index_t ABlockTransferSrcScalarPerVector,
146 index_t ABlockTransferDstScalarPerVector_AK1,
147 bool AThreadTransferSrcResetCoordinateAfterRun,
148 index_t ABlockLdsExtraM,
149 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
150 typename BBlockTransferThreadClusterArrangeOrder,
151 typename BBlockTransferSrcAccessOrder,
152 index_t BBlockTransferSrcVectorDim,
153 index_t BBlockTransferSrcScalarPerVector,
154 index_t BBlockTransferDstScalarPerVector_BK1,
155 bool BThreadTransferSrcResetCoordinateAfterRun,
156 index_t BBlockLdsExtraN,
157 index_t CShuffleMXdlPerWavePerShuffle,
158 index_t CShuffleNXdlPerWavePerShuffle,
159 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
160 typename CDEShuffleBlockTransferScalarPerVectors,
163 typename ComputeTypeA = CDataType,
164 typename ComputeTypeB = ComputeTypeA,
165 typename LDSTypeA = ADataType,
166 typename LDSTypeB = BDataType>
168{
 // Compile-time index constants Number<0> .. Number<7>.
169 static constexpr auto I0 = Number<0>{};
170 static constexpr auto I1 = Number<1>{};
171 static constexpr auto I2 = Number<2>{};
172 static constexpr auto I3 = Number<3>{};
173 static constexpr auto I4 = Number<4>{};
174 static constexpr auto I5 = Number<5>{};
175 static constexpr auto I6 = Number<6>{};
176 static constexpr auto I7 = Number<7>{};
177
 // Tail of a declaration whose first line is not present in this listing; it takes
 // element I0 of CDEShuffleBlockTransferScalarPerVectors.
179 CDEShuffleBlockTransferScalarPerVectors{}[I0];
180 // K1 should be Number<...>
 // K0 extents are KPerBlock divided by the respective K1 vector length.
181 static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
182 static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
183 static constexpr auto AK1Number = Number<AK1Value>{};
184 static constexpr auto BK1Number = Number<BK1Value>{};
185 static constexpr auto BlockSizeNumber = Number<BlockSize>{};
186
 // Number of D tensors, taken from the size of the DsDataType type list.
187 static constexpr index_t NumDTensor = DsDataType::Size();
188
189 static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
 // The first operand of this condition is not present in this listing. The visible
 // operands select true for int8_t or f8_t ComputeTypeA when KPerBlock < 128.
190 static constexpr bool is_single_rate_mfma =
192 lcm_AK1_BK1 <= 4) ||
193 (is_same<ComputeTypeA, int8_t>::value && KPerBlock < 128) ||
194 (is_same<ComputeTypeA, f8_t>::value && KPerBlock < 128))
195 ? true
196 : false;
197 static constexpr auto is_scale_mfma = false;
 // MFMA selection; one template argument line is not present in this listing.
198 static constexpr auto mfma = MfmaSelector<ComputeTypeA,
199 MPerXdl,
200 NPerXdl,
201 ComputeTypeA,
203 is_scale_mfma>{};
 // KPack is the larger of lcm(AK1, BK1) and the selected MFMA's k_per_blk.
204 static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk);
 // KGroup is 2 only when the guarded branch is taken and k_per_blk == 32, otherwise 1.
 // The guarding condition line is not present in this listing.
205 static constexpr index_t KGroup = []() {
207 // On gfx950 there is an mfma that requires 32 f8 elements as input,
208 // split into 2 groups of 16 f8 elements.
209 // The 2 groups are not contiguous in the B preshuffled layout,
210 // and we do not want them to be contiguous in the B preshuffled layout
211 // because a memory instruction can only read 16 f8 elements at a time.
212 return mfma.selected_mfma.k_per_blk == 32 ? 2 : 1;
213 else
214 return 1;
215 }();
 // Derived tiling constants, all computed from the values above.
216 static constexpr index_t KLane = mfma.GetKPerXdlops() / mfma.GetK1PerXdlops();
217 static constexpr index_t KPackPerGroup = KPack / KGroup;
218 static constexpr index_t KRepeat = KPerBlock / KLane / KPackPerGroup;
219 static constexpr index_t NLane = NPerXdl;
220 static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
221
 // Builds a tuple through generate_tuple whose i-th element is a null `const DDataType*`,
 // where DDataType is the i-th entry of DsDataType with cv/ref qualifiers removed.
 // The second argument of generate_tuple (the element count) is not present in this listing.
222 static constexpr auto MakeDsGridPointer()
223 {
224 return generate_tuple(
225 [&](auto i) {
226 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
227
 // Only the pointer type matters here; the value is always nullptr.
228 return static_cast<const DDataType*>(nullptr);
229 },
231 }
232
233 using DsGridPointer = decltype(MakeDsGridPointer());
234
236
237 __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
238 {
239 return std::make_tuple(Block2CTileMapDefault::CalculateGridSize(M, N), 1, KBatch);
240 }
241
242 __host__ __device__ static auto CalculateMPadded(index_t M)
243 {
244 return math::integer_least_multiple(M, MPerBlock);
245 }
246
247 __host__ __device__ static auto CalculateNPadded(index_t N)
248 {
249 return math::integer_least_multiple(N, NPerBlock);
250 }
251
 // Maps N to the N0 extent that Run passes to MakeBGridDescriptor_Preshuffled.
 // The return statement is not present in this listing, so the formula is not documented here.
252 __host__ __device__ static auto CalculateBN0Shuffled(index_t N)
253 {
255 }
 // Maps K to the K0 extent that Run passes to MakeBGridDescriptor_Preshuffled; the kernels
 // in this file also multiply its result by blockIdx.z to obtain k_id.
 // The return statement is not present in this listing, so the formula is not documented here.
256 __host__ __device__ static auto CalculateBK0Shuffled(index_t K)
257 {
259 }
260
261 __host__ __device__ static auto CalculateKPadded(index_t K)
262 {
263 return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
264 }
265
266 __host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
267 {
268 auto K_t = K_Batch * KPerBlock;
269 return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
270 }
271
272 __host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
273 {
274 auto K_t = K_Batch * KPerBlock;
275 return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
276 }
277
278 __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
279 {
280 auto K_t = K_Batch * KPerBlock;
281 return (K + K_t - 1) / K_t * KPerBlock;
282 }
283
284 __host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
285 {
286 constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
287 auto K_t = K_Batch * KReadVec;
288 return (K + K_t - 1) / K_t * KReadVec;
289 }
290
291 __host__ __device__ static auto CalculateMBlock(index_t M)
292 {
293 return math::integer_divide_ceil(M, MPerBlock);
294 }
295
296 __host__ __device__ static auto CalculateNBlock(index_t N)
297 {
298 return math::integer_divide_ceil(N, NPerBlock);
299 }
300
 // Derives an MMA tile descriptor from a (K0, MN, K1) block descriptor type.
 // The argument is used only for its type: the body default-constructs TileDesc_K0_MN_K1
 // and reads its lengths at dimensions 0 (K0) and 2 (K1).
 // Most of the return expression is not present in this listing, so the resulting
 // dimension layout is not documented here.
301 template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
302 __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
303 {
304 constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
305 constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
306
308 TileDesc_K0_MN_K1{},
314 }
315
 // Builds the global-memory descriptor of A, returned under the name a_grid_desc_ak0_m_ak1.
 //
 // Step 1: a naive (M, K) descriptor with strides (StrideA, 1) or (1, StrideA). The layout
 //         conditions choosing between the two are not present in this listing.
 // Step 2: one of four GemmSpec branches pads M and/or K (using MPad - M / KPad - K as the
 //         right-pad amounts) and transforms the result. Most transform arguments are not
 //         present in this listing.
 // NOTE(review): AK0 is presumably consumed by the missing unmerge arguments -- confirm.
316 __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
317 index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
318 {
319 const auto a_grid_desc_mraw_kraw = [&]() {
321 {
 // K has unit stride in this branch.
322 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
323 }
325 {
 // M has unit stride in this branch.
326 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
327 }
328 }();
329
330 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
331
332 if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
333 GemmSpec == GemmSpecialization::MNKPadding)
334 {
335 // pad both M and K
336 const auto a_grid_desc_m_k =
337 transform_tensor_descriptor(a_grid_desc_mraw_kraw,
339 make_right_pad_transform(K, KPad - K)),
342
343 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
344 a_grid_desc_m_k,
349
350 return a_grid_desc_ak0_m_ak1;
351 }
352 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
353 GemmSpec == GemmSpecialization::MNPadding)
354 {
355 // pad M, but not K
356 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
357 a_grid_desc_mraw_kraw,
359 make_right_pad_transform(M, MPad - M)),
362
363 return a_grid_desc_ak0_m_ak1;
364 }
365 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
366 GemmSpec == GemmSpecialization::NKPadding)
367 {
368 // pad K, but not M
369 const auto a_grid_desc_m_k = transform_tensor_descriptor(
370 a_grid_desc_mraw_kraw,
374
375 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
376 a_grid_desc_m_k,
381
382 return a_grid_desc_ak0_m_ak1;
383 }
384 else
385 {
386 // not pad M or K
387 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
388 a_grid_desc_mraw_kraw,
393
394 return a_grid_desc_ak0_m_ak1;
395 }
396 }
397
 // Builds the descriptor of the preshuffled B tensor from the shuffled extents N0 and K0.
 //
 // Lengths are (N0 / NWave, NWave, K0, NkSwizzleNumber) with
 // NkSwizzleNumber = WaveSize * KPackPerGroup. Each stride equals the product of the lengths
 // to its right, so the four dimensions are laid out densely with the last one contiguous.
 // The line that opens the descriptor-construction call is not present in this listing.
398 __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
399 {
 // WaveSize is derived from the block size and the M/N wave counts.
400 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
401 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
402 constexpr index_t NkSwizzleNumber = Number<WaveSize * KPackPerGroup>{};
404 make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
405 make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
406 }
407
 // Builds the global-memory descriptor of B, returned under the name b_grid_desc_bk0_n_bk1.
 //
 // Step 1: a naive (N, K) descriptor with strides (1, StrideB) or (StrideB, 1). The layout
 //         conditions choosing between the two are not present in this listing.
 // Step 2: one of four GemmSpec branches pads N and/or K (using NPad - N / KPad - K as the
 //         right-pad amounts) and transforms the result. Most transform arguments are not
 //         present in this listing.
 // NOTE(review): BK0 is presumably consumed by the missing unmerge arguments -- confirm.
408 __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
409 index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
410 {
411 const auto b_grid_desc_nraw_kraw = [&]() {
413 {
 // N has unit stride in this branch.
414 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
415 }
417 {
 // K has unit stride in this branch.
418 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
419 }
420 }();
421
422 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
423
424 if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
425 GemmSpec == GemmSpecialization::MNKPadding)
426 {
427 // pad both N and K
428 const auto b_grid_desc_n_k =
429 transform_tensor_descriptor(b_grid_desc_nraw_kraw,
431 make_right_pad_transform(K, KPad - K)),
434
435 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
436 b_grid_desc_n_k,
441
442 return b_grid_desc_bk0_n_bk1;
443 }
444 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
445 GemmSpec == GemmSpecialization::MNPadding)
446 {
447 // pad N, but not K
448 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
449 b_grid_desc_nraw_kraw,
451 make_right_pad_transform(N, NPad - N)),
454
455 return b_grid_desc_bk0_n_bk1;
456 }
457 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
458 GemmSpec == GemmSpecialization::MKPadding)
459 {
460 // pad K, but not N
461 const auto b_grid_desc_n_k = transform_tensor_descriptor(
462 b_grid_desc_nraw_kraw,
466
467 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
468 b_grid_desc_n_k,
473
474 return b_grid_desc_bk0_n_bk1;
475 }
476 else
477 {
478 // not pad N or K
479 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
480 b_grid_desc_nraw_kraw,
485
486 return b_grid_desc_bk0_n_bk1;
487 }
488 }
489
490 template <typename ABlockDesc_AK0_M_AK1>
491 __host__ __device__ static constexpr auto
492 MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
493 {
494 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
495
496 return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
497 }
498
499 template <typename BBlockDesc_BK0_N_BK1>
500 __host__ __device__ static constexpr auto
501 MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
502 {
503 return MakeGemmMmaTileDescriptor<NXdlPerWave, NWave, NPerXdl>(BBlockDesc_BK0_N_BK1{});
504 }
505
 // Builds the (M, N) descriptor of an output-like tensor with layout ELayout.
 //
 // The line carrying the function name and parameter list is not present in this listing;
 // callers in this file invoke it as MakeCGridDescriptor_M_N<Layout>(M, MPad, N, NPad, Stride).
 // A naive (M, N) descriptor with strides (StrideC, 1) or (1, StrideC) is created first (the
 // layout conditions are not visible), then M and N are always padded. The GemmSpec-dependent
 // variant is compiled out by the `#if 0` block.
506 template <typename ELayout>
507 __host__ __device__ static auto
509 {
510 const auto c_grid_desc_mraw_nraw = [&]() {
512 {
 // N has unit stride in this branch.
513 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
514 }
516 {
 // M has unit stride in this branch.
517 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
518 }
519 }();
520
521 // pad M and N
522 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
524 make_right_pad_transform(N, NPad - N)),
 // Everything between #if 0 and #endif is disabled; the unconditional return above is
 // what executes.
527#if 0
528 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
529
530 if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
531 GemmSpec == GemmSpecialization::MNKPadding)
532 {
533 // pad M and N
534 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
536 make_right_pad_transform(N, NPad - N)),
539 }
540 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
541 GemmSpec == GemmSpecialization::MKPadding)
542 {
543 // pad M, but not N
545 c_grid_desc_mraw_nraw,
549 }
550 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
551 GemmSpec == GemmSpecialization::NKPadding)
552 {
553 // pad N, but not M
555 c_grid_desc_mraw_nraw,
559 }
560 else
561 {
562 // not pad M or N
563 return c_grid_desc_mraw_nraw;
564 }
565#endif
566 }
567
 // Builds one (M, N) descriptor per D tensor.
 // For index i the layout is the i-th entry of DsLayout and the stride is StrideDs[i];
 // M, MPad, N and NPad are shared by all D tensors. Each element is produced by
 // MakeCGridDescriptor_M_N. The element-count argument of generate_tuple is not present in
 // this listing.
568 __host__ __device__ static auto MakeDsGridDescriptor_M_N(
569 index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
570 {
571 return generate_tuple(
572 [&](auto i) {
573 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
574 return MakeCGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
575 },
577 }
578
 // Maps every entry of a tuple of D descriptors through a per-tensor helper that also
 // receives MBlock and NBlock, and returns the results as a tuple.
 // The lines carrying this function's name, the helper being called, and the element count
 // are not present in this listing.
579 template <typename DsGridDesc>
581 const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
582 {
583 return generate_tuple(
584 [&](auto i) {
586 ds_grid_desc_m_n[i], MBlock, NBlock);
587 },
589 }
590
591 using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(0, 0, 0, 0, {}))>;
592
 // Plain description of one GEMM problem: the raw sizes and strides supplied by the caller
 // plus quantities derived from them in the constructor.
 // Several initializer lines and almost all member declarations are not present in this
 // listing; Print() shows that MPadded, NPadded, MBlock and NBlock are members as well.
593 struct Problem
594 {
 // Stores the arguments and derives KRead, KPadded, AK0 and BK0 from (K_, KBatch_) via
 // the corresponding Calculate* helpers.
595 __host__ __device__ Problem(index_t M_,
596 index_t N_,
597 index_t K_,
598 index_t StrideA_,
599 index_t StrideB_,
600 std::array<index_t, NumDTensor> StrideDs_,
601 index_t StrideC_,
602 index_t KBatch_)
603 : M{M_},
604 N{N_},
605 K{K_},
606 StrideA{StrideA_},
607 StrideB{StrideB_},
608 StrideDs{StrideDs_},
609 StrideC{StrideC_},
610 KBatch{KBatch_},
613 KRead{CalculateKRead(K_, KBatch_)},
614 KPadded{CalculateKPadded(K_, KBatch_)},
615 AK0{CalculateAK0Padded(K_, KBatch_)},
616 BK0{CalculateBK0Padded(K_, KBatch_)},
619 {
620 }
621
 // Writes a one-line summary of the problem to std::cout (host only). StrideDs and
 // KBatch are not part of the output.
622 __host__ void Print() const
623 {
624 std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
625 << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
626 << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
627 << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
628 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
629 << "NBlock: " << NBlock << " }" << std::endl;
630 }
631
 // One stride per D tensor.
637 std::array<index_t, NumDTensor> StrideDs;
648 };
649
650 // Argument
 // Kernel argument bundle: the Problem fields plus the tensor pointers and the three
 // element-wise operators. The line that opens the struct (name and base classes) is not
 // present in this listing; the constructor's initializer list shows that Problem is a base.
652 {
 // Host-side constructor. Forwards sizes and strides to Problem, stores the A/B/C
 // pointers and operators, and converts each type-erased D pointer to its typed form.
653 __host__ Argument(const ADataType* p_a_grid_,
654 const BDataType* p_b_grid_,
655 std::array<const void*, NumDTensor> p_ds_grid_,
656 CDataType* p_c_grid_,
657 index_t M_,
658 index_t N_,
659 index_t K_,
660 index_t StrideA_,
661 index_t StrideB_,
662 std::array<index_t, NumDTensor> StrideDs_,
663 index_t StrideC_,
664 index_t k_batch_,
665 AElementwiseOperation a_element_op_,
666 BElementwiseOperation b_element_op_,
667 CElementwiseOperation c_element_op_)
668 : Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_},
669 p_a_grid{p_a_grid_},
670 p_b_grid{p_b_grid_},
671 p_ds_grid{},
672 p_c_grid{p_c_grid_},
673 a_element_op{a_element_op_},
674 b_element_op{b_element_op_},
675 c_element_op{c_element_op_}
676 {
677
678 // populate pointer, desc for Ds
 // p_ds_grid was value-initialized above; each slot is filled here with the
 // matching const void* cast to the i-th D data type.
679 static_for<0, NumDTensor, 1>{}([&](auto i) {
680 using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
681
682 // D pointer
683 p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
684 });
685 }
686
 // Tensor base pointers. The declaration of p_ds_grid is not present in this listing.
687 const ADataType* p_a_grid;
688 const BDataType* p_b_grid;
690 CDataType* p_c_grid;
691
 // Element-wise operators, stored by value and immutable after construction.
692 const AElementwiseOperation a_element_op;
693 const BElementwiseOperation b_element_op;
694 const CElementwiseOperation c_element_op;
695 };
696
 // Per-split-K bookkeeping. The line that opens the struct and the member declaration near
 // the end are not present in this listing.
698 {
 // Computes a_k_split_offset for split index k_id and rewrites karg.K in place so that
 // it holds the K extent handled by this split.
 // karg is taken by non-const reference and IS modified: callers that need the original
 // K must read it before constructing this object.
699 __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
700 {
 // Two layout-dependent cases (their conditions are not present in this listing): the
 // offset is either k_id * KRead elements, or that value scaled by StrideA.
702 {
703 a_k_split_offset = k_id * karg.KRead;
704 }
706 {
707 a_k_split_offset = k_id * karg.KRead * karg.StrideA;
708 }
709
 // Every split except the last covers KRead elements; the last one takes whatever
 // remains of the original K.
710 if(k_id < karg.KBatch - 1)
711 {
712 karg.K = karg.KRead;
713 }
714 else
715 {
716 karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
717 }
718 }
719
721 };
722
 // Returns the compile-time descriptor of the A tile as stored in LDS.
 //
 // Three alternatives are visible:
 //   1. ABlockLdsExtraM is set: body not present in this listing.
 //   2. A second branch (condition not present in this listing): a base descriptor is
 //      transformed once and returned as a_lds_block_desc_permuted.
 //   3. The remaining branch, labelled "ColumnMajor A": a packed descriptor is permuted,
 //      unmerged and transformed again into a_lds_block_desc_ak0_m_ak1, using the
 //      kfold / mpair factors computed below.
 // Most transform arguments are not present in this listing.
723 __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
724 {
725 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
726 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
727 // A matrix in LDS memory, dst of blockwise copy
728 if constexpr(ABlockLdsExtraM)
729 {
733 }
734 // The xor tensor transformation requires additional, otherwise unnecessary, VGPR usage
735 // and can cause register spills in some cases.
737 {
738 constexpr auto a_lds_block_desc =
741
742 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
743 a_lds_block_desc,
749
750 return a_lds_block_desc_permuted;
751 }
752 else // ColumnMajor A
753 {
754 // The kfold and mpair dimensions are not always required.
755 // More dimensions in merge_transform make it harder for the compiler to generate
756 // immarg offsets.
 // M0 / KThreadWrite are entries 1 / 0 of the A thread-cluster lengths.
757 constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
758 constexpr auto M1 = MPerBlock / M0;
759
760 constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
761 constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
762 constexpr auto KThreadRead = WaveSize / MPerXdl;
763 constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
764
 // kfold is 1 when AK1 * M0 elements of LDSTypeA exceed 128 bytes, otherwise the
 // number of such rows that fit into 128 bytes.
765 constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128)
766 ? 1
767 : 128 / (AK1Number * M0 * sizeof(LDSTypeA));
768 constexpr auto KThreadReadPerm =
769 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
770 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
771 : KThreadRead;
772
773 // 1<=mpair<=n0
 // mpair is 1 when AK1 * MPerXdl elements exceed 128 bytes; otherwise it is
 // 128 / (AK1 * MPerXdl * sizeof(LDSTypeA)) capped at M0.
774 constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128)
775 ? 1
776 : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0
777 ? M0
778 : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA)));
779
780 constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
784 Number<kfold * M0 / mpair>{},
786 AK1Number));
787
788 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
789 a_lds_block_desc,
794 make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
801
802 constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
803 a_lds_block_desc_permuted,
812 Sequence<1>{},
813 Sequence<2>{},
814 Sequence<3>{},
815 Sequence<4>{},
816 Sequence<5>{}),
818 Sequence<2>{},
821 Sequence<6>{},
822 Sequence<7>{}));
823
824 constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
825 a_lds_block_desc_unmerged,
828 Number<KThreadWrite / kfold / KThreadReadPerm>{},
836
837 return a_lds_block_desc_ak0_m_ak1;
838 }
839 }
840
 // Returns the compile-time block descriptor used for B.
 // The return statement is not present in this listing; the surviving comment gives the
 // intended dimension order.
841 __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
842 {
843 // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
846 }
847
 // Returns the compile-time descriptor named
 // c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.
 // The line carrying the function name and most of the descriptor construction are not
 // present in this listing. NOTE(review): presumably this is the C-shuffle LDS tile whose
 // size GetSharedMemoryNumberOfByte() takes into account -- confirm against upstream.
849 {
850 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
851
852 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
856 I1,
858
859 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
860 }
861
 // Tail of a type alias: the template argument list of a selector call whose result type
 // is taken (the closing `>())>` ends both the call and the enclosing type expression).
 // The head of the declaration and several arguments are not present in this listing.
 // NOTE(review): presumably this defines BlockwiseGemmPipe, which CheckValidity and the
 // CalculateHasMainKBlockLoop / CalculateKBlockLoopTailNum helpers refer to -- confirm.
864 BlkGemmPipelineVer,
865 BlkGemmPipeSched,
866 BlockSize,
867 LDSTypeA,
868 LDSTypeB,
869 ComputeTypeA,
870 AccDataType,
877 ABlockTransferSrcScalarPerVector,
878 BBlockTransferSrcScalarPerVector,
879 MPerBlock,
880 NPerBlock,
881 KPerBlock,
882 MPerXdl,
883 NPerXdl,
884 MXdlPerWave,
885 NXdlPerWave,
886 KPack>())>;
887
 // Returns the number of LDS bytes one buffer needs: the larger of
 //   - the A tile, its element count rounded up to a multiple of lcm(AK1, BK1) and
 //     multiplied by sizeof(LDSTypeA), and
 //   - the C-shuffle tile, element count times sizeof(CShuffleDataType).
 // Taking the maximum rather than the sum means the two uses share one allocation.
 // No term for B appears in the visible code.
888 __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
889 {
890 // LDS allocation for A and B: be careful of alignment
891 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
892 // lds max alignment
893 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
894
895 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
896 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
897
898 // LDS allocation for C shuffle in LDS
 // The initializer of this descriptor is not present in this listing.
899 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
901
902 constexpr auto c_block_size =
903 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
904
905 return math::max(a_block_space_size_aligned * sizeof(LDSTypeA),
906 c_block_size * sizeof(CShuffleDataType));
907 }
908
 // Compile-time gate used by the kernels before instantiating any device code.
 // Returns false when
 //   - the generic IsValidGemmCompilationParameter check rejects the tile configuration,
 //   - KPerThread (KPerBlock divided by the MFMA's K-per-xdlops / K1-per-xdlops ratio) is
 //     not a multiple of KPack, or
 //   - NXdlPerWave is not a multiple of CShuffleNXdlPerWavePerShuffle;
 // otherwise returns true.
909 template <
910 InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set>
911 __device__ static bool constexpr IsValidCompilationParameter()
912 {
913 constexpr bool valid = ck::tensor_operation::device::IsValidGemmCompilationParameter<
914 BlockSize,
915 MPerBlock,
916 NPerBlock,
917 MPerXdl,
918 NPerXdl,
919 MXdlPerWave,
920 NXdlPerWave,
921 CDataType,
922 CGlobalMemoryDataOperation_>();
923 if constexpr(!valid)
924 {
925 return false;
926 }
927
 // The remaining template arguments of this alias are not present in this listing.
928 using MfmaInst = MfmaSelector<ComputeTypeA,
929 MPerXdl,
930 NPerXdl,
931 ComputeTypeB,
934
935 constexpr index_t KPerThread =
936 KPerBlock / (MfmaInst::GetKPerXdlops() / MfmaInst::GetK1PerXdlops());
937 if constexpr(KPerThread % KPack != 0)
938 {
939 return false;
940 }
941
942 if constexpr(NXdlPerWave % CShuffleNXdlPerWavePerShuffle != 0)
943 {
944 return false;
945 }
946 return true;
947 }
948
949 // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
 // Host-side check that a runtime Argument is compatible with this instantiation.
 // Returns false as soon as one requirement fails; with DEBUG_LOG enabled most failures
 // also print a diagnostic to std::cout.
 //
 // Checks visible in this listing, in order:
 //   - static_assert that the block tile is divisible by the per-wave xdl tile,
 //   - NXdlPerWave divisible by CShuffleNXdlPerWavePerShuffle,
 //   - M divisible by MPerBlock, N divisible by NPerBlock, and K divisible by
 //     KBatch * KPerBlock, each inside a guarded block,
 //   - otherwise, for K, that the last split is left with a positive K extent,
 //   - A and B source vector widths divide the dimension they are read along,
 //   - the C-shuffle store vector width check (its conditions are missing),
 //   - the K loop count exceeds BlockwiseGemmPipe::PrefetchStages.
 // The guarding conditions of most blocks are not present in this listing.
950 __host__ static constexpr bool CheckValidity(const Argument& karg)
951 {
952 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
953 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
954 "Invalid tuning param!");
955
956 if constexpr(NXdlPerWave % CShuffleNXdlPerWavePerShuffle != 0)
957 {
958 return false;
959 }
960
 // Guarded block (condition not present in this listing): M must be a multiple of MPerBlock.
966 {
967 if(!(karg.M % MPerBlock == 0))
968 {
969#if DEBUG_LOG
970 std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
971 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
972 << std::endl;
973
974#endif // DEBUG_LOG
975 return false;
976 }
977 }
978
 // Guarded block (condition not present in this listing): N must be a multiple of NPerBlock.
984 {
985 if(!(karg.N % NPerBlock == 0))
986 {
987#if DEBUG_LOG
988 std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
989 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
990 << std::endl;
991
992#endif // DEBUG_LOG
993 return false;
994 }
995 }
996
 // Guarded block (condition not present in this listing): K must be a multiple of
 // KBatch * KPerBlock.
1001 {
1002
1003 auto K_t = karg.KBatch * KPerBlock;
1004 if(!(karg.K % K_t == 0))
1005 {
1006#if DEBUG_LOG
1007 std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1008 << karg.K << " " << __FILE__ << ":" << __LINE__
1009 << ", in function: " << __func__ << std::endl;
1010
1011#endif // DEBUG_LOG
1012 return false;
1013 }
1014 }
1015 else
1016 {
 // Same per-split read length as CalculateKRead. Reject the configuration when the
 // first KBatch - 1 splits already cover all of K, which would leave the last split
 // with nothing to do. No diagnostic is printed for this case.
1017 constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1018 auto K_t = karg.KBatch * KReadVec;
1019 auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1020 if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1021 {
1022 return false;
1023 }
1024 }
1025
 // A source vector width: checked against K in the first branch and against M in the
 // else branch (the layout condition is not present in this listing).
1027 {
1028 if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1029 {
1030#if DEBUG_LOG
1031 std::cout << "Arg K (" << karg.K
1032 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1033 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1034 << __LINE__ << ", in function: " << __func__ << std::endl;
1035
1036#endif // DEBUG_LOG
1037 return false;
1038 }
1039 }
1040 else
1041 {
1042 if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1043 {
1044#if DEBUG_LOG
1045 std::cout << "Arg M (" << karg.M
1046 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1047 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1048 << __LINE__ << ", in function: " << __func__ << std::endl;
1049
1050#endif // DEBUG_LOG
1051 return false;
1052 }
1053 }
1054
 // B source vector width: checked against N in the first branch and against K in the
 // else branch (the layout condition is not present in this listing).
1056 {
1057 if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1058 {
1059#if DEBUG_LOG
1060 std::cout << "Arg N (" << karg.N
1061 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1062 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1063 << __LINE__ << ", in function: " << __func__ << std::endl;
1064
1065#endif // DEBUG_LOG
1066 return false;
1067 }
1068 }
1069 else
1070 {
1071 if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1072 {
1073#if DEBUG_LOG
1074 std::cout << "Arg K (" << karg.K
1075 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1076 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1077 << __LINE__ << ", in function: " << __func__ << std::endl;
1078
1079#endif // DEBUG_LOG
1080 return false;
1081 }
1082 }
1083
 // C-shuffle store vector width. Both the layout condition and the two inner `if`
 // conditions are not present in this listing; the diagnostics name N in the first
 // branch and M in the else branch.
1085 {
1087 {
1088#if DEBUG_LOG
1089 std::cout << "Arg N (" << karg.N
1090 << ") value is not a multiple of "
1091 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1092 << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
1093 << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1094
1095#endif // DEBUG_LOG
1096 return false;
1097 }
1098 }
1099 else
1100 {
1102 {
1103#if DEBUG_LOG
1104 std::cout << "Arg M (" << karg.M
1105 << ") value is not a multiple of "
1106 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1107 << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
1108 << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1109
1110#endif // DEBUG_LOG
1111 return false;
1112 }
1113 }
1114
1115 // check gridwise gemm pipeline
1116#if 1
 // karg.AK0 divided by the AK0 count of one K tile gives the number of K iterations;
 // it must be strictly greater than the pipeline's PrefetchStages.
1117 const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1118
1119 if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1120 {
1121 return false;
1122 }
1123#endif
1124 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1125 return true;
1126 }
1127
1128 __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1129 {
1130 const index_t num_loop = K / KPerBlock;
1131
1132 return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1133 }
1134
1135 __host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1136 {
1137 const index_t num_loop = K / KPerBlock;
1138
1139 return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1140 }
1141
 // Transforms an (M, N) descriptor into the descriptor named
 // c_grid_desc_mblock_mperblock_nblock_nperblock, taking MBlock and NBlock as inputs.
 // The line carrying the function name and all transform arguments are not present in this
 // listing, so the exact dimension split is not documented here.
1142 template <typename CGridDesc>
1144 const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1145 {
1146 const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1147 c_grid_desc_m_n,
1152
1153 return c_grid_desc_mblock_mperblock_nblock_nperblock;
1154 }
1155
1156 // return block_id to C matrix tile idx (m0, n0) mapping
1157 // if arch = gfx942
1159
 // Convenience overload of Run that supplies the default block-to-C-tile map.
 //
 // Constructs Block2CTileMapDefault from (problem.M, problem.N, 4) and forwards every
 // argument unchanged, plus that map, to another function. The line naming the callee is
 // not present in this listing. NOTE(review): presumably the Block2CTileMap overload of Run
 // declared right after this one; the meaning of the literal 4 is also not visible here.
 //
 // p_shared : LDS scratch buffer provided by the kernel.
 // k_id, Kt : split-K position and full K extent as computed by the kernel.
1160 template <bool HasMainKBlockLoop,
1161 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1162 TailNumber TailNum = TailNumber::Odd>
1163 __device__ static void Run(const ADataType* p_a_grid,
1164 const BDataType* p_b_grid,
1165 DsGridPointer& p_ds_grid,
1166 CDataType* p_c_grid,
1167 void* p_shared,
1168 const Problem& problem,
1169 AElementwiseOperation a_element_op,
1170 BElementwiseOperation b_element_op,
1171 CElementwiseOperation c_element_op,
1172 const index_t k_id,
1173 const index_t Kt)
1174 {
1175 const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4};
1177 p_a_grid,
1178 p_b_grid,
1179 p_ds_grid,
1180 p_c_grid,
1181 p_shared,
1182 problem,
1183 a_element_op,
1184 b_element_op,
1185 c_element_op,
1186 block_2_ctile_map,
1187 k_id,
1188 Kt);
1189 }
1190
// Run (caller-supplied tile map): GEMM body for one workgroup using a single LDS
// allocation (p_shared).
//
// Visible flow:
//   1. Build the A, preshuffled-B and C grid descriptors and wrap the raw pointers
//      in global-memory dynamic buffers.
//   2. Map get_block_1d_id() to a C tile through block_2_ctile_map; return early
//      if ValidCTileIndex rejects the tile.
//   3. Run BlockwiseGemmPipe with num_k_block_main_loop, accumulating into
//      c_thread_buf.
//   4. Shuffle c_thread_buf through LDS and copy it, together with the Ds tensors,
//      to p_c_grid; c_element_op is handed to that copy.
//
// Parameters:
//   p_a_grid, p_b_grid - global-memory inputs (B is addressed through the
//                        preshuffled descriptor).
//   p_ds_grid          - pointers to the D tensors, read in the epilogue.
//   p_c_grid           - global-memory output.
//   p_shared           - LDS scratch; used for the A block buffer and later reused
//                        as the C-shuffle buffer.
//   problem            - sizes/strides (M, N, K, padded sizes, StrideA, StrideC,
//                        StrideDs, AK0, MBlock, NBlock are read here).
//   a_element_op       - passed to the A blockwise copy.
//   b_element_op       - unused in this function (assigned to `ignore`).
//   c_element_op       - passed to the C/Ds/E blockwise copy.
//   block_2_ctile_map  - provides CalculateBottomIndex and ValidCTileIndex.
//   k_id               - one component of the B copy origin index.
//   Kt                 - only used to compute BK0Shuffled.
1191 template <typename Block2CTileMap,
1192 bool HasMainKBlockLoop,
1193 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1194 TailNumber TailNum = TailNumber::Odd>
1195 __device__ static void Run(const ADataType* p_a_grid,
1196 const BDataType* p_b_grid,
1197 DsGridPointer& p_ds_grid,
1198 CDataType* p_c_grid,
1199 void* p_shared,
1200 const Problem& problem,
1201 AElementwiseOperation a_element_op,
1202 BElementwiseOperation b_element_op,
1203 CElementwiseOperation c_element_op,
1204 const Block2CTileMap& block_2_ctile_map,
1205 const index_t k_id,
1206 const index_t Kt)
1207 {
1208 ignore = b_element_op;
// Extents of the preshuffled B descriptor: N side from problem.N, K side from Kt.
1209 index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
1210 index_t BK0Shuffled = CalculateBK0Shuffled(Kt);
1211
1212 const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1213 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
1214
1215 const auto b_grid_desc_bpreshuffled =
1216 MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled);
1217 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1218 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
1219
1220 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1222 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1223
1224 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1225 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1226 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1227 p_b_grid, b_grid_desc_bpreshuffled.GetElementSpaceSize());
1229 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1230
1231 const auto block_work_idx =
1232 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
1233
// Blocks whose tile index is rejected by the map exit before doing any work.
1234 if(!block_2_ctile_map.ValidCTileIndex(
1235 block_work_idx,
1236 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
1237 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
1238 {
1239 return;
1240 }
1241
1242 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
1243 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
1244
1245 // HACK: this forces m/n_block_data_idx_on_grid into SGPR
1246 const index_t m_block_data_idx_on_grid =
1247 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1248
// Unlike the M offset (scaled by MPerBlock), the N offset is scaled by NXdlPerWave;
// it is used below as the leading index into the preshuffled B descriptor.
1249 const index_t n_block_data_idx_on_grid =
1250 __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
1251
1252 // A matrix in LDS memory, dst of blockwise copy
1253 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1254
1255 // B matrix in LDS memory, dst of blockwise copy
1256 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1257
1258 // A matrix blockwise copy
1259 auto a_blockwise_copy =
1261 AElementwiseOperation,
1265 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1266 ABlockTransferThreadClusterArrangeOrder,
1267 ADataType,
1268 LDSTypeA,
1269 decltype(a_grid_desc_ak0_m_ak1),
1270 decltype(a_block_desc_ak0_m_ak1),
1271 ABlockTransferSrcAccessOrder,
1273 ABlockTransferSrcVectorDim,
1274 2,
1275 ABlockTransferSrcScalarPerVector,
1276 ABlockTransferDstScalarPerVector_AK1,
1277 1,
1278 1,
1279 AThreadTransferSrcResetCoordinateAfterRun,
1280 true,
1281 BlockwiseGemmPipe::GlobalBufferNum>(
1282 a_grid_desc_ak0_m_ak1,
1283 make_multi_index(0, m_block_data_idx_on_grid, 0),
1284 a_element_op,
1285 a_block_desc_ak0_m_ak1,
1286 make_multi_index(0, 0, 0),
1288
1289 // Thread-wise copy
1290 // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
1292 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1293
// B is read with a per-thread (not thread-group) transfer; the origin combines the
// block's N offset, k_id and a lane-dependent term (thread id modulo WarpSize).
1294 auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
1295 BDataType,
1296 BDataType,
1297 decltype(b_grid_desc_bpreshuffled),
1298 decltype(b_block_desc_bk0_n_bk1),
1301 3,
1302 BBlockTransferSrcScalarPerVector,
1303 BThreadTransferSrcResetCoordinateAfterRun,
1304 true>(b_grid_desc_bpreshuffled,
1305 make_multi_index(n_block_data_idx_on_grid,
1307 k_id,
1308 KPackPerGroup * (get_thread_local_1d_id() % WarpSize)));
1309
1310 // LDS allocation for A and B: be careful of alignment
1311 // Cast after lds
1313 static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1314
// Per-iteration window steps handed to the pipeline for the A and B copies.
1315 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1316 constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
1317
1318 // Blockwise GEMM pipeline
1319 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1320 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1321 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1322
// K-loop trip count: (I0 length * I2 length of the A grid descriptor) / KPerBlock.
1323 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1324 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1325 KPerBlock);
1326
1327 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
1328 a_block_desc_ak0_m_ak1,
1329 a_blockwise_copy,
1330 a_grid_buf,
1331 a_block_buf,
1332 a_block_slice_copy_step,
1333 b_grid_desc_bpreshuffled,
1334 b_blockwise_copy,
1335 b_grid_buf,
1336 b_block_buf,
1337 b_block_slice_copy_step,
1338 c_thread_buf,
1339 num_k_block_main_loop);
1340
1341 // shuffle C and write out
1342 {
1343 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1344 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1345 "wrong!");
1346
1347 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1348
1349 // TODO: hacky, fix it!
1350 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1351 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1352
1353 // TODO: hacky, fix it!
1354 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1355 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1356 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1357
1358 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1359 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1360 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1361 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1362 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1363 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1364 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1365 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1366
1367 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1369
// The C-shuffle buffer is built on p_shared, the same LDS pointer that backed the
// A block buffer during the main loop.
1370 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1371 static_cast<CShuffleDataType*>(p_shared),
1372 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1373
1374 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1375 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1376 make_tuple(
1379 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1380 M1, // M1 = MWave
1381 M2, // M2 * M3 * M4 = MPerXdl
1382 M3,
1383 M4)),
1386 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1387 N1, // N1 = NWave
1388 N2))), // N2 = NPerXdl
1390 make_tuple(
1392
1393 // calculate origin of thread output tensor on global memory
1394 // blockwise GEMM c matrix starting index
1395 const auto c_thread_mtx_on_block =
1396 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1397
1398 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1399 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1400
// Split the flat per-thread M offset into (M0, M1, M2, M3, M4) components.
1401 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1403 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
1406
1407 const auto m_thread_data_on_block_idx =
1408 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1409 make_multi_index(m_thread_data_on_block));
1410
1411 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1416
1417 const auto n_thread_data_on_block_idx =
1418 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1419 make_multi_index(n_thread_data_on_block));
1420
1421 // shuffle: threadwise copy C from VGPR to LDS
1422 auto c_thread_copy_vgpr_to_lds =
1424 CShuffleDataType,
1425 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1426 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1428 Sequence<CShuffleMXdlPerWavePerShuffle,
1429 CShuffleNXdlPerWavePerShuffle,
1430 I1,
1431 I1,
1432 M2,
1433 I1,
1434 M4,
1435 I1>,
1437 7,
1438 1,
1440 1,
1441 true>{
1442 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1444 0,
1445 m_thread_data_on_block_idx[I1],
1446 n_thread_data_on_block_idx[I1],
1447 m_thread_data_on_block_idx[I2],
1448 m_thread_data_on_block_idx[I3],
1449 m_thread_data_on_block_idx[I4],
1450 n_thread_data_on_block_idx[I2]),
1452
1453 using EDataType = CDataType;
1454
1455 const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
1456 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1457
1458 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1460 ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1461
// One global-memory buffer per D tensor, sized from the matching M_N descriptor.
1462 const auto ds_grid_buf = generate_tuple(
1463 [&](auto i) {
1465 p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
1466 },
1468
1469 // tuple of reference to C/Ds tensor descriptors
1470 const auto c_ds_desc_refs = concat_tuple_of_reference(
1471 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1472 generate_tie([&](auto i) -> const auto& // return type should be reference
1473 { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1475
1476 // tuple of reference to C/Ds tensor buffers
1477 const auto c_ds_buf_refs = concat_tuple_of_reference(
1478 tie(c_shuffle_block_buf),
1479 generate_tie([&](auto i) -> const auto& // return type should be reference
1480 { return ds_grid_buf[i]; },
1482
1483 // tuple of starting index of C/Ds blockwise copy
// The first entry (the LDS shuffle source) starts at the origin; each Ds entry
// starts at this block's tile (block_work_idx[I0], 0, block_work_idx[I1], 0).
1484 const auto idx_c_ds_block_begin = container_concat(
1485 make_tuple(make_multi_index(0, 0, 0, 0)),
1487 [&](auto) {
1488 return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
1489 },
1491
// E (the output) uses the same descriptor as C.
1492 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1493 c_grid_desc_mblock_mperblock_nblock_nperblock;
1494
1495 using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
1496 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1497 const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1498
1499 auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3<
1501 decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
1503 decltype(c_ds_desc_refs),
1504 decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1505 CElementwiseOperation,
1506 Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
1507 // support arbitrary type
1508 Sequence<1,
1509 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1510 1,
1511 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1512 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1513 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1514 Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
1515 Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
1516 3, // index_t SrcVectorDim,
1517 3, // index_t DstVectorDim,
1518 CDEShuffleBlockTransferScalarPerVectors,
1523 false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
1524 Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
1525 {c_ds_desc_refs,
1526 idx_c_ds_block_begin,
1527 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1528 make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
1529 c_element_op};
1530
1531 // space filling curve for threadwise C in VGPR
1532 constexpr auto sfc_c_vgpr =
1535 Sequence<CShuffleMXdlPerWavePerShuffle,
1536 CShuffleNXdlPerWavePerShuffle,
1537 1,
1538 1,
1539 M2,
1540 1,
1541 M4,
1542 1>>{};
1543
1544 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1545
1546 // space filling curve for shuffled blockwise C/D/E
1547 constexpr auto sfc_cde_block =
1550 Sequence<1,
1551 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1552 1,
1553 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1554
// Both curves must take the same number of steps because they are walked in
// lock-step by the loop below.
1555 static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
1556
1557 static_for<0, num_access, 1>{}([&](auto access_id) {
1558 // make sure it's safe to write to LDS
1560
1561 // each thread write its data from VGPR to LDS
1562 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1563 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1564 c_thread_buf,
1565 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1566 c_shuffle_block_buf);
1567
1568 // make sure it's safe to read from LDS
1570
1571 // each block copy its data from LDS to global
1572 cde_block_copy_lds_and_global.Run(
1573 c_ds_desc_refs,
1574 c_ds_buf_refs,
1575 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1576 tie(c_grid_buf));
1577
// No window move after the last access.
1578 if constexpr(access_id < num_access - 1)
1579 {
1580 constexpr auto cde_lds_and_global_step =
1581 sfc_cde_block.GetForwardStep(access_id);
1582
1583 // move on Ds
// Source 0 is the LDS shuffle buffer and is not moved; the Ds sources
// are at indices i + I1.
1584 static_for<0, NumDTensor, 1>{}([&](auto i) {
1585 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
1586 c_ds_desc_refs, i + I1, cde_lds_and_global_step);
1587 });
1588
1589 // move on E
1590 cde_block_copy_lds_and_global.MoveDstSliceWindow(
1591 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1592 I0,
1593 cde_lds_and_global_step);
1594 }
1595 });
1596 }
1597 }
1598
// Run_2Lds (default tile map): convenience overload that constructs a
// Block2CTileMapDefault from problem.M and problem.N and passes it, together with
// every incoming argument unchanged, in the argument list below (presumably to the
// Block2CTileMap overload of Run_2Lds - confirm against the full source).
//
// p_shared and p_shared1 are two separate LDS pointers; both are forwarded as-is.
1599 template <bool HasMainKBlockLoop,
1600 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1601 TailNumber TailNum = TailNumber::Odd>
1602 __device__ static void Run_2Lds(const ADataType* p_a_grid,
1603 const BDataType* p_b_grid,
1604 DsGridPointer& p_ds_grid,
1605 CDataType* p_c_grid,
1606 void* p_shared,
1607 void* p_shared1,
1608 const Problem& problem,
1609 AElementwiseOperation a_element_op,
1610 BElementwiseOperation b_element_op,
1611 CElementwiseOperation c_element_op,
1612 const index_t k_id,
1613 const index_t Kt)
1614 {
// NOTE(review): the meaning of the literal 4 is not visible from here - confirm
// against the Block2CTileMapDefault constructor.
1615 const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4};
1617 p_a_grid,
1618 p_b_grid,
1619 p_ds_grid,
1620 p_c_grid,
1621 p_shared,
1622 p_shared1,
1623 problem,
1624 a_element_op,
1625 b_element_op,
1626 c_element_op,
1627 block_2_ctile_map,
1628 k_id,
1629 Kt);
1630 }
1631
// Run_2Lds (caller-supplied tile map): GEMM body for one workgroup using two LDS
// allocations (p_shared and p_shared1) as ping/pong A block buffers.
//
// Visible flow:
//   1. Build the A, preshuffled-B and C grid descriptors and wrap the raw pointers
//      in global-memory dynamic buffers.
//   2. Map get_block_1d_id() to a C tile through block_2_ctile_map; return early
//      if ValidCTileIndex rejects the tile.
//   3. Run BlockwiseGemmPipe with the ping/pong buffer tuples (a_block_bufs,
//      b_block_bufs) and num_k_block_main_loop, accumulating into c_thread_buf.
//   4. Shuffle c_thread_buf through LDS and copy it, together with the Ds tensors,
//      to p_c_grid; c_element_op is handed to that copy.
//
// Parameters:
//   p_a_grid, p_b_grid - global-memory inputs (B is addressed through the
//                        preshuffled descriptor).
//   p_ds_grid          - pointers to the D tensors, read in the epilogue.
//   p_c_grid           - global-memory output.
//   p_shared           - LDS scratch; backs the "ping" A block buffer and is later
//                        reused as the C-shuffle buffer.
//   p_shared1          - LDS scratch; backs the "pong" A block buffer.
//   problem            - sizes/strides (M, N, K, padded sizes, StrideA, StrideC,
//                        StrideDs, AK0, MBlock, NBlock are read here).
//   a_element_op       - passed to the A blockwise copy.
//   b_element_op       - unused in this function (assigned to `ignore`).
//   c_element_op       - passed to the C/Ds/E blockwise copy.
//   block_2_ctile_map  - provides CalculateBottomIndex and ValidCTileIndex.
//   k_id               - one component of the B copy origin index.
//   Kt                 - only used to compute BK0Shuffled.
1632 template <typename Block2CTileMap,
1633 bool HasMainKBlockLoop,
1634 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1635 TailNumber TailNum = TailNumber::Odd>
1636 __device__ static void Run_2Lds(const ADataType* p_a_grid,
1637 const BDataType* p_b_grid,
1638 DsGridPointer& p_ds_grid,
1639 CDataType* p_c_grid,
1640 void* p_shared,
1641 void* p_shared1,
1642 const Problem& problem,
1643 AElementwiseOperation a_element_op,
1644 BElementwiseOperation b_element_op,
1645 CElementwiseOperation c_element_op,
1646 const Block2CTileMap& block_2_ctile_map,
1647 const index_t k_id,
1648 const index_t Kt)
1649 {
1650 ignore = b_element_op;
// Extents of the preshuffled B descriptor: N side from problem.N, K side from Kt.
1651 index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
1652 index_t BK0Shuffled = CalculateBK0Shuffled(Kt);
1653 const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1654 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
1655
1656 const auto b_grid_desc_bpreshuffled =
1657 MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled);
1658 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1659 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
1660
1661 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1663 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1664
1665 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1666 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1667 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1668 p_b_grid, b_grid_desc_bpreshuffled.GetElementSpaceSize());
1670 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1671
1672 const auto block_work_idx =
1673 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
1674
// Blocks whose tile index is rejected by the map exit before doing any work.
1675 if(!block_2_ctile_map.ValidCTileIndex(
1676 block_work_idx,
1677 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
1678 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
1679 {
1680 return;
1681 }
1682
1683 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
1684 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
1685
1686 // HACK: this forces m/n_block_data_idx_on_grid into SGPR
1687 const index_t m_block_data_idx_on_grid =
1688 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1689
// Unlike the M offset (scaled by MPerBlock), the N offset is scaled by NXdlPerWave;
// it is used below as the leading index into the preshuffled B descriptor.
1690 const index_t n_block_data_idx_on_grid =
1691 __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
1692
1693 // A matrix in LDS memory, dst of blockwise copy
1694 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1695
1696 // B matrix in LDS memory, dst of blockwise copy
1697 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1698
1699 // A matrix blockwise copy
1700 auto a_blockwise_copy =
1702 AElementwiseOperation,
1706 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1707 ABlockTransferThreadClusterArrangeOrder,
1708 ADataType,
1709 LDSTypeA,
1710 decltype(a_grid_desc_ak0_m_ak1),
1711 decltype(a_block_desc_ak0_m_ak1),
1712 ABlockTransferSrcAccessOrder,
1714 ABlockTransferSrcVectorDim,
1715 2,
1716 ABlockTransferSrcScalarPerVector,
1717 ABlockTransferDstScalarPerVector_AK1,
1718 1,
1719 1,
1720 AThreadTransferSrcResetCoordinateAfterRun,
1721 true,
// NOTE(review): the single-LDS Run passes BlockwiseGemmPipe::GlobalBufferNum in this
// position; here the last template argument is the literal 2.
1722 2>(
1723 a_grid_desc_ak0_m_ak1,
1724 make_multi_index(0, m_block_data_idx_on_grid, 0),
1725 a_element_op,
1726 a_block_desc_ak0_m_ak1,
1727 make_multi_index(0, 0, 0),
1729
1730 // Thread-wise copy
1731 // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
1733 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1735 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
// Ping/pong pair of B block buffers handed to the pipeline as one tuple.
1736 auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
1737
// B is read with a per-thread (not thread-group) transfer; the origin combines the
// block's N offset, k_id and a lane-dependent term (thread id modulo WarpSize).
1738 auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
1739 BDataType,
1740 BDataType,
1741 decltype(b_grid_desc_bpreshuffled),
1742 decltype(b_block_desc_bk0_n_bk1),
1745 3,
1746 BBlockTransferSrcScalarPerVector,
1747 BThreadTransferSrcResetCoordinateAfterRun,
1748 true>(b_grid_desc_bpreshuffled,
1749 make_multi_index(n_block_data_idx_on_grid,
1751 k_id,
1752 KPackPerGroup * (get_thread_local_1d_id() % WarpSize)));
1753
1754 // LDS allocation for A and B: be careful of alignment
1755 // Cast after lds
// Two A block buffers of identical size: ping on p_shared, pong on p_shared1.
1756 auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1757 static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1758 auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1759 static_cast<LDSTypeA*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1760 auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
1761
// Per-iteration window steps handed to the pipeline for the A and B copies.
1762 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1763 constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
1764
1765 // Blockwise GEMM pipeline
1766 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1767 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1768 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1769
// K-loop trip count: (I0 length * I2 length of the A grid descriptor) / KPerBlock.
1770 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1771 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1772 KPerBlock);
1773
1774 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
1775 a_block_desc_ak0_m_ak1,
1776 a_blockwise_copy,
1777 a_grid_buf,
1778 a_block_bufs,
1779 a_block_slice_copy_step,
1780 b_grid_desc_bpreshuffled,
1781 b_blockwise_copy,
1782 b_grid_buf,
1783 b_block_bufs,
1784 b_block_slice_copy_step,
1785 c_thread_buf,
1786 num_k_block_main_loop);
1787
1788 // shuffle C and write out
1789 {
1790 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1791 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1792 "wrong!");
1793
1794 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1795
1796 // TODO: hacky, fix it!
1797 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1798 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1799
1800 // TODO: hacky, fix it!
1801 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1802 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1803 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1804
1805 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1806 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1807 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1808 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1809 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1810 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1811 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1812 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1813
1814 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1816
// The C-shuffle buffer is built on p_shared, the same LDS pointer that backed the
// "ping" A block buffer during the main loop; p_shared1 is not used here.
1817 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1818 static_cast<CShuffleDataType*>(p_shared),
1819 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1820
1821 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1822 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1823 make_tuple(
1826 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1827 M1, // M1 = MWave
1828 M2, // M2 * M3 * M4 = MPerXdl
1829 M3,
1830 M4)),
1833 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1834 N1, // N1 = NWave
1835 N2))), // N2 = NPerXdl
1837 make_tuple(
1839
1840 // calculate origin of thread output tensor on global memory
1841 // blockwise GEMM c matrix starting index
1842 const auto c_thread_mtx_on_block =
1843 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1844
1845 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1846 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1847
// Split the flat per-thread M offset into (M0, M1, M2, M3, M4) components.
1848 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1850 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
1853
1854 const auto m_thread_data_on_block_idx =
1855 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1856 make_multi_index(m_thread_data_on_block));
1857
1858 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1863
1864 const auto n_thread_data_on_block_idx =
1865 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1866 make_multi_index(n_thread_data_on_block));
1867
1868 // shuffle: threadwise copy C from VGPR to LDS
1869 auto c_thread_copy_vgpr_to_lds =
1871 CShuffleDataType,
1872 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1873 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1875 Sequence<CShuffleMXdlPerWavePerShuffle,
1876 CShuffleNXdlPerWavePerShuffle,
1877 I1,
1878 I1,
1879 M2,
1880 I1,
1881 M4,
1882 I1>,
1884 7,
1885 1,
1887 1,
1888 true>{
1889 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1891 0,
1892 m_thread_data_on_block_idx[I1],
1893 n_thread_data_on_block_idx[I1],
1894 m_thread_data_on_block_idx[I2],
1895 m_thread_data_on_block_idx[I3],
1896 m_thread_data_on_block_idx[I4],
1897 n_thread_data_on_block_idx[I2]),
1899
1900 using EDataType = CDataType;
1901
1902 const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
1903 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1904
1905 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1907 ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1908
// One global-memory buffer per D tensor, sized from the matching M_N descriptor.
1909 const auto ds_grid_buf = generate_tuple(
1910 [&](auto i) {
1912 p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
1913 },
1915
1916 // tuple of reference to C/Ds tensor descriptors
1917 const auto c_ds_desc_refs = concat_tuple_of_reference(
1918 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1919 generate_tie([&](auto i) -> const auto& // return type should be reference
1920 { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1922
1923 // tuple of reference to C/Ds tensor buffers
1924 const auto c_ds_buf_refs = concat_tuple_of_reference(
1925 tie(c_shuffle_block_buf),
1926 generate_tie([&](auto i) -> const auto& // return type should be reference
1927 { return ds_grid_buf[i]; },
1929
1930 // tuple of starting index of C/Ds blockwise copy
// The first entry (the LDS shuffle source) starts at the origin; each Ds entry
// starts at this block's tile (block_work_idx[I0], 0, block_work_idx[I1], 0).
1931 const auto idx_c_ds_block_begin = container_concat(
1932 make_tuple(make_multi_index(0, 0, 0, 0)),
1934 [&](auto) {
1935 return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
1936 },
1938
// E (the output) uses the same descriptor as C.
1939 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1940 c_grid_desc_mblock_mperblock_nblock_nperblock;
1941
1942 using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
1943 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1944 const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1945
1946 auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3<
1948 decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
1950 decltype(c_ds_desc_refs),
1951 decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1952 CElementwiseOperation,
1953 Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
1954 // support arbitrary type
1955 Sequence<1,
1956 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1957 1,
1958 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1959 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1960 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1961 Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
1962 Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
1963 3, // index_t SrcVectorDim,
1964 3, // index_t DstVectorDim,
1965 CDEShuffleBlockTransferScalarPerVectors,
1970 false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
1971 Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
1972 {c_ds_desc_refs,
1973 idx_c_ds_block_begin,
1974 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1975 make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
1976 c_element_op};
1977
1978 // space filling curve for threadwise C in VGPR
1979 constexpr auto sfc_c_vgpr =
1982 Sequence<CShuffleMXdlPerWavePerShuffle,
1983 CShuffleNXdlPerWavePerShuffle,
1984 1,
1985 1,
1986 M2,
1987 1,
1988 M4,
1989 1>>{};
1990
1991 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1992
1993 // space filling curve for shuffled blockwise C/D/E
1994 constexpr auto sfc_cde_block =
1997 Sequence<1,
1998 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1999 1,
2000 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2001
// Both curves must take the same number of steps because they are walked in
// lock-step by the loop below.
2002 static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2003
2004 static_for<0, num_access, 1>{}([&](auto access_id) {
2005 // make sure it's safe to write to LDS
2007
2008 // each thread write its data from VGPR to LDS
2009 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2010 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2011 c_thread_buf,
2012 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2013 c_shuffle_block_buf);
2014
2015 // make sure it's safe to read from LDS
2017
2018 // each block copy its data from LDS to global
2019 cde_block_copy_lds_and_global.Run(
2020 c_ds_desc_refs,
2021 c_ds_buf_refs,
2022 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2023 tie(c_grid_buf));
2024
// No window move after the last access.
2025 if constexpr(access_id < num_access - 1)
2026 {
2027 constexpr auto cde_lds_and_global_step =
2028 sfc_cde_block.GetForwardStep(access_id);
2029
2030 // move on Ds
// Source 0 is the LDS shuffle buffer and is not moved; the Ds sources
// are at indices i + I1.
2031 static_for<0, NumDTensor, 1>{}([&](auto i) {
2032 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2033 c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2034 });
2035
2036 // move on E
2037 cde_block_copy_lds_and_global.MoveDstSliceWindow(
2038 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2039 I0,
2040 cde_lds_and_global_step);
2041 }
2042 });
2043 }
2044 }
2045};
2046
2047} // namespace ck
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition utility/sequence.hpp:928
__device__ index_t get_warp_local_1d_id()
Definition get_id.hpp:45
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
f8_fnuz_t f8_t
Definition amd_ck_fp8.hpp:1762
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__host__ __device__ constexpr auto container_concat(const X &x, const Ys &... ys)
Definition utility/container_helper.hpp:320
__global__ void kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:39
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v1
Definition blkgemmpipe_scheduler.hpp:14
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
__host__ __device__ constexpr auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:185
integral_constant< index_t, N > Number
Definition number.hpp:12
__global__ void kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:82
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
constexpr auto BlockGemmBPreshufflePipeline_Selector()
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp:41
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
typename sequence_merge< Sx, Sy >::type sequence_merge_t
Definition utility/sequence.hpp:925
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition tuple_helper.hpp:42
Definition block_to_ctile_map.hpp:271
__host__ static __device__ constexpr index_t CalculateGridSize(index_t M, index_t N)
Definition block_to_ctile_map.hpp:283
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:652
const CElementwiseOperation c_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:694
__host__ Argument(const ADataType *p_a_grid_, const BDataType *p_b_grid_, std::array< const void *, NumDTensor > p_ds_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:653
const ADataType * p_a_grid
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:687
DsGridPointer p_ds_grid
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:689
const BElementwiseOperation b_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:693
const AElementwiseOperation a_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:692
CDataType * p_c_grid
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:690
const BDataType * p_b_grid
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:688
index_t NBlock
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:647
index_t MPadded
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:640
index_t KBatch
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:639
index_t KPadded
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:643
index_t MBlock
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:646
std::array< index_t, NumDTensor > StrideDs
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:637
index_t N
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:633
index_t StrideB
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:636
index_t AK0
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:644
index_t StrideA
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:635
__host__ __device__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t KBatch_)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:595
index_t BK0
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:645
index_t M
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:632
index_t StrideC
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:638
__host__ void Print() const
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:622
index_t K
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:634
index_t KRead
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:642
index_t NPadded
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:641
__device__ SplitKBatchOffset(Argument &karg, index_t k_id)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:699
index_t a_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:720
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:168
remove_cvref_t< decltype(BlockGemmBPreshufflePipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, LDSTypeA, LDSTypeB, ComputeTypeA, GemmAccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack >())> BlockwiseGemmPipe
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:862
static __device__ void Run(const ADataType *p_a_grid, const BDataType *p_b_grid, DsGridPointer &p_ds_grid, CDataType *p_c_grid, void *p_shared, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, const Block2CTileMap &block_2_ctile_map, const index_t k_id, const index_t Kt)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:1195
static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:1143
static __device__ void Run_2Lds(const ADataType *p_a_grid, const BDataType *p_b_grid, DsGridPointer &p_ds_grid, CDataType *p_c_grid, void *p_shared, void *p_shared1, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, const index_t k_id, const index_t Kt)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:1602
static __device__ void Run_2Lds(const ADataType *p_a_grid, const BDataType *p_b_grid, DsGridPointer &p_ds_grid, CDataType *p_c_grid, void *p_shared, void *p_shared1, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, const Block2CTileMap &block_2_ctile_map, const index_t k_id, const index_t Kt)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:1636
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition thread_group_tensor_slice_transfer_v7r3.hpp:48
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition utility/tuple.hpp:117
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340