1162 const index_t* p_sorted_expert_ids,
1163 const index_t* p_max_token_id,
1164 const ADataType* p_a_grid,
1165 const BDataType* p_b_grid,
1167 CDataType* p_c_grid,
1169 const Problem& problem,
1170 AElementwiseOperation a_element_op,
1171 BElementwiseOperation b_element_op,
1172 CElementwiseOperation c_element_op)
1178 IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1184 const auto b_grid_desc_bpreshuffled =
1187 IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1192 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1194 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1195 const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1197 const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1198 if(expert_block_id * MPerBlock >= max_token_id)
1201 __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1202 const auto block_mn = [&]() -> std::pair<int, int> {
1203 if constexpr(NSwizzle)
1205 const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1206 const index_t prefix_block = ecnt_prefix * problem.NBlock;
1207 const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1208 const index_t expert_swizzle =
1209 ecnt > 0 ? ecnt : 1;
1210 const index_t bid_new = blockIdx.x - prefix_block;
1211 const index_t nid = __builtin_amdgcn_readfirstlane(
1212 bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1214 __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1219 return {blockIdx.x, blockIdx.y};
1223 const index_t block_n_id = block_mn.first;
1224 const index_t block_m_id = block_mn.second;
1226 __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1229 constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I1);
1230 constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I0);
1231 constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I2);
1232 constexpr auto AKThreads = AK0Threads * AK1Threads;
1233 constexpr auto AMRepeats = MPerBlock / AMThreads;
1234 const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1236 if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1240 const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1241 index_t token_offset = fused_token & 0xffffff;
1242 if constexpr(!IsInputGemm)
1244 token_offset = token_offset * problem.TopK + (fused_token >> 24);
1246 gather_offsets(m0) =
static_cast<IndexType
>(token_offset) * problem.K;
1248 const IndexType expert_stride =
1249 __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1250 const IndexType expert_offset = expert_id * expert_stride /
BPackedSize;
1252 const index_t n_block_data_idx_on_grid =
1253 __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
1256 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1258 p_b_grid + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize());
1268 AElementwiseOperation,
1272 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1273 ABlockTransferThreadClusterArrangeOrder,
1276 decltype(a_grid_desc_ak0_m_ak1),
1277 decltype(a_block_desc_ak0_m_ak1),
1278 ABlockTransferSrcAccessOrder,
1280 ABlockTransferSrcVectorDim,
1282 ABlockTransferSrcScalarPerVector,
1283 ABlockTransferDstScalarPerVector_AK1,
1286 AThreadTransferSrcResetCoordinateAfterRun,
1290 BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
1293 a_block_desc_ak0_m_ak1,
1301 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1306 decltype(b_grid_desc_bpreshuffled),
1307 decltype(b_block_desc_bk0_n_bk1),
1311 BBlockTransferSrcScalarPerVector,
1312 BThreadTransferSrcResetCoordinateAfterRun,
1313 true>(b_grid_desc_bpreshuffled,
1322 static_cast<LDSTypeA*
>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1328 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1330 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1331 decltype(c_thread_buf) c_thread_buf_up;
1335 c_thread_buf.num_of_v_,
1336 c_thread_buf.s_per_v,
1340 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1341 (a_grid_desc_ak0_m_ak1.GetLength(
I0) * a_grid_desc_ak0_m_ak1.GetLength(
I2)) /
1343 if constexpr(IsInputGemm)
1345 const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 /
BPackedSize;
1347 p_b_grid_up + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize());
1351 decltype(b_grid_desc_bpreshuffled),
1352 decltype(b_block_desc_bk0_n_bk1),
1356 BBlockTransferSrcScalarPerVector,
1357 BThreadTransferSrcResetCoordinateAfterRun,
1358 true>(b_grid_desc_bpreshuffled,
1365 a_grid_desc_ak0_m_ak1,
1366 a_block_desc_ak0_m_ak1,
1370 a_block_slice_copy_step,
1371 b_grid_desc_bpreshuffled,
1373 b_blockwise_copy_up,
1377 b_block_slice_copy_step,
1380 num_k_block_main_loop);
1385 a_grid_desc_ak0_m_ak1,
1386 a_block_desc_ak0_m_ak1,
1390 a_block_slice_copy_step,
1391 b_grid_desc_bpreshuffled,
1395 b_block_slice_copy_step,
1397 num_k_block_main_loop);
1402 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1403 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1406 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1409 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1410 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1414 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1415 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1417 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I0);
1418 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I1);
1419 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I2);
1420 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I3);
1421 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I4);
1422 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I5);
1423 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I6);
1424 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I7);
1427 const float* p_sorted_weights_0 = p_ds_grid[
I0];
1428 const float* p_scale_b = p_ds_grid[
I1];
1430 static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
1431 static_assert(M4 == 4 || M4 == 8);
1435 if(p_sorted_weights_0 !=
nullptr && p_scale_b !=
nullptr)
1437 if constexpr(PerTokenQuant)
1439 constexpr index_t scale_stride = (IsInputGemm ? 2 : 1);
1440 p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock +
1445 p_scale_b += expert_id;
1451 const float scale_b = p_scale_b[n0 *
NWave * NPerXdl * PerTokenQuant];
1454 const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
1455 m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
1456 if constexpr(PerTokenQuant)
1460 p_sorted_token_ids + m_pos);
1462 if constexpr(MulRoutedWeight)
1465 p_ds_grid[
I2] + m_pos);
1468 float scale_a = [&]() {
1469 if constexpr(PerTokenQuant)
1472 scale_token_ids.template AsType<index_t>()[m4];
1473 const index_t token_offset = fused_token & 0xffffff;
1474 return token_offset < problem.NumTokens
1475 ? p_sorted_weights_0[IsInputGemm
1485 return p_sorted_weights_0[0];
1489 blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1492 if constexpr(IsInputGemm)
1494 if constexpr(ActivationOperation == Activation::silu_and_mul)
1496 const float scale_up =
1497 p_scale_b[(n0 *
NWave * NPerXdl + problem.N) *
1499 float gate = scale_a * scale_b * c_thread_buf[cidx];
1500 float up = scale_a * scale_up * c_thread_buf_up[cidx];
1501 if constexpr(MulRoutedWeight)
1503 gate = gate * topk_weights.template AsType<float>()[m4];
1504 up = up * topk_weights.template AsType<float>()[m4];
1512 c_thread_buf_fp32(cidx) = gate * up;
1514 else if(ActivationOperation == Activation::gelu_and_mul)
1516 const float scale_up =
1517 p_scale_b[(n0 *
NWave * NPerXdl + problem.N) *
1519 float gate = scale_a * scale_b * c_thread_buf[cidx];
1520 float up = scale_a * scale_up * c_thread_buf_up[cidx];
1521 if constexpr(MulRoutedWeight)
1523 gate = gate * topk_weights.template AsType<float>()[m4];
1524 up = up * topk_weights.template AsType<float>()[m4];
1532 c_thread_buf_fp32(cidx) = gate * up;
1537 c_thread_buf_fp32(cidx) =
1538 scale_a * scale_b * c_thread_buf[cidx];
1539 if constexpr(MulRoutedWeight)
1541 c_thread_buf_fp32(cidx) =
1542 c_thread_buf_fp32(cidx) *
1543 topk_weights.template AsType<float>()[m4];
1557 const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
1558 m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
1559 if constexpr(MulRoutedWeight)
1562 p_ds_grid[
I2] + m_pos);
1566 blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1570 if constexpr(IsInputGemm)
1572 if constexpr(ActivationOperation == Activation::silu_and_mul)
1574 float gate = c_thread_buf[cidx];
1575 float up = c_thread_buf_up[cidx];
1576 if constexpr(MulRoutedWeight)
1578 gate = gate * topk_weights.template AsType<float>()[m4];
1579 up = up * topk_weights.template AsType<float>()[m4];
1582 c_thread_buf_fp32(cidx) = gate * up;
1584 else if(ActivationOperation == Activation::gelu_and_mul)
1586 float gate = c_thread_buf[cidx];
1587 float up = c_thread_buf_up[cidx];
1588 if constexpr(MulRoutedWeight)
1590 gate = gate * topk_weights.template AsType<float>()[m4];
1591 up = up * topk_weights.template AsType<float>()[m4];
1594 c_thread_buf_fp32(cidx) = gate * up;
1599 c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
1600 if constexpr(MulRoutedWeight)
1602 c_thread_buf_fp32(cidx) =
1603 topk_weights.template AsType<float>()[m4] *
1604 c_thread_buf_fp32[cidx];
1613 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1617 static_cast<CShuffleDataType*
>(p_shared),
1618 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1621 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1641 const auto c_thread_mtx_on_block =
1642 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(
I0,
I0,
I0,
I0);
1644 const index_t m_thread_data_on_block = c_thread_mtx_on_block[
I0];
1645 const index_t n_thread_data_on_block = c_thread_mtx_on_block[
I1];
1647 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1653 const auto m_thread_data_on_block_idx =
1654 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1657 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1663 const auto n_thread_data_on_block_idx =
1664 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1668 auto c_thread_copy_vgpr_to_lds =
1671 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1672 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1674 Sequence<CShuffleMXdlPerWavePerShuffle,
1675 CShuffleNXdlPerWavePerShuffle,
1688 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1691 m_thread_data_on_block_idx[
I1],
1692 n_thread_data_on_block_idx[
I1],
1693 m_thread_data_on_block_idx[
I2],
1694 m_thread_data_on_block_idx[
I3],
1695 m_thread_data_on_block_idx[
I4],
1696 n_thread_data_on_block_idx[
I2]),
1699 using EDataType = CDataType;
1702 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1704 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1706 ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1711 p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
1717 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1719 {
return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1724 tie(c_shuffle_block_buf),
1726 {
return ds_grid_buf[i]; },
1730 const auto idx_c_ds_block_begin =
1740 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1741 c_grid_desc_mblock_mperblock_nblock_nperblock;
1743 using CDEBlockTransferCluster =
1744 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1745 const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1746 constexpr index_t scatter_weight_idx = 3;
1751 decltype(c_ds_desc_refs),
1752 decltype(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1753 CElementwiseOperation,
1757 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1759 CShuffleNXdlPerWavePerShuffle *
NWave * NPerXdl>,
1760 CDEBlockTransferCluster,
1766 CDEShuffleBlockTransferScalarPerVectors,
1778 idx_c_ds_block_begin,
1779 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1784 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1785 constexpr auto sfc_c_vgpr =
1788 Sequence<CShuffleMXdlPerWavePerShuffle,
1789 CShuffleNXdlPerWavePerShuffle,
1797 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1800 constexpr auto sfc_cde_block =
1804 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1806 CShuffleNXdlPerWavePerShuffle *
NWave * NPerXdl>>{};
1808 static_assert(num_access == sfc_cde_block.GetNumOfAccess(),
"wrong!");
1809 constexpr auto EMThreads =
1810 CDEBlockTransferCluster{}.At(
I0) * CDEBlockTransferCluster{}.At(
I1);
1811 constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
1812 constexpr auto ENThreads =
1813 CDEBlockTransferCluster{}.At(
I2) * CDEBlockTransferCluster{}.At(
I3);
1818 auto dstidx = sfc_cde_block.GetIndex(access_id);
1820 block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(
I1);
1822 const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
1823 IndexType token_offset = fused_token & 0xffffff;
1824 if constexpr(IsInputGemm)
1826 token_offset = token_offset * problem.TopK + (fused_token >> 24);
1828 scatter_offsets(m0) =
static_cast<IndexType
>(token_offset) * problem.N;
1834 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1835 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1837 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1838 c_shuffle_block_buf);
1844 cde_block_copy_lds_and_global.Run(
1847 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1851 if constexpr(access_id < num_access - 1)
1853 constexpr auto cde_lds_and_global_step =
1854 sfc_cde_block.GetForwardStep(access_id);
1858 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
1859 c_ds_desc_refs, i +
I1, cde_lds_and_global_step);
1863 cde_block_copy_lds_and_global.MoveDstSliceWindow(
1864 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1866 cde_lds_and_global_step);
1876 const index_t* p_sorted_expert_ids,
1877 const index_t* p_max_token_id,
1878 const ADataType* p_a_grid,
1879 const BDataType* p_b_grid,
1881 CDataType* p_c_grid,
1884 const Problem& problem,
1885 AElementwiseOperation a_element_op,
1886 BElementwiseOperation b_element_op,
1887 CElementwiseOperation c_element_op)
1893 IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1899 const auto b_grid_desc_bpreshuffled =
1902 IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1907 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1909 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1910 const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1912 const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1913 if(expert_block_id * MPerBlock >= max_token_id)
1916 __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1917 const auto block_mn = [&]() -> std::pair<int, int> {
1918 if constexpr(NSwizzle)
1920 const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1921 const index_t prefix_block = ecnt_prefix * problem.NBlock;
1922 const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1923 const index_t expert_swizzle =
1924 ecnt > 0 ? ecnt : 1;
1925 const index_t bid_new = blockIdx.x - prefix_block;
1926 const index_t nid = __builtin_amdgcn_readfirstlane(
1927 bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1929 __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1934 return {blockIdx.x, blockIdx.y};
1938 const index_t block_n_id = block_mn.first;
1939 const index_t block_m_id = block_mn.second;
1941 __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1944 constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I1);
1945 constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I0);
1946 constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I2);
1947 constexpr auto AKThreads = AK0Threads * AK1Threads;
1948 constexpr auto AMRepeats = MPerBlock / AMThreads;
1949 const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1951 if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1955 const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1956 index_t token_offset = fused_token & 0xffffff;
1957 if constexpr(!IsInputGemm)
1959 token_offset = token_offset * problem.TopK + (fused_token >> 24);
1961 gather_offsets(m0) =
static_cast<IndexType
>(token_offset) * problem.K;
1963 const IndexType expert_stride =
1964 __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1965 const IndexType expert_offset = expert_id * expert_stride /
BPackedSize;
1967 const index_t n_block_data_idx_on_grid =
1968 __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
1971 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1973 p_b_grid + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize());
1984 AElementwiseOperation,
1988 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1989 ABlockTransferThreadClusterArrangeOrder,
1992 decltype(a_grid_desc_ak0_m_ak1),
1993 decltype(a_block_desc_ak0_m_ak1),
1994 ABlockTransferSrcAccessOrder,
1996 ABlockTransferSrcVectorDim,
1998 ABlockTransferSrcScalarPerVector,
1999 ABlockTransferDstScalarPerVector_AK1,
2002 AThreadTransferSrcResetCoordinateAfterRun,
2006 2>(a_grid_desc_ak0_m_ak1,
2009 a_block_desc_ak0_m_ak1,
2017 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2019 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2020 auto b_block_bufs =
make_tuple(b_block_buf_ping, b_block_buf_pong);
2025 decltype(b_grid_desc_bpreshuffled),
2026 decltype(b_block_desc_bk0_n_bk1),
2030 BBlockTransferSrcScalarPerVector,
2031 BThreadTransferSrcResetCoordinateAfterRun,
2032 true>(b_grid_desc_bpreshuffled,
2041 static_cast<ADataType*
>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2043 static_cast<ADataType*
>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2044 auto a_block_bufs =
make_tuple(a_block_buf_ping, a_block_buf_pong);
2050 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2052 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2053 decltype(c_thread_buf) c_thread_buf_up;
2057 c_thread_buf.num_of_v_,
2058 c_thread_buf.s_per_v,
2062 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2063 (a_grid_desc_ak0_m_ak1.GetLength(
I0) * a_grid_desc_ak0_m_ak1.GetLength(
I2)) /
2066 if constexpr(IsInputGemm)
2068 const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 /
BPackedSize;
2070 p_b_grid_up + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize());
2074 decltype(b_grid_desc_bpreshuffled),
2075 decltype(b_block_desc_bk0_n_bk1),
2079 BBlockTransferSrcScalarPerVector,
2080 BThreadTransferSrcResetCoordinateAfterRun,
2081 true>(b_grid_desc_bpreshuffled,
2087 a_grid_desc_ak0_m_ak1,
2088 a_block_desc_ak0_m_ak1,
2092 a_block_slice_copy_step,
2093 b_grid_desc_bpreshuffled,
2095 b_blockwise_copy_up,
2099 b_block_slice_copy_step,
2102 num_k_block_main_loop);
2108 a_grid_desc_ak0_m_ak1,
2109 a_block_desc_ak0_m_ak1,
2113 a_block_slice_copy_step,
2114 b_grid_desc_bpreshuffled,
2118 b_block_slice_copy_step,
2120 num_k_block_main_loop);
2125 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2126 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2129 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2132 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2133 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2137 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2138 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2140 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I0);
2141 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I1);
2142 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I2);
2143 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I3);
2144 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I4);
2145 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I5);
2146 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I6);
2147 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I7);
2150 const float* p_sorted_weights_0 = p_ds_grid[
I0];
2151 const float* p_scale_b = p_ds_grid[
I1];
2153 static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
2154 static_assert(M4 == 4 || M4 == 8);
2158 if(p_sorted_weights_0 !=
nullptr && p_scale_b !=
nullptr)
2160 if constexpr(PerTokenQuant)
2162 constexpr index_t scale_stride = (IsInputGemm ? 2 : 1);
2163 p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock +
2168 p_scale_b += expert_id;
2174 const float scale_b = p_scale_b[n0 *
NWave * NPerXdl * PerTokenQuant];
2177 const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
2178 m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
2179 if constexpr(PerTokenQuant)
2183 p_sorted_token_ids + m_pos);
2185 if constexpr(MulRoutedWeight)
2188 p_ds_grid[
I2] + m_pos);
2191 float scale_a = [&]() {
2192 if constexpr(PerTokenQuant)
2195 scale_token_ids.template AsType<index_t>()[m4];
2196 const index_t token_offset = fused_token & 0xffffff;
2197 return token_offset < problem.NumTokens
2198 ? p_sorted_weights_0[IsInputGemm
2208 return p_sorted_weights_0[0];
2212 blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2215 if constexpr(IsInputGemm)
2217 if constexpr(ActivationOperation == Activation::silu_and_mul)
2219 const float scale_up =
2220 p_scale_b[(n0 *
NWave * NPerXdl + problem.N) *
2222 float gate = scale_a * scale_b * c_thread_buf[cidx];
2223 float up = scale_a * scale_up * c_thread_buf_up[cidx];
2224 if constexpr(MulRoutedWeight)
2226 gate = gate * topk_weights.template AsType<float>()[m4];
2227 up = up * topk_weights.template AsType<float>()[m4];
2235 c_thread_buf_fp32(cidx) = gate * up;
2237 else if(ActivationOperation == Activation::gelu_and_mul)
2239 const float scale_up =
2240 p_scale_b[(n0 *
NWave * NPerXdl + problem.N) *
2242 float gate = scale_a * scale_b * c_thread_buf[cidx];
2243 float up = scale_a * scale_up * c_thread_buf_up[cidx];
2244 if constexpr(MulRoutedWeight)
2246 gate = gate * topk_weights.template AsType<float>()[m4];
2247 up = up * topk_weights.template AsType<float>()[m4];
2255 c_thread_buf_fp32(cidx) = gate * up;
2260 c_thread_buf_fp32(cidx) =
2261 scale_a * scale_b * c_thread_buf[cidx];
2262 if constexpr(MulRoutedWeight)
2264 c_thread_buf_fp32(cidx) =
2265 c_thread_buf_fp32(cidx) *
2266 topk_weights.template AsType<float>()[m4];
2280 const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
2281 m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
2282 if constexpr(MulRoutedWeight)
2285 p_ds_grid[
I2] + m_pos);
2289 blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2293 if constexpr(IsInputGemm)
2295 if constexpr(ActivationOperation == Activation::silu_and_mul)
2297 float gate = c_thread_buf[cidx];
2298 float up = c_thread_buf_up[cidx];
2299 if constexpr(MulRoutedWeight)
2301 gate = gate * topk_weights.template AsType<float>()[m4];
2302 up = up * topk_weights.template AsType<float>()[m4];
2305 c_thread_buf_fp32(cidx) = gate * up;
2307 else if(ActivationOperation == Activation::gelu_and_mul)
2309 float gate = c_thread_buf[cidx];
2310 float up = c_thread_buf_up[cidx];
2311 if constexpr(MulRoutedWeight)
2313 gate = gate * topk_weights.template AsType<float>()[m4];
2314 up = up * topk_weights.template AsType<float>()[m4];
2317 c_thread_buf_fp32(cidx) = gate * up;
2322 c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
2323 if constexpr(MulRoutedWeight)
2325 c_thread_buf_fp32(cidx) =
2326 topk_weights.template AsType<float>()[m4] *
2327 c_thread_buf_fp32[cidx];
2336 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2340 static_cast<CShuffleDataType*
>(p_shared),
2341 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2344 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2364 const auto c_thread_mtx_on_block =
2365 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(
I0,
I0,
I0,
I0);
2367 const index_t m_thread_data_on_block = c_thread_mtx_on_block[
I0];
2368 const index_t n_thread_data_on_block = c_thread_mtx_on_block[
I1];
2370 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2376 const auto m_thread_data_on_block_idx =
2377 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2380 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2386 const auto n_thread_data_on_block_idx =
2387 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2391 auto c_thread_copy_vgpr_to_lds =
2394 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2395 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2397 Sequence<CShuffleMXdlPerWavePerShuffle,
2398 CShuffleNXdlPerWavePerShuffle,
2411 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2414 m_thread_data_on_block_idx[
I1],
2415 n_thread_data_on_block_idx[
I1],
2416 m_thread_data_on_block_idx[
I2],
2417 m_thread_data_on_block_idx[
I3],
2418 m_thread_data_on_block_idx[
I4],
2419 n_thread_data_on_block_idx[
I2]),
2422 using EDataType = CDataType;
2425 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
2427 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
2429 ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
2434 p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
2440 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2442 {
return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2447 tie(c_shuffle_block_buf),
2449 {
return ds_grid_buf[i]; },
2453 const auto idx_c_ds_block_begin =
2463 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2464 c_grid_desc_mblock_mperblock_nblock_nperblock;
2466 using CDEBlockTransferCluster =
2467 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2468 const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2469 constexpr index_t scatter_weight_idx = 3;
2474 decltype(c_ds_desc_refs),
2475 decltype(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2476 CElementwiseOperation,
2480 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2482 CShuffleNXdlPerWavePerShuffle *
NWave * NPerXdl>,
2483 CDEBlockTransferCluster,
2489 CDEShuffleBlockTransferScalarPerVectors,
2501 idx_c_ds_block_begin,
2502 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2507 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2508 constexpr auto sfc_c_vgpr =
2511 Sequence<CShuffleMXdlPerWavePerShuffle,
2512 CShuffleNXdlPerWavePerShuffle,
2520 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2523 constexpr auto sfc_cde_block =
2527 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2529 CShuffleNXdlPerWavePerShuffle *
NWave * NPerXdl>>{};
2531 static_assert(num_access == sfc_cde_block.GetNumOfAccess(),
"wrong!");
2532 constexpr auto EMThreads =
2533 CDEBlockTransferCluster{}.At(
I0) * CDEBlockTransferCluster{}.At(
I1);
2534 constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2535 constexpr auto ENThreads =
2536 CDEBlockTransferCluster{}.At(
I2) * CDEBlockTransferCluster{}.At(
I3);
2541 auto dstidx = sfc_cde_block.GetIndex(access_id);
2543 block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(
I1);
2545 const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2546 IndexType token_offset = fused_token & 0xffffff;
2547 if constexpr(IsInputGemm)
2549 token_offset = token_offset * problem.TopK + (fused_token >> 24);
2551 scatter_offsets(m0) =
static_cast<IndexType
>(token_offset) * problem.N;
2557 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2558 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2560 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2561 c_shuffle_block_buf);
2567 cde_block_copy_lds_and_global.Run(
2570 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2574 if constexpr(access_id < num_access - 1)
2576 constexpr auto cde_lds_and_global_step =
2577 sfc_cde_block.GetForwardStep(access_id);
2581 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2582 c_ds_desc_refs, i +
I1, cde_lds_and_global_step);
2586 cde_block_copy_lds_and_global.MoveDstSliceWindow(
2587 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2589 cde_lds_and_global_step);