device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp Source File#
device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp
Go to the documentation of this file.
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
__global__ void kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, const index_t group_count, const index_t grid_size_grp, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op)
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:41
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ Default
Definition gemm_specialization.hpp:13
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE * cast_pointer_to_constant_address_space(T *p)
Definition amd_address_space.hpp:35
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ T * cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE *p)
Definition amd_address_space.hpp:24
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_multiple_abd_xdl_cshuffle.hpp:77
Definition functional2.hpp:33
Definition device_base.hpp:197
BaseArgument()=default
BaseInvoker()=default
virtual void SetWorkSpacePointer(BaseArgument *p_arg, void *p_workspace, const StreamConfig &=StreamConfig{}) const
Definition device_base.hpp:249
virtual size_t GetWorkSpaceSize(const BaseArgument *) const
Definition device_base.hpp:247
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:322
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops &)=default
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops()=default
__host__ bool CheckValidity(const CGridDesc_M_N &) const
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:369
__host__ __device__ bool ValidCTileIndex(const CTileIdx &, const CTileDim &) const
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:402
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(const CGridDesc_M_N &c_grid_desc_m_n, index_t KBatch, index_t M01=8)
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:346
__host__ __device__ constexpr index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:363
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:375
static constexpr auto I0
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:323
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops & operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops &)=default
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops &&)=default
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops & operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops &&)=default
__host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:353
static constexpr auto I1
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:324
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M, index_t N, index_t KBatch, index_t M01=8)
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:337
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:273
Block2ETileMap block_to_ctile_map_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:315
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:285
index_t id_off_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:317
index_t block_start_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:316
__host__ __device__ bool ValidCTileIndex(const CTileIdx &c_tile_idx, const CTileDim &c_tile_dim) const
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:297
__host__ bool CheckValidity(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:304
UnderlyingBlockToCTileMap underlying_type
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:274
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:310
__host__ __device__ OffsettedBlockToCTileMapMLoops(UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off=0)
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:276
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:419
std::array< const void *, NumATensor > as_ptr_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:421
std::array< const void *, NumBTensor > bs_ptr_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:422
void * e_ptr_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:424
index_t M_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:426
std::array< index_t, NumBTensor > StrideBs_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:428
index_t K_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:426
std::array< index_t, NumDTensor > StrideDs_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:429
std::array< const void *, NumDTensor > ds_ptr_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:423
std::array< index_t, NumATensor > StrideAs_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:427
index_t N_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:426
index_t StrideE_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:430
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:435
Argument(std::vector< std::array< const void *, NumATensor > > &, std::vector< std::array< const void *, NumBTensor > > &, std::vector< std::array< const void *, NumDTensor > > &, std::vector< void * > &, std::vector< GemmMultiABDDesc > &gemm_descs, AElementwiseOperation a_element_op=AElementwiseOperation{}, BElementwiseOperation b_element_op=BElementwiseOperation{}, CDEElementwiseOperation c_element_op=CDEElementwiseOperation{})
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:439
std::vector< GemmBiasTransKernelArg > gemm_desc_kernel_arg_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:573
index_t sum_of_m
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:582
BElementwiseOperation b_element_op_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:570
std::vector< Tuple< index_t, index_t > > a_mtx_mraw_kraw_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:574
index_t grid_size_grp_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:580
index_t grid_size_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:579
AElementwiseOperation a_element_op_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:569
CDEElementwiseOperation c_element_op_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:571
std::vector< Tuple< index_t, index_t > > b_mtx_nraw_kraw_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:575
const void * grouped_gemm_kernel_args_dev
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:577
index_t group_count_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:567
void UpdateKBatch(index_t)
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:437
index_t barrier_size_grp_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:581
index_t k_batch_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:584
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:589
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:593
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:682
DeviceOp::Argument Argument
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:590
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:207
static constexpr auto I0
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:217
static constexpr auto I2
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:219
static auto MakeArgument(std::vector< std::array< const void *, NumATensor > > &p_As, std::vector< std::array< const void *, NumBTensor > > &p_Bs, std::vector< std::array< const void *, NumDTensor > > &p_Ds, std::vector< void * > &p_Es, std::vector< GemmMultiABDDesc > gemm_descs, AElementwiseOperation a_element_op=AElementwiseOperation{}, BElementwiseOperation b_element_op=BElementwiseOperation{}, CDEElementwiseOperation c_element_op=CDEElementwiseOperation{})
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:732
static constexpr index_t NumBTensor
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:214
static constexpr index_t NumGemmKPrefetchStage
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:221
size_t GetDeviceKernelArgSize(const BaseArgument *p_arg) const override
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:828
static constexpr index_t NumATensor
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:213
std::unique_ptr< BaseArgument > MakeArgumentPointer(std::vector< std::array< const void *, NumATensor > > &p_As, std::vector< std::array< const void *, NumBTensor > > &p_Bs, std::vector< std::array< const void *, NumDTensor > > &p_Ds, std::vector< void * > &p_Es, std::vector< GemmMultiABDDesc > &gemm_descs, AElementwiseOperation a_element_op=AElementwiseOperation{}, BElementwiseOperation b_element_op=BElementwiseOperation{}, CDEElementwiseOperation c_element_op=CDEElementwiseOperation{}) override
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:749
static constexpr auto I1
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:218
void SetElementwiseOps(BaseArgument *p_arg, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation c_element_op) const override
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:818
BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops< MPerBlock, NPerBlock > Block2ETileMap
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:415
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:763
static void SetKBatch(Argument &arg, index_t k_batch)
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:856
static constexpr auto NXdlPerWave32
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:211
static auto MakeInvoker()
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:745
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:269
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:727
void SetKBatch(BaseArgument *p_arg, index_t k_batch) const override
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:859
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:210
OffsettedBlockToCTileMapMLoops< Block2ETileMap > GroupedGemmBlock2ETileMap
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:416
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:689
std::string GetTypeString() const override
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:769
void SetDeviceKernelArgs(BaseArgument *p_arg, const void *kernel_args) const override
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:813
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK DeviceOp
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:208
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:270
GridwiseGemmMultipleABD_xdl_cshuffle< AsDataType, BsDataType, ComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched > GridwiseGemmBase
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:225
static void SetDeviceKernelArgs(Argument &arg, const void *kernel_args)
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:807
static void SetElementwiseOps(Argument &arg, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation c_element_op)
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:797
static constexpr index_t NumDTensor
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:215
Definition device_grouped_gemm_multi_abd_fixed_nk.hpp:73
Definition device_grouped_gemm.hpp:80
Definition device_grouped_gemm_multi_abd_fixed_nk.hpp:17