fmha_fwd_splitkv_combine_kernel.hpp Source File
fmha_fwd_splitkv_combine_kernel.hpp
Go to the documentation of this file.
53 _SS_("fmha_fwd_splitkv_combine_d") + _TS_(FmhaPipeline::kHeadDimV) + "_" + _SS_(t2s<ODataType>::name) +
#define _TS_
#define _SS_
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
CK_TILE_HOST_DEVICE constexpr auto integer_divide_floor(X x, Y y)
Definition tile/core/numeric/math.hpp:143
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
CK_TILE_HOST_DEVICE constexpr auto make_unmerge_transform(const UpLengths &up_lengths, bool_constant< Use24BitIntegerCalculation >=bool_constant< false >{})
Definition coordinate_transform.hpp:1622
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
CK_TILE_DEVICE constexpr auto make_null_tile_window(const WindowLengths &window_lengths)
Definition null_tile_window.hpp:66
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_view.hpp:511
__host__ __device__ composes(Ts &&...) -> composes< remove_cvref_t< Ts >... >
FIXME: create macro to replace '__host__ __device__' and nothing more.
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition fmha_fwd_splitkv_combine_kernel.hpp:113
ck_tile::index_t batch_stride_lse_acc
Definition fmha_fwd_splitkv_combine_kernel.hpp:114
ck_tile::index_t batch_stride_o
Definition fmha_fwd_splitkv_combine_kernel.hpp:116
ck_tile::index_t batch_stride_o_acc
Definition fmha_fwd_splitkv_combine_kernel.hpp:115
Definition fmha_fwd_splitkv_combine_kernel.hpp:76
ck_tile::index_t row_stride_o
Definition fmha_fwd_splitkv_combine_kernel.hpp:87
ck_tile::index_t nhead_stride_lse_acc
Definition fmha_fwd_splitkv_combine_kernel.hpp:89
ck_tile::index_t nhead_stride_o
Definition fmha_fwd_splitkv_combine_kernel.hpp:91
ck_tile::index_t nhead_stride_o_acc
Definition fmha_fwd_splitkv_combine_kernel.hpp:90
ck_tile::index_t hdim_v
Definition fmha_fwd_splitkv_combine_kernel.hpp:83
ck_tile::index_t row_stride_o_acc
Definition fmha_fwd_splitkv_combine_kernel.hpp:86
void * o_ptr
Definition fmha_fwd_splitkv_combine_kernel.hpp:79
ck_tile::index_t num_splits
Definition fmha_fwd_splitkv_combine_kernel.hpp:84
const void * lse_acc_ptr
Definition fmha_fwd_splitkv_combine_kernel.hpp:77
ck_tile::index_t seqlen_q
Definition fmha_fwd_splitkv_combine_kernel.hpp:82
ck_tile::index_t split_stride_lse_acc
Definition fmha_fwd_splitkv_combine_kernel.hpp:93
ck_tile::index_t batch
Definition fmha_fwd_splitkv_combine_kernel.hpp:81
const void * o_acc_ptr
Definition fmha_fwd_splitkv_combine_kernel.hpp:78
ck_tile::index_t split_stride_o_acc
Definition fmha_fwd_splitkv_combine_kernel.hpp:94
Definition fmha_fwd_splitkv_combine_kernel.hpp:98
ck_tile::index_t batch_stride_lse
Definition fmha_fwd_splitkv_combine_kernel.hpp:101
ck_tile::index_t nhead_stride_lse
Definition fmha_fwd_splitkv_combine_kernel.hpp:100
void * lse_ptr
Definition fmha_fwd_splitkv_combine_kernel.hpp:99
Definition fmha_fwd_splitkv_combine_kernel.hpp:69
Definition fmha_fwd_splitkv_combine_kernel.hpp:105
float scale_o
Definition fmha_fwd_splitkv_combine_kernel.hpp:106
Definition fmha_fwd_splitkv_combine_kernel.hpp:123
const int32_t * seqstart_q_ptr
Definition fmha_fwd_splitkv_combine_kernel.hpp:124
static constexpr const char * name
Definition fmha_fwd_splitkv_combine_kernel.hpp:35
static constexpr const char * name
Definition fmha_fwd_splitkv_combine_kernel.hpp:37
static constexpr const char * name
Definition fmha_fwd_splitkv_combine_kernel.hpp:34
static constexpr const char * name
Definition fmha_fwd_splitkv_combine_kernel.hpp:36
static constexpr const char * name
Definition fmha_fwd_splitkv_combine_kernel.hpp:33
Definition fmha_fwd_splitkv_combine_kernel.hpp:32
Definition fmha_fwd_splitkv_combine_kernel.hpp:10
static constexpr bool kStoreLSE
Definition fmha_fwd_splitkv_combine_kernel.hpp:28
remove_cvref_t< typename FmhaPipeline::ODataType > ODataType
Definition fmha_fwd_splitkv_combine_kernel.hpp:23
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition fmha_fwd_splitkv_combine_kernel.hpp:288
static constexpr index_t kBlockPerCuInput
Definition fmha_fwd_splitkv_combine_kernel.hpp:19
static constexpr bool kIsGroupMode
Definition fmha_fwd_splitkv_combine_kernel.hpp:25
static CK_TILE_HOST dim3 BlockSize()
Definition fmha_fwd_splitkv_combine_kernel.hpp:271
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs(const void *lse_acc_ptr, const void *o_acc_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t batch, ck_tile::index_t seqlen_q, ck_tile::index_t hdim_v, ck_tile::index_t num_splits, float scale_o, ck_tile::index_t row_stride_o_acc, ck_tile::index_t row_stride_o, ck_tile::index_t nhead_stride_lse_acc, ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_lse_acc, ck_tile::index_t batch_stride_o_acc, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_o_acc)
Definition fmha_fwd_splitkv_combine_kernel.hpp:131
static constexpr index_t kNumWarps
Definition fmha_fwd_splitkv_combine_kernel.hpp:14
remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition fmha_fwd_splitkv_combine_kernel.hpp:21
static constexpr bool kPadSeqLenQ
Definition fmha_fwd_splitkv_combine_kernel.hpp:26
static CK_TILE_DEVICE constexpr auto GetTileIndex(const Kargs &kargs)
Definition fmha_fwd_splitkv_combine_kernel.hpp:252
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition fmha_fwd_splitkv_combine_kernel.hpp:283
remove_cvref_t< typename FmhaPipeline::OaccDataType > OaccDataType
Definition fmha_fwd_splitkv_combine_kernel.hpp:22
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs(const void *lse_acc_ptr, const void *o_acc_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t batch, const void *seqstart_q_ptr, ck_tile::index_t hdim_v, ck_tile::index_t num_splits, float scale_o, ck_tile::index_t row_stride_o_acc, ck_tile::index_t row_stride_o, ck_tile::index_t nhead_stride_lse_acc, ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_o_acc)
Definition fmha_fwd_splitkv_combine_kernel.hpp:189
static constexpr bool kPadHeadDimV
Definition fmha_fwd_splitkv_combine_kernel.hpp:27
static CK_TILE_HOST std::string GetName()
Definition fmha_fwd_splitkv_combine_kernel.hpp:40
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition fmha_fwd_splitkv_combine_kernel.hpp:12
static constexpr bool kDoFp8StaticQuant
Definition fmha_fwd_splitkv_combine_kernel.hpp:29
static constexpr index_t kBlockSize
Definition fmha_fwd_splitkv_combine_kernel.hpp:15
static CK_TILE_HOST constexpr auto GridSize(ck_tile::index_t batch_size, ck_tile::index_t nhead, ck_tile::index_t max_seqlen_q, ck_tile::index_t hdim_v)
Definition fmha_fwd_splitkv_combine_kernel.hpp:238
static constexpr index_t kBlockPerCu
Definition fmha_fwd_splitkv_combine_kernel.hpp:16
remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition fmha_fwd_splitkv_combine_kernel.hpp:11
std::conditional_t< kIsGroupMode, GroupModeKargs, BatchModeKargs > Kargs
Definition fmha_fwd_splitkv_combine_kernel.hpp:127
Definition tile/core/utility/functional.hpp:86
Definition unary_element_function.hpp:56
Definition tile/core/numeric/math.hpp:28
Definition tile/core/container/sequence.hpp:49