fmha_fwd_pagedkv_kernel.hpp Source File

fmha_fwd_pagedkv_kernel.hpp Source File#

Composable Kernel: fmha_fwd_pagedkv_kernel.hpp Source File
fmha_fwd_pagedkv_kernel.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
10
11#include <string>
12#include <type_traits>
13#include <utility>
14#include <variant>
15
16// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
17// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
18// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
19// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
20// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
21
22namespace ck_tile {
23
24// TODO: This class is a variant of the existing FmhaFwdSplitKVKernel pipeline.
25// Refactoring to extract shared logic is recommended as future work.
26template <typename FmhaPipeline_, typename EpiloguePipeline_>
28{
// Launch configuration forwarded from the pipeline.
31 static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
32 static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
33
34 static_assert(kBlockPerCu > 0);
// Occupancy value as declared on the pipeline's Problem. GetName() leaves the
// occupancy segment out of the kernel name when this is -1.
35 static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
36
44
46
// Compile-time feature switches mirrored from the pipeline (or its Problem). They
// select which optional kargs groups exist and which `if constexpr` paths are taken.
47 static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
48 static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
49 static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
50 static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
51 static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
52 static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
53 static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
54 static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
55 static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
56 static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
57 static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
58
// True when the mask type performs masking; enables the mask kargs group.
61 static constexpr bool kHasMask = FmhaMask::IsMasking;
62
// Taken from the pipeline policy; not referenced elsewhere in the visible part of this file.
63 static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
64
65 // clang-format off
// Type-to-string trait: maps an element type to the short tag used in kernel names
// (see GetName()). Only the types specialized below are supported; any other type
// fails to compile because the primary template is left undefined.
66 template <typename T> struct t2s;
67 template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
68 template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
69 template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
70 template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
71 template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
72 // clang-format on
73
// Builds the textual name of this kernel instantiation from its compile-time
// configuration. The segments, in order, are: head dim and Q data type, batch/group
// mode, block tile sizes ("b..."), block-warp layouts of both gemms ("r..."), warp
// tile sizes of both gemms ("w..."), optional occupancy ("o<N>"), pipeline name,
// V layout ("vr"/"vc"), padding suffix, then the mask / lse / skip / squant / pagedkv
// flags. The exact layout must stay in sync with generate.py.
74 CK_TILE_HOST static std::string GetName()
75 {
76 // sync with generate.py
77 // clang-format off
78 using bfs = typename FmhaPipeline::BlockFmhaShape;
79 using g0br = typename bfs::Gemm0BlockWarps;
80 using g1br = typename bfs::Gemm1BlockWarps;
81 using g0wt = typename bfs::Gemm0WarpTile;
82 using g1wt = typename bfs::Gemm1WarpTile;
// Local shorthands, undefined again before the function returns to the caller's scope.
83 #define _SS_ std::string
84 #define _TS_ std::to_string
// Padding suffix: "p" followed by one tag per enabled padding flag
// (s = seqlen_q, sk = seqlen_k, d = hdim_q, dv = hdim_v); empty when nothing is padded.
85 auto pn = [&] () {
86 std::string n;
87 if (kPadSeqLenQ) n += "s";
88 if (kPadSeqLenK) n += "sk";
89 if (kPadHeadDimQ) n += "d";
90 if (kPadHeadDimV) n += "dv";
91 return n.empty() ? n : std::string("p") + n; }();
92 return
93 _SS_("fmha_fwd_pagedkv_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
94 "_" + (kIsGroupMode ? "group" : "batch") + "_"
95 "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
96 _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
97 "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
98 "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
99 "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
100 "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
// The occupancy segment is emitted only when an explicit value was requested (input != -1).
101 (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
102 "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
// NOTE(review): the embedded listing numbers jump from 102 to 104 here, so one
// source line may be absent from this copy - confirm against the upstream header.
104 (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kIsPagedKV ? "_pagedkv" : "_npagedkv" );
105 #undef _SS_
106 #undef _TS_
107 // clang-format on
108 }
109
110 template <ck_tile::index_t I> // to avoid the duplicated base class problem, introduce a template
111 // arg
113 {
114 };
115
116 // kargs use aggregate initialization, so no constructor is provided
117 // use inheritance to minimize kargs size
118 // users need to call the MakeKargs() function to create kargs.
147
149 {
151
// Stores the logits soft-cap value. A strictly positive argument is kept as-is;
// anything else (zero, negative, or NaN, for which `0 < x` is false) is stored as 0.f.
// NOTE(review): the embedded listing numbers skip 157 and 162, so each branch may
// contain one more statement that is absent from this copy - confirm upstream.
152 void init_logits_soft_cap(float logits_soft_cap_)
153 {
154 if(0 < logits_soft_cap_)
155 {
156 logits_soft_cap = logits_soft_cap_;
158 }
159 else
160 {
161 logits_soft_cap = 0.f;
163 }
164 }
165
168 };
169
176
181
183 {
184 // alibi is batch*nhead*1, no matter in batch/group mode, they are the same
185 const void* alibi_slope_ptr;
186 ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope
187 };
188
190 {
191 // ck_tile::index_t window_size_left, window_size_right;
194 };
195
197 {
198 float scale_p;
199 float scale_o;
200 };
201
208
213
220
225
227 {
229 };
230
233 std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
234 FmhaFwdBatchModeBiasKargs,
235 std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
236 FmhaFwdAlibiKargs,
237 FmhaFwdEmptyKargs<0>>>,
238 std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
239 std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
240 std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
241 std::conditional_t<kIsPagedKV, CommonPageBlockTableKargs, CacheBatchIdxKargs>,
242 std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<4>>
243 {
245
247 ck_tile::index_t batch_stride_k; // when using paged-kvcache, this will be stride/size for
248 // single kcache page-block
249 ck_tile::index_t batch_stride_v; // when using paged-kvcache, this will be stride/size for
250 // single vcache page-block
252 };
253
256 std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
257 FmhaFwdCommonBiasKargs,
258 std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
259 FmhaFwdAlibiKargs,
260 FmhaFwdEmptyKargs<0>>>,
261 std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
262 std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
263 std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
264 std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<4>>,
265 std::conditional_t<kIsPagedKV, GroupModePageBlockTableKargs, FmhaFwdEmptyKargs<5>>,
266 std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>
267 {
271
272 ck_tile::index_t batch_stride_k; // only used for paged-kvcache, this will be stride/size
273 // for single kcache page-block
274 ck_tile::index_t batch_stride_v; // only used for paged-kvcache, this will be stride/size
275 // for single vcache page-block
276 };
277
278 using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
279
286
287 template <bool Cond = !kIsGroupMode>
288 CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
289 MakeKargsImpl(const void* q_ptr,
290 const void* k_ptr,
291 const void* v_ptr,
292 const void* bias_ptr,
293 void* lse_ptr,
294 void* o_ptr,
295 ck_tile::index_t seqlen_q,
296 ck_tile::index_t seqlen_k,
297 const void* seqlen_k_ptr, // only used for (paged-) kvcache
298 ck_tile::index_t hdim_q,
299 ck_tile::index_t hdim_v,
300 ck_tile::index_t num_head_q,
301 ck_tile::index_t nhead_ratio_qk,
302 const void* block_table_ptr,
303 ck_tile::index_t batch_stride_block_table,
304 ck_tile::index_t page_block_size,
305 const void* cache_batch_idx,
306 float scale_s,
307 float scale_p,
308 float scale_o,
309 float logits_soft_cap,
310 ck_tile::index_t stride_q,
311 ck_tile::index_t stride_k,
312 ck_tile::index_t stride_v,
313 ck_tile::index_t stride_bias,
314 ck_tile::index_t stride_o,
315 ck_tile::index_t nhead_stride_q,
316 ck_tile::index_t nhead_stride_k,
317 ck_tile::index_t nhead_stride_v,
318 ck_tile::index_t nhead_stride_bias,
319 ck_tile::index_t nhead_stride_lse,
320 ck_tile::index_t nhead_stride_o,
321 ck_tile::index_t batch_stride_q,
322 ck_tile::index_t batch_stride_k,
323 ck_tile::index_t batch_stride_v,
324 ck_tile::index_t batch_stride_bias,
325 ck_tile::index_t batch_stride_lse,
326 ck_tile::index_t batch_stride_o,
327 ck_tile::index_t window_size_left,
328 ck_tile::index_t window_size_right,
329 ck_tile::index_t mask_type)
330 {
331 Kargs kargs{{q_ptr,
332 k_ptr,
333 v_ptr,
334 o_ptr,
335 seqlen_q,
336 seqlen_k,
337 hdim_q,
338 hdim_v,
339 num_head_q,
340 nhead_ratio_qk,
341#if CK_TILE_FMHA_FWD_FAST_EXP2
342 static_cast<float>(scale_s * ck_tile::log2e_v<>),
343#else
344 scale_s,
345#endif
346 stride_q,
347 stride_k,
348 stride_v,
349 stride_o,
350 nhead_stride_q,
351 nhead_stride_k,
352 nhead_stride_v,
353 nhead_stride_o}, // args for common karg
354 {}, // placeholder for bias
355 {}, // placeholder for mask
356 {}, // placeholder for lse
357 {}, // placeholder for fp8_static_quant args
358 {}, // placeholder for pagedkv
359 {}, // placeholder for logits_soft_cap
360 reinterpret_cast<const int32_t*>(seqlen_k_ptr),
361 batch_stride_q,
362 batch_stride_k,
363 batch_stride_v,
364 batch_stride_o};
365
367 {
368 kargs.bias_ptr = bias_ptr;
369 kargs.stride_bias = stride_bias;
370 kargs.nhead_stride_bias = nhead_stride_bias;
371 kargs.batch_stride_bias = batch_stride_bias;
372 }
373 else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
374 {
375 kargs.alibi_slope_ptr = bias_ptr;
376 kargs.alibi_slope_stride = stride_bias;
377 }
378 if constexpr(kHasMask)
379 {
380 kargs.window_size_left = window_size_left;
381 kargs.window_size_right = window_size_right;
382 kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
383 }
384 if constexpr(kStoreLSE)
385 {
386 kargs.lse_ptr = lse_ptr;
387 kargs.nhead_stride_lse = nhead_stride_lse;
388 kargs.batch_stride_lse = batch_stride_lse;
389 }
390 if constexpr(kDoFp8StaticQuant)
391 {
392 kargs.scale_p = scale_p;
393 kargs.scale_o = scale_o;
394 }
395 if constexpr(kIsPagedKV)
396 {
397 kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
398 kargs.batch_stride_block_table = batch_stride_block_table;
399 kargs.page_block_size = page_block_size;
400 }
401 else
402 {
403 kargs.cache_batch_idx = reinterpret_cast<const int32_t*>(cache_batch_idx);
404 }
405 if constexpr(kHasLogitsSoftCap)
406 {
407 kargs.init_logits_soft_cap(logits_soft_cap);
408 }
409
410 return kargs;
411 }
412
413 // std::variant<> can't take in a list initializer, overload for backward compatibility
414 template <bool Cond = !kIsGroupMode>
415 CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
416 MakeKargs(const void* q_ptr,
417 const void* k_ptr,
418 const void* v_ptr,
419 const void* bias_ptr,
420 void* lse_ptr,
421 void* o_ptr,
422 ck_tile::index_t seqlen_q,
423 ck_tile::index_t seqlen_k,
424 const void* seqlen_k_ptr, // only used for (paged-) kvcache
425 ck_tile::index_t hdim_q,
426 ck_tile::index_t hdim_v,
427 ck_tile::index_t num_head_q,
428 ck_tile::index_t nhead_ratio_qk,
429 const void* block_table_ptr,
430 ck_tile::index_t batch_stride_block_table,
431 ck_tile::index_t page_block_size,
432 const void* cache_batch_idx,
433 float scale_s,
434 float scale_p,
435 float scale_o,
436 float logits_soft_cap,
437 ck_tile::index_t stride_q,
438 ck_tile::index_t stride_k,
439 ck_tile::index_t stride_v,
440 ck_tile::index_t stride_bias,
441 ck_tile::index_t stride_o,
442 ck_tile::index_t nhead_stride_q,
443 ck_tile::index_t nhead_stride_k,
444 ck_tile::index_t nhead_stride_v,
445 ck_tile::index_t nhead_stride_bias,
446 ck_tile::index_t nhead_stride_lse,
447 ck_tile::index_t nhead_stride_o,
448 ck_tile::index_t batch_stride_q,
449 ck_tile::index_t batch_stride_k,
450 ck_tile::index_t batch_stride_v,
451 ck_tile::index_t batch_stride_bias,
452 ck_tile::index_t batch_stride_lse,
453 ck_tile::index_t batch_stride_o,
454 ck_tile::index_t window_size_left,
455 ck_tile::index_t window_size_right,
456 ck_tile::index_t mask_type)
457 {
458 return MakeKargsImpl(q_ptr,
459 k_ptr,
460 v_ptr,
461 bias_ptr,
462 lse_ptr,
463 o_ptr,
464 seqlen_q,
465 seqlen_k,
466 seqlen_k_ptr,
467 hdim_q,
468 hdim_v,
469 num_head_q,
470 nhead_ratio_qk,
471 block_table_ptr,
472 batch_stride_block_table,
473 page_block_size,
474 cache_batch_idx,
475 scale_s,
476 scale_p,
477 scale_o,
478 logits_soft_cap,
479 stride_q,
480 stride_k,
481 stride_v,
482 stride_bias,
483 stride_o,
484 nhead_stride_q,
485 nhead_stride_k,
486 nhead_stride_v,
487 nhead_stride_bias,
488 nhead_stride_lse,
489 nhead_stride_o,
490 batch_stride_q,
491 batch_stride_k,
492 batch_stride_v,
493 batch_stride_bias,
494 batch_stride_lse,
495 batch_stride_o,
496 window_size_left,
497 window_size_right,
498 mask_type);
499 }
500
501 template <bool Cond = kIsGroupMode>
502 CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
503 MakeKargsImpl(const void* q_ptr,
504 const void* k_ptr,
505 const void* v_ptr,
506 const void* bias_ptr,
507 void* lse_ptr,
508 void* o_ptr,
509 const void* seqstart_q_ptr,
510 const void* seqstart_k_ptr,
511 const void* seqlen_k_ptr,
512 ck_tile::index_t hdim_q,
513 ck_tile::index_t hdim_v,
514 ck_tile::index_t num_head_q,
515 ck_tile::index_t nhead_ratio_qk,
516 const void* block_table_ptr,
517 ck_tile::index_t batch_stride_block_table,
518 ck_tile::index_t page_block_size,
519 bool is_gappy,
520 float scale_s,
521 float scale_p,
522 float scale_o,
523 float logits_soft_cap,
524 ck_tile::index_t stride_q,
525 ck_tile::index_t stride_k,
526 ck_tile::index_t stride_v,
527 ck_tile::index_t stride_bias,
528 ck_tile::index_t stride_o,
529 ck_tile::index_t nhead_stride_q,
530 ck_tile::index_t nhead_stride_k,
531 ck_tile::index_t nhead_stride_v,
532 ck_tile::index_t nhead_stride_bias,
533 ck_tile::index_t nhead_stride_lse,
534 ck_tile::index_t nhead_stride_o,
535 ck_tile::index_t batch_stride_k, // only used for paged-kvcache
536 ck_tile::index_t batch_stride_v, // only used for paged-kvcache
537 ck_tile::index_t window_size_left,
538 ck_tile::index_t window_size_right,
539 ck_tile::index_t mask_type,
540 ck_tile::index_t min_seqlen_q)
541 {
542 Kargs kargs{{q_ptr,
543 k_ptr,
544 v_ptr,
545 o_ptr,
546 -1, // seqlen will be updated by another pointer
547 -1, //
548 hdim_q,
549 hdim_v,
550 num_head_q,
551 nhead_ratio_qk,
552#if CK_TILE_FMHA_FWD_FAST_EXP2
553 static_cast<float>(scale_s * ck_tile::log2e_v<>),
554#else
555 scale_s,
556#endif
557 stride_q,
558 stride_k,
559 stride_v,
560 stride_o,
561 nhead_stride_q,
562 nhead_stride_k,
563 nhead_stride_v,
564 nhead_stride_o}, // args for common karg
565 {}, // placeholder for bias
566 {}, // placeholder for mask
567 {}, // placeholder for lse
568 {}, // placeholder for fp8_static_quant args
569 {}, // placeholder for logits_soft_cap
570 {}, // placeholder for pagdkv
571 {}, // placeholder for min_seqlen_q
572 reinterpret_cast<const int32_t*>(seqstart_q_ptr),
573 reinterpret_cast<const int32_t*>(seqstart_k_ptr),
574 reinterpret_cast<const int32_t*>(seqlen_k_ptr),
575 batch_stride_k,
576 batch_stride_v};
577
579 {
580 kargs.bias_ptr = bias_ptr;
581 kargs.stride_bias = stride_bias;
582 kargs.nhead_stride_bias = nhead_stride_bias;
583 }
584 else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
585 {
586 kargs.alibi_slope_ptr = bias_ptr;
587 kargs.alibi_slope_stride = stride_bias;
588 }
589 if constexpr(kHasMask)
590 {
591 kargs.window_size_left = window_size_left;
592 kargs.window_size_right = window_size_right;
593 kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
594 }
595 if constexpr(kStoreLSE)
596 {
597 kargs.lse_ptr = lse_ptr;
598 kargs.nhead_stride_lse = nhead_stride_lse;
599 }
600 if constexpr(kDoFp8StaticQuant)
601 {
602 kargs.scale_p = scale_p;
603 kargs.scale_o = scale_o;
604 }
605 if constexpr(kHasLogitsSoftCap)
606 {
607 kargs.init_logits_soft_cap(logits_soft_cap);
608 }
609 if constexpr(kIsPagedKV)
610 {
611 kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
612 kargs.batch_stride_block_table = batch_stride_block_table;
613 kargs.page_block_size = page_block_size;
614 kargs.is_gappy = is_gappy;
615 }
616 if constexpr(kSkipMinSeqlenQ)
617 {
618 kargs.min_seqlen_q = min_seqlen_q;
619 }
620
621 return kargs;
622 }
623
624 // std::variant<> can't take in a list initializer, overload for backward compatibility
625 template <bool Cond = kIsGroupMode>
626 CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
627 MakeKargs(const void* q_ptr,
628 const void* k_ptr,
629 const void* v_ptr,
630 const void* bias_ptr,
631 void* lse_ptr,
632 void* o_ptr,
633 const void* seqstart_q_ptr,
634 const void* seqstart_k_ptr,
635 const void* seqlen_k_ptr,
636 ck_tile::index_t hdim_q,
637 ck_tile::index_t hdim_v,
638 ck_tile::index_t num_head_q,
639 ck_tile::index_t nhead_ratio_qk,
640 const void* block_table_ptr,
641 ck_tile::index_t batch_stride_block_table,
642 ck_tile::index_t page_block_size,
643 bool is_gappy,
644 float scale_s,
645 float scale_p,
646 float scale_o,
647 float logits_soft_cap,
648 ck_tile::index_t stride_q,
649 ck_tile::index_t stride_k,
650 ck_tile::index_t stride_v,
651 ck_tile::index_t stride_bias,
652 ck_tile::index_t stride_o,
653 ck_tile::index_t nhead_stride_q,
654 ck_tile::index_t nhead_stride_k,
655 ck_tile::index_t nhead_stride_v,
656 ck_tile::index_t nhead_stride_bias,
657 ck_tile::index_t nhead_stride_lse,
658 ck_tile::index_t nhead_stride_o,
659 ck_tile::index_t batch_stride_k, // only used for paged-kvcache
660 ck_tile::index_t batch_stride_v, // only used for paged-kvcache
661 ck_tile::index_t window_size_left,
662 ck_tile::index_t window_size_right,
663 ck_tile::index_t mask_type,
664 ck_tile::index_t min_seqlen_q)
665 {
666 return MakeKargsImpl(q_ptr,
667 k_ptr,
668 v_ptr,
669 bias_ptr,
670 lse_ptr,
671 o_ptr,
672 seqstart_q_ptr,
673 seqstart_k_ptr,
674 seqlen_k_ptr,
675 hdim_q,
676 hdim_v,
677 num_head_q,
678 nhead_ratio_qk,
679 block_table_ptr,
680 batch_stride_block_table,
681 page_block_size,
682 is_gappy,
683 scale_s,
684 scale_p,
685 scale_o,
686 logits_soft_cap,
687 stride_q,
688 stride_k,
689 stride_v,
690 stride_bias,
691 stride_o,
692 nhead_stride_q,
693 nhead_stride_k,
694 nhead_stride_v,
695 nhead_stride_bias,
696 nhead_stride_lse,
697 nhead_stride_o,
698 batch_stride_k,
699 batch_stride_v,
700 window_size_left,
701 window_size_right,
702 mask_type,
703 min_seqlen_q);
704 }
705
706 CK_TILE_HOST static void PrintParameters(const Kargs& kargs, int num_batches)
707 {
708 static bool dummy = [&]() {
709 std::cout << std::endl;
710
711 std::cout << " q_ptr: " << kargs.q_ptr << " k_ptr:" << kargs.k_ptr
712 << " v_ptr: " << kargs.v_ptr << " o_ptr:" << kargs.o_ptr
713 << " hdim_q: " << kargs.hdim_q << " hdim_v: " << kargs.hdim_v
714 << " num_head_q:" << kargs.num_head_q
715 << " nhead_ratio_qk: " << kargs.nhead_ratio_qk << " scale_s:" << kargs.scale_s
716 << " stride_q:" << kargs.stride_q << " stride_k:" << kargs.stride_k
717 << " stride_v:" << kargs.stride_v << " stride_o:" << kargs.stride_o
718 << " nhead_stride_q: " << kargs.nhead_stride_q
719 << " nhead_stride_k: " << kargs.nhead_stride_k
720 << " nhead_stride_v:" << kargs.nhead_stride_v
721 << " nhead_stride_o: " << kargs.nhead_stride_o;
722 if constexpr(!kIsGroupMode)
723 {
724 std::cout << " batch_stride_q:" << kargs.batch_stride_q;
725 }
726 std::cout << " batch_stride_k:" << kargs.batch_stride_k
727 << " batch_stride_v:" << kargs.batch_stride_v;
728
729 if constexpr(kIsGroupMode)
730 {
731 if constexpr(kSkipMinSeqlenQ)
732 {
733 std::cout << " min_seqlen_q: " << kargs.min_seqlen_q;
734 }
735
736 std::cout << " seqstart_q_ptr:" << kargs.seqstart_q_ptr
737 << " seqstart_k_ptr: " << kargs.seqstart_k_ptr
738 << " seqlen_k_ptr:" << kargs.seqlen_k_ptr;
739 if(kargs.seqlen_k_ptr != nullptr)
740 {
741 std::cout << "{";
742 for(int i_batch = 0; i_batch < num_batches; i_batch++)
743 std::cout << kargs.seqlen_k_ptr[i_batch] << ",";
744 std::cout << "}";
745 }
746 }
747 if constexpr(kHasMask)
748 {
749 std::cout << " window_size_left: " << kargs.window_size_left
750 << " window_size_right:" << kargs.window_size_right
751 << " mask_type: " << static_cast<int>(kargs.mask_type);
752 }
753
754 if constexpr(kIsPagedKV)
755 {
756 std::cout << " block_table_ptr: " << kargs.block_table_ptr
757 << " batch_stride_block_table:" << kargs.batch_stride_block_table
758 << " page_block_size: " << kargs.page_block_size;
759
760 std::cout << "table value: [";
761 for(int b = 0; b < num_batches; b++)
762 {
763 std::cout << "[ ";
764 for(int i = 0; i < kargs.batch_stride_block_table; i++)
765 {
766 std::cout << kargs.block_table_ptr[b * kargs.batch_stride_block_table + i]
767 << ",";
768 }
769 std::cout << " ]";
770 }
771 std::cout << " ]";
772 }
773 std::cout << std::endl;
774 return true;
775 }();
776 (void)dummy;
777 }
778 CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
779 ck_tile::index_t nhead_,
780 ck_tile::index_t seqlen_q_,
781 ck_tile::index_t hdim_v_,
782 bool has_padded_seqlen_k)
783 {
784 // has_padded_seqlen_k is determined by checking (seqlen_k_ptr != nullptr)
785 if(has_padded_seqlen_k)
786 {
787 // TODO: this may need tuning
788 return dim3(nhead_,
789 batch_size_,
790 ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
791 ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1));
792 }
793 else
794 {
795 // TODO: this may need tuning
796 return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
797 ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
798 nhead_,
799 batch_size_);
800 }
801 }
802
803 CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
804 {
805 bool has_padded_seqlen_k = false;
806
807 if constexpr(kIsGroupMode)
808 has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr);
809
810 if(has_padded_seqlen_k)
811 {
812 // const index_t num_tile_m0 = seqlen_q / kM0;
813 const index_t num_tile_n1 =
814 ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
815
816 const index_t i_block = blockIdx.z;
817 const index_t i_nhead = blockIdx.x;
818 const index_t i_batch = blockIdx.y;
819
820 const auto f = [](index_t dividend, index_t divisor) {
821 index_t quotient = dividend / divisor;
822 index_t modulus = dividend - quotient * divisor;
823 return ck_tile::make_tuple(quotient, modulus);
824 };
825
826 const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
827
828 if constexpr(kHasMask)
829 {
830 // assume that num_tile_n1 is always 1
831 return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
832 }
833 else
834 {
835 return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
836 }
837 }
838 else
839 {
840 // const index_t num_tile_m0 = seqlen_q / kM0;
841 const index_t num_tile_n1 =
842 ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
843
844 const index_t i_block = blockIdx.x;
845 const index_t i_nhead = blockIdx.y;
846 const index_t i_batch = blockIdx.z;
847
848 const auto f = [](index_t dividend, index_t divisor) {
849 index_t quotient = dividend / divisor;
850 index_t modulus = dividend - quotient * divisor;
851 return ck_tile::make_tuple(quotient, modulus);
852 };
853
854 const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
855
856 if constexpr(kHasMask)
857 {
858 // assume that num_tile_n1 is always 1
859 return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
860 }
861 else
862 {
863 return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
864 }
865 }
866 }
867
869 {
870 if(is_wave32())
871 {
872 return dim3(kBlockSize / 2);
873 }
874 else
875 {
876 return dim3(kBlockSize);
877 }
878 }
879
881 {
882 return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
883 }
884
886 {
887 // allocate LDS
888 __shared__ char smem_ptr[GetSmemSize()];
889
890 // divide problem
891 const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
892
893 const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0);
894 const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1);
895
896 long_index_t batch_offset_q = 0;
897 long_index_t batch_offset_k = 0;
898 long_index_t batch_offset_v = 0;
899 long_index_t batch_offset_bias = 0;
900 long_index_t batch_offset_lse = 0;
901 long_index_t batch_offset_o = 0;
902 index_t kv_l2p_offset = 0;
903
904 if constexpr(kIsGroupMode)
905 {
906 // get starting offset for each batch
907 const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
908 const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
909
910 batch_offset_q = query_start * kargs.stride_q;
911 batch_offset_k = key_start * kargs.stride_k;
912 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
913 {
914 batch_offset_v = key_start * kargs.stride_v;
915 }
916 else
917 {
918 batch_offset_v = key_start;
919 }
921 {
922 batch_offset_bias = query_start * kargs.stride_bias;
923 }
924 if constexpr(kStoreLSE)
925 {
926 batch_offset_lse = query_start;
927 }
928
929 batch_offset_o = query_start * kargs.stride_o;
930
931 // get real # queries & # keys under group mode
932 const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
933 kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
934
935 if constexpr(kSkipMinSeqlenQ)
936 {
937 if(kargs.seqlen_q <= kargs.min_seqlen_q)
938 {
939 return;
940 }
941 }
942
943 // # of required blocks is different in each groups, terminate unnecessary blocks
944 // earlier
945 if(kargs.seqlen_q <= i_m0)
946 {
947 return;
948 }
949
950 if(kargs.seqlen_k_ptr != nullptr)
951 {
952 kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
953 }
954 else
955 {
956 const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
957 kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
958 }
959
960 if constexpr(kIsPagedKV)
961 {
962 if(kargs.is_gappy)
963 {
964 // seqstart_k_ptr has different meaning in this case
965 kv_l2p_offset = kargs.seqstart_k_ptr[i_batch];
966 }
967 }
968 }
969 else
970 {
971 const index_t i_cache_batch = [&, i_batch_ = i_batch] {
972 if constexpr(kIsPagedKV)
973 {
974 return i_batch_;
975 }
976 else
977 {
978 return (kargs.cache_batch_idx != nullptr ? kargs.cache_batch_idx[i_batch_]
979 : i_batch_);
980 }
981 }();
982
983 batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
984 batch_offset_k = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_k;
985 batch_offset_v = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_v;
987 {
988 batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
989 }
990 if constexpr(kStoreLSE)
991 {
992 batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
993 }
994
995 batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
996
997 if(kargs.seqlen_k_ptr != nullptr)
998 {
999 kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
1000 }
1001 }
1002
1003 // for simplicity, batch stride we just modify the pointer
1004 const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
1005 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
1006 batch_offset_q;
1007 const KDataType* k_ptr =
1008 reinterpret_cast<const KDataType*>(kargs.k_ptr) +
1009 static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
1010 batch_offset_k;
1011 const VDataType* v_ptr =
1012 reinterpret_cast<const VDataType*>(kargs.v_ptr) +
1013 static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
1014 batch_offset_v;
1015 ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
1016 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
1017 batch_offset_o;
1018
1019 // Q/K/V DRAM and DRAM window
1020 const auto q_dram = [&]() {
1022 q_ptr,
1023 make_tuple(kargs.seqlen_q, kargs.hdim_q),
1024 make_tuple(kargs.stride_q, 1),
1026 number<1>{});
1027 if constexpr(FmhaPipeline::kQLoadOnce)
1028 {
1029 return pad_tensor_view(
1030 q_dram_naive,
1033 }
1034 else
1035 {
1036 return pad_tensor_view(
1037 q_dram_naive,
1040 }
1041 }();
1042
1043 const auto make_k_dram = [&](const KDataType* data, index_t height) {
1045 data, // will update this pointer if using paged-kvcache
1046 make_tuple(height, kargs.hdim_q),
1047 make_tuple(kargs.stride_k, 1),
1049 number<1>{});
1050
1051 return pad_tensor_view(
1052 k_dram_naive,
1055 };
1056 const auto k_dram = [&]() {
1057 if constexpr(kIsPagedKV)
1058 {
1059 return make_k_dram(nullptr, kargs.page_block_size);
1060 }
1061 else
1062 {
1063 return make_k_dram(k_ptr, kargs.seqlen_k);
1064 }
1065 }();
1066
1067 const auto make_v_dram = [&](const VDataType* data, index_t length) {
1068 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
1069 {
1071 data, // will update this pointer if using paged-kvcache
1072 make_tuple(length, kargs.hdim_v),
1073 make_tuple(kargs.stride_v, 1),
1075 number<1>{});
1076
1077 const auto v_dram_transposed =
1078 transform_tensor_view(v_dram_naive,
1083
1084 return pad_tensor_view(
1085 v_dram_transposed,
1088 }
1089 else
1090 {
1092 data, // will update this pointer if using paged-kvcache
1093 make_tuple(kargs.hdim_v, length),
1094 make_tuple(kargs.stride_v, 1),
1096 number<1>{});
1097
1098 return pad_tensor_view(
1099 v_dram_naive,
1102 }
1103 };
1104 const auto v_dram = [&]() {
1105 if constexpr(kIsPagedKV)
1106 {
1107 return make_v_dram(nullptr, kargs.page_block_size);
1108 }
1109 else
1110 {
1111 return make_v_dram(v_ptr, kargs.seqlen_k);
1112 }
1113 }();
1114
1115 auto q_dram_window = make_tile_window(
1116 q_dram,
1117 [&]() {
1118 if constexpr(FmhaPipeline::kQLoadOnce)
1121 else
1123 }(),
1124 {i_m0, 0});
1125
1126 auto k_page_block_navigator =
1127 [&, i_batch_ = i_batch, i_nhead_ = i_nhead / kargs.nhead_ratio_qk]() {
1128 if constexpr(kIsPagedKV)
1129 {
1130 const auto* block_indices =
1131 reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
1132 i_batch_ * kargs.batch_stride_block_table;
1133 const index_t num_blocks =
1134 integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
1135
1136 const long_index_t fixed_offset =
1137 static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_k;
1138
1140 kargs.k_ptr,
1141 kargs.batch_stride_k, // kcache page-block stride/size
1142 fixed_offset,
1143 block_indices,
1144 num_blocks,
1145 kargs.page_block_size,
1146 k_dram,
1147 make_k_dram(nullptr,
1148 (kv_l2p_offset + kargs.seqlen_k) -
1149 (num_blocks - 1) * kargs.page_block_size));
1150 }
1151 else
1152 {
1153 return make_page_block_navigator(k_dram);
1154 }
1155 }();
1156
1157 auto v_page_block_navigator =
1158 [&, i_batch_ = i_batch, i_nhead_ = i_nhead / kargs.nhead_ratio_qk]() {
1159 if constexpr(kIsPagedKV)
1160 {
1161 const auto* block_indices =
1162 reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
1163 i_batch_ * kargs.batch_stride_block_table;
1164 const index_t num_blocks =
1165 integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
1166
1167 const long_index_t fixed_offset =
1168 static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_v;
1169
1171 kargs.v_ptr,
1172 kargs.batch_stride_v, // vcache page-block stride/size
1173 fixed_offset,
1174 block_indices,
1175 num_blocks,
1176 kargs.page_block_size,
1177 v_dram,
1178 make_v_dram(nullptr,
1179 (kv_l2p_offset + kargs.seqlen_k) -
1180 (num_blocks - 1) * kargs.page_block_size));
1181 }
1182 else
1183 {
1184 return make_page_block_navigator(v_dram);
1185 }
1186 }();
1187
1188 auto k_dram_window_lengths =
1190 auto v_dram_window_lengths =
1192
1195 const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
1196 constexpr auto bias_dram_window_lengths =
1199 {
1200 const BiasDataType* bias_ptr =
1201 reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
1202 static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
1203 batch_offset_bias;
1204
1205 const auto bias_dram = [&]() {
1206 const auto bias_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1207 bias_ptr,
1208 make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1209 make_tuple(kargs.stride_bias, 1),
1211 number<1>{});
1212
1213 return pad_tensor_view(bias_dram_naive,
1214 bias_dram_window_lengths,
1216 }();
1217
1218 return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
1219 }
1220 else
1221 {
1222 return make_null_tile_window(bias_dram_window_lengths);
1223 }
1224 }();
1225
1226 // lse
1227 auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
1228 constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
1229 if constexpr(kStoreLSE)
1230 {
1231 LSEDataType* lse_ptr =
1232 reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
1233 static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse;
1234
1235 const auto lse_dram = [&]() {
1236 const auto lse_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1237 lse_ptr,
1238 make_tuple(kargs.seqlen_q),
1239 make_tuple(1),
1240 number<1>{},
1241 number<1>{});
1242
1243 return pad_tensor_view(
1244 lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
1245 }();
1246
1247 return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
1248 }
1249 else
1250 {
1251 return make_null_tile_window(lse_dram_window_lengths);
1252 }
1253 }();
1254
1255 FmhaMask mask = [&]() {
1256 if constexpr(kHasMask)
1258 kargs.window_size_left,
1259 kargs.window_size_right,
1260 kargs.seqlen_q,
1261 kargs.seqlen_k,
1263 else
1264 return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
1265 }();
1266
1267 // Workaround: a lambda cannot capture a structured binding before C++20, so i_batch is passed in via an init-capture
1268 auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
1270 {
1271 // data loading, shared by entire wg
1272 // TODO: how to use s_read?
1273 SaccDataType slope =
1274 *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
1275 i_batch_ * kargs.alibi_slope_stride + i_nhead_);
1276#if CK_TILE_FMHA_FWD_FAST_EXP2
1277 slope *= ck_tile::log2e_v<>;
1278#endif
1279 if constexpr(kHasMask)
1280 {
1282 kargs.window_size_left,
1283 kargs.window_size_right,
1284 kargs.seqlen_q,
1285 kargs.seqlen_k,
1286 kargs.mask_type);
1287 }
1288 else
1289 {
1291 slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
1292 }
1293 }
1294 else
1295 {
1297 }
1298 }();
1299
1300 AttentionVariant variant;
1301 const auto variant_params = [&] {
1302 if constexpr(kHasLogitsSoftCap)
1303 {
1305 mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
1306 }
1307 else
1308 {
1309 return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
1310 }
1311 }();
1312
1313 BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
1314
1315 auto o_acc_tile = [&]() {
1316 if constexpr(kDoFp8StaticQuant)
1317 {
1318 return FmhaPipeline{}(
1319 q_dram_window,
1320 identity{}, // q_element_func
1321 k_dram_window_lengths,
1322 k_page_block_navigator,
1323 identity{}, // k_element_func
1324 v_dram_window_lengths,
1325 v_page_block_navigator,
1326 identity{}, // v_element_func
1327 bias_dram_window,
1328 identity{}, // bias_element_func
1329 lse_dram_window,
1330 identity{}, // lse_element_func
1331 identity{}, // s_acc_element_func
1332 scales{kargs.scale_p}, // p_compute_element_func
1333 composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
1334 mask,
1335 position_encoding,
1336 kargs.scale_s,
1337 variant,
1338 variant_params,
1339 block_indices,
1340 kv_l2p_offset,
1341 smem_ptr);
1342 }
1343 else
1344 {
1345 return FmhaPipeline{}(q_dram_window,
1346 k_dram_window_lengths,
1347 k_page_block_navigator,
1348 v_dram_window_lengths,
1349 v_page_block_navigator,
1350 bias_dram_window,
1351 lse_dram_window,
1352 mask,
1353 position_encoding,
1354 kargs.scale_s,
1355 variant,
1356 variant_params,
1357 block_indices,
1358 kv_l2p_offset,
1359 smem_ptr);
1360 }
1361 }();
1362
1363 // O DRAM and O DRAM window
1364 auto o_dram = [&]() {
1366 o_ptr,
1367 make_tuple(kargs.seqlen_q, kargs.hdim_v),
1368 make_tuple(kargs.stride_o, 1),
1370 number<1>{});
1371 return pad_tensor_view(
1372 o_dram_naive,
1375 }();
1376
1377 auto o_dram_window =
1378 make_tile_window(o_dram,
1380 {i_m0, i_n1});
1381
1382 EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
1383 }
1384};
1385
1386} // namespace ck_tile
#define _TS_
#define _SS_
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
CK_TILE_HOST_DEVICE constexpr auto make_generic_attention_mask_from_lr_window(index_t left_size, index_t right_size, index_t y_total, index_t x_total, bool is_top_left=true)
Definition block_masking.hpp:632
@ ALIBI
Definition block_attention_bias_enum.hpp:15
@ NO_BIAS
Definition block_attention_bias_enum.hpp:13
@ ELEMENTWISE_BIAS
Definition block_attention_bias_enum.hpp:14
CK_TILE_HOST_DEVICE auto make_page_block_navigator(const TensorView &tensor_view)
Definition page_block_navigator.hpp:333
bfloat16_t bf16_t
Definition bfloat16.hpp:113
_Float16 fp16_t
Definition half.hpp:110
_BitInt(8) fp8_t
Definition float8.hpp:204
constexpr T log2e_v
Definition tile/core/numeric/math.hpp:488
int64_t long_index_t
Definition integer.hpp:11
CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope, index_t window_left_size, index_t window_right_size, index_t y_total, index_t x_total, GenericAttentionMaskEnum mask_enum)
Definition block_position_encoding.hpp:148
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
int32_t int32_t
Definition integer.hpp:10
CK_TILE_DEVICE constexpr auto make_null_tile_window(const WindowLengths &window_lengths)
Definition null_tile_window.hpp:66
unsigned _BitInt(8) bf8_t
Definition float8.hpp:206
GenericAttentionMaskEnum
Definition block_masking.hpp:11
@ MASK_FROM_TOP_LEFT
Definition block_masking.hpp:15
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
int32_t index_t
Definition integer.hpp:9
@ FROM_BOTTOM_RIGHT
Definition block_position_encoding.hpp:43
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_view.hpp:511
__host__ __device__ composes(Ts &&...) -> composes< remove_cvref_t< Ts >... >
FIXME: create macro to replace '__host__ __device__' and nothing more.
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
CK_TILE_HOST bool is_wave32()
Definition arch.hpp:72
Definition block_position_encoding.hpp:48
Definition block_attention_bias_enum.hpp:19
Definition block_position_encoding.hpp:137
Definition fmha_fwd_pagedkv_kernel.hpp:281
ck_tile::index_t qo_head_idx
Definition fmha_fwd_pagedkv_kernel.hpp:283
ck_tile::index_t batch_idx
Definition fmha_fwd_pagedkv_kernel.hpp:282
ck_tile::index_t kv_head_idx
Definition fmha_fwd_pagedkv_kernel.hpp:284
Definition fmha_fwd_pagedkv_kernel.hpp:227
const int32_t * cache_batch_idx
Definition fmha_fwd_pagedkv_kernel.hpp:228
Definition fmha_fwd_pagedkv_kernel.hpp:215
ck_tile::index_t batch_stride_block_table
Definition fmha_fwd_pagedkv_kernel.hpp:217
const int32_t * block_table_ptr
Definition fmha_fwd_pagedkv_kernel.hpp:216
ck_tile::index_t page_block_size
Definition fmha_fwd_pagedkv_kernel.hpp:218
Definition fmha_fwd_pagedkv_kernel.hpp:183
ck_tile::index_t alibi_slope_stride
Definition fmha_fwd_pagedkv_kernel.hpp:186
const void * alibi_slope_ptr
Definition fmha_fwd_pagedkv_kernel.hpp:185
Definition fmha_fwd_pagedkv_kernel.hpp:178
ck_tile::index_t batch_stride_bias
Definition fmha_fwd_pagedkv_kernel.hpp:179
Definition fmha_fwd_pagedkv_kernel.hpp:243
ck_tile::index_t batch_stride_k
Definition fmha_fwd_pagedkv_kernel.hpp:247
ck_tile::index_t batch_stride_q
Definition fmha_fwd_pagedkv_kernel.hpp:246
const int32_t * seqlen_k_ptr
Definition fmha_fwd_pagedkv_kernel.hpp:244
ck_tile::index_t batch_stride_o
Definition fmha_fwd_pagedkv_kernel.hpp:251
ck_tile::index_t batch_stride_v
Definition fmha_fwd_pagedkv_kernel.hpp:249
Definition fmha_fwd_pagedkv_kernel.hpp:171
ck_tile::index_t stride_bias
Definition fmha_fwd_pagedkv_kernel.hpp:173
const void * bias_ptr
Definition fmha_fwd_pagedkv_kernel.hpp:172
ck_tile::index_t nhead_stride_bias
Definition fmha_fwd_pagedkv_kernel.hpp:174
Definition fmha_fwd_pagedkv_kernel.hpp:120
ck_tile::index_t hdim_v
Definition fmha_fwd_pagedkv_kernel.hpp:129
ck_tile::index_t seqlen_q
Definition fmha_fwd_pagedkv_kernel.hpp:126
ck_tile::index_t nhead_stride_k
Definition fmha_fwd_pagedkv_kernel.hpp:143
ck_tile::index_t stride_o
Definition fmha_fwd_pagedkv_kernel.hpp:140
const void * k_ptr
Definition fmha_fwd_pagedkv_kernel.hpp:122
ck_tile::index_t nhead_stride_q
Definition fmha_fwd_pagedkv_kernel.hpp:142
ck_tile::index_t stride_v
Definition fmha_fwd_pagedkv_kernel.hpp:139
float scale_s
Definition fmha_fwd_pagedkv_kernel.hpp:135
ck_tile::index_t stride_k
Definition fmha_fwd_pagedkv_kernel.hpp:138
const void * v_ptr
Definition fmha_fwd_pagedkv_kernel.hpp:123
ck_tile::index_t nhead_stride_v
Definition fmha_fwd_pagedkv_kernel.hpp:144
ck_tile::index_t nhead_stride_o
Definition fmha_fwd_pagedkv_kernel.hpp:145
ck_tile::index_t hdim_q
Definition fmha_fwd_pagedkv_kernel.hpp:128
ck_tile::index_t seqlen_k
Definition fmha_fwd_pagedkv_kernel.hpp:127
const void * q_ptr
Definition fmha_fwd_pagedkv_kernel.hpp:121
ck_tile::index_t num_head_q
Definition fmha_fwd_pagedkv_kernel.hpp:131
ck_tile::index_t nhead_ratio_qk
Definition fmha_fwd_pagedkv_kernel.hpp:134
void * o_ptr
Definition fmha_fwd_pagedkv_kernel.hpp:124
ck_tile::index_t stride_q
Definition fmha_fwd_pagedkv_kernel.hpp:137
Definition fmha_fwd_pagedkv_kernel.hpp:203
ck_tile::index_t batch_stride_lse
Definition fmha_fwd_pagedkv_kernel.hpp:206
void * lse_ptr
Definition fmha_fwd_pagedkv_kernel.hpp:204
ck_tile::index_t nhead_stride_lse
Definition fmha_fwd_pagedkv_kernel.hpp:205
Definition fmha_fwd_pagedkv_kernel.hpp:113
Definition fmha_fwd_pagedkv_kernel.hpp:197
float scale_p
Definition fmha_fwd_pagedkv_kernel.hpp:198
float scale_o
Definition fmha_fwd_pagedkv_kernel.hpp:199
Definition fmha_fwd_pagedkv_kernel.hpp:267
const int32_t * seqlen_k_ptr
Definition fmha_fwd_pagedkv_kernel.hpp:270
const int32_t * seqstart_k_ptr
Definition fmha_fwd_pagedkv_kernel.hpp:269
ck_tile::index_t batch_stride_k
Definition fmha_fwd_pagedkv_kernel.hpp:272
ck_tile::index_t batch_stride_v
Definition fmha_fwd_pagedkv_kernel.hpp:274
const int32_t * seqstart_q_ptr
Definition fmha_fwd_pagedkv_kernel.hpp:268
float logits_soft_cap
Definition fmha_fwd_pagedkv_kernel.hpp:166
void init_logits_soft_cap(float logits_soft_cap_)
Definition fmha_fwd_pagedkv_kernel.hpp:152
float logits_soft_cap_rcp
Definition fmha_fwd_pagedkv_kernel.hpp:167
Definition fmha_fwd_pagedkv_kernel.hpp:190
ck_tile::index_t window_size_left
Definition fmha_fwd_pagedkv_kernel.hpp:192
ck_tile::index_t window_size_right
Definition fmha_fwd_pagedkv_kernel.hpp:192
ck_tile::GenericAttentionMaskEnum mask_type
Definition fmha_fwd_pagedkv_kernel.hpp:193
Definition fmha_fwd_pagedkv_kernel.hpp:210
ck_tile::index_t min_seqlen_q
Definition fmha_fwd_pagedkv_kernel.hpp:211
Definition fmha_fwd_pagedkv_kernel.hpp:222
bool is_gappy
Definition fmha_fwd_pagedkv_kernel.hpp:223
static constexpr const char * name
Definition fmha_fwd_pagedkv_kernel.hpp:69
static constexpr const char * name
Definition fmha_fwd_pagedkv_kernel.hpp:71
static constexpr const char * name
Definition fmha_fwd_pagedkv_kernel.hpp:68
static constexpr const char * name
Definition fmha_fwd_pagedkv_kernel.hpp:70
static constexpr const char * name
Definition fmha_fwd_pagedkv_kernel.hpp:67
Definition fmha_fwd_pagedkv_kernel.hpp:66
Definition fmha_fwd_pagedkv_kernel.hpp:28
static constexpr bool kIsGroupMode
Definition fmha_fwd_pagedkv_kernel.hpp:47
static constexpr bool kHasLogitsSoftCap
Definition fmha_fwd_pagedkv_kernel.hpp:52
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargsImpl(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, bool is_gappy, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q)
Definition fmha_fwd_pagedkv_kernel.hpp:503
ck_tile::remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition fmha_fwd_pagedkv_kernel.hpp:30
static constexpr ck_tile::index_t kBlockPerCu
Definition fmha_fwd_pagedkv_kernel.hpp:32
static constexpr bool kStoreLSE
Definition fmha_fwd_pagedkv_kernel.hpp:54
static CK_TILE_HOST std::string GetName()
Definition fmha_fwd_pagedkv_kernel.hpp:74
static constexpr bool kPadSeqLenK
Definition fmha_fwd_pagedkv_kernel.hpp:49
static constexpr bool kIsPagedKV
Definition fmha_fwd_pagedkv_kernel.hpp:57
static constexpr bool kSkipMinSeqlenQ
Definition fmha_fwd_pagedkv_kernel.hpp:56
static CK_TILE_HOST void PrintParameters(const Kargs &kargs, int num_batches)
Definition fmha_fwd_pagedkv_kernel.hpp:706
static constexpr ck_tile::index_t kBlockSize
Definition fmha_fwd_pagedkv_kernel.hpp:31
static constexpr bool kPadHeadDimV
Definition fmha_fwd_pagedkv_kernel.hpp:51
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition fmha_fwd_pagedkv_kernel.hpp:41
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, bool is_gappy, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q)
Definition fmha_fwd_pagedkv_kernel.hpp:627
static CK_TILE_HOST dim3 BlockSize()
Definition fmha_fwd_pagedkv_kernel.hpp:868
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition fmha_fwd_pagedkv_kernel.hpp:39
static CK_TILE_DEVICE constexpr auto GetTileIndex(const Kargs &kargs)
Definition fmha_fwd_pagedkv_kernel.hpp:803
static constexpr ck_tile::index_t kBlockPerCuInput
Definition fmha_fwd_pagedkv_kernel.hpp:35
static constexpr bool kUseAsyncCopy
Definition fmha_fwd_pagedkv_kernel.hpp:63
static constexpr bool kHasMask
Definition fmha_fwd_pagedkv_kernel.hpp:61
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition fmha_fwd_pagedkv_kernel.hpp:29
ck_tile::remove_cvref_t< typename FmhaPipeline::ODataType > ODataType
Definition fmha_fwd_pagedkv_kernel.hpp:42
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasDataType > BiasDataType
Definition fmha_fwd_pagedkv_kernel.hpp:40
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition fmha_fwd_pagedkv_kernel.hpp:880
static constexpr bool kPadHeadDimQ
Definition fmha_fwd_pagedkv_kernel.hpp:50
ck_tile::remove_cvref_t< typename FmhaPipeline::VLayout > VLayout
Definition fmha_fwd_pagedkv_kernel.hpp:45
static constexpr bool kPadSeqLenQ
Definition fmha_fwd_pagedkv_kernel.hpp:48
ck_tile::remove_cvref_t< typename FmhaPipeline::SaccDataType > SaccDataType
Definition fmha_fwd_pagedkv_kernel.hpp:43
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, const void *cache_batch_idx, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type)
Definition fmha_fwd_pagedkv_kernel.hpp:416
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask
Definition fmha_fwd_pagedkv_kernel.hpp:60
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargsImpl(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, const void *cache_batch_idx, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type)
Definition fmha_fwd_pagedkv_kernel.hpp:289
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition fmha_fwd_pagedkv_kernel.hpp:885
ck_tile::remove_cvref_t< typename FmhaPipeline::AttentionVariant > AttentionVariant
Definition fmha_fwd_pagedkv_kernel.hpp:59
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition fmha_fwd_pagedkv_kernel.hpp:38
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition fmha_fwd_pagedkv_kernel.hpp:37
static constexpr bool kDoFp8StaticQuant
Definition fmha_fwd_pagedkv_kernel.hpp:55
std::conditional_t< kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs > Kargs
Definition fmha_fwd_pagedkv_kernel.hpp:278
static CK_TILE_HOST constexpr auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, ck_tile::index_t hdim_v_, bool has_padded_seqlen_k)
Definition fmha_fwd_pagedkv_kernel.hpp:778
static constexpr auto BiasEnum
Definition fmha_fwd_pagedkv_kernel.hpp:53
Definition variants.hpp:63
Definition variants.hpp:51
Definition tile/core/utility/functional.hpp:86
Definition unary_element_function.hpp:56
Definition tile/core/numeric/math.hpp:28
Definition tile/core/container/sequence.hpp:49