device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp Source File#
device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
Go to the documentation of this file.
116 typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n};
118 const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})};
198 std::array<ck::index_t, array_size> qk_gs_ms_ks_lengths{batch_size, head_count, sequence_length, head_size};
199 std::array<ck::index_t, array_size> qk_gs_ms_ks_strides{sequence_length * head_count * 3 * head_size, 3 * head_size, head_count * 3 * head_size, 1};
201 std::array<ck::index_t, array_size> v_gs_os_ns_lengths{batch_size, head_count, head_size, sequence_length};
202 std::array<ck::index_t, array_size> v_gs_os_ns_strides{sequence_length * head_count * 3 * head_size, 3 * head_size, 1, head_count * 3 * head_size};
204 std::array<ck::index_t, array_size> c_gs_ms_os_lengths{batch_size, head_count, sequence_length, head_size};
205 std::array<ck::index_t, array_size> c_gs_ms_os_strides{sequence_length * head_count * head_size, head_size, head_count * head_size, 1};
214 const auto a_grid_desc = DeviceOp::MakeAGridDescriptor(qk_gs_ms_ks_lengths, qk_gs_ms_ks_strides);
234 typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n};
236 const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})};
332 std::array<ck::index_t, array_size> q_gs_ms_ks_lengths{batch_size, head_count, q_sequence_length, head_size};
333 std::array<ck::index_t, array_size> q_gs_ms_ks_strides{q_sequence_length * head_count * head_size, head_size, head_count * head_size, 1};
335 std::array<ck::index_t, array_size> k_gs_ms_ks_lengths{batch_size, head_count, kv_sequence_length, head_size};
336 std::array<ck::index_t, array_size> k_gs_ms_ks_strides{kv_sequence_length * head_count * 2 * head_size, 2 * head_size, head_count * 2 * head_size, 1};
338 std::array<ck::index_t, array_size> v_gs_os_ns_lengths{batch_size, head_count, head_size, kv_sequence_length};
339 std::array<ck::index_t, array_size> v_gs_os_ns_strides{kv_sequence_length * head_count * 2 * head_size, 2 * head_size, 1, head_count * 2 * head_size};
341 std::array<ck::index_t, array_size> c_gs_ms_os_lengths{batch_size, head_count, q_sequence_length, head_size};
342 std::array<ck::index_t, array_size> c_gs_ms_os_strides{q_sequence_length * head_count * head_size, head_size, head_count * head_size, 1};
371 typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n};
373 const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})};
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
std::string getMaskingSpecializationString(const MaskingSpecialization &s)
Definition masking_specialization.hpp:17
MaskingSpecialization
Definition masking_specialization.hpp:11
@ MaskDisabled
Definition masking_specialization.hpp:12
@ MaskOutUpperTriangle
Definition masking_specialization.hpp:13
TensorSpecialization
Definition tensor_specialization.hpp:11
GemmSpecialization
Definition gemm_specialization.hpp:11
__global__ void kernel_wmma_cross_attention_forward(const QDataType *__restrict__ p_q_grid, const KVDataType *__restrict__ p_kv_grid, ODataType *__restrict__ p_out_grid, index_t batch_size, index_t q_sequence_length, index_t kv_sequence_length, index_t head_count, index_t head_size, float alpha)
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:315
__global__ void kernel_batched_gemm_softmax_gemm_wmma_cshuffle(const ADataType *__restrict__ p_a_grid, const B0DataType *__restrict__ p_b0_grid, const B1DataType *__restrict__ p_b1_grid, CDataType *__restrict__ p_c_grid, index_t M, index_t N, index_t K, index_t O, index_t G0, index_t G1, float alpha, bool input_permute, bool output_permute)
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:45
std::string getTensorSpecializationString(const TensorSpecialization &s)
Definition tensor_specialization.hpp:16
__global__ void kernel_wmma_self_attention_forward(const QKVDataType *__restrict__ p_qkv_grid, ODataType *__restrict__ p_out_grid, index_t batch_size, index_t sequence_length, index_t head_count, index_t head_size, float alpha)
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:183
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:93
__host__ static __device__ constexpr auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n, index_t, index_t)
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:672
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))> CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:679
__host__ static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:653
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc &a_grid_desc, const B0GridDesc &b0_grid_desc, const B1GridDesc &b1_grid_desc, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:511
Definition utility/sequence.hpp:43
Definition utility/integral_constant.hpp:20
Definition transform_contraction_to_gemm_arraybase.hpp:122
__host__ static __device__ constexpr auto MakeB0GridDescriptor_BK0_N_BK1(const BGridDesc_N_K &b_grid_desc_n_k, const Number &BK1)
Definition transform_contraction_to_gemm_arraybase.hpp:245
__host__ static __device__ auto MakeCGridDescriptor_G_M_N(const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_os_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_os_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:375
__host__ static __device__ constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K &a_grid_desc_m_k, const Number &AK1)
Definition transform_contraction_to_gemm_arraybase.hpp:172
static constexpr auto matrix_padder
Definition transform_contraction_to_gemm_arraybase.hpp:140
__host__ static __device__ auto MakeB1GridDescriptor_N_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_os_ns_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_os_ns_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:307
__host__ static __device__ auto MakeB1GridDescriptor_G_N_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_os_ns_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_os_ns_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:301
__host__ static __device__ constexpr auto MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1(const BGridDesc_L_K &b_grid_desc_l_k, const WmmaK &, const LRepeat &, const LWaves &, const LPerWmma &, const BK1 &)
Definition transform_contraction_to_gemm_arraybase.hpp:266
__host__ static __device__ auto MakeB0GridDescriptor_G_N_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ns_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ns_ks_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:228
__host__ static __device__ constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K &b1_grid_desc_n_k, const Number &B1K1)
Definition transform_contraction_to_gemm_arraybase.hpp:318
__host__ static __device__ auto MakeAGridDescriptor_G_M_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:156
__host__ static __device__ constexpr auto MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1(const AGridDesc_M_K &a_grid_desc_m_k, const WmmaK &, const MRepeat &, const MWaves &, const MPerWmma &, const AK1 &)
Definition transform_contraction_to_gemm_arraybase.hpp:193
__host__ static __device__ auto MakeCGridDescriptor_M_N(const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_os_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_os_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:381
__host__ static __device__ constexpr auto MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1(const BGridDesc_N_L &b_grid_desc_n_l, const WmmaL &, const NRepeat &, const NWaves &, const NPerWmma &, const BL1 &)
Definition transform_contraction_to_gemm_arraybase.hpp:340
__host__ static __device__ auto MakeAGridDescriptor_M_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:162
__host__ static __device__ auto MakeB0GridDescriptor_N_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ns_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ns_ks_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:234
Definition device_base.hpp:197
BaseArgument()=default
BaseInvoker()=default
Definition masking_specialization.hpp:57
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:679
__host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:696
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:691
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:701
__host__ __device__ ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K &a_grid_desc_g_m_k, const B0GridDesc_G_L_K &b0_grid_desc_g_l_k, const B1GridDesc_G_N_L &b1_grid_desc_g_n_l, const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:680
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:706
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1112
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle::Argument::p_c_grid_
CDataType * p_c_grid_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1203
AGridDesc_G_M_K a_grid_desc_g_m_k_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1211
ComputeBasePtrOfStridedBatch compute_ptr_offset_of_batch_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1242
const B0DataType * p_b0_grid_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1201
AccElementwiseOperation acc_element_op_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1225
const B1DataType * p_b1_grid_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1202
B1ElementwiseOperation b1_element_op_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1226
std::array< index_t, NumDimG+NumDimM+NumDimN > a_mz_kz_strides_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1235
std::array< index_t, NumDimG+NumDimM+NumDimN > c_mz_nz_strides_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1238
Argument(const ADataType *p_a_grid, const B0DataType *p_b0_grid, const B1DataType *p_b1_grid, CDataType *p_c_grid, const std::array< void *, NumAcc0Bias > p_acc0_biases, const std::array< void *, NumAcc1Bias > p_acc1_biases, const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_lengths, const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_strides, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ls_ks_lengths, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ls_ks_strides, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_ns_ls_lengths, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_ns_ls_strides, const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_ns_lengths, const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_ns_strides, const std::array< std::vector< ck::index_t >, NumAcc0Bias > acc0_biases_gs_ms_ls_lengths, const std::array< std::vector< ck::index_t >, NumAcc0Bias > acc0_biases_gs_ms_ls_strides, const std::array< std::vector< ck::index_t >, NumAcc1Bias > acc1_biases_gs_ms_ns_lengths, const std::array< std::vector< ck::index_t >, NumAcc1Bias > acc1_biases_gs_ms_ns_strides, const index_t M01, const index_t N01, AElementwiseOperation a_element_op, B0ElementwiseOperation b0_element_op, AccElementwiseOperation acc_element_op, B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op)
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1113
B0GridDesc_G_L_K b0_grid_desc_g_l_k_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1212
CElementwiseOperation c_element_op_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1227
GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1220
C0MatrixMask c0_matrix_mask_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1230
B0ElementwiseOperation b0_element_op_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1224
AGridDesc a_grid_desc
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1206
B1GridDesc b1_grid_desc
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1208
B1GridDesc_G_N_L b1_grid_desc_g_n_l_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1213
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle::Argument::p_a_grid_
const ADataType * p_a_grid_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1200
CGridDesc_M_N c_grid_desc_m_n_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1209
AElementwiseOperation a_element_op_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1223
std::array< index_t, NumDimG+NumDimM+NumDimN > raw_lengths_mz_lz_kz_nz_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1234
CGridDesc_G_M_N c_grid_desc_g_m_n_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1214
B0GridDesc b0_grid_desc
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1207
index_t batch_count_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1240
std::array< index_t, NumDimG+NumDimM+NumDimN > b0_lz_kz_strides_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1236
GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1217
std::array< index_t, NumDimG+NumDimM+NumDimN > b1_nz_lz_strides_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1237
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1054
const ADataType * p_q_grid_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1076
float alpha_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1086
const B0DataType * p_kv_grid_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1077
CDataType * p_out_grid_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1078
index_t q_sequence_length_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1082
index_t kv_sequence_length_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1083
index_t head_size_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1085
CrossAttnArg(const ADataType *p_q_grid, const B0DataType *p_kv_grid, CDataType *p_out_grid, index_t batch_size, index_t q_sequence_length, index_t kv_sequence_length, index_t head_count, index_t head_size, float alpha)
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1055
index_t batch_size_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1081
index_t head_count_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1084
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1306
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1309
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1357
DeviceOp::CrossAttnArg Argument
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1307
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1367
DeviceOp::RawArg Argument
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1368
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1424
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1370
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:795
const B0DataType * p_b0_grid_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:826
index_t K_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:833
bool output_permute_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:839
const B1DataType * p_b1_grid_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:827
index_t G1_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:836
float alpha_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:837
index_t O_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:834
CDataType * p_c_grid_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:828
RawArg(const ADataType *p_a_grid, const B0DataType *p_b0_grid, const B1DataType *p_b1_grid, CDataType *p_c_grid, index_t M, index_t N, index_t K, index_t O, index_t G0, index_t G1, float alpha, bool input_permute, bool output_permute)
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:796
index_t N_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:832
index_t M_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:831
const ADataType * p_a_grid_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:825
index_t G0_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:835
bool input_permute_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:838
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1012
SelfAttnArg(const ADataType *p_qkv_grid, CDataType *p_out_grid, index_t batch_size, index_t sequence_length, index_t head_count, index_t head_size, float alpha)
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1013
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle::SelfAttnArg::alpha_
float alpha_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1038
index_t sequence_length_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1035
index_t head_count_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1036
const ADataType * p_qkv_grid_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1030
index_t head_size_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1037
index_t batch_size_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1034
CDataType * p_out_grid_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1031
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1247
DeviceOp::SelfAttnArg Argument
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1248
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1250
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1295
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:528
static constexpr index_t NumAcc0Bias
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:532
static constexpr auto B0EnableLds_manu
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:566
__host__ __device__ static constexpr auto make_MaskOutPredicate()
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:665
decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})) B1GridDesc_G_N_L
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:662
decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})) B0GridDesc_G_L_K
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:661
static constexpr auto I4
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:551
decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})) AGridDesc_G_M_K
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:660
TransformBatchedContractionContractionToBatchedGemmGemm_Wmma< Sequence< NumDimG, NumDimM, NumDimL, NumDimK, NumDimN >, Sequence< MPerBlock, LPerBlock, KPerBlock, NPerBlock >, GemmSpec, ASpec, B0Spec, B1Spec, CSpec > Transform
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:573
__host__ static __device__ auto MakeB1GridDescriptor(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_ns_ls_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_ns_ls_strides_vec)
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:631
static constexpr index_t NumDimGemm1N
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:542
decltype(MakeAGridDescriptor({}, {})) AGridDesc
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:656
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1674
static auto MakeArgument(const ADataType *p_a, const B0DataType *p_b0, const B1DataType *p_b1, CDataType *p_c, index_t M, index_t N, index_t K, index_t O, index_t G0, index_t G1, float alpha, bool input_permute, bool output_permute)
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:842
__host__ static __device__ auto MakeB0GridDescriptor(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ls_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ls_ks_strides_vec)
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:606
std::string GetTypeString() const override
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1680
static constexpr auto B1EnableLds
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:571
decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})) CGridDesc_G_M_N
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:663
static constexpr auto MWaves
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:557
static constexpr auto WmmaK
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:555
static auto MakeCrossAttnArgument(const ADataType *p_q_grid, const B0DataType *p_kv_grid, CDataType *p_out_grid, index_t batch_size, index_t q_sequence_length, index_t kv_sequence_length, index_t head_count, index_t head_size, float alpha)
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1089
static constexpr auto B0EnableLds
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:570
static auto MakeCrossAttnInvoker()
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1364
DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle DeviceOp
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:545
static constexpr index_t NumDimGemm1M
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:541
static auto MakeInvoker()
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1671
static constexpr index_t NumDimGemm0N
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:539
static constexpr index_t NumDimGemm0K
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:540
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle::IsSupportedArgument
static bool IsSupportedArgument(const RawArg &arg)
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:860
static constexpr auto I2
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:549
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle::IsSupportedArgument
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1006
static constexpr index_t NumDimGemm0M
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:538
static constexpr auto I5
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:552
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle::MakeArgumentPointer
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b0, const void *p_b1, void *p_c, const std::array< void *, NumAcc0Bias > p_acc0_biases, const std::array< void *, NumAcc1Bias > p_acc1_biases, const std::vector< index_t > &a_gs_ms_ks_lengths, const std::vector< index_t > &a_gs_ms_ks_strides, const std::vector< index_t > &b0_gs_ls_ks_lengths, const std::vector< index_t > &b0_gs_ls_ks_strides, const std::vector< index_t > &b1_gs_ns_ls_lengths, const std::vector< index_t > &b1_gs_ns_ls_strides, const std::vector< index_t > &c_gs_ms_ns_lengths, const std::vector< index_t > &c_gs_ms_ns_strides, const std::array< std::vector< ck::index_t >, NumAcc0Bias > acc0_biases_gs_ms_ls_lengths, const std::array< std::vector< ck::index_t >, NumAcc0Bias > acc0_biases_gs_ms_ls_strides, const std::array< std::vector< ck::index_t >, NumAcc1Bias > acc1_biases_gs_ms_ns_lengths, const std::array< std::vector< ck::index_t >, NumAcc1Bias > acc1_biases_gs_ms_ns_strides, AElementwiseOperation a_element_op, B0ElementwiseOperation b0_element_op, AccElementwiseOperation acc_element_op, B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op) override
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1579
static auto MakeSelfAttnArgument(const ADataType *p_qkv_grid, CDataType *p_out_grid, index_t batch_size, index_t sequence_length, index_t head_count, index_t head_size, float alpha)
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1041
static constexpr auto B1EnableLds_manu
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:567
static constexpr index_t NumDimGemm1K
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:543
static constexpr auto I6
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:553
static constexpr auto LWaves
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:558
static constexpr auto AEnableLds
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:569
C0MatrixMask_impl< decltype(make_MaskOutPredicate())> C0MatrixMask
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:676
static constexpr auto B0EnableLds_auto
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:562
decltype(MakeB0GridDescriptor({}, {})) B0GridDesc
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:657
decltype(MakeB1GridDescriptor({}, {})) B1GridDesc
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:658
static constexpr auto I3
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:550
static constexpr bool IsValidCompilationParameter()
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1431
static constexpr auto AEnableLds_auto
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:561
static constexpr auto I0
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:547
static constexpr auto I1
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:548
static constexpr index_t NumAcc1Bias
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:533
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle::MakeSelfAttnInvoker
static auto MakeSelfAttnInvoker()
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1302
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle::MakeAGridDescriptor
__host__ static __device__ auto MakeAGridDescriptor(const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_strides_vec)
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:582
static constexpr auto NWaves
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:559
GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer > GridwiseOp
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:719
decltype(Transform::MakeCGridDescriptor_M_N({}, {})) CGridDesc_M_N
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:659
static constexpr auto B1EnableLds_auto
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:563
static constexpr auto AEnableLds_manu
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:565
Definition device_batched_gemm_softmax_gemm_permute.hpp:34
Definition masking_specialization.hpp:29
Definition masking_specialization.hpp:43