104template <
typename ALayout,
108 typename AScaleDataType,
110 typename BScaleDataType,
112 typename GemmAccDataType,
113 typename CShuffleDataType,
114 typename AElementwiseOperation,
115 typename BElementwiseOperation,
116 typename CElementwiseOperation,
129 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
130 typename ABlockTransferThreadClusterArrangeOrder,
131 typename ABlockTransferSrcAccessOrder,
132 index_t ABlockTransferSrcVectorDim,
133 index_t ABlockTransferSrcScalarPerVector,
134 index_t ABlockTransferDstScalarPerVector_AK1,
135 bool ABlockLdsExtraM,
136 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
137 typename BBlockTransferThreadClusterArrangeOrder,
138 typename BBlockTransferSrcAccessOrder,
139 index_t BBlockTransferSrcVectorDim,
140 index_t BBlockTransferSrcScalarPerVector,
141 index_t BBlockTransferDstScalarPerVector_BK1,
142 bool BBlockLdsExtraN,
143 index_t CShuffleMXdlPerWavePerShuffle,
144 index_t CShuffleNXdlPerWavePerShuffle,
145 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
146 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
149 typename ComputeTypeA =
151 typename ComputeTypeB =
163 AElementwiseOperation,
164 BElementwiseOperation,
165 CElementwiseOperation>
172 template <index_t NXdlPerWave_>
184 AElementwiseOperation,
185 BElementwiseOperation,
186 CElementwiseOperation,
199 ABlockTransferThreadClusterLengths_AK0_M_AK1,
200 ABlockTransferThreadClusterArrangeOrder,
201 ABlockTransferSrcAccessOrder,
202 ABlockTransferSrcVectorDim,
203 ABlockTransferSrcScalarPerVector,
204 ABlockTransferDstScalarPerVector_AK1,
207 BBlockTransferThreadClusterLengths_BK0_N_BK1,
208 BBlockTransferThreadClusterArrangeOrder,
209 BBlockTransferSrcAccessOrder,
210 BBlockTransferSrcVectorDim,
211 BBlockTransferSrcScalarPerVector,
212 BBlockTransferDstScalarPerVector_BK1,
215 CShuffleMXdlPerWavePerShuffle,
216 CShuffleNXdlPerWavePerShuffle,
217 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
218 CShuffleBlockTransferScalarPerVector_NPerBlock,
223 template <index_t NXdlPerWave_>
235 AElementwiseOperation,
236 BElementwiseOperation,
237 CElementwiseOperation,
250 ABlockTransferThreadClusterLengths_AK0_M_AK1,
251 ABlockTransferThreadClusterArrangeOrder,
252 ABlockTransferSrcAccessOrder,
253 ABlockTransferSrcVectorDim,
254 ABlockTransferSrcScalarPerVector,
255 ABlockTransferDstScalarPerVector_AK1,
258 BBlockTransferThreadClusterLengths_BK0_N_BK1,
259 BBlockTransferThreadClusterArrangeOrder,
260 BBlockTransferSrcAccessOrder,
261 BBlockTransferSrcVectorDim,
262 BBlockTransferSrcScalarPerVector,
263 BBlockTransferDstScalarPerVector_BK1,
266 CShuffleMXdlPerWavePerShuffle,
267 CShuffleNXdlPerWavePerShuffle,
268 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
269 CShuffleBlockTransferScalarPerVector_NPerBlock,
284 using Argument =
typename GridwiseGemm64::Argument;
289 template <
typename Gr
idwiseGemm>
290 float RunImp(
const typename GridwiseGemm::Argument& arg,
293 if(stream_config.log_level_ > 0)
296 GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
299 if(!GridwiseGemm::CheckValidity(arg))
301 throw std::runtime_error(
"wrong! GridwiseGemm has invalid setting");
305 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
309 index_t k_grain = arg.KBatch * KPerBlock;
310 index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
312 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
314 const auto Run = [&](
const auto& kernel) {
315 if(stream_config.flush_cache)
319 const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
320 arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
321 const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
322 arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
325 a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
sizeof(ADataType);
327 b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
sizeof(BDataType);
330 arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
331 rotating_mem.Print();
333 auto run_flush_cache = [&]() {
340 hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
342 arg_.M * arg_.N *
sizeof(CDataType),
343 stream_config.stream_id_));
358 hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
360 arg.M * arg.N *
sizeof(CDataType),
361 stream_config.stream_id_));
364 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
369 constexpr index_t minimum_occupancy =
372 MPerBlock * NPerBlock * KPerBlock *
sizeof(ADataType) <= 128 * 128 * 64 * 2)
377 constexpr auto TailNumChoices = []() {
379 return Tuple<constant<TailNumber::Full>>{};
381 return Tuple<constant<TailNumber::Even>, constant<TailNumber::Odd>>{};
383 static_assert(
false,
"Unexpected BlkGemmPipelineVer!");
385 constexpr bool Use2LDS = []() {
391 static_assert(
false,
"Unexpected BlkGemmPipelineVer!");
393 const TailNumber tail_num = GridwiseGemm::CalculateKBlockLoopTailNum(K_split);
394 using BoolChoices = Tuple<ck::true_type, ck::false_type>;
395 static_for_product<BoolChoices,
398 [&](
auto mainloop_choice,
auto KBatch_cond_choice,
auto tail_num_choice) {
399 constexpr auto CGlobalMemoryDataOperation =
402 if(mainloop_choice.value == has_main_k_block_loop &&
403 KBatch_cond_choice.value == (arg.KBatch > 1) &&
404 tail_num_choice.value == tail_num)
409 mainloop_choice.value,
410 CGlobalMemoryDataOperation,
412 tail_num_choice.value>;
424 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
430 static_assert(is_scale_mfma_data_type<ADataType>() && is_scale_mfma_data_type<BDataType>(),
431 "Only microscaling formats are supported for ADataType and BDataType");
433 static_assert(ScaleBlockSize == 32,
"Only ScaleBlockSize 32 is supported");
436 "ComputeTypeA and ComputeTypeB must be the same as ADataType and BDataType");
470 return GridwiseGemm64::CheckValidity(arg);
477 return GridwiseGemm32::CheckValidity(
478 reinterpret_cast<const typename GridwiseGemm32::Argument&
>(arg));
491 const AScaleDataType* p_a_scale,
492 const BDataType* p_b,
493 const BScaleDataType* p_b_scale,
504 AElementwiseOperation a_element_op,
505 BElementwiseOperation b_element_op,
506 CElementwiseOperation c_element_op)
531 const void* p_a_scale,
533 const void* p_b_scale,
544 AElementwiseOperation a_element_op,
545 BElementwiseOperation b_element_op,
546 CElementwiseOperation c_element_op)
override
548 return std::make_unique<Argument>(
static_cast<const ADataType*
>(p_a),
549 static_cast<const AScaleDataType*
>(p_a_scale),
550 static_cast<const BDataType*
>(p_b),
551 static_cast<const BScaleDataType*
>(p_b_scale),
552 static_cast<CDataType*
>(p_c),
570 return std::make_unique<Invoker>(
Invoker{});
576 auto str = std::stringstream();
578 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
582 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
590 str <<
"DeviceGemmMX_Xdl_CShuffleV3"
593 << std::string(ALayout::name)[0]
594 << std::string(BLayout::name)[0]
595 << std::string(CLayout::name)[0]
600 << MPerBlock<<
"x"<<NPerBlock<<
"x"<<KPerBlock <<
", "
602 << MPerXDL<<
"x"<<NPerXDL <<
", "
604 << MXdlPerWave<<
"x" << NXdlPerWave<<
", "
606 << ABlockTransferSrcScalarPerVector<<
"x"<<BBlockTransferSrcScalarPerVector<<
", "
607 <<
"BlkGemmPipelineScheduler: "
608 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] <<
", "
609 <<
"BlkGemmPipelineVersion: "
610 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] <<
", "
611 <<
"BlkGemmPipelinePrefetchStages: "
612 << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages <<
", "
614 << GridwiseGemm64::BlockwiseGemmPipe::AMmaKStride <<
", "
615 <<
"ScaleBlockSize: "
#define INVOKER_RUN3_IMPL
Definition device_base.hpp:114
#define REGISTER_EXTRA_PRINTING_METHODS
Definition device_base.hpp:47
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
void flush_icache()
Definition flush_cache.hpp:383
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition flush_cache.hpp:398
int32_t index_t
Definition ck.hpp:299
typename conditional< predicate, X, Y >::type conditional_t
Definition utility/functional.hpp:115
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
std::string get_device_name()
Definition host_utility/device_prop.hpp:19
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
constexpr bool is_same_v
Definition type.hpp:283
__global__ enable_if_t<!Use2LDS, void > kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:40
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
bool is_bf16_atomic_supported()
Definition host_utility/device_prop.hpp:108
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:156
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:156
Definition device_base.hpp:197
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:288
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:421
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:290
WIP: Implements XDL CShuffle V3 GEMM for microscale-compliant data types.
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:166
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:428
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:441
conditional_t< !is_same_v< BLayout, tensor_layout::gemm::MFMA >, GridwiseGemmMXBase< math::max(NXdlPerWave64, 1)>, GridwiseGemmMXBPreshuffleBase< math::max(NXdlPerWave64, 1)> > GridwiseGemm64
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:275
std::string GetTypeString() const override
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:574
static constexpr auto NXdlPerWave32
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:169
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_a_scale, const void *p_b, const void *p_b_scale, void *p_c, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideScaleA, ck::index_t StrideB, ck::index_t StrideScaleB, ck::index_t StrideC, ck::index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:530
static auto MakeArgument(const ADataType *p_a, const AScaleDataType *p_a_scale, const BDataType *p_b, const BScaleDataType *p_b_scale, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideScaleA, index_t StrideB, index_t StrideScaleB, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:490
GridwiseGemmMX_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB > GridwiseGemmMXBase
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:173
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:168
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:568
static auto MakeInvoker()
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:527
typename GridwiseGemm64::Argument Argument
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:284
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:485
GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB > GridwiseGemmMXBPreshuffleBase
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:224
conditional_t< !is_same_v< BLayout, tensor_layout::gemm::MFMA >, GridwiseGemmMXBase< NXdlPerWave32 >, GridwiseGemmMXBPreshuffleBase< NXdlPerWave32 > > GridwiseGemm32
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:279
Definition device_gemm_mx.hpp:25
Definition flush_cache.hpp:299