BlockFmhaPipelineQRKSVSAsyncTrload< Problem_, Policy_ > Struct Template Reference
Public Types |
Public Member Functions |
Static Public Member Functions |
Static Public Attributes |
List of all members
ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload< Problem_, Policy_ > Struct Template Reference
#include <block_fmha_pipeline_qr_ks_vs_async_trload.hpp>
Public Types | |
| using | Problem = remove_cvref_t<Problem_> |
| using | Policy = remove_cvref_t<Policy_> |
| using | QDataType = remove_cvref_t<typename Problem::QDataType> |
| using | KDataType = remove_cvref_t<typename Problem::KDataType> |
| using | VDataType = remove_cvref_t<typename Problem::VDataType> |
| using | SaccDataType = remove_cvref_t<typename Problem::SaccDataType> |
| using | SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType> |
| using | BiasDataType = remove_cvref_t<typename Problem::BiasDataType> |
| using | RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType> |
| using | LSEDataType = remove_cvref_t<typename Problem::LSEDataType> |
| using | PDataType = remove_cvref_t<typename Problem::PDataType> |
| using | OaccDataType = remove_cvref_t<typename Problem::OaccDataType> |
| using | ODataType = remove_cvref_t<typename Problem::ODataType> |
| using | AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant> |
| using | FmhaMask = remove_cvref_t<typename Problem::FmhaMask> |
| using | BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape> |
| using | VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout> |
Public Member Functions | |
| template<typename QDramBlockWindowTmp, typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename LSEaccDramBlockWindowTmp, typename PositionEncoding> | |
| CK_TILE_HOST_DEVICE auto | operator() (const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowTmp &k_dram_block_window_tmp, const VDramBlockWindowTmp &v_dram_block_window_tmp, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, LSEaccDramBlockWindowTmp &lse_acc_dram_window_tmp, FmhaMask mask, PositionEncoding position_encoding, float scale_s, void *smem_ptr) const |
| template<typename QDramBlockWindowTmp, typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename LSEaccDramBlockWindowTmp, typename PositionEncoding> | |
| CK_TILE_HOST_DEVICE auto | operator() (const QDramBlockWindowTmp &__restrict__ q_dram_block_window_tmp, const KDramBlockWindowTmp &__restrict__ k_dram_block_window_tmp, const VDramBlockWindowTmp &__restrict__ v_dram_block_window_tmp, const BiasDramBlockWindowTmp &__restrict__ bias_dram_block_window_tmp, LSEaccDramBlockWindowTmp &__restrict__ lse_acc_dram_window_tmp, FmhaMask mask, PositionEncoding position_encoding, float scale_s, void *__restrict__ smem_ptrk0, void *__restrict__ smem_ptrk1, void *__restrict__ smem_ptrv0, void *__restrict__ smem_ptrv1) const |
Static Public Member Functions | |
| static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t | GetSmemSize () |
Static Public Attributes | |
| static constexpr auto | I0 = number<0>{} |
| static constexpr auto | I1 = number<1>{} |
| static constexpr bool | kQLoadOnce = true |
| static constexpr bool | kKLoadOnce = BlockFmhaShape::kM0 >= 64 |
| static constexpr index_t | kBlockSize = Problem::kBlockSize |
| static constexpr index_t | kM0 = BlockFmhaShape::kM0 |
| static constexpr index_t | kN0 = BlockFmhaShape::kN0 |
| static constexpr index_t | kK0 = BlockFmhaShape::kK0 |
| static constexpr index_t | kN1 = BlockFmhaShape::kN1 |
| static constexpr index_t | kK1 = BlockFmhaShape::kK1 |
| static constexpr index_t | kQKHeaddim = BlockFmhaShape::kQKHeaddim |
| static constexpr index_t | kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim |
| static constexpr index_t | kNWarp = BlockFmhaShape::Gemm0BlockWarps::at(I1) |
| static constexpr index_t | kNXdl = BlockFmhaShape::Gemm0WarpTile::at(I1) |
| static constexpr bool | kIsGroupMode = Problem::kIsGroupMode |
| static constexpr bool | kPadSeqLenQ = Problem::kPadSeqLenQ |
| static constexpr bool | kPadSeqLenK = Problem::kPadSeqLenK |
| static constexpr bool | kPadHeadDimQ |
| static constexpr bool | kPadHeadDimV |
| static constexpr bool | kHasLogitsSoftCap = Problem::kHasLogitsSoftCap |
| static constexpr bool | kHasDropout = Problem::kHasDropout |
| static constexpr auto | BiasEnum = Problem::BiasEnum |
| static constexpr bool | kStoreLSE = Problem::kStoreLSE |
| static constexpr bool | kHasUnevenSplits = true |
| static constexpr index_t | kAlignmentQ = Policy::template GetAlignmentQ<Problem>() |
| static constexpr index_t | kAlignmentK = Policy::template GetAlignmentK<Problem>() |
| static constexpr index_t | kAlignmentV |
| static constexpr index_t | kAlignmentOacc = Policy::template GetAlignmentO<Problem>() |
| static constexpr index_t | kAlignmentBias |
| static constexpr index_t | kBlockPerCu |
| static constexpr const char * | name = "qr_async_trload" |
Member Typedef Documentation
◆ AttentionVariant
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload< Problem_, Policy_ >::AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant> |
◆ BiasDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload< Problem_, Policy_ >::BiasDataType = remove_cvref_t<typename Problem::BiasDataType> |
◆ BlockFmhaShape
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload< Problem_, Policy_ >::BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape> |
◆ FmhaMask
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload< Problem_, Policy_ >::FmhaMask = remove_cvref_t<typename Problem::FmhaMask> |
◆ KDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload< Problem_, Policy_ >::KDataType = remove_cvref_t<typename Problem::KDataType> |
◆ LSEDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload< Problem_, Policy_ >::LSEDataType = remove_cvref_t<typename Problem::LSEDataType> |
◆ OaccDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload< Problem_, Policy_ >::OaccDataType = remove_cvref_t<typename Problem::OaccDataType> |
◆ ODataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload< Problem_, Policy_ >::ODataType = remove_cvref_t<typename Problem::ODataType> |
◆ PDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload< Problem_, Policy_ >::PDataType = remove_cvref_t<typename Problem::PDataType> |
◆ Policy
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload< Problem_, Policy_ >::Policy = remove_cvref_t<Policy_> |
◆ Problem
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload< Problem_, Policy_ >::Problem = remove_cvref_t<Problem_> |
◆ QDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload< Problem_, Policy_ >::QDataType = remove_cvref_t<typename Problem::QDataType> |
◆ RandValOutputDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload< Problem_, Policy_ >::RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType> |
◆ SaccDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload< Problem_, Policy_ >::SaccDataType = remove_cvref_t<typename Problem::SaccDataType> |
◆ SMPLComputeDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload< Problem_, Policy_ >::SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType> |
◆ VDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload< Problem_, Policy_ >::VDataType = remove_cvref_t<typename Problem::VDataType> |
◆ VLayout
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload< Problem_, Policy_ >::VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout> |
Member Function Documentation
◆ GetSmemSize()
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
|
inline static constexpr |
◆ operator()() [1/2]
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
template<typename QDramBlockWindowTmp, typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename LSEaccDramBlockWindowTmp, typename PositionEncoding>
|
inline |
NOTICE: the bias may be a materialized mask that includes -inf values; this needs to be taken into account.
◆ operator()() [2/2]
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
template<typename QDramBlockWindowTmp, typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename LSEaccDramBlockWindowTmp, typename PositionEncoding>
|
inline |
NOTICE: the bias may be a materialized mask that includes -inf values; this needs to be taken into account.
Member Data Documentation
◆ BiasEnum
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
|
static constexpr |
◆ I0
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
|
static constexpr |
◆ I1
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
|
static constexpr |
◆ kAlignmentBias
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
|
static constexpr |
Initial value:
=
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>()
static constexpr bool kPadSeqLenK
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:64
◆ kAlignmentK
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
|
static constexpr |
◆ kAlignmentOacc
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
|
static constexpr |
◆ kAlignmentQ
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
|
static constexpr |
◆ kAlignmentV
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
|
static constexpr |
Initial value:
= []() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return Policy::template GetAlignmentV<Problem>();
else
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}()
◆ kBlockPerCu
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
|
static constexpr |
Initial value:
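= []() {
    if constexpr(Problem::kBlockPerCu != -1)
        return Problem::kBlockPerCu;
    else
    {
        /* [branch condition not preserved in this extract] */
        {
            return 2;
        }
        /* [branch condition not preserved in this extract] */
        {
            return 3;
        }
        /* [branch condition not preserved in this extract] */
        {
            /* [branch condition not preserved in this extract] */
            return 1;
            else
                return 2;
        }
        /* [branch condition not preserved in this extract] */
        {
            return 1;
        }
        else
        {
            return 1;
        }
    }
}()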
@ ELEMENTWISE_BIAS
Definition block_attention_bias_enum.hpp:14
static constexpr index_t kQKHeaddim
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:46
static constexpr index_t kQKHeaddim
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:49
static constexpr index_t kM0
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:44
static constexpr auto BiasEnum
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:69
◆ kBlockSize
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
|
static constexpr |
◆ kHasDropout
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
|
static constexpr |
◆ kHasLogitsSoftCap
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
|
static constexpr |
◆ kHasUnevenSplits
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
|
static constexpr |
◆ kIsGroupMode
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
|
static constexpr |
◆ kK0
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
|
static constexpr |
◆ kK1
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
|
static constexpr |
◆ kKLoadOnce
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
|
static constexpr |
◆ kM0
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
|
static constexpr |
◆ kN0
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
|
static constexpr |
◆ kN1
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
|
static constexpr |
◆ kNWarp
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
|
static constexpr |
◆ kNXdl
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
|
static constexpr |
◆ kPadHeadDimQ
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
|
static constexpr |
Initial value:
=
Problem::kPadHeadDimQ
◆ kPadHeadDimV
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
|
static constexpr |
Initial value:
=
Problem::kPadHeadDimV
◆ kPadSeqLenK
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
|
static constexpr |
◆ kPadSeqLenQ
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
|
static constexpr |
◆ kQKHeaddim
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
|
static constexpr |
◆ kQLoadOnce
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
|
static constexpr |
◆ kStoreLSE
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
|
static constexpr |
◆ kSubQKHeaddim
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
|
static constexpr |
◆ name
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
|
static constexpr |
The documentation for this struct was generated from the following file: block_fmha_pipeline_qr_ks_vs_async_trload.hpp