BlockFmhaFwdV3Pipeline< Problem_, Policy_ > Struct Template Reference
#include <block_fmha_fwd_v3_pipeline.hpp>
Public Types | |
| using | Problem = ck_tile::remove_cvref_t<Problem_> |
| using | Policy = ck_tile::remove_cvref_t<Policy_> |
| using | QDataType = ck_tile::remove_cvref_t<typename Problem::QDataType> |
| using | KDataType = ck_tile::remove_cvref_t<typename Problem::KDataType> |
| using | VDataType = ck_tile::remove_cvref_t<typename Problem::VDataType> |
| using | SaccDataType = ck_tile::remove_cvref_t<typename Problem::SaccDataType> |
| using | SMPLComputeDataType = ck_tile::remove_cvref_t<typename Problem::SMPLComputeDataType> |
| using | LSEDataType = ck_tile::remove_cvref_t<typename Problem::LSEDataType> |
| using | PDataType = ck_tile::remove_cvref_t<typename Problem::PDataType> |
| using | OaccDataType = ck_tile::remove_cvref_t<typename Problem::OaccDataType> |
| using | ODataType = ck_tile::remove_cvref_t<typename Problem::ODataType> |
| using | FmhaMask = ck_tile::remove_cvref_t<typename Problem::FmhaMask> |
| using | BlockFmhaShape = ck_tile::remove_cvref_t<typename Problem::BlockFmhaShape> |
Public Member Functions | |
| template<typename QDramBlockWindowTmp, typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename QElementFunction, typename KElementFunction, typename VElementFunction, typename LSEElementFunction, typename SAccElementFunction, typename PComputeElementFunction, typename OAccElementFunction> | |
| CK_TILE_DEVICE auto | operator() (const QDramBlockWindowTmp &q_dram_block_window_tmp, const QElementFunction &q_element_func, const KDramBlockWindowTmp &k_dram_block_window_tmp, const KElementFunction &k_element_func, const VDramBlockWindowTmp &v_dram_block_window_tmp, const VElementFunction &v_element_func, LSEDramBlockWindowTmp &lse_dram_window_tmp, const LSEElementFunction &lse_element_func, const SAccElementFunction &s_acc_element_func, const PComputeElementFunction &p_compute_element_func, const OAccElementFunction &o_acc_element_func, FmhaMask mask, float scale_s, void *smem_ptr) const |
| template<typename QDramBlockWindowTmp, typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename LSEDramBlockWindowTmp> | |
| CK_TILE_DEVICE auto | operator() (const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowTmp &k_dram_block_window_tmp, const VDramBlockWindowTmp &v_dram_block_window_tmp, LSEDramBlockWindowTmp &lse_dram_block_window_tmp, FmhaMask mask, float scale_s, void *smem_ptr) const |
Static Public Member Functions | |
| static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t | GetSmemSize () |
| template<ck_tile::index_t MPerBlock, ck_tile::index_t NPerBlock> | |
| static CK_TILE_DEVICE constexpr auto | MakeSimpleLdsDesc () |
| template<ck_tile::index_t MPerBlock> | |
| static CK_TILE_DEVICE constexpr auto | MakeSimpleLdsDesc1D () |
| template<typename DataType, typename Descriptor> | |
| static CK_TILE_DEVICE constexpr auto | make_lds_tile_window (void *base, const Descriptor &desc) |
| template<uint16_t Vmcnt, uint8_t Lgkmcnt, uint8_t Expcnt = 7> | |
| static CK_TILE_DEVICE constexpr void | s_waitcnt () |
| template<uint16_t Vmcnt> | |
| static CK_TILE_DEVICE constexpr void | s_waitcnt_vmcnt () |
| template<uint8_t Lgkmcnt> | |
| static CK_TILE_DEVICE constexpr void | s_waitcnt_lgkmcnt () |
Static Public Attributes | |
| static constexpr ck_tile::index_t | kBlockSize = Problem::kBlockSize |
| static constexpr ck_tile::index_t | kM0 = BlockFmhaShape::kM0 |
| static constexpr ck_tile::index_t | kN0 = BlockFmhaShape::kN0 |
| static constexpr ck_tile::index_t | kK0 = BlockFmhaShape::kK0 |
| static constexpr ck_tile::index_t | kN1 = BlockFmhaShape::kN1 |
| static constexpr ck_tile::index_t | kK1 = BlockFmhaShape::kK1 |
| static constexpr ck_tile::index_t | kQKHeaddim = BlockFmhaShape::kQKHeaddim |
| static constexpr ck_tile::index_t | kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim |
| static constexpr bool | kIsGroupMode = Problem::kIsGroupMode |
| static constexpr bool | kPadSeqLenQ = Problem::kPadSeqLenQ |
| static constexpr bool | kPadSeqLenK = Problem::kPadSeqLenK |
| static constexpr bool | kPadHeadDimQ = Problem::kPadHeadDimQ |
| static constexpr bool | kPadHeadDimV = Problem::kPadHeadDimV |
| static constexpr bool | kStoreLSE = Problem::kStoreLSE |
| static constexpr ck_tile::index_t | kAlignmentQ |
| static constexpr ck_tile::index_t | kAlignmentK |
| static constexpr ck_tile::index_t | kAlignmentV |
| static constexpr ck_tile::index_t | kAlignmentO |
| static constexpr ck_tile::index_t | kBlockPerCu |
Member Typedef Documentation
◆ BlockFmhaShape
| using ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::BlockFmhaShape = ck_tile::remove_cvref_t<typename Problem::BlockFmhaShape> |
◆ FmhaMask
| using ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::FmhaMask = ck_tile::remove_cvref_t<typename Problem::FmhaMask> |
◆ KDataType
| using ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::KDataType = ck_tile::remove_cvref_t<typename Problem::KDataType> |
◆ LSEDataType
| using ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::LSEDataType = ck_tile::remove_cvref_t<typename Problem::LSEDataType> |
◆ OaccDataType
| using ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::OaccDataType = ck_tile::remove_cvref_t<typename Problem::OaccDataType> |
◆ ODataType
| using ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::ODataType = ck_tile::remove_cvref_t<typename Problem::ODataType> |
◆ PDataType
| using ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::PDataType = ck_tile::remove_cvref_t<typename Problem::PDataType> |
◆ Policy
| using ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::Policy = ck_tile::remove_cvref_t<Policy_> |
◆ Problem
| using ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::Problem = ck_tile::remove_cvref_t<Problem_> |
◆ QDataType
| using ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::QDataType = ck_tile::remove_cvref_t<typename Problem::QDataType> |
◆ SaccDataType
| using ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::SaccDataType = ck_tile::remove_cvref_t<typename Problem::SaccDataType> |
◆ SMPLComputeDataType
| using ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::SMPLComputeDataType = ck_tile::remove_cvref_t<typename Problem::SMPLComputeDataType> |
◆ VDataType
| using ck_tile::BlockFmhaFwdV3Pipeline< Problem_, Policy_ >::VDataType = ck_tile::remove_cvref_t<typename Problem::VDataType> |
Member Function Documentation
◆ GetSmemSize()
|
inline static constexpr |
◆ make_lds_tile_window()
|
inline static constexpr |
◆ MakeSimpleLdsDesc()
|
inline static constexpr |
◆ MakeSimpleLdsDesc1D()
|
inline static constexpr |
◆ operator()() [1/2]
|
inline |
◆ operator()() [2/2]
|
inline |
FIXME: use the future-predicting method to move the window
FIXME: use the future-predicting method to move the window
TODO: remove the sp_delta and use sp_compute directly
TODO: move some fmha_alu1() code here if necessary
Note: The compiler keeps moving the following instructions elsewhere because 'l' is first consumed later. To anchor them here, we rewrite the final addition in inline assembly to create a dependency, forcing the dependent instructions to be emitted at this point.
Note: The compiler keeps sinking the conversion instructions because the result 'p' is only consumed later. To anchor them here, we rewrite the cast_tile() call as inline assembly, forcing the conversions to be emitted at this point.
Note: Place fmha_alu1() at the end of the phase. The surrounding inline assembly can interfere with the behavior of sched_group_barrier(), so ending the phase here avoids unintended reordering.
NOTICE: Use inline asm v_pk_mul_f32 to reduce latency. The fmha_alu_D_upd() call should be placed at the end of a phase.
TODO: find a better way to map the fmha_alu(0,96) call
◆ s_waitcnt()
|
inline static constexpr |
◆ s_waitcnt_lgkmcnt()
|
inline static constexpr |
◆ s_waitcnt_vmcnt()
|
inline static constexpr |
Member Data Documentation
◆ kAlignmentK
|
static constexpr |
◆ kAlignmentO
|
static constexpr |
◆ kAlignmentQ
|
static constexpr |
◆ kAlignmentV
|
static constexpr |
◆ kBlockPerCu
|
static constexpr |
◆ kBlockSize
|
static constexpr |
◆ kIsGroupMode
|
static constexpr |
◆ kK0
|
static constexpr |
◆ kK1
|
static constexpr |
◆ kM0
|
static constexpr |
◆ kN0
|
static constexpr |
◆ kN1
|
static constexpr |
◆ kPadHeadDimQ
|
static constexpr |
◆ kPadHeadDimV
|
static constexpr |
◆ kPadSeqLenK
|
static constexpr |
◆ kPadSeqLenQ
|
static constexpr |
◆ kQKHeaddim
|
static constexpr |
◆ kStoreLSE
|
static constexpr |
◆ kSubQKHeaddim
|
static constexpr |
The documentation for this struct was generated from the following file: block_fmha_fwd_v3_pipeline.hpp