23template <
typename ADataType,
30 typename AElementwiseOperation,
31 typename BElementwiseOperation,
32 typename CElementwiseOperation,
43 typename ABlockTransferThreadClusterLengths_K0_M_K1,
44 typename ABlockTransferThreadClusterArrangeOrder,
45 typename ABlockTransferSrcAccessOrder,
49 bool ABlockLdsAddExtraM,
50 typename BBlockTransferThreadClusterLengths_K0_N_K1,
51 typename BBlockTransferThreadClusterArrangeOrder,
52 typename BBlockTransferSrcAccessOrder,
56 bool BBlockLdsAddExtraN,
57 index_t CShuffleMRepeatPerShuffle,
58 index_t CShuffleNRepeatPerShuffle,
59 typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
60 index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
61 typename ComputeType = CDataType,
64 typename LDSTypeA = ComputeType,
65 typename LDSTypeB = ComputeType>
73 AElementwiseOperation,
74 BElementwiseOperation,
75 CElementwiseOperation,
93 template <index_t NXdlPerWave_>
103 AElementwiseOperation,
104 BElementwiseOperation,
105 CElementwiseOperation,
116 ABlockTransferThreadClusterLengths_K0_M_K1,
117 ABlockTransferThreadClusterArrangeOrder,
118 ABlockTransferSrcAccessOrder,
119 ABlockTransferSrcVectorDim,
120 ABlockTransferSrcScalarPerVector,
121 ABlockTransferDstScalarPerVector_K1,
124 BBlockTransferThreadClusterLengths_K0_N_K1,
125 BBlockTransferThreadClusterArrangeOrder,
126 BBlockTransferSrcAccessOrder,
127 BBlockTransferSrcVectorDim,
128 BBlockTransferSrcScalarPerVector,
129 BBlockTransferDstScalarPerVector_K1,
132 CShuffleMRepeatPerShuffle,
133 CShuffleNRepeatPerShuffle,
134 CBlockTransferScalarPerVector_NWaveNPerXDL,
135 CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
148 const BDataType* p_b_grid_,
149 CDataType* p_c_grid_,
161 AElementwiseOperation a_element_op_,
162 BElementwiseOperation b_element_op_,
163 CElementwiseOperation c_element_op_)
197 template <
typename Gr
idwiseGemm>
200 if(stream_config.log_level_ > 0)
205 typename GridwiseGemm::Argument karg(arg.p_a_grid,
219 const auto kbatch = karg.k_batch;
220 if(!GridwiseGemm::CheckValidity(karg))
222 throw std::runtime_error(
223 "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid "
229 ck::tie(gdx, gdy, gdz) = b2c_map.CalculateGridSize(karg.M, karg.N, karg.k_batch);
230 const auto K0Padded = karg.K0Padded;
232 const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0Padded);
236 const auto Run = [&](
const auto& kernel) {
238 hipGetErrorString(hipMemsetAsync(karg.p_c_grid,
240 karg.M * karg.N *
sizeof(CDataType),
241 stream_config.stream_id_));
249 static_cast<typename GridwiseGemm::Argument
>(karg),
256 if(has_main_k0_block_loop)
265 AElementwiseOperation,
266 BElementwiseOperation,
267 CElementwiseOperation>;
278 AElementwiseOperation,
279 BElementwiseOperation,
280 CElementwiseOperation>;
294 AElementwiseOperation,
295 BElementwiseOperation,
296 CElementwiseOperation>;
307 AElementwiseOperation,
308 BElementwiseOperation,
309 CElementwiseOperation>;
324 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
357 reinterpret_cast<const typename GridwiseGemm32::Argument&
>(karg));
370 const BDataType* p_b,
378 AElementwiseOperation a_element_op,
379 BElementwiseOperation b_element_op,
380 CElementwiseOperation c_element_op,
414 AElementwiseOperation a_element_op,
415 BElementwiseOperation b_element_op,
416 CElementwiseOperation c_element_op,
419 return std::make_unique<Argument>(
static_cast<const ADataType*
>(p_a),
420 static_cast<const BDataType*
>(p_b),
421 static_cast<CDataType*
>(p_c),
441 return std::make_unique<Invoker>(
Invoker{});
447 auto str = std::stringstream();
449 std::map<LoopScheduler, std::string> LoopSchedToString{
452 std::map<PipelineVersion, std::string> PipelineVersionToString{{
PipelineVersion::v1,
"v1"},
456 <<
", PipelineVersion: " << PipelineVersionToString[PipelineVer];
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
__global__ void kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg, const Block2CTileMap &b2c_map, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op)
Definition gridwise_gemm_xdlops_v2r4r2.hpp:33
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
integral_constant< index_t, N > Number
Definition number.hpp:12
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
LoopScheduler
Definition loop_scheduler.hpp:15
@ Default
Definition loop_scheduler.hpp:16
@ Interwave
Definition loop_scheduler.hpp:17
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v2
Definition gridwise_gemm_pipeline_selector.hpp:20
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_xdlops_v2r4r2.hpp:106
ck::GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CBlockTransferScalarPerVector_NWaveNPerXDL, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::CheckValidity __host__ static __device__ constexpr bool CheckValidity(const Argument &karg)
Definition gridwise_gemm_xdlops_v2r4r2.hpp:440
ck::GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CBlockTransferScalarPerVector_NWaveNPerXDL, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::CalculateMPadded __host__ static __device__ auto CalculateMPadded(index_t M)
Definition gridwise_gemm_xdlops_v2r4r2.hpp:196
ck::GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CBlockTransferScalarPerVector_NWaveNPerXDL, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::CalculateKPadded __host__ static __device__ auto CalculateKPadded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdlops_v2r4r2.hpp:213
ck::GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CBlockTransferScalarPerVector_NWaveNPerXDL, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::CalculateK0Padded __host__ static __device__ auto CalculateK0Padded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdlops_v2r4r2.hpp:206
ck::GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CBlockTransferScalarPerVector_NWaveNPerXDL, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::DefaultBlock2CTileMap remove_cvref_t< decltype(MakeDefaultBlock2CTileMap())> DefaultBlock2CTileMap
Definition gridwise_gemm_xdlops_v2r4r2.hpp:662
ck::GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CBlockTransferScalarPerVector_NWaveNPerXDL, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::CalculateNPadded __host__ static __device__ auto CalculateNPadded(index_t N)
Definition gridwise_gemm_xdlops_v2r4r2.hpp:201
ck::GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CBlockTransferScalarPerVector_NWaveNPerXDL, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::GetTypeString static std::string GetTypeString()
Definition gridwise_gemm_xdlops_v2r4r2.hpp:1104
Definition device_base.hpp:197
Definition device_gemm_splitk.hpp:26
Definition device_gemm_xdl_splitk_c_shuffle.hpp:146
AElementwiseOperation a_element_op
Definition device_gemm_xdl_splitk_c_shuffle.hpp:184
BElementwiseOperation b_element_op
Definition device_gemm_xdl_splitk_c_shuffle.hpp:185
CElementwiseOperation c_element_op
Definition device_gemm_xdl_splitk_c_shuffle.hpp:186
Argument(const ADataType *p_a_grid_, const BDataType *p_b_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t MPadded_, index_t NPadded_, index_t KPadded_, index_t K0Padded_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_)
Definition device_gemm_xdl_splitk_c_shuffle.hpp:147
Definition device_gemm_xdl_splitk_c_shuffle.hpp:193
void Print(const Argument &karg)
Definition device_gemm_xdl_splitk_c_shuffle.hpp:195
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_xdl_splitk_c_shuffle.hpp:198
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_xdl_splitk_c_shuffle.hpp:321
Definition device_gemm_xdl_splitk_c_shuffle.hpp:77
ComputeType ComputeTypeA
Definition device_gemm_xdl_splitk_c_shuffle.hpp:90
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, index_t KBatch)
Definition device_gemm_xdl_splitk_c_shuffle.hpp:369
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_xdl_splitk_c_shuffle.hpp:142
static constexpr auto I1
Definition device_gemm_xdl_splitk_c_shuffle.hpp:83
static bool IsSupportedArgument(const Argument &karg)
Definition device_gemm_xdl_splitk_c_shuffle.hpp:334
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_xdl_splitk_c_shuffle.hpp:328
static constexpr auto I0
Definition device_gemm_xdl_splitk_c_shuffle.hpp:82
typename GridwiseGemm64::DefaultBlock2CTileMap DefaultBlock2CTileMap
Definition device_gemm_xdl_splitk_c_shuffle.hpp:189
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_xdl_splitk_c_shuffle.hpp:439
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CBlockTransferScalarPerVector_NWaveNPerXDL, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB > GridwiseGemmBase
Definition device_gemm_xdl_splitk_c_shuffle.hpp:94
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, ck::index_t KBatch=1) override
Definition device_gemm_xdl_splitk_c_shuffle.hpp:405
static constexpr auto I2
Definition device_gemm_xdl_splitk_c_shuffle.hpp:84
static constexpr auto NXdlPerWave32
Definition device_gemm_xdl_splitk_c_shuffle.hpp:80
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_xdl_splitk_c_shuffle.hpp:364
static constexpr auto I3
Definition device_gemm_xdl_splitk_c_shuffle.hpp:85
std::string GetTypeString() const override
Definition device_gemm_xdl_splitk_c_shuffle.hpp:445
static auto MakeInvoker()
Definition device_gemm_xdl_splitk_c_shuffle.hpp:402
ComputeType ComputeTypeB
Definition device_gemm_xdl_splitk_c_shuffle.hpp:91
static constexpr index_t NumGemmKPrefetchStage
Definition device_gemm_xdl_splitk_c_shuffle.hpp:88
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_xdl_splitk_c_shuffle.hpp:143
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_xdl_splitk_c_shuffle.hpp:79