24template <
typename AsLayout,
30 typename GemmAccDataType,
31 typename CShuffleDataType,
34 typename AElementwiseOperation,
35 typename BElementwiseOperation,
36 typename CElementwiseOperation,
49 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
50 typename ABlockTransferThreadClusterArrangeOrder,
51 typename ABlockTransferSrcAccessOrder,
52 index_t ABlockTransferSrcVectorDim,
53 index_t ABlockTransferSrcScalarPerVector,
54 index_t ABlockTransferDstScalarPerVector_AK1,
56 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
57 typename BBlockTransferThreadClusterArrangeOrder,
58 typename BBlockTransferSrcAccessOrder,
59 index_t BBlockTransferSrcVectorDim,
60 index_t BBlockTransferSrcScalarPerVector,
61 index_t BBlockTransferDstScalarPerVector_BK1,
63 index_t CShuffleMXdlPerWavePerShuffle,
64 index_t CShuffleNXdlPerWavePerShuffle,
65 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
66 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
69 typename ComputeTypeA = CDataType,
70 typename ComputeTypeB = ComputeTypeA>
79 AElementwiseOperation,
80 BElementwiseOperation,
81 CElementwiseOperation>
95 template <index_t NXdlPerWave_>
106 AElementwiseOperation,
107 BElementwiseOperation,
108 CElementwiseOperation,
120 ABlockTransferThreadClusterLengths_AK0_M_AK1,
121 ABlockTransferThreadClusterArrangeOrder,
122 ABlockTransferSrcAccessOrder,
123 ABlockTransferSrcVectorDim,
124 ABlockTransferSrcScalarPerVector,
125 ABlockTransferDstScalarPerVector_AK1,
128 BBlockTransferThreadClusterLengths_BK0_N_BK1,
129 BBlockTransferThreadClusterArrangeOrder,
130 BBlockTransferSrcAccessOrder,
131 BBlockTransferSrcVectorDim,
132 BBlockTransferSrcScalarPerVector,
133 BBlockTransferDstScalarPerVector_BK1,
136 CShuffleMXdlPerWavePerShuffle,
137 CShuffleNXdlPerWavePerShuffle,
138 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
139 CShuffleBlockTransferScalarPerVector_NPerBlock,
147 using Argument =
typename GridwiseGemm64::Argument;
152 template <
typename Gr
idwiseGemm>
153 float RunImp(
const typename GridwiseGemm::Argument& arg,
156 if(stream_config.log_level_ > 0)
161 if(!GridwiseGemm::CheckValidity(arg))
163 throw std::runtime_error(
"wrong! GridwiseGemm has invalid setting");
167 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
171 index_t k_grain = arg.KBatch * KPerBlock;
172 index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
174 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
176 const auto Run = [&](
const auto& kernel) {
178 hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
180 arg.M * arg.N *
sizeof(CDataType),
181 stream_config.stream_id_));
184 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
187 constexpr index_t minimum_occupancy =
190 if(has_main_k_block_loop)
223 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::One)
233 else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
245 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
247 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Two)
259 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
261 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
274 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
276 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
289 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
291 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
304 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
306 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Six)
318 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
320 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
336 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::One)
346 else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
358 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
360 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Two)
372 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
374 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
387 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
389 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
402 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
404 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
417 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
419 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Six)
431 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
433 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
453 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
477 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
504 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
528 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
588 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
624 reinterpret_cast<const typename GridwiseGemm32::Argument&
>(arg));
637 std::array<const void*, NumBTensor> p_bs,
638 std::array<const void*, NumDTensor> p_ds,
643 std::array<index_t, NumATensor> StrideAs,
644 std::array<index_t, NumBTensor> StrideBs,
645 std::array<index_t, NumDTensor> StrideDs,
647 AElementwiseOperation a_element_op,
648 BElementwiseOperation b_element_op,
649 CElementwiseOperation c_element_op)
691 std::array<const void*, NumBTensor> p_bs,
692 std::array<const void*, NumDTensor> p_ds,
697 std::array<ck::index_t, NumATensor> StrideAs,
698 std::array<ck::index_t, NumBTensor> StrideBs,
699 std::array<ck::index_t, NumDTensor> StrideDs,
701 AElementwiseOperation a_element_op,
702 BElementwiseOperation b_element_op,
703 CElementwiseOperation c_element_op)
override
705 return std::make_unique<Argument>(p_as,
725 return std::make_unique<Invoker>(
Invoker{});
731 auto str = std::stringstream();
733 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
737 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
745 str <<
"DeviceGemmXdlUniversal"
748 << std::string(ALayout::name)[0]
749 << std::string(BLayout::name)[0]
750 << std::string(CLayout::name)[0]
755 << MPerBlock<<
"x"<<NPerBlock<<
"x"<<KPerBlock <<
", "
757 << MPerXDL<<
"x"<<NPerXDL <<
", "
759 << MXdlPerWave<<
"x" << NXdlPerWave<<
", "
761 << ABlockTransferSrcScalarPerVector<<
"x"<<BBlockTransferSrcScalarPerVector<<
", "
762 <<
"BlkGemmPipelineScheduler: "
763 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] <<
", "
764 <<
"BlkGemmPipelineVersion: "
765 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] <<
", "
766 <<
"BlkGemmPipelinePrefetchStages: "
767 << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
#define INVOKER_RUN3_IMPL
Definition device_base.hpp:114
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
@ One
Definition blkgemmpipe_scheduler.hpp:37
@ Seven
Definition blkgemmpipe_scheduler.hpp:43
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Four
Definition blkgemmpipe_scheduler.hpp:40
@ Two
Definition blkgemmpipe_scheduler.hpp:38
@ Full
Definition blkgemmpipe_scheduler.hpp:49
@ Three
Definition blkgemmpipe_scheduler.hpp:39
@ Five
Definition blkgemmpipe_scheduler.hpp:41
@ Six
Definition blkgemmpipe_scheduler.hpp:42
__global__ void kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:62
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
__global__ void kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:38
Definition ck/stream_config.hpp:10
"Universal" GEMM kernel with SplitK support.
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:247
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, AsDataType, BsDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::CheckValidity static __host__ constexpr bool CheckValidity(const Argument &karg)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:1202
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:151
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:153
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:585
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:82
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:723
std::string GetTypeString() const override
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:729
typename GridwiseGemm64::Argument Argument
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:147
GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, AsDataType, BsDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB > GridwiseGemmBase
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:96
std::unique_ptr< BaseArgument > MakeArgumentPointer(std::array< const void *, NumATensor > p_as, std::array< const void *, NumBTensor > p_bs, std::array< const void *, NumDTensor > p_ds, void *p_e, index_t MRaw, index_t NRaw, index_t KRaw, std::array< ck::index_t, NumATensor > StrideAs, std::array< ck::index_t, NumBTensor > StrideBs, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:690
static auto MakeArgument(std::array< const void *, NumATensor > p_as, std::array< const void *, NumBTensor > p_bs, std::array< const void *, NumDTensor > p_ds, void *p_e, index_t MRaw, index_t NRaw, index_t KRaw, std::array< index_t, NumATensor > StrideAs, std::array< index_t, NumBTensor > StrideBs, std::array< index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:636
remove_cvref_t< tuple_element_t< 0, AsLayout > > ALayout
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:91
static constexpr index_t NumATensor
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:87
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:598
static constexpr auto NXdlPerWave32
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:85
static constexpr index_t NumBTensor
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:88
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:592
remove_cvref_t< tuple_element_t< 0, BsLayout > > BLayout
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:92
static auto MakeInvoker()
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:687
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:631
static constexpr index_t NumDTensor
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:89
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:84
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:145
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:144
Definition device_gemm_multiple_abd.hpp:34