device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp Source File
device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
Go to the documentation of this file.
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
@ MaskOutUpperTriangle
Definition masking_specialization.hpp:13
GemmSpecialization
Definition gemm_specialization.hpp:11
__global__ void kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, const FloatAB *__restrict__ p_b1_grid, FloatC *__restrict__ p_c_grid, D0sPointer p_d0s_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const C0DEElementwiseOperation c0de_element_op, const B1ElementwiseOperation b1_element_op, const C1DEElementwiseOperation c1de_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c1_grid_desc_mblock_mperblock_nblock_nperblock, const D0sGridDesc_M_N d0s_griddesc_m_n, const Block2CTileMap block_2_ctile_map, const index_t batch_count, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, const C0MatrixMask c0_matrix_mask)
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:48
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
typename conditional< predicate, X, Y >::type conditional_t
Definition utility/functional.hpp:115
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
Definition ck/stream_config.hpp:10
Gridwise gemm + softmax + gemm fusion.
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:87
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 &b1_grid_desc_bk0_n_bk1, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:231
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))> CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:319
Definition utility/sequence.hpp:43
Definition utility/integral_constant.hpp:20
Definition tensor_operation/gpu/device/tensor_layout.hpp:31
Definition tensor_operation/gpu/device/tensor_layout.hpp:26
Definition device_base.hpp:197
BaseArgument()=default
BaseInvoker()=default
Definition masking_specialization.hpp:57
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:334
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:361
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:356
__host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:351
ComputeBasePtrOfStridedBatch(index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideB1, index_t BatchStrideC)
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:335
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:346
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:451
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:503
B1ElementwiseOperation b1_element_op_
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:509
index_t batch_count_
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:511
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:512
CDataType * p_c_grid_
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:500
GridwiseGemm64::DefaultBlock2CTileMap block_2_ctile_map_
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:505
AElementwiseOperation a_element_op_
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:506
std::vector< index_t > raw_lengths_m_n_k_o_
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:518
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:502
const B1DataType * p_b1_grid_
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:499
CElementwiseOperation c_element_op_
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:510
const ADataType * p_a_grid_
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:497
C0MatrixMask c0_matrix_mask_
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:515
const BDataType * p_b_grid_
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:498
Argument(const ADataType *p_a_grid, const BDataType *p_b_grid, const B1DataType *p_b1_grid, CDataType *p_c_grid, index_t MRaw, index_t NRaw, index_t KRaw, index_t Gemm1NRaw, index_t Batch, index_t StrideA, index_t StrideB, index_t StrideB1, index_t StrideC, index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideB1, index_t BatchStrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, AccElementwiseOperation acc_element_op, B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op)
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:452
BElementwiseOperation b_element_op_
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:507
CGridDesc_M_N c_grid_desc_m_n_
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:504
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:501
AccElementwiseOperation acc_element_op_
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:508
remove_cvref_t< decltype(MakeB1GridDescriptor_BK0_N_BK1(B1Desc{}))> B1GridDesc_BK0_N_BK1
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:955
constexpr bool IsValid() const
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1104
BElementwiseOperation b_element_op
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1037
AElementwiseOperation a_element_op
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1036
B1ElementwiseOperation b1_element_op
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1038
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1028
static constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDescriptor &b1_grid_desc)
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:928
GridwiseGemm64::DefaultBlock2CTileMap block_2_ctile_map
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1031
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1026
static constexpr auto MakeCGridDescriptor_M_N(const CGridDescriptor &c_grid_desc)
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:946
bool is_valid
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1042
static constexpr auto MakeBGridDescriptor_BK0_N_BK1(const BGridDescriptor &b_grid_desc)
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:910
static constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDescriptor &a_grid_desc)
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:892
C0MatrixMask c0_matrix_mask
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1030
remove_cvref_t< decltype(MakeBGridDescriptor_BK0_N_BK1(BDesc{}))> BGridDesc_BK0_N_BK1
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:953
GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle > GridwiseGemmBase
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:961
constexpr Descriptor(ADesc a, BDesc b, B1Desc b1, CDesc c, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, B1ElementwiseOperation b1_element_op_, CElementwiseOperation c_element_op_)
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1044
GridwiseGemm64::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_descriptor_mblock_mperblock_nblock_nperblock
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1033
GridwiseGemmBase< math::max(MXdlPerWave64, 1)> GridwiseGemm64
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1023
bool has_main_k_block_loop
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1041
remove_cvref_t< decltype(MakeCGridDescriptor_M_N(CDesc{}))> CGridDesc_M_N
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:957
remove_cvref_t< decltype(MakeAGridDescriptor_AK0_M_AK1(ADesc{}))> AGridDesc_AK0_M_AK1
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:951
CElementwiseOperation c_element_op
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1039
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle::Descriptor::c_grid_desc_m_n
CGridDesc_M_N c_grid_desc_m_n
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1029
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1027
GridwiseGemmBase< MXdlPerWave32 > GridwiseGemm32
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1024
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:523
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:527
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:626
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:606
DeviceOp::Argument Argument
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:524
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:210
static constexpr auto I2
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:220
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:858
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const void *p_b1, void *p_c, index_t MRaw, index_t NRaw, index_t KRaw, index_t Gemm1NRaw, index_t Batch, index_t StrideA, index_t StrideB, index_t StrideB1, index_t StrideC, index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideB1, index_t BatchStrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, AccElementwiseOperation acc_element_op, B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op) override
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:810
static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_)
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:641
GridwiseGemmBase< math::max(MXdlPerWave64, 1)> GridwiseGemm64
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:445
decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1)) B1GridDesc_BK0_N_BK1
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:375
static constexpr auto I1
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:219
static constexpr auto matrix_padder
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:222
static constexpr auto make_descriptor(ADesc a, BDesc b, B1Desc b1, CDesc c, AElementwiseOperation a_element_op=AElementwiseOperation{}, BElementwiseOperation b_element_op=BElementwiseOperation{}, B1ElementwiseOperation b1_element_op=B1ElementwiseOperation{}, CElementwiseOperation c_element_op=CElementwiseOperation{})
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1109
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle DeviceOp
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:212
static constexpr auto I0
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:218
static constexpr auto MXdlPerWave32
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:215
decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)) BGridDesc_BK0_N_BK1
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:374
static auto MakeB1GridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:285
decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)) AGridDesc_AK0_M_AK1
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:373
std::string GetTypeString() const override
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:864
GridwiseGemmBase< MXdlPerWave32 > GridwiseGemm32
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:446
GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle > GridwiseGemmBase
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:384
static constexpr auto MXdlPerWave64
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:213
conditional_t< MaskOutUpperTriangle, C0MatrixMask_impl< MaskOutUpperTrianglePredicate >, C0MatrixMask_impl< MaskDisabledPredicate > > C0MatrixMask
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:378
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:226
decltype(MakeCGridDescriptor_M_N(1, 1, 1)) CGridDesc_M_N
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:376
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, const B1DataType *p_b1, CDataType *p_c, index_t MRaw, index_t NRaw, index_t KRaw, index_t Gemm1NRaw, index_t Batch, index_t StrideA, index_t StrideB, index_t StrideB1, index_t StrideC, index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideB1, index_t BatchStrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, AccElementwiseOperation acc_element_op, B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op)
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:777
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle::IsValidCompilationParameter
static constexpr bool IsValidCompilationParameter()
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:634
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:772
static auto MakeInvoker()
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:807
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:255
static __device__ void Run(const Desc &desc, const float scale, const ADataType *__restrict__ p_a_grid, const ADataType *__restrict__ p_b_grid, const ADataType *__restrict__ p_b1_grid, CDataType *__restrict__ p_c_grid)
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1123
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:315
static bool IsSupportedArgument(const Argument &arg)
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:731
Definition device_batched_gemm_softmax_gemm.hpp:31
Definition matrix_padder.hpp:63