block_fmha_fwd_v3_pipeline.hpp Source File
block_fmha_fwd_v3_pipeline.hpp
Go to the documentation of this file.
289 // last dimension vector length used to create tensor view (and decide buffer_load vector length)
387 CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
1230 CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
CK_TILE_DEVICE bf16x2_t cvt_pk_bf16_f32(float a, float b)
Definition block_fmha_fwd_v3_pipeline.hpp:230
CK_TILE_DEVICE float fma_impl_vsv(float a, float b, float c)
Definition block_fmha_fwd_v3_pipeline.hpp:190
CK_TILE_DEVICE float mul_impl_vv(float lhs, float rhs)
Definition block_fmha_fwd_v3_pipeline.hpp:212
CK_TILE_DEVICE float add_impl_vv(float lhs, float rhs)
Definition block_fmha_fwd_v3_pipeline.hpp:203
CK_TILE_DEVICE fp16x2_t cvt_pk_fp16_f32(float a, float b)
Definition block_fmha_fwd_v3_pipeline.hpp:221
CK_TILE_DEVICE fp32x2_t pk_mul_f32(fp32x2_t lhs, fp32x2_t rhs)
Definition block_fmha_fwd_v3_pipeline.hpp:239
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
CK_TILE_DEVICE void set_tile(DstrTensors &dstr_tensor, const T &value)
Definition tile_elementwise.hpp:95
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
CK_TILE_DEVICE constexpr auto get_slice_tile(const tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile, sequence< SliceBegins... > slice_begins, sequence< SliceEnds... > slice_ends)
Definition slice_tile.hpp:23
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_ &acc_tensor, const ReduceFunc &reduce_func, bool_constant< WithBroadcast >={}, bool_constant< CrossWarp >={})
Definition block_reduce.hpp:21
CK_TILE_DEVICE index_t get_warp_id(bool_constant< ReturnSgpr >={})
Definition arch.hpp:104
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_DEVICE auto load_tile_transpose(const tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window)
Transpose-loads a tile from a tensor and returns the resulting tensor with a new (transposed) tile distribution.
Definition load_tile_transpose.hpp:403
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_ &acc_tensor, const InDistributedTensor_ &in_tensor, sequence< InReduceDims... >, const ReduceFunc &reduce_func)
Definition block_reduce.hpp:191
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_DEVICE constexpr auto make_tile_window_linear(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, LinearBottomDims_={})
Definition tile_window_linear.hpp:993
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_HOST_DEVICE void set_tile_if(static_distributed_tensor< DataType, StaticTileDistribution > &out_tensor, DataType value, XIndicesPredicate predicate)
Definition static_distributed_tensor.hpp:175
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_ &&lds_tile, const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Definition load_tile.hpp:133
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition tile_elementwise.hpp:177
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
tuple_array< T, N > statically_indexed_array
Definition tile/core/container/statically_indexed_array.hpp:16
Definition block_fmha_fwd_v3_pipeline.hpp:251
static constexpr bool kPadSeqLenQ
Definition block_fmha_fwd_v3_pipeline.hpp:283
ck_tile::remove_cvref_t< typename Problem::PDataType > PDataType
Definition block_fmha_fwd_v3_pipeline.hpp:260
static constexpr bool kPadHeadDimV
Definition block_fmha_fwd_v3_pipeline.hpp:286
ck_tile::remove_cvref_t< typename Problem::SMPLComputeDataType > SMPLComputeDataType
Definition block_fmha_fwd_v3_pipeline.hpp:258
static constexpr ck_tile::index_t kQKHeaddim
Definition block_fmha_fwd_v3_pipeline.hpp:277
ck_tile::remove_cvref_t< typename Problem::KDataType > KDataType
Definition block_fmha_fwd_v3_pipeline.hpp:255
ck_tile::remove_cvref_t< typename Problem::ODataType > ODataType
Definition block_fmha_fwd_v3_pipeline.hpp:262
ck_tile::remove_cvref_t< typename Problem::BlockFmhaShape > BlockFmhaShape
Definition block_fmha_fwd_v3_pipeline.hpp:268
ck_tile::remove_cvref_t< typename Problem::OaccDataType > OaccDataType
Definition block_fmha_fwd_v3_pipeline.hpp:261
static constexpr ck_tile::index_t kAlignmentV
Definition block_fmha_fwd_v3_pipeline.hpp:295
static constexpr ck_tile::index_t kN0
Definition block_fmha_fwd_v3_pipeline.hpp:273
static constexpr ck_tile::index_t kK0
Definition block_fmha_fwd_v3_pipeline.hpp:274
ck_tile::remove_cvref_t< Problem_ > Problem
Definition block_fmha_fwd_v3_pipeline.hpp:252
static CK_TILE_DEVICE constexpr void s_waitcnt()
Definition block_fmha_fwd_v3_pipeline.hpp:355
static CK_TILE_DEVICE constexpr void s_waitcnt_lgkmcnt()
Definition block_fmha_fwd_v3_pipeline.hpp:371
static CK_TILE_DEVICE constexpr auto make_lds_tile_window(void *base, const Descriptor &desc)
Definition block_fmha_fwd_v3_pipeline.hpp:344
static constexpr ck_tile::index_t kAlignmentK
Definition block_fmha_fwd_v3_pipeline.hpp:293
static constexpr ck_tile::index_t kAlignmentO
Definition block_fmha_fwd_v3_pipeline.hpp:298
static CK_TILE_DEVICE constexpr auto MakeSimpleLdsDesc()
Definition block_fmha_fwd_v3_pipeline.hpp:320
ck_tile::remove_cvref_t< typename Problem::QDataType > QDataType
Definition block_fmha_fwd_v3_pipeline.hpp:254
static constexpr bool kIsGroupMode
Definition block_fmha_fwd_v3_pipeline.hpp:282
static constexpr ck_tile::index_t kSubQKHeaddim
Definition block_fmha_fwd_v3_pipeline.hpp:278
ck_tile::remove_cvref_t< typename Problem::LSEDataType > LSEDataType
Definition block_fmha_fwd_v3_pipeline.hpp:259
static CK_TILE_DEVICE constexpr auto MakeSimpleLdsDesc1D()
Definition block_fmha_fwd_v3_pipeline.hpp:334
ck_tile::remove_cvref_t< typename Problem::SaccDataType > SaccDataType
Definition block_fmha_fwd_v3_pipeline.hpp:257
ck_tile::remove_cvref_t< Policy_ > Policy
Definition block_fmha_fwd_v3_pipeline.hpp:253
static constexpr bool kPadSeqLenK
Definition block_fmha_fwd_v3_pipeline.hpp:284
ck_tile::remove_cvref_t< typename Problem::FmhaMask > FmhaMask
Definition block_fmha_fwd_v3_pipeline.hpp:263
static constexpr ck_tile::index_t kN1
Definition block_fmha_fwd_v3_pipeline.hpp:275
static constexpr ck_tile::index_t kK1
Definition block_fmha_fwd_v3_pipeline.hpp:276
static constexpr bool kPadHeadDimQ
Definition block_fmha_fwd_v3_pipeline.hpp:285
static constexpr bool kStoreLSE
Definition block_fmha_fwd_v3_pipeline.hpp:287
static constexpr ck_tile::index_t kM0
Definition block_fmha_fwd_v3_pipeline.hpp:272
static constexpr ck_tile::index_t kBlockPerCu
Definition block_fmha_fwd_v3_pipeline.hpp:301
static constexpr ck_tile::index_t kAlignmentQ
Definition block_fmha_fwd_v3_pipeline.hpp:291
CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const QElementFunction &q_element_func, const KDramBlockWindowTmp &k_dram_block_window_tmp, const KElementFunction &k_element_func, const VDramBlockWindowTmp &v_dram_block_window_tmp, const VElementFunction &v_element_func, LSEDramBlockWindowTmp &lse_dram_window_tmp, const LSEElementFunction &lse_element_func, const SAccElementFunction &s_acc_element_func, const PComputeElementFunction &p_compute_element_func, const OAccElementFunction &o_acc_element_func, FmhaMask mask, float scale_s, void *smem_ptr) const
Definition block_fmha_fwd_v3_pipeline.hpp:387
static CK_TILE_DEVICE constexpr void s_waitcnt_vmcnt()
Definition block_fmha_fwd_v3_pipeline.hpp:365
ck_tile::remove_cvref_t< typename Problem::VDataType > VDataType
Definition block_fmha_fwd_v3_pipeline.hpp:256
static constexpr ck_tile::index_t kBlockSize
Definition block_fmha_fwd_v3_pipeline.hpp:270
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition block_fmha_fwd_v3_pipeline.hpp:310
CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowTmp &k_dram_block_window_tmp, const VDramBlockWindowTmp &v_dram_block_window_tmp, LSEDramBlockWindowTmp &lse_dram_block_window_tmp, FmhaMask mask, float scale_s, void *smem_ptr) const
Definition block_fmha_fwd_v3_pipeline.hpp:1230
static CK_TILE_DEVICE constexpr void schedule(ck_tile::number< WaveGroup >, ck_tile::number< Phase >)
Definition block_fmha_fwd_v3_pipeline.hpp:119
static CK_TILE_DEVICE constexpr void schedule(ck_tile::number< WaveGroup >, ck_tile::number< Phase >)
Definition block_fmha_fwd_v3_pipeline.hpp:45
Definition block_fmha_fwd_v3_pipeline.hpp:39
Definition tile/core/utility/functional.hpp:86
static CK_TILE_HOST_DEVICE constexpr T infinity()
Definition tile/core/numeric/numeric.hpp:38
Definition coordinate_transform.hpp:1392
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43
Definition tensor_view.hpp:41