30template <
typename GridwiseGemm,
31 bool HasMainKBlockLoop,
36#if CK_USE_LAUNCH_BOUNDS
42#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
43 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
45 __shared__
char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
47 auto splitk_batch_offset =
typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
49 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
50 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
51 karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
65template <
typename GridwiseGemm,
66 bool HasMainKBlockLoop,
71#if CK_USE_LAUNCH_BOUNDS
77#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
78 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
82 __shared__
char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
83 __shared__
char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
85 auto splitk_batch_offset =
typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
87 GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
88 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
89 karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
104template <
typename ALayout,
110 typename AccDataType,
111 typename CShuffleDataType,
114 typename AElementwiseOperation,
115 typename BElementwiseOperation,
116 typename CElementwiseOperation,
128 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
129 typename ABlockTransferThreadClusterArrangeOrder,
130 typename ABlockTransferSrcAccessOrder,
131 index_t ABlockTransferSrcVectorDim,
132 index_t ABlockTransferSrcScalarPerVector,
133 index_t ABlockTransferDstScalarPerVector_AK1,
134 bool AThreadTransferSrcResetCoordinateAfterRun,
136 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
137 typename BBlockTransferThreadClusterArrangeOrder,
138 typename BBlockTransferSrcAccessOrder,
139 index_t BBlockTransferSrcVectorDim,
140 index_t BBlockTransferSrcScalarPerVector,
141 index_t BBlockTransferDstScalarPerVector_BK1,
142 bool BThreadTransferSrcResetCoordinateAfterRun,
144 index_t CShuffleMXdlPerWavePerShuffle,
145 index_t CShuffleNXdlPerWavePerShuffle,
146 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
147 typename CDEShuffleBlockTransferScalarPerVectors,
150 typename ComputeTypeA = CDataType,
151 typename ComputeTypeB = ComputeTypeA,
152 typename LDSTypeA = ADataType,
153 typename LDSTypeB = BDataType,
154 bool DoElementwiseBeforeCShuffle =
false,
155 bool DirectLoad =
false>
172 CDEShuffleBlockTransferScalarPerVectors{}[
I0];
190 return static_cast<const DDataType*
>(
nullptr);
204 KPerBlock < 128 && MPerXdl == 16))
241 auto K_t = K_Batch * KPerBlock;
242 return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
247 auto K_t = K_Batch * KPerBlock;
248 return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
253 auto K_t = K_Batch * KPerBlock;
254 return (K + K_t - 1) / K_t * KPerBlock;
260 auto K_t = K_Batch * KReadVec;
261 return (K + K_t - 1) / K_t * KReadVec;
274 template <
typename Gr
idDesc_K0_MN_K1_T, index_t K0Number, index_t K1Value>
278 if constexpr(!DirectLoad)
284 const index_t K = desc.GetLength(
I0) * desc.GetLength(
I2);
314 template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl,
typename TileDesc_K0_MN_K1>
320 if constexpr(!DirectLoad)
356 const auto a_grid_desc_mraw_kraw = [&]() {
369 if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
370 GemmSpec == GemmSpecialization::MNKPadding)
373 const auto a_grid_desc_m_k =
387 return a_grid_desc_ak0_m_ak1;
389 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
390 GemmSpec == GemmSpecialization::MNPadding)
394 a_grid_desc_mraw_kraw,
400 return a_grid_desc_ak0_m_ak1;
402 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
403 GemmSpec == GemmSpecialization::NKPadding)
407 a_grid_desc_mraw_kraw,
419 return a_grid_desc_ak0_m_ak1;
425 a_grid_desc_mraw_kraw,
431 return a_grid_desc_ak0_m_ak1;
438 const auto b_grid_desc_nraw_kraw = [&]() {
451 if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
452 GemmSpec == GemmSpecialization::MNKPadding)
455 const auto b_grid_desc_n_k =
469 return b_grid_desc_bk0_n_bk1;
471 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
472 GemmSpec == GemmSpecialization::MNPadding)
476 b_grid_desc_nraw_kraw,
482 return b_grid_desc_bk0_n_bk1;
484 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
485 GemmSpec == GemmSpecialization::MKPadding)
489 b_grid_desc_nraw_kraw,
501 return b_grid_desc_bk0_n_bk1;
507 b_grid_desc_nraw_kraw,
513 return b_grid_desc_bk0_n_bk1;
517 template <
typename ABlockDesc_AK0_M_AK1>
518 __host__ __device__
static constexpr auto
521 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
526 template <
typename BBlockDesc_BK0_N_BK1>
527 __host__ __device__
static constexpr auto
530 constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
535 template <
typename ELayout>
536 __host__ __device__
static auto
539 const auto c_grid_desc_mraw_nraw = [&]() {
551 "The layout configuration is not supported! "
552 "Only support Row & Col major.");
565 if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
566 GemmSpec == GemmSpecialization::MNKPadding)
575 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
576 GemmSpec == GemmSpecialization::MKPadding)
580 c_grid_desc_mraw_nraw,
585 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
586 GemmSpec == GemmSpecialization::NKPadding)
590 c_grid_desc_mraw_nraw,
598 return c_grid_desc_mraw_nraw;
614 template <
typename DsGr
idDesc>
616 const DsGridDesc& ds_grid_desc_m_n,
index_t MBlock,
index_t NBlock)
621 ds_grid_desc_m_n[i], MBlock, NBlock);
634 std::array<index_t, NumDTensor> StrideDs_,
658 std::cout <<
"problem {" <<
"M:" <<
M <<
", " <<
"N:" <<
N <<
", " <<
"K:" <<
K <<
", "
661 <<
"KRead:" <<
KRead <<
", " <<
"KP:" <<
KPadded <<
", " <<
"AK0:" <<
AK0
662 <<
", " <<
"BK0:" <<
BK0 <<
", " <<
"MBlock: " <<
MBlock <<
", "
663 <<
"NBlock: " <<
NBlock <<
"}" << std::endl;
689 const BDataType* p_b_grid_,
690 std::array<const void*, NumDTensor> p_ds_grid_,
691 CDataType* p_c_grid_,
697 std::array<index_t, NumDTensor> StrideDs_,
700 AElementwiseOperation a_element_op_,
701 BElementwiseOperation b_element_op_,
702 CElementwiseOperation c_element_op_)
703 :
Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_},
718 p_ds_grid(i) =
static_cast<const DDataType_*
>(p_ds_grid_[i]);
754 if(k_id < karg.
KBatch - 1)
770 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
771 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
772 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
773#if defined(__gfx950__)
775 constexpr index_t ABlockLdsExtraM = 1;
777 constexpr index_t ABlockLdsExtraM = ABlockLdsExtraMCustom;
781 if constexpr(DirectLoad)
799 constexpr auto MLdsLayer = 32 * 4 / KPerBlock /
sizeof(LDSTypeA) < 1
801 : 32 * 4 / KPerBlock /
sizeof(LDSTypeA);
816 a_lds_block_desc_permuted,
824 a_lds_block_desc_ak0_mldslayer_m_ak1,
832 return a_lds_block_desc_ak0_m_ak1;
839 constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I1);
840 constexpr auto M1 = MPerBlock / M0;
842 constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I0);
843 constexpr auto K0PerThreadWrite =
AK0Number / KThreadWrite;
844 constexpr auto KThreadRead = WaveSize / MPerXdl;
845 constexpr auto K0PerThreadRead =
AK0Number / KThreadRead;
847 constexpr auto kfold = (
AK1Number * M0 *
sizeof(LDSTypeA) > 128)
849 : 128 / (
AK1Number * M0 *
sizeof(LDSTypeA));
850 constexpr auto KThreadReadPerm =
851 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
852 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
856 constexpr auto mpair = (
AK1Number * MPerXdl *
sizeof(LDSTypeA) > 128)
858 : ((128 / (
AK1Number * MPerXdl *
sizeof(LDSTypeA))) > M0
860 : 128 / (
AK1Number * MPerXdl *
sizeof(LDSTypeA)));
866 Number<kfold * M0 / mpair>{},
885 a_lds_block_desc_permuted,
907 a_lds_block_desc_unmerged,
910 Number<KThreadWrite / kfold / KThreadReadPerm>{},
919 return a_lds_block_desc_ak0_m_ak1;
925 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
926 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
927 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
928#if defined(__gfx950__)
930 constexpr index_t BBlockLdsExtraN = 1;
932 constexpr index_t BBlockLdsExtraN = BBlockLdsExtraNCustom;
936 if constexpr(DirectLoad)
953 constexpr auto NLdsLayer = 32 * 4 / KPerBlock /
sizeof(LDSTypeB) < 1
955 : 32 * 4 / KPerBlock /
sizeof(LDSTypeB);
971 b_lds_block_desc_permuted,
979 b_lds_block_desc_bk0_nldslayer_n_bk1,
987 return b_lds_block_desc_bk0_n_bk1;
991 constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(
I1);
992 constexpr auto N1 = NPerBlock / N0;
994 constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(
I0);
995 constexpr auto K0PerThreadWrite =
BK0Number / KThreadWrite;
996 constexpr auto KThreadRead = WaveSize / NPerXdl;
997 constexpr auto K0PerThreadRead =
BK0Number / KThreadRead;
999 constexpr auto kfold = (
BK1Number * N0 *
sizeof(LDSTypeB) > 128)
1001 : 128 / (
BK1Number * N0 *
sizeof(LDSTypeB));
1002 constexpr auto KThreadReadPerm =
1003 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
1004 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
1008 constexpr auto npair = (
BK1Number * NPerXdl *
sizeof(LDSTypeB) > 128)
1010 : ((128 / (
BK1Number * NPerXdl *
sizeof(LDSTypeB))) > N0
1012 : 128 / (
BK1Number * NPerXdl *
sizeof(LDSTypeB)));
1018 Number<kfold * N0 / npair>{},
1037 b_lds_block_desc_permuted,
1059 b_lds_block_desc_unmerged,
1062 Number<KThreadWrite / kfold / KThreadReadPerm>{},
1071 return b_lds_block_desc_bk0_n_bk1;
1077 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1078 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1080 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1087 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
1105 ABlockTransferSrcScalarPerVector,
1106 BBlockTransferSrcScalarPerVector,
1127 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1130 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1133 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1136 constexpr auto c_block_size =
1137 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
1139 return math::max((a_block_space_size_aligned *
sizeof(LDSTypeA) +
1140 b_block_space_size_aligned *
sizeof(LDSTypeB)),
1141 c_block_size *
sizeof(CShuffleDataType));
1148 constexpr bool valid = ck::tensor_operation::device::IsValidGemmCompilationParameter<
1157 CGlobalMemoryDataOperation_>();
1158 if constexpr(!valid)
1170 constexpr index_t KPerThread =
1171 KPerBlock / (MfmaInst::GetKPerXdlops() / MfmaInst::GetK1PerXdlops());
1172 if constexpr(KPerThread %
KPack != 0)
1178 if constexpr(NXdlPerWave % CShuffleNXdlPerWavePerShuffle != 0)
1188 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1189 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1190 "Invalid tuning param!");
1192 if constexpr(NXdlPerWave % CShuffleNXdlPerWavePerShuffle != 0)
1203 if(!(karg.M % MPerBlock == 0))
1206 std::cout <<
"Arg M value is not a multiple of MPerBlock! M: " << karg.M <<
" "
1207 << __FILE__ <<
":" << __LINE__ <<
", in function: " << __func__
1221 if(!(karg.N % NPerBlock == 0))
1224 std::cout <<
"Arg N value is not a multiple of NPerBlock! N: " << karg.N <<
" "
1225 << __FILE__ <<
":" << __LINE__ <<
", in function: " << __func__
1239 auto K_t = karg.KBatch * KPerBlock;
1240 if(!(karg.K % K_t == 0))
1243 std::cout <<
"Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1244 << karg.K <<
" " << __FILE__ <<
":" << __LINE__
1245 <<
", in function: " << __func__ << std::endl;
1254 auto K_t = karg.KBatch * KReadVec;
1256 if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1264 if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1267 std::cout <<
"Arg K (" << karg.K
1268 <<
") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1269 << ABlockTransferSrcScalarPerVector <<
" )! " << __FILE__ <<
":"
1270 << __LINE__ <<
", in function: " << __func__ << std::endl;
1278 if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1281 std::cout <<
"Arg M (" << karg.M
1282 <<
") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1283 << ABlockTransferSrcScalarPerVector <<
" )! " << __FILE__ <<
":"
1284 << __LINE__ <<
", in function: " << __func__ << std::endl;
1293 if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1296 std::cout <<
"Arg N (" << karg.N
1297 <<
") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1298 << BBlockTransferSrcScalarPerVector <<
" )! " << __FILE__ <<
":"
1299 << __LINE__ <<
", in function: " << __func__ << std::endl;
1307 if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1310 std::cout <<
"Arg K (" << karg.K
1311 <<
") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1312 << BBlockTransferSrcScalarPerVector <<
" )! " << __FILE__ <<
":"
1313 << __LINE__ <<
", in function: " << __func__ << std::endl;
1325 std::cout <<
"Arg N (" << karg.N
1326 <<
") value is not a multiple of "
1327 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1329 <<
":" << __LINE__ <<
", in function: " << __func__ << std::endl;
1340 std::cout <<
"Arg M (" << karg.M
1341 <<
") value is not a multiple of "
1342 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1344 <<
":" << __LINE__ <<
", in function: " << __func__ << std::endl;
1352 const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1356 if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1363 if(!(karg.M * karg.K *
sizeof(ADataType) <= TwoGB &&
1364 karg.N * karg.K *
sizeof(BDataType) <= TwoGB &&
1365 karg.M * karg.N *
sizeof(CDataType) <= TwoGB))
1376 const index_t num_loop = K / KPerBlock;
1378 return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1383 const index_t num_loop = K / KPerBlock;
1385 return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1388 template <
typename CGr
idDesc>
1390 const CGridDesc& c_grid_desc_m_n,
index_t MBlock,
index_t NBlock)
1399 return c_grid_desc_mblock_mperblock_nblock_nperblock;
1406 template <
bool HasMainKBlockLoop,
1409 __device__
static void Run(
const ADataType* __restrict__ p_a_grid,
1410 const BDataType* __restrict__ p_b_grid,
1412 CDataType* __restrict__ p_c_grid,
1413 void* __restrict__ p_shared,
1414 const Problem& problem,
1415 AElementwiseOperation a_element_op,
1416 BElementwiseOperation b_element_op,
1417 CElementwiseOperation c_element_op)
1433 template <
typename Block2CTileMap,
1434 bool HasMainKBlockLoop,
1437 __device__
static void Run(
const ADataType* __restrict__ p_a_grid,
1438 const BDataType* __restrict__ p_b_grid,
1440 CDataType* __restrict__ p_c_grid,
1441 void* __restrict__ p_shared,
1442 const Problem& problem,
1443 AElementwiseOperation a_element_op,
1444 BElementwiseOperation b_element_op,
1445 CElementwiseOperation c_element_op,
1446 const Block2CTileMap& block_2_ctile_map)
1449 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
1451 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
1454 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
1456 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1469 a_grid_desc_ak0_m_ak1,
1470 b_grid_desc_bk0_n_bk1,
1475 template <
bool HasMainKBlockLoop,
1478 typename Block2CTileMap,
1479 typename AGridDesc_AK0_M_K1,
1480 typename BGridDesc_BK0_N_K1,
1481 typename DsGridDesc_M_N,
1482 typename CGridDesc_M_N>
1483 __device__
static void Run(
const ADataType* __restrict__ p_a_grid,
1484 const BDataType* __restrict__ p_b_grid,
1486 CDataType* __restrict__ p_c_grid,
1487 void* __restrict__ p_shared,
1488 const Problem& problem,
1489 [[maybe_unused]] AElementwiseOperation a_element_op,
1490 [[maybe_unused]] BElementwiseOperation b_element_op,
1491 CElementwiseOperation c_element_op,
1492 const Block2CTileMap& block_2_ctile_map,
1493 const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
1494 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
1495 const DsGridDesc_M_N& ds_grid_desc_m_n,
1496 const CGridDesc_M_N& c_grid_desc_m_n)
1500 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1502 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1504 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1506 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1509 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1511 const auto block_work_idx =
1514 if(!block_2_ctile_map.ValidCTileIndex(
1516 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I0),
1517 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I2))))
1522 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I0]);
1523 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I1]);
1526 const index_t m_block_data_idx_on_grid =
1527 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1529 const index_t n_block_data_idx_on_grid =
1530 __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1541 auto get_a_blockwise_copy = [&]() {
1542 if constexpr(DirectLoad)
1547 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1548 ABlockTransferThreadClusterArrangeOrder,
1551 decltype(a_grid_desc_ak0_m_ak1),
1552 decltype(a_block_desc_ak0_m_ak1),
1553 ABlockTransferSrcAccessOrder,
1554 ABlockTransferSrcVectorDim,
1556 ABlockTransferSrcScalarPerVector>(
1557 a_grid_desc_ak0_m_ak1,
1559 a_block_desc_ak0_m_ak1,
1566 AElementwiseOperation,
1570 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1571 ABlockTransferThreadClusterArrangeOrder,
1574 decltype(a_grid_desc_ak0_m_ak1),
1575 decltype(a_block_desc_ak0_m_ak1),
1576 ABlockTransferSrcAccessOrder,
1578 ABlockTransferSrcVectorDim,
1580 ABlockTransferSrcScalarPerVector,
1581 ABlockTransferDstScalarPerVector_AK1,
1584 AThreadTransferSrcResetCoordinateAfterRun,
1586 BlockwiseGemmPipe::GlobalBufferNum>(
1587 a_grid_desc_ak0_m_ak1,
1590 a_block_desc_ak0_m_ak1,
1597 auto get_b_blockwise_copy = [&]() {
1598 if constexpr(DirectLoad)
1603 BBlockTransferThreadClusterLengths_BK0_N_BK1,
1604 BBlockTransferThreadClusterArrangeOrder,
1607 decltype(b_grid_desc_bk0_n_bk1),
1608 decltype(b_block_desc_bk0_n_bk1),
1609 BBlockTransferSrcAccessOrder,
1610 BBlockTransferSrcVectorDim,
1612 BBlockTransferSrcScalarPerVector>(
1613 b_grid_desc_bk0_n_bk1,
1615 b_block_desc_bk0_n_bk1,
1622 BElementwiseOperation,
1626 BBlockTransferThreadClusterLengths_BK0_N_BK1,
1627 BBlockTransferThreadClusterArrangeOrder,
1630 decltype(b_grid_desc_bk0_n_bk1),
1631 decltype(b_block_desc_bk0_n_bk1),
1632 BBlockTransferSrcAccessOrder,
1634 BBlockTransferSrcVectorDim,
1636 BBlockTransferSrcScalarPerVector,
1637 BBlockTransferDstScalarPerVector_BK1,
1640 BThreadTransferSrcResetCoordinateAfterRun,
1642 BlockwiseGemmPipe::GlobalBufferNum>(
1643 b_grid_desc_bk0_n_bk1,
1646 b_block_desc_bk0_n_bk1,
1652 auto a_blockwise_copy = get_a_blockwise_copy();
1653 auto b_blockwise_copy = get_b_blockwise_copy();
1657 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1661 static_cast<LDSTypeA*
>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1664 static_cast<LDSTypeB*
>(p_shared) +
1665 a_block_space_size_aligned *
sizeof(LDSTypeA) /
sizeof(LDSTypeB),
1666 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1672 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1674 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1676 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1677 (a_grid_desc_ak0_m_ak1.GetLength(
I0) * a_grid_desc_ak0_m_ak1.GetLength(
I2)) /
1681 a_block_desc_ak0_m_ak1,
1685 a_block_slice_copy_step,
1686 b_grid_desc_bk0_n_bk1,
1687 b_block_desc_bk0_n_bk1,
1691 b_block_slice_copy_step,
1693 num_k_block_main_loop);
1697 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1698 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1701 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1702 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1705 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1706 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1710 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1711 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1713 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I0);
1714 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I1);
1715 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I2);
1716 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I3);
1717 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I4);
1718 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I5);
1719 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I6);
1720 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I7);
1722 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1726 static_cast<CShuffleDataType*
>(p_shared),
1727 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1730 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1750 const auto c_thread_mtx_on_block =
1751 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(
I0,
I0,
I0,
I0);
1753 const index_t m_thread_data_on_block = c_thread_mtx_on_block[
I0];
1754 const index_t n_thread_data_on_block = c_thread_mtx_on_block[
I1];
1756 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1762 const auto m_thread_data_on_block_idx =
1763 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1766 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1772 const auto n_thread_data_on_block_idx =
1773 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1777 const auto& vpgr_to_lds_element_op = [&] {
1778 if constexpr(DoElementwiseBeforeCShuffle)
1780 return c_element_op;
1784 return pass_through;
1787 const auto& lds_to_global_element_op = [&] {
1788 if constexpr(!DoElementwiseBeforeCShuffle)
1790 return c_element_op;
1794 return pass_through;
1802 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1803 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1805 CElementwiseOperation,
1807 Sequence<CShuffleMXdlPerWavePerShuffle,
1808 CShuffleNXdlPerWavePerShuffle,
1820 true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1823 m_thread_data_on_block_idx[
I1],
1824 n_thread_data_on_block_idx[
I1],
1825 m_thread_data_on_block_idx[
I2],
1826 m_thread_data_on_block_idx[
I3],
1827 m_thread_data_on_block_idx[
I4],
1828 n_thread_data_on_block_idx[
I2]),
1829 vpgr_to_lds_element_op()};
1831 using EDataType = CDataType;
1833 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1835 ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1840 p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
1846 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1848 {
return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1853 tie(c_shuffle_block_buf),
1855 {
return ds_grid_buf[i]; },
1867 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1868 c_grid_desc_mblock_mperblock_nblock_nperblock;
1870 using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
1871 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1872 const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1878 decltype(c_ds_desc_refs),
1879 decltype(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1881 CElementwiseOperation,
1886 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1888 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
1889 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1895 CDEShuffleBlockTransferScalarPerVectors,
1903 idx_c_ds_block_begin,
1904 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1906 lds_to_global_element_op()};
1909 constexpr auto sfc_c_vgpr =
1912 Sequence<CShuffleMXdlPerWavePerShuffle,
1913 CShuffleNXdlPerWavePerShuffle,
1921 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1924 constexpr auto sfc_cde_block =
1928 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1930 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1932 static_assert(num_access == sfc_cde_block.GetNumOfAccess(),
"wrong!");
1939 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1940 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1942 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1943 c_shuffle_block_buf);
1949 cde_block_copy_lds_and_global.Run(
1952 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1955 if constexpr(access_id < num_access - 1)
1957 constexpr auto cde_lds_and_global_step =
1958 sfc_cde_block.GetForwardStep(access_id);
1962 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
1963 c_ds_desc_refs, i +
I1, cde_lds_and_global_step);
1967 cde_block_copy_lds_and_global.MoveDstSliceWindow(
1968 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1970 cde_lds_and_global_step);
1976 template <
bool HasMainKBlockLoop,
1979 __device__
static void Run_2Lds(
const ADataType* __restrict__ p_a_grid,
1980 const BDataType* __restrict__ p_b_grid,
1982 CDataType* __restrict__ p_c_grid,
1983 void* __restrict__ p_shared_0,
1984 void* __restrict__ p_shared_1,
1985 const Problem& problem,
1986 AElementwiseOperation a_element_op,
1987 BElementwiseOperation b_element_op,
1988 CElementwiseOperation c_element_op)
2006 template <
typename Block2CTileMap,
2007 bool HasMainKBlockLoop,
2010 __device__
static void Run_2Lds(
const ADataType* __restrict__ p_a_grid,
2011 const BDataType* __restrict__ p_b_grid,
2013 CDataType* __restrict__ p_c_grid,
2014 void* __restrict__ p_shared_0,
2015 void* __restrict__ p_shared_1,
2016 const Problem& problem,
2017 AElementwiseOperation a_element_op,
2018 BElementwiseOperation b_element_op,
2019 CElementwiseOperation c_element_op,
2020 const Block2CTileMap& block_2_ctile_map)
2023 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
2025 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
2028 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
2030 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
2043 a_grid_desc_ak0_m_ak1,
2044 b_grid_desc_bk0_n_bk1,
2049 template <
bool HasMainKBlockLoop,
2052 typename Block2CTileMap,
2053 typename AGridDesc_AK0_M_K1,
2054 typename BGridDesc_BK0_N_K1,
2055 typename DsGridDesc_M_N,
2056 typename CGridDesc_M_N>
2057 __device__
static void Run_2Lds(
const ADataType* __restrict__ p_a_grid,
2058 const BDataType* __restrict__ p_b_grid,
2060 CDataType* __restrict__ p_c_grid,
2061 void* __restrict__ p_shared_0,
2062 void* __restrict__ p_shared_1,
2063 const Problem& problem,
2064 [[maybe_unused]] AElementwiseOperation a_element_op,
2065 [[maybe_unused]] BElementwiseOperation b_element_op,
2066 CElementwiseOperation c_element_op,
2067 const Block2CTileMap& block_2_ctile_map,
2068 const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
2069 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
2070 const DsGridDesc_M_N& ds_grid_desc_m_n,
2071 const CGridDesc_M_N& c_grid_desc_m_n)
2074 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
2076 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
2079 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
2081 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
2083 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2085 const auto block_work_idx =
2088 if(!block_2_ctile_map.ValidCTileIndex(
2090 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I0),
2091 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I2))))
2096 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I0]);
2097 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I1]);
2100 const index_t m_block_data_idx_on_grid =
2101 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
2103 const index_t n_block_data_idx_on_grid =
2104 __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
2115 auto get_a_blockwise_copy = [&]() {
2116 if constexpr(DirectLoad)
2121 ABlockTransferThreadClusterLengths_AK0_M_AK1,
2122 ABlockTransferThreadClusterArrangeOrder,
2125 decltype(a_grid_desc_ak0_m_ak1),
2126 decltype(a_block_desc_ak0_m_ak1),
2127 ABlockTransferSrcAccessOrder,
2128 ABlockTransferSrcVectorDim,
2130 ABlockTransferSrcScalarPerVector>(
2131 a_grid_desc_ak0_m_ak1,
2133 a_block_desc_ak0_m_ak1,
2140 AElementwiseOperation,
2144 ABlockTransferThreadClusterLengths_AK0_M_AK1,
2145 ABlockTransferThreadClusterArrangeOrder,
2148 decltype(a_grid_desc_ak0_m_ak1),
2149 decltype(a_block_desc_ak0_m_ak1),
2150 ABlockTransferSrcAccessOrder,
2152 ABlockTransferSrcVectorDim,
2154 ABlockTransferSrcScalarPerVector,
2155 ABlockTransferDstScalarPerVector_AK1,
2158 AThreadTransferSrcResetCoordinateAfterRun,
2160 BlockwiseGemmPipe::GlobalBufferNum>(
2161 a_grid_desc_ak0_m_ak1,
2164 a_block_desc_ak0_m_ak1,
2171 auto get_b_blockwise_copy = [&]() {
2172 if constexpr(DirectLoad)
2177 BBlockTransferThreadClusterLengths_BK0_N_BK1,
2178 BBlockTransferThreadClusterArrangeOrder,
2181 decltype(b_grid_desc_bk0_n_bk1),
2182 decltype(b_block_desc_bk0_n_bk1),
2183 BBlockTransferSrcAccessOrder,
2184 BBlockTransferSrcVectorDim,
2186 BBlockTransferSrcScalarPerVector>(
2187 b_grid_desc_bk0_n_bk1,
2189 b_block_desc_bk0_n_bk1,
2196 BElementwiseOperation,
2200 BBlockTransferThreadClusterLengths_BK0_N_BK1,
2201 BBlockTransferThreadClusterArrangeOrder,
2204 decltype(b_grid_desc_bk0_n_bk1),
2205 decltype(b_block_desc_bk0_n_bk1),
2206 BBlockTransferSrcAccessOrder,
2208 BBlockTransferSrcVectorDim,
2210 BBlockTransferSrcScalarPerVector,
2211 BBlockTransferDstScalarPerVector_BK1,
2214 BThreadTransferSrcResetCoordinateAfterRun,
2216 BlockwiseGemmPipe::GlobalBufferNum>(
2217 b_grid_desc_bk0_n_bk1,
2220 b_block_desc_bk0_n_bk1,
2226 auto a_blockwise_copy = get_a_blockwise_copy();
2227 auto b_blockwise_copy = get_b_blockwise_copy();
2231 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
2234 static_cast<LDSTypeA*
>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2237 static_cast<LDSTypeB*
>(p_shared_0) +
2238 a_block_space_size_aligned *
sizeof(LDSTypeA) /
sizeof(LDSTypeB),
2239 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2242 static_cast<LDSTypeA*
>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2245 static_cast<LDSTypeB*
>(p_shared_1) +
2246 a_block_space_size_aligned *
sizeof(LDSTypeA) /
sizeof(LDSTypeB),
2247 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2249 auto a_block_bufs =
make_tuple(a_block_buf_ping, a_block_buf_pong);
2250 auto b_block_bufs =
make_tuple(b_block_buf_ping, b_block_buf_pong);
2256 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2258 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2260 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2261 (a_grid_desc_ak0_m_ak1.GetLength(
I0) * a_grid_desc_ak0_m_ak1.GetLength(
I2)) /
2265 a_block_desc_ak0_m_ak1,
2269 a_block_slice_copy_step,
2270 b_grid_desc_bk0_n_bk1,
2271 b_block_desc_bk0_n_bk1,
2275 b_block_slice_copy_step,
2277 num_k_block_main_loop);
2281 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2282 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2285 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2286 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
2289 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2290 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2294 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2295 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2297 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I0);
2298 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I1);
2299 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I2);
2300 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I3);
2301 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I4);
2302 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I5);
2303 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I6);
2304 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I7);
2306 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2310 static_cast<CShuffleDataType*
>(p_shared_0),
2311 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2314 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2334 const auto c_thread_mtx_on_block =
2335 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(
I0,
I0,
I0,
I0);
2337 const index_t m_thread_data_on_block = c_thread_mtx_on_block[
I0];
2338 const index_t n_thread_data_on_block = c_thread_mtx_on_block[
I1];
2340 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2346 const auto m_thread_data_on_block_idx =
2347 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2350 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2356 const auto n_thread_data_on_block_idx =
2357 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2361 const auto& vpgr_to_lds_element_op = [&] {
2362 if constexpr(DoElementwiseBeforeCShuffle)
2364 return c_element_op;
2368 return pass_through;
2371 const auto& lds_to_global_element_op = [&] {
2372 if constexpr(!DoElementwiseBeforeCShuffle)
2374 return c_element_op;
2378 return pass_through;
2386 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2387 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2389 CElementwiseOperation,
2391 Sequence<CShuffleMXdlPerWavePerShuffle,
2392 CShuffleNXdlPerWavePerShuffle,
2404 true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2407 m_thread_data_on_block_idx[
I1],
2408 n_thread_data_on_block_idx[
I1],
2409 m_thread_data_on_block_idx[
I2],
2410 m_thread_data_on_block_idx[
I3],
2411 m_thread_data_on_block_idx[
I4],
2412 n_thread_data_on_block_idx[
I2]),
2413 vpgr_to_lds_element_op()};
2415 using EDataType = CDataType;
2417 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
2419 ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
2424 p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
2430 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2432 {
return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2437 tie(c_shuffle_block_buf),
2439 {
return ds_grid_buf[i]; },
2451 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2452 c_grid_desc_mblock_mperblock_nblock_nperblock;
2454 using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
2455 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2456 const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2462 decltype(c_ds_desc_refs),
2463 decltype(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2465 CElementwiseOperation,
2470 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2472 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
2473 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
2479 CDEShuffleBlockTransferScalarPerVectors,
2487 idx_c_ds_block_begin,
2488 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2490 lds_to_global_element_op()};
2493 constexpr auto sfc_c_vgpr =
2496 Sequence<CShuffleMXdlPerWavePerShuffle,
2497 CShuffleNXdlPerWavePerShuffle,
2505 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2508 constexpr auto sfc_cde_block =
2512 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2514 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2516 static_assert(num_access == sfc_cde_block.GetNumOfAccess(),
"wrong!");
2523 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2524 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2526 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2527 c_shuffle_block_buf);
2533 cde_block_copy_lds_and_global.Run(
2536 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2539 if constexpr(access_id < num_access - 1)
2541 constexpr auto cde_lds_and_global_step =
2542 sfc_cde_block.GetForwardStep(access_id);
2546 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2547 c_ds_desc_refs, i +
I1, cde_lds_and_global_step);
2551 cde_block_copy_lds_and_global.MoveDstSliceWindow(
2552 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2554 cde_lds_and_global_step);
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition utility/sequence.hpp:928
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__host__ __device__ constexpr auto container_concat(const X &x, const Ys &... ys)
Definition utility/container_helper.hpp:320
int32_t index_t
Definition ck.hpp:299
typename conditional< predicate, X, Y >::type conditional_t
Definition utility/functional.hpp:115
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
__host__ __device__ constexpr auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:185
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Full
Definition blkgemmpipe_scheduler.hpp:49
constexpr auto BlockGemmPipeline_Selector()
Definition blockwise_gemm_pipeline_wmma_selector.hpp:32
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__global__ void kernel_gemm_xdl_cshuffle_v3_multi_d(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:40
__global__ void kernel_gemm_xdl_cshuffle_v3_multi_d_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:75
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
typename sequence_merge< Sx, Sy >::type sequence_merge_t
Definition utility/sequence.hpp:925
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition tuple_helper.hpp:42
Definition block_to_ctile_map.hpp:271
__host__ static __device__ constexpr index_t CalculateGridSize(index_t M, index_t N)
Definition block_to_ctile_map.hpp:283
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:686
AElementwiseOperation a_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:727
const ADataType * p_a_grid
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:722
const BDataType * p_b_grid
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:723
CDataType * p_c_grid
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:725
BElementwiseOperation b_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:728
__host__ Argument()=default
CElementwiseOperation c_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:729
__host__ Argument(const ADataType *p_a_grid_, const BDataType *p_b_grid_, std::array< const void *, NumDTensor > p_ds_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:688
DsGridPointer p_ds_grid
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:724
index_t NPadded
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:675
index_t StrideC
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:672
index_t KBatch
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:673
index_t StrideB
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:670
std::array< index_t, NumDTensor > StrideDs
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:671
index_t BK0
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:679
__host__ __device__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t KBatch_)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:629
index_t StrideA
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:669
index_t K
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:668
index_t MBlock
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:680
index_t N
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:667
index_t NBlock
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:681
index_t KRead
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:676
index_t MPadded
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:674
index_t M
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:666
__host__ void Print() const
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:656
index_t AK0
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:678
__host__ __device__ Problem()=default
index_t KPadded
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:677
index_t b_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:765
index_t a_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:764
__device__ SplitKBatchOffset(Argument &karg, index_t k_id)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:734
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:157
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::CalculateKBlockLoopTailNum __host__ static __device__ constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:1381
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >< math::max(NXdlPerWave64, 1)>::I2 static constexpr auto I2
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:164
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >< math::max(NXdlPerWave64, 1)>::BK1Number static constexpr auto BK1Number
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:178
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::MakeBGridDescriptor_BK0_N_BK1 __host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:435
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::CalculateMPadded __host__ static __device__ auto CalculateMPadded(index_t M)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:224
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::Run static __device__ void Run(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsGridPointer &p_ds_grid, CDataType *__restrict__ p_c_grid, void *__restrict__ p_shared, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:1409
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::GetSharedMemoryNumberOfByte static __device__ constexpr index_t GetSharedMemoryNumberOfByte()
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:1117
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >< math::max(NXdlPerWave64, 1)>::DsGridPointer decltype(MakeDsGridPointer()) DsGridPointer
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:195
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1 static __device__ constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:768
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::MakeDsGridPointer static constexpr auto MakeDsGridPointer()
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:184
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::CalculateKPadded __host__ static __device__ auto CalculateKPadded(index_t K)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:234
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >< math::max(NXdlPerWave64, 1)>::is_single_rate_mfma static constexpr bool is_single_rate_mfma
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:198
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::CalculateMBlock __host__ static __device__ auto CalculateMBlock(index_t M)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:264
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::CalculateKPadded __host__ static __device__ auto CalculateKPadded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:251
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >< math::max(NXdlPerWave64, 1)>::lcm_AK1_BK1 static constexpr auto lcm_AK1_BK1
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:197
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >< math::max(NXdlPerWave64, 1)>::KPack static constexpr index_t KPack
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:208
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::MakeCGridDescriptor_M_N __host__ static __device__ auto MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:537
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::CalculateHasMainKBlockLoop __host__ static __device__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:1374
static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:1389
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >< math::max(NXdlPerWave64, 1)>::BK0Number static constexpr auto BK0Number
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:176
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >< math::max(NXdlPerWave64, 1)>::BlockwiseGemmPipe remove_cvref_t< decltype(BlockGemmPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, LDSTypeA, LDSTypeB, ComputeTypeA, GemmAccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack, false >())> BlockwiseGemmPipe
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:1090
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::CalculateNPadded __host__ static __device__ auto CalculateNPadded(index_t N)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:229
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >< math::max(NXdlPerWave64, 1)>::AK1Number static constexpr auto AK1Number
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:177
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::CalculateBK0Padded __host__ static __device__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:245
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock static __device__ constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:1075
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::CalculateAK0Padded __host__ static __device__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:239
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >< math::max(NXdlPerWave64, 1)>::CShuffleBlockTransferScalarPerVector_NPerBlock static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:171
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >< math::max(NXdlPerWave64, 1)>::I6 static constexpr auto I6
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:168
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::MakeDsGridDescriptor_M_N __host__ static __device__ auto MakeDsGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, std::array< index_t, NumDTensor > StrideDs)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:603
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >< math::max(NXdlPerWave64, 1)>::ThisThreadBlock ThisThreadBlock< BlockSize > ThisThreadBlock
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:217
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::MakeAGridDescriptor_AK0_M_AK1 __host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:353
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >< math::max(NXdlPerWave64, 1)>::I5 static constexpr auto I5
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:167
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >< math::max(NXdlPerWave64, 1)>::DirectLoadEnabled static constexpr bool DirectLoadEnabled
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:180
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >< math::max(NXdlPerWave64, 1)>::I7 static constexpr auto I7
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:169
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::CalculateGridSize static __host__ auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:219
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::Run_2Lds static __device__ void Run_2Lds(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsGridPointer &p_ds_grid, CDataType *__restrict__ p_c_grid, void *__restrict__ p_shared_0, void *__restrict__ p_shared_1, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, const Block2CTileMap &block_2_ctile_map, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const DsGridDesc_M_N &ds_grid_desc_m_n, const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:2057
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock static __device__ constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc &ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:615
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >< math::max(NXdlPerWave64, 1)>::I0 static constexpr auto I0
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:162
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >< math::max(NXdlPerWave64, 1)>::AK0Number static constexpr auto AK0Number
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:175
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::MakeAMmaTileDescriptor_M0_M1_M2_K __host__ static __device__ constexpr auto MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1 &)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:519
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::Run static __device__ void Run(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsGridPointer &p_ds_grid, CDataType *__restrict__ p_c_grid, void *__restrict__ p_shared, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, const Block2CTileMap &block_2_ctile_map, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const DsGridDesc_M_N &ds_grid_desc_m_n, const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:1483
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >< math::max(NXdlPerWave64, 1)>::I1 static constexpr auto I1
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:163
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::MakeBMmaTileDescriptor_N0_N1_N2_K __host__ static __device__ constexpr auto MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1 &)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:528
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::TransformGrid __host__ static __device__ auto TransformGrid(GridDesc_K0_MN_K1_T &desc)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:275
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::CalculateNBlock __host__ static __device__ auto CalculateNBlock(index_t N)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:269
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::Run static __device__ void Run(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsGridPointer &p_ds_grid, CDataType *__restrict__ p_c_grid, void *__restrict__ p_shared, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:1437
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >< math::max(NXdlPerWave64, 1)>::NumDTensor static constexpr index_t NumDTensor
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:182
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::MakeGemmMmaTileDescriptor __host__ static __device__ constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1 &)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:315
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >< math::max(NXdlPerWave64, 1)>::I4 static constexpr auto I4
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:166
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >< math::max(NXdlPerWave64, 1)>::is_scale_mfma static constexpr auto is_scale_mfma
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:207
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::Run_2Lds static __device__ void Run_2Lds(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsGridPointer &p_ds_grid, CDataType *__restrict__ p_c_grid, void *__restrict__ p_shared_0, void *__restrict__ p_shared_1, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:1979
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::CalculateKRead __host__ static __device__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:257
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::Run_2Lds static __device__ void Run_2Lds(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsGridPointer &p_ds_grid, CDataType *__restrict__ p_c_grid, void *__restrict__ p_shared_0, void *__restrict__ p_shared_1, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:2010
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::IsValidCompilationParameter static __device__ bool constexpr IsValidCompilationParameter()
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:1146
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >< math::max(NXdlPerWave64, 1)>::I3 static constexpr auto I3
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:165
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::CheckValidity static __host__ constexpr bool CheckValidity(const Argument &karg)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:1186
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1 static __device__ constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:923
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >< math::max(NXdlPerWave64, 1)>::Block2CTileMapDefault BlockToCTileMap_Grouped_M00_N0_M01Adapt< 8, MPerBlock, NPerBlock > Block2CTileMapDefault
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:1404
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Definition thread_group_tensor_slice_transfer_direct_load.hpp:55
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition thread_group_tensor_slice_transfer_v7r3.hpp:48
Definition threadwise_tensor_slice_transfer.hpp:39
Definition utility/tuple.hpp:117
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340