fmha_fwd_splitkv_combine_kernel.hpp Source File

fmha_fwd_splitkv_combine_kernel.hpp Source File

Composable Kernel: fmha_fwd_splitkv_combine_kernel.hpp Source File
fmha_fwd_splitkv_combine_kernel.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6namespace ck_tile {
7
8template <typename FmhaPipeline_, typename EpiloguePipeline_>
10{
13
// Compile-time launch parameters forwarded unchanged from FmhaPipeline.
14 static constexpr index_t kNumWarps = FmhaPipeline::kNumWarps;
15 static constexpr index_t kBlockSize = FmhaPipeline::kBlockSize;
16 static constexpr index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
17
// kBlockPerCu must be positive. kBlockPerCuInput is the value read from
// FmhaPipeline::Problem; GetName() compares it against -1 to decide whether the
// kernel name carries an "o<kBlockPerCu>_" component.
18 static_assert(kBlockPerCu > 0);
19 static constexpr index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
20
24
// Boolean switches read from FmhaPipeline / FmhaPipeline::Problem. Within this file
// they select the Kargs type and MakeKargs overload (kIsGroupMode), the optional
// kargs bases (kStoreLSE, kDoFp8StaticQuant) and the suffixes built by GetName().
25 static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
26 static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
27 static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
28 static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
29 static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
30
31 // clang-format off
// t2s<T>::name is the short text tag for data type T; GetName() reads
// t2s<ODataType>::name. The primary template is only declared, so a type without
// a specialization below has no `name` member to read.
32 template <typename T> struct t2s;
33 template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
34 template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
35 template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
36 template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
37 template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
38 // clang-format on
39
// Builds this kernel instance's name string on the host.
//
// Components, concatenated in this order:
//   fmha_fwd_splitkv_combine_d<FmhaPipeline::kHeadDimV>
//   _<t2s<ODataType>::name>
//   _group or _batch                  (kIsGroupMode)
//   _b<FmhaPipeline::kN1>_
//   o<kBlockPerCu>_                   (omitted when kBlockPerCuInput == -1)
//   <FmhaPipeline::name>
//   _npad, or _p followed by s and/or dv   (kPadSeqLenQ / kPadHeadDimV)
//   _lse or _nlse                     (kStoreLSE)
//   _squant or _nsquant               (kDoFp8StaticQuant)
//
// The existing note below says the format has to stay in sync with generate.py.
// _SS_ and _TS_ are shorthands for std::string and std::to_string that exist only
// inside this function: they are #undef'd before it returns.
40 CK_TILE_HOST static std::string GetName()
41 {
42 // sync with generate.py
43 // clang-format off
44
45 #define _SS_ std::string
46 #define _TS_ std::to_string
// pn is the padding tag: empty when neither padding flag is set, otherwise the
// letter p followed by s (kPadSeqLenQ) and/or dv (kPadHeadDimV).
47 auto pn = [&] () {
48 std::string n;
49 if (kPadSeqLenQ) n += "s";
50 if (kPadHeadDimV) n += "dv";
51 return n.empty() ? n : std::string("p") + n; }();
52 return
53 _SS_("fmha_fwd_splitkv_combine_d") + _TS_(FmhaPipeline::kHeadDimV) + "_" + _SS_(t2s<ODataType>::name) +
54 "_" + (kIsGroupMode ? "group" : "batch") + "_"
55 "b" + _TS_(FmhaPipeline::kN1) + "_" +
56 (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
57 _SS_(FmhaPipeline::name) +
58 (pn.empty() ? "_npad" : "_" + pn) +
59 (kStoreLSE ? "_lse" : "_nlse" ) +
60 (kDoFp8StaticQuant ? "_squant" : "_nsquant" );
61 #undef _SS_
62 #undef _TS_
63 // clang-format on
64 }
65
66 template <ck_tile::index_t I> // to avoid the duplicated base class problem, introduce a template
67 // arg
69 {
70 };
71
72 // kargs use aggregate initialization, so no constructor is provided
73 // use inheritance to minimize karg size
74 // users need to call MakeKargs() to create kargs.
96
103
105 {
106 float scale_o;
107 };
108
110 : CommonKargs,
111 std::conditional_t<kStoreLSE, CommonLSEKargs, EmptyKargs<0>>,
112 std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>>
113 {
117 };
118
120 : CommonKargs,
121 std::conditional_t<kStoreLSE, CommonLSEKargs, EmptyKargs<0>>,
122 std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<3>>
123 {
125 };
126
// Kernel-argument type, chosen at compile time: GroupModeKargs when kIsGroupMode
// is true, BatchModeKargs otherwise.
127 using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
128
// Batch-mode MakeKargs: this overload takes part in overload resolution only when
// kIsGroupMode is false (Cond defaults to !kIsGroupMode).
//
// Packs the arguments into a Kargs value with aggregate initialization:
//   - the first brace group fills the common kargs;
//   - the two empty brace groups are placeholders for the LSE base and the
//     fp8 static-quant base;
//   - the trailing three values are the batch strides of lse_acc, o_acc and o.
//
// lse_ptr, nhead_stride_lse and batch_stride_lse are stored only when kStoreLSE is
// true, and scale_o only when kDoFp8StaticQuant is true; otherwise those arguments
// are accepted but not used.
//
// Returns the populated Kargs by value.
129 template <bool Cond = !kIsGroupMode>
130 CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
131 MakeKargs(const void* lse_acc_ptr,
132 const void* o_acc_ptr,
133 void* lse_ptr,
134 void* o_ptr,
135 ck_tile::index_t batch,
136 ck_tile::index_t seqlen_q,
137 ck_tile::index_t hdim_v,
138 ck_tile::index_t num_splits,
139 float scale_o,
140 ck_tile::index_t row_stride_o_acc,
141 ck_tile::index_t row_stride_o,
142 ck_tile::index_t nhead_stride_lse_acc,
143 ck_tile::index_t nhead_stride_o_acc,
144 ck_tile::index_t nhead_stride_lse,
145 ck_tile::index_t nhead_stride_o,
146 ck_tile::index_t batch_stride_lse_acc,
147 ck_tile::index_t batch_stride_o_acc,
148 ck_tile::index_t batch_stride_lse,
149 ck_tile::index_t batch_stride_o,
150 ck_tile::index_t split_stride_lse_acc,
151 ck_tile::index_t split_stride_o_acc)
152 {
153 Kargs kargs{{lse_acc_ptr,
154 o_acc_ptr,
155 o_ptr,
156 batch,
157 seqlen_q,
158 hdim_v,
159 num_splits,
160 row_stride_o_acc,
161 row_stride_o,
162 nhead_stride_lse_acc,
163 nhead_stride_o_acc,
164 nhead_stride_o,
165 split_stride_lse_acc,
166 split_stride_o_acc}, // args for common karg
167 {}, // placeholder for lse
168 {}, // placeholder for fp8_static_quant args
169 batch_stride_lse_acc,
170 batch_stride_o_acc,
171 batch_stride_o};
172
// The optional members exist only when the matching flag is set, so they are
// assigned inside if constexpr rather than in the initializer list above.
173 if constexpr(kStoreLSE)
174 {
175 kargs.lse_ptr = lse_ptr;
176 kargs.nhead_stride_lse = nhead_stride_lse;
177 kargs.batch_stride_lse = batch_stride_lse;
178 }
179 if constexpr(kDoFp8StaticQuant)
180 {
181 kargs.scale_o = scale_o;
182 }
183
184 return kargs;
185 }
186
// Group-mode MakeKargs: this overload takes part in overload resolution only when
// kIsGroupMode is true (Cond defaults to kIsGroupMode).
//
// Differences from the batch-mode overload:
//   - it takes seqstart_q_ptr instead of seqlen_q and has no batch-stride
//     parameters;
//   - seqlen_q is initialized to -1; the kernel's operator() overwrites it with
//     the difference of two consecutive seqstart_q_ptr entries;
//   - seqstart_q_ptr is stored as a const int32_t* after the two placeholder
//     brace groups.
//
// lse_ptr and nhead_stride_lse are stored only when kStoreLSE is true, and scale_o
// only when kDoFp8StaticQuant is true. No batch_stride_lse is assigned here.
//
// Returns the populated Kargs by value.
187 template <bool Cond = kIsGroupMode>
188 CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
189 MakeKargs(const void* lse_acc_ptr,
190 const void* o_acc_ptr,
191 void* lse_ptr,
192 void* o_ptr,
193 ck_tile::index_t batch,
194 const void* seqstart_q_ptr,
195 ck_tile::index_t hdim_v,
196 ck_tile::index_t num_splits,
197 float scale_o,
198 ck_tile::index_t row_stride_o_acc,
199 ck_tile::index_t row_stride_o,
200 ck_tile::index_t nhead_stride_lse_acc,
201 ck_tile::index_t nhead_stride_o_acc,
202 ck_tile::index_t nhead_stride_lse,
203 ck_tile::index_t nhead_stride_o,
204 ck_tile::index_t split_stride_lse_acc,
205 ck_tile::index_t split_stride_o_acc)
206 {
207 Kargs kargs{{lse_acc_ptr,
208 o_acc_ptr,
209 o_ptr,
210 batch,
211 -1, // seqlen will be updated by another pointer
212 hdim_v,
213 num_splits,
214 row_stride_o_acc,
215 row_stride_o,
216 nhead_stride_lse_acc,
217 nhead_stride_o_acc,
218 nhead_stride_o,
219 split_stride_lse_acc,
220 split_stride_o_acc}, // args for common karg
221 {}, // placeholder for lse
222 {}, // placeholder for fp8_static_quant args
223 reinterpret_cast<const int32_t*>(seqstart_q_ptr)};
224
// The optional members exist only when the matching flag is set, so they are
// assigned inside if constexpr rather than in the initializer list above.
225 if constexpr(kStoreLSE)
226 {
227 kargs.lse_ptr = lse_ptr;
228 kargs.nhead_stride_lse = nhead_stride_lse;
229 }
230 if constexpr(kDoFp8StaticQuant)
231 {
232 kargs.scale_o = scale_o;
233 }
234
235 return kargs;
236 }
237
// Computes the launch grid on the host and returns it as a dim3:
//   x = ceil(max_seqlen_q / m0) * ceil(hdim_v / FmhaPipeline::kN1)
//   y = nhead
//   z = batch_size
// m0 is 32 or 64 (chosen by is_wave32()) divided by FmhaPipeline::Problem::NThreads.
// GetTileIndex() splits x back into its two tile indices on the device.
238 CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
239 ck_tile::index_t nhead,
240 ck_tile::index_t max_seqlen_q,
241 ck_tile::index_t hdim_v)
242 {
243 // Recalculate kM0 = get_warp_size() / NThreads on host
244 const index_t m0 = (is_wave32() ? 32 : 64) / FmhaPipeline::Problem::NThreads;
245 // TODO: this may need tuning
246 return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, m0) *
247 ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1),
248 nhead,
249 batch_size);
250 }
251
// Decodes the block indices into the tuple (i_tile_m, i_tile_n, i_nhead, i_batch).
//
// blockIdx.x is divided by the number of kN1-wide tiles along hdim_v: the quotient
// is i_tile_m and the remainder is i_tile_n. blockIdx.y is returned as i_nhead and
// blockIdx.z as i_batch. Only kargs.hdim_v is read from kargs.
252 CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
253 {
254 const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
255
256 const index_t i_block = blockIdx.x;
257 const index_t i_nhead = blockIdx.y;
258 const index_t i_batch = blockIdx.z;
259
// Returns (quotient, remainder); the remainder is computed as
// dividend - quotient * divisor rather than with the % operator.
260 const auto f = [](index_t dividend, index_t divisor) {
261 index_t quotient = dividend / divisor;
262 index_t modulus = dividend - quotient * divisor;
263 return ck_tile::make_tuple(quotient, modulus);
264 };
265
266 const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
267
268 return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
269 }
270
272 {
273 if(is_wave32())
274 {
275 return dim3(kBlockSize / 2);
276 }
277 else
278 {
279 return dim3(kBlockSize);
280 }
281 }
282
284 {
285 return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
286 }
287
289 {
290 // allocate LDS
291 __shared__ char smem_ptr[GetSmemSize()];
292
293 // divide problem
294 const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
295
296 const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0);
297 const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1);
298
299 long_index_t batch_offset_lse_acc = 0;
300 long_index_t batch_offset_o_acc = 0;
301 long_index_t batch_offset_lse = 0;
302 long_index_t batch_offset_o = 0;
303
304 if constexpr(kIsGroupMode)
305 {
306 // get starting offset for each batch
307 const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
308
309 batch_offset_lse_acc = query_start;
310 batch_offset_o_acc = query_start * kargs.row_stride_o_acc;
311
312 if constexpr(kStoreLSE)
313 {
314 batch_offset_lse = query_start;
315 }
316
317 batch_offset_o = query_start * kargs.row_stride_o;
318
319 // get real # queries & # keys under group mode
320 const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
321 kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
322
323 // # of required blocks is different in each groups, terminate unnecessary blocks
324 // earlier
325 if(kargs.seqlen_q <= i_m0)
326 {
327 return;
328 }
329 }
330 else
331 {
332 batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
333 batch_offset_o_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
334
335 if constexpr(kStoreLSE)
336 {
337 batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
338 }
339
340 batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
341 }
342
343 // for simplicity, batch stride we just modify the pointer
344 const LSEDataType* lse_acc_ptr =
345 reinterpret_cast<const LSEDataType*>(kargs.lse_acc_ptr) +
346 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_lse_acc + batch_offset_lse_acc;
347 const OaccDataType* o_acc_ptr =
348 reinterpret_cast<const OaccDataType*>(kargs.o_acc_ptr) +
349 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o_acc + batch_offset_o_acc;
350 ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
351 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
352 batch_offset_o;
353
354 // LSEacc/Oacc DRAM and DRAM windows
355 const auto lse_acc_dram = [&]() {
356 const auto lse_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
357 lse_acc_ptr,
358 make_tuple(kargs.num_splits, kargs.seqlen_q),
359 make_tuple(kargs.split_stride_lse_acc, number<1>{}),
361 number<1>{});
362
363 return pad_tensor_view(
364 lse_acc_dram_naive,
367 }();
368
369 auto o_acc_dram = [&]() {
370 const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
371 o_acc_ptr,
372 make_tuple(kargs.num_splits, kargs.seqlen_q, kargs.hdim_v),
373 make_tuple(kargs.split_stride_o_acc, kargs.row_stride_o_acc, number<1>{}),
375 number<1>{});
376
377 // read kNumWarps * (kM0, kN1) o_acc tiles simultaneously by kNumWarps warps
378 const auto o_acc_dram_view = pad_tensor_view(
379 o_acc_dram_naive,
383
384 const index_t padded_num_splits =
385 o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<0>{}];
386 const index_t padded_seqlen_q =
387 o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<1>{}];
388 const index_t padded_hdim_v =
389 o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<2>{}];
390
391 const index_t num_m_tiles = integer_divide_floor(padded_seqlen_q, FmhaPipeline::kM0);
392
393 // transform tensor view by following steps, given shape: (padded_num_splits,
394 // padded_seqlen_q, padded_hdim_v)
395 // 1. unmerge to (padded_num_splits, num_m_tiles, kM0, padded_hdim_v)
396 // 2. transpose to (num_m_tiles, padded_num_splits, kM0, padded_hdim_v)
397 // 3. merge to (num_m_tiles * padded_num_splits * kM0, padded_hdim_v)
398 auto transposed = transform_tensor_view(
399 o_acc_dram_view,
400 make_tuple(make_pass_through_transform(padded_num_splits),
401 make_unmerge_transform(make_tuple(num_m_tiles, FmhaPipeline::kM0)),
402 make_pass_through_transform(padded_hdim_v)),
405
407 transposed,
409 make_tuple(num_m_tiles, padded_num_splits, FmhaPipeline::kM0)),
410 make_pass_through_transform(padded_hdim_v)),
413 }();
414
415 auto lse_acc_dram_window = make_tile_window(
416 lse_acc_dram,
418 {0, i_m0});
419
420 const index_t padded_num_splits =
421 integer_divide_ceil(kargs.num_splits, kNumWarps) * kNumWarps;
422
423 auto o_acc_dram_window = make_tile_window(
424 o_acc_dram,
426 {i_tile_m * padded_num_splits * FmhaPipeline::kM0, i_n1});
427
428 // LSE DRAM window
429 auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
430 constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
431 if constexpr(kStoreLSE)
432 {
433 LSEDataType* lse_ptr =
434 reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
435 static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse;
436
437 const auto lse_dram = [&]() {
438 const auto lse_dram_naive = make_naive_tensor_view<address_space_enum::global>(
439 lse_ptr,
440 make_tuple(kargs.seqlen_q),
441 make_tuple(1),
443 number<1>{});
444
445 return pad_tensor_view(
446 lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
447 }();
448
449 return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
450 }
451 else
452 {
453 return make_null_tile_window(lse_dram_window_lengths);
454 }
455 }();
456
457 auto o_acc_tile = [&]() {
458 if constexpr(kDoFp8StaticQuant)
459 {
460 return FmhaPipeline{}(
461 lse_acc_dram_window,
462 o_acc_dram_window,
463 lse_dram_window,
464 identity{}, // lse_element_func
465 composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
466 kargs.num_splits,
467 smem_ptr);
468 }
469 else
470 {
471 return FmhaPipeline{}(lse_acc_dram_window,
472 o_acc_dram_window,
473 lse_dram_window,
474 kargs.num_splits,
475 smem_ptr);
476 }
477 }();
478
479 // O DRAM and DRAM window
480 auto o_dram = [&]() {
482 o_ptr,
483 make_tuple(kargs.seqlen_q, kargs.hdim_v),
484 make_tuple(kargs.row_stride_o, number<1>{}),
486 number<1>{});
487
488 return pad_tensor_view(
489 o_dram_naive,
492 }();
493
494 auto o_dram_window =
495 make_tile_window(o_dram,
497 {i_m0, i_n1});
498
499 EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
500 }
501};
502
503} // namespace ck_tile
#define _TS_
#define _SS_
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
CK_TILE_HOST_DEVICE constexpr auto integer_divide_floor(X x, Y y)
Definition tile/core/numeric/math.hpp:143
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
bfloat16_t bf16_t
Definition bfloat16.hpp:113
_Float16 fp16_t
Definition half.hpp:110
_BitInt(8) fp8_t
Definition float8.hpp:204
CK_TILE_HOST_DEVICE constexpr auto make_unmerge_transform(const UpLengths &up_lengths, bool_constant< Use24BitIntegerCalculation >=bool_constant< false >{})
Definition coordinate_transform.hpp:1622
int64_t long_index_t
Definition integer.hpp:11
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
int32_t int32_t
Definition integer.hpp:10
CK_TILE_DEVICE constexpr auto make_null_tile_window(const WindowLengths &window_lengths)
Definition null_tile_window.hpp:66
unsigned _BitInt(8) bf8_t
Definition float8.hpp:206
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_view.hpp:511
__host__ __device__ composes(Ts &&...) -> composes< remove_cvref_t< Ts >... >
FIXME: create macro to replace 'host device' and nothing more.
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
CK_TILE_HOST bool is_wave32()
Definition arch.hpp:72
Definition fmha_fwd_splitkv_combine_kernel.hpp:113
ck_tile::index_t batch_stride_lse_acc
Definition fmha_fwd_splitkv_combine_kernel.hpp:114
ck_tile::index_t batch_stride_o
Definition fmha_fwd_splitkv_combine_kernel.hpp:116
ck_tile::index_t batch_stride_o_acc
Definition fmha_fwd_splitkv_combine_kernel.hpp:115
Definition fmha_fwd_splitkv_combine_kernel.hpp:76
ck_tile::index_t row_stride_o
Definition fmha_fwd_splitkv_combine_kernel.hpp:87
ck_tile::index_t nhead_stride_lse_acc
Definition fmha_fwd_splitkv_combine_kernel.hpp:89
ck_tile::index_t nhead_stride_o
Definition fmha_fwd_splitkv_combine_kernel.hpp:91
ck_tile::index_t nhead_stride_o_acc
Definition fmha_fwd_splitkv_combine_kernel.hpp:90
ck_tile::index_t hdim_v
Definition fmha_fwd_splitkv_combine_kernel.hpp:83
ck_tile::index_t row_stride_o_acc
Definition fmha_fwd_splitkv_combine_kernel.hpp:86
void * o_ptr
Definition fmha_fwd_splitkv_combine_kernel.hpp:79
ck_tile::index_t num_splits
Definition fmha_fwd_splitkv_combine_kernel.hpp:84
const void * lse_acc_ptr
Definition fmha_fwd_splitkv_combine_kernel.hpp:77
ck_tile::index_t seqlen_q
Definition fmha_fwd_splitkv_combine_kernel.hpp:82
ck_tile::index_t split_stride_lse_acc
Definition fmha_fwd_splitkv_combine_kernel.hpp:93
ck_tile::index_t batch
Definition fmha_fwd_splitkv_combine_kernel.hpp:81
const void * o_acc_ptr
Definition fmha_fwd_splitkv_combine_kernel.hpp:78
ck_tile::index_t split_stride_o_acc
Definition fmha_fwd_splitkv_combine_kernel.hpp:94
Definition fmha_fwd_splitkv_combine_kernel.hpp:98
ck_tile::index_t batch_stride_lse
Definition fmha_fwd_splitkv_combine_kernel.hpp:101
ck_tile::index_t nhead_stride_lse
Definition fmha_fwd_splitkv_combine_kernel.hpp:100
void * lse_ptr
Definition fmha_fwd_splitkv_combine_kernel.hpp:99
Definition fmha_fwd_splitkv_combine_kernel.hpp:69
Definition fmha_fwd_splitkv_combine_kernel.hpp:105
float scale_o
Definition fmha_fwd_splitkv_combine_kernel.hpp:106
Definition fmha_fwd_splitkv_combine_kernel.hpp:123
const int32_t * seqstart_q_ptr
Definition fmha_fwd_splitkv_combine_kernel.hpp:124
static constexpr const char * name
Definition fmha_fwd_splitkv_combine_kernel.hpp:35
static constexpr const char * name
Definition fmha_fwd_splitkv_combine_kernel.hpp:37
static constexpr const char * name
Definition fmha_fwd_splitkv_combine_kernel.hpp:34
static constexpr const char * name
Definition fmha_fwd_splitkv_combine_kernel.hpp:36
static constexpr const char * name
Definition fmha_fwd_splitkv_combine_kernel.hpp:33
Definition fmha_fwd_splitkv_combine_kernel.hpp:32
Definition fmha_fwd_splitkv_combine_kernel.hpp:10
static constexpr bool kStoreLSE
Definition fmha_fwd_splitkv_combine_kernel.hpp:28
remove_cvref_t< typename FmhaPipeline::ODataType > ODataType
Definition fmha_fwd_splitkv_combine_kernel.hpp:23
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition fmha_fwd_splitkv_combine_kernel.hpp:288
static constexpr index_t kBlockPerCuInput
Definition fmha_fwd_splitkv_combine_kernel.hpp:19
static constexpr bool kIsGroupMode
Definition fmha_fwd_splitkv_combine_kernel.hpp:25
static CK_TILE_HOST dim3 BlockSize()
Definition fmha_fwd_splitkv_combine_kernel.hpp:271
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs(const void *lse_acc_ptr, const void *o_acc_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t batch, ck_tile::index_t seqlen_q, ck_tile::index_t hdim_v, ck_tile::index_t num_splits, float scale_o, ck_tile::index_t row_stride_o_acc, ck_tile::index_t row_stride_o, ck_tile::index_t nhead_stride_lse_acc, ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_lse_acc, ck_tile::index_t batch_stride_o_acc, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_o_acc)
Definition fmha_fwd_splitkv_combine_kernel.hpp:131
static constexpr index_t kNumWarps
Definition fmha_fwd_splitkv_combine_kernel.hpp:14
remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition fmha_fwd_splitkv_combine_kernel.hpp:21
static constexpr bool kPadSeqLenQ
Definition fmha_fwd_splitkv_combine_kernel.hpp:26
static CK_TILE_DEVICE constexpr auto GetTileIndex(const Kargs &kargs)
Definition fmha_fwd_splitkv_combine_kernel.hpp:252
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition fmha_fwd_splitkv_combine_kernel.hpp:283
remove_cvref_t< typename FmhaPipeline::OaccDataType > OaccDataType
Definition fmha_fwd_splitkv_combine_kernel.hpp:22
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs(const void *lse_acc_ptr, const void *o_acc_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t batch, const void *seqstart_q_ptr, ck_tile::index_t hdim_v, ck_tile::index_t num_splits, float scale_o, ck_tile::index_t row_stride_o_acc, ck_tile::index_t row_stride_o, ck_tile::index_t nhead_stride_lse_acc, ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_o_acc)
Definition fmha_fwd_splitkv_combine_kernel.hpp:189
static constexpr bool kPadHeadDimV
Definition fmha_fwd_splitkv_combine_kernel.hpp:27
static CK_TILE_HOST std::string GetName()
Definition fmha_fwd_splitkv_combine_kernel.hpp:40
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition fmha_fwd_splitkv_combine_kernel.hpp:12
static constexpr bool kDoFp8StaticQuant
Definition fmha_fwd_splitkv_combine_kernel.hpp:29
static constexpr index_t kBlockSize
Definition fmha_fwd_splitkv_combine_kernel.hpp:15
static CK_TILE_HOST constexpr auto GridSize(ck_tile::index_t batch_size, ck_tile::index_t nhead, ck_tile::index_t max_seqlen_q, ck_tile::index_t hdim_v)
Definition fmha_fwd_splitkv_combine_kernel.hpp:238
static constexpr index_t kBlockPerCu
Definition fmha_fwd_splitkv_combine_kernel.hpp:16
remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition fmha_fwd_splitkv_combine_kernel.hpp:11
std::conditional_t< kIsGroupMode, GroupModeKargs, BatchModeKargs > Kargs
Definition fmha_fwd_splitkv_combine_kernel.hpp:127
Definition tile/core/utility/functional.hpp:86
Definition unary_element_function.hpp:56
Definition tile/core/numeric/math.hpp:28
Definition tile/core/container/sequence.hpp:49