24template <
typename ALayout,
29 typename AScaleDataType,
31 typename BScaleDataType,
34 typename GemmAccDataType,
35 typename CShuffleDataType,
36 typename AElementwiseOperation,
37 typename BElementwiseOperation,
38 typename CElementwiseOperation,
53 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
54 typename ABlockTransferThreadClusterArrangeOrder,
55 typename ABlockTransferSrcAccessOrder,
56 index_t ABlockTransferSrcVectorDim,
57 index_t ABlockTransferSrcScalarPerVector,
58 index_t ABlockTransferDstScalarPerVector_AK1,
60 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
61 typename BBlockTransferThreadClusterArrangeOrder,
62 typename BBlockTransferSrcAccessOrder,
63 index_t BBlockTransferSrcVectorDim,
64 index_t BBlockTransferSrcScalarPerVector,
65 index_t BBlockTransferDstScalarPerVector_BK1,
67 index_t CShuffleMXdlPerWavePerShuffle,
68 index_t CShuffleNXdlPerWavePerShuffle,
69 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
70 typename CDEShuffleBlockTransferScalarPerVectors,
73 typename ComputeTypeA = CDataType,
74 typename ComputeTypeB = ComputeTypeA,
75 typename LDSTypeA = ComputeTypeA,
76 typename LDSTypeB = ComputeTypeB>
91 AElementwiseOperation,
92 BElementwiseOperation,
93 CElementwiseOperation>
101 template <index_t NXdlPerWave_>
113 AElementwiseOperation,
114 BElementwiseOperation,
115 CElementwiseOperation,
130 ABlockTransferThreadClusterLengths_AK0_M_AK1,
131 ABlockTransferThreadClusterArrangeOrder,
132 ABlockTransferSrcAccessOrder,
133 ABlockTransferSrcVectorDim,
134 ABlockTransferSrcScalarPerVector,
135 ABlockTransferDstScalarPerVector_AK1,
138 BBlockTransferThreadClusterLengths_BK0_N_BK1,
139 BBlockTransferThreadClusterArrangeOrder,
140 BBlockTransferSrcAccessOrder,
141 BBlockTransferSrcVectorDim,
142 BBlockTransferSrcScalarPerVector,
143 BBlockTransferDstScalarPerVector_BK1,
146 CShuffleMXdlPerWavePerShuffle,
147 math::min(CShuffleNXdlPerWavePerShuffle, NXdlPerWave_),
148 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
149 CDEShuffleBlockTransferScalarPerVectors,
159 using Argument =
typename GridwiseGemm64::Argument;
164 template <
typename Gr
idwiseGemm>
165 float RunImp(
const typename GridwiseGemm::Argument& arg,
168 if(stream_config.log_level_ > 0)
173 if(!GridwiseGemm::CheckValidity(arg))
175 throw std::runtime_error(
"wrong! GridwiseGemm has invalid setting");
179 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
183 index_t k_grain = arg.KBatch * KPerBlock;
184 index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
186 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
188 const auto Run = [&](
const auto& kernel) {
189 if(stream_config.flush_cache)
193 const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
194 arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
195 const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
196 arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
199 a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
sizeof(ADataType);
201 b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
sizeof(BDataType);
204 arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
205 rotating_mem.Print();
207 auto run_flush_cache = [&]() {
214 hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
216 arg_.M * arg_.N *
sizeof(CDataType),
217 stream_config.stream_id_));
232 hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
234 arg.M * arg.N *
sizeof(CDataType),
235 stream_config.stream_id_));
238 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
242 constexpr index_t minimum_occupancy = [&]() {
253 MPerBlock * NPerBlock / BlockSize > 64)
259 if(has_main_k_block_loop)
289 else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
310 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
316 auto& arg = *
dynamic_cast<Argument*
>(base_arg);
359 reinterpret_cast<const typename GridwiseGemm32::Argument&
>(arg));
373 std::array<const void*, NumDTensor> p_ds,
380 const std::array<index_t, NumDTensor> StrideDs,
382 const void* p_a_scale,
383 const void* p_b_scale,
384 AElementwiseOperation a_element_op,
385 BElementwiseOperation b_element_op,
386 CElementwiseOperation c_element_op)
388 return Argument{
static_cast<const ADataType*
>(p_a),
389 static_cast<const BDataType*
>(p_b),
391 static_cast<CDataType*
>(p_c),
399 static_cast<const AScaleDataType*
>(p_a_scale),
400 static_cast<const BScaleDataType*
>(p_b_scale),
410 std::unique_ptr<BaseArgument>
413 std::array<const void*, NumDTensor> p_ds,
420 const std::array<ck::index_t, NumDTensor> StrideDs,
422 const void* p_a_scale,
423 const void* p_b_scale,
424 AElementwiseOperation a_element_op,
425 BElementwiseOperation b_element_op,
426 CElementwiseOperation c_element_op)
override
428 return std::make_unique<Argument>(
static_cast<const ADataType*
>(p_a),
429 static_cast<const BDataType*
>(p_b),
431 static_cast<CDataType*
>(p_c),
439 static_cast<const AScaleDataType*
>(p_a_scale),
440 static_cast<const BScaleDataType*
>(p_b_scale),
450 return std::make_unique<Invoker>(
Invoker{});
456 auto str = std::stringstream();
458 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
462 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
468 str <<
"DeviceGemmXdlUniversal"
471 << std::string(ALayout::name)[0]
472 << std::string(BLayout::name)[0]
473 << std::string(CLayout::name)[0]
478 << MPerBlock<<
"x"<<NPerBlock<<
"x"<<KPerBlock <<
", "
480 << MPerXDL<<
"x"<<NPerXDL <<
", "
482 << MXdlPerWave<<
"x" << NXdlPerWave<<
", "
484 << ABlockTransferSrcScalarPerVector<<
"x"<<BBlockTransferSrcScalarPerVector<<
", "
485 <<
"BlkGemmPipelineScheduler: "
486 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] <<
", "
487 <<
"BlkGemmPipelineVersion: "
488 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] <<
", "
489 <<
"BlkGemmPipelinePrefetchStages: "
490 << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
#define INVOKER_RUN3_IMPL
Definition device_base.hpp:114
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr T min(T x)
Definition utility/math.hpp:116
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
void flush_icache()
Definition flush_cache.hpp:383
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition flush_cache.hpp:398
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v1
Definition blkgemmpipe_scheduler.hpp:14
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Full
Definition blkgemmpipe_scheduler.hpp:49
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
constexpr bool is_same_v
Definition type.hpp:283
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
__global__ void kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:38
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:118
ck::GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockM, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, math::min(CShuffleNXdlPerWavePerShuffle, NXdlPerWave_), CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::CheckValidity static __host__ constexpr bool CheckValidity(const Argument &karg)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:994
Definition device_base.hpp:197
Definition device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp:163
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp:165
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp:307
Definition device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp:94
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp:448
static constexpr auto NXdlPerWave32
Definition device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp:97
static auto MakeArgument(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, const index_t M, const index_t N, const index_t K, const index_t StrideA, const index_t StrideB, const std::array< index_t, NumDTensor > StrideDs, const index_t StrideC, const void *p_a_scale, const void *p_b_scale, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp:371
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp:96
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp:326
static constexpr index_t NumDTensor
Definition device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp:98
GridwiseGemmBase< math::max(NXdlPerWave32, 1)> GridwiseGemm32
Definition device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp:157
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, const index_t M, const index_t N, const index_t K, const index_t StrideA, const index_t StrideB, const std::array< ck::index_t, NumDTensor > StrideDs, const index_t StrideC, const void *p_a_scale, const void *p_b_scale, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp:411
GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockM, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, math::min(CShuffleNXdlPerWavePerShuffle, NXdlPerWave_), CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB > GridwiseGemmBase
Definition device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp:102
std::string GetTypeString() const override
Definition device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp:454
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp:320
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp:156
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp:366
static auto MakeInvoker()
Definition device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp:407
void SetKBatch(BaseArgument *base_arg, int KBatch) const override
Definition device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp:314
typename GridwiseGemm64::Argument Argument
Definition device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp:159
Definition device_gemm_multiple_d_ab_scale.hpp:39
Definition flush_cache.hpp:299