block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp Source File

block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp Source File#

Composable Kernel: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp Source File
block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
13
14namespace ck_tile {
15
16// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
17template <typename Problem_,
20{
36
39 static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
40 static_assert(kQLoadOnce == Policy::QLoadOnce);
41
42 static constexpr index_t kBlockSize = Problem::kBlockSize;
43
44 static constexpr index_t kM0 = BlockFmhaShape::kM0;
45 static constexpr index_t kN0 = BlockFmhaShape::kN0;
46 static constexpr index_t kK0 = BlockFmhaShape::kK0;
47 static constexpr index_t kN1 = BlockFmhaShape::kN1;
48 static constexpr index_t kK1 = BlockFmhaShape::kK1;
49 static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
50 static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
51 static constexpr auto I0 = number<0>{};
52 static constexpr auto I1 = number<1>{};
53 static constexpr auto I2 = number<2>{};
54 static constexpr auto I3 = number<3>{};
55
56 static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
57
58 static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
59 // TODO: seq_q always supports padding; hdim_q/v support multiples of the vector length (like 8x)
60 // only seq_k padding needs special care (out-of-bound p must be set to -INF instead of zero)
61 static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true &&
62 Problem::kPadHeadDimV == true);
63 static constexpr bool kPadSeqLenQ = true;
64 static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
65 static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x)
66 static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x)
67 static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
68 static constexpr auto BiasEnum = Problem::BiasEnum;
69 static constexpr bool kStoreLSE = Problem::kStoreLSE;
70 static constexpr bool kHasDropout = Problem::kHasDropout;
71
72 static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
73 (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
76
77 // last-dimension vector length used to create the tensor view (and to decide the buffer_load vector length)
78 // ... together with the tensor distribution. The tensor distribution should be able to override this
79 static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ<Problem>();
80 static constexpr index_t kAlignmentK = Policy::template GetAlignmentK<Problem>();
81 static constexpr index_t kAlignmentV = []() {
82 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
83 return Policy::template GetAlignmentV<Problem>();
84 else
85 return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
86 }();
87 static constexpr index_t kAlignmentO = Policy::template GetAlignmentO<Problem>();
88 static constexpr index_t kAlignmentBias =
89 kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
90
91#if CK_TILE_FMHA_FWD_FAST_EXP2
92 static constexpr auto R_LOG2E = 1.0 / log2e_v<SaccDataType>;
93#endif
94
95 static constexpr index_t kBlockPerCu = []() {
96 if constexpr(Problem::kBlockPerCu != -1)
97 return Problem::kBlockPerCu;
98 else
99 {
100 // minimize occupancy
102 {
103 return 1;
104 }
105
106 if constexpr(kQKHeaddim <= 32)
107 {
109 FmhaMask::IsMasking)
110 return 1;
111 else
112 return 2;
113 }
114 else if constexpr(kQKHeaddim <= 64)
115 {
117 return 2;
118 else
119 return 3;
120 }
121 else if constexpr(kQKHeaddim <= 128)
122 {
124 return 1;
125 else
126 return 2;
127 }
128 else if constexpr(kQKHeaddim <= 192)
129 {
131 return 1;
132 else
133 return 2;
134 }
135 else if constexpr(kQKHeaddim <= 256)
136 {
137 return 1;
138 }
139 else
140 {
141 return 1;
142 };
143 }
144 }();
145
146 static constexpr const char* name = "qr_async";
147
148 using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
149
151 {
152 return Policy::template GetSmemSize<Problem>();
153 }
154
155 template <typename QDramBlockWindowTmp,
156 typename KDramBlockWindowTmp,
157 typename VDramBlockWindowTmp,
158 typename BiasDramBlockWindowTmp,
159 typename RandValDramBlockWindowTmp,
160 typename LSEDramBlockWindowTmp,
161 typename QElementFunction,
162 typename KElementFunction,
163 typename VElementFunction,
164 typename BiasElementFunction,
165 typename LSEElementFunction,
166 typename SAccElementFunction,
167 typename PComputeElementFunction,
168 typename OAccElementFunction,
169 typename PositionEncoding,
170 typename AttentionVariantParams,
171 typename BlockIndices>
173 operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
174 const QElementFunction& q_element_func,
175 const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
176 const KElementFunction& /*k_element_func*/,
177 const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
178 const VElementFunction& v_element_func,
179 const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
180 const BiasElementFunction& bias_element_func,
181 RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
182 LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
183 const LSEElementFunction& lse_element_func,
184 const SAccElementFunction& s_acc_element_func,
185 const PComputeElementFunction& p_compute_element_func,
186 const OAccElementFunction& o_acc_element_func,
187 FmhaMask mask,
188 PositionEncoding position_encoding,
189 float scale_s,
190 const AttentionVariant& variant,
191 const AttentionVariantParams& variant_params,
192 const BlockIndices& block_indices,
193 void* smem_ptr,
194 const index_t* page_idx,
195 const index_t stride_k,
196 const index_t stride_v,
197 DropoutType& dropout) const
198 {
199 static_assert(
200 std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
201 std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
202 std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
203 "wrong!");
204
205 static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
206 kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
207 kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
208 kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
209 kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
210 kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
211 kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
212 "wrong!");
213
214 constexpr auto LdsSeq = Policy::template GetLdsBufferSequence<Problem>();
215
216 // K tile in LDS
217 auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);
218 auto k_lds_store = generate_tuple(
219 [&](auto i_buf) {
220 return make_tile_window(
222 k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf)),
223 Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf).get_lengths(),
224 {0, 0, 0});
225 },
227
228 auto k_lds_Load_view = make_tensor_view<address_space_enum::lds>(
229 k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor<Problem>());
230
231 auto k_lds_load =
232 make_tile_window(k_lds_Load_view,
233 Policy::template MakeKLdsLoadBlockDescriptor<Problem>().get_lengths(),
234 {0, 0});
235
236 // V tile in LDS
238 reinterpret_cast<VDataType*>(smem_ptr),
239 Policy::template MakeVLdsBlockDescriptor<Problem>());
240 auto v_lds_window = make_tile_window(
241 v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
242
243 // Block GEMM
244 constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
245 constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
246
247 auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
248 q_dram_block_window_tmp.get_window_lengths(),
249 q_dram_block_window_tmp.get_window_origin(),
250 Policy::template MakeQRegTileDistribution<Problem>());
251 q_dram_window.init_raw();
252
253 // TODO: we use async Copy for K, which is inline asm
254 // a side effect is we have to use inline asm for q as well
255 auto q = decltype(load_tile(q_dram_window)){};
256 // TODO: starting from rocm-6.2, the compiler has problems if q is cleared manually.
257 // however, q would be cleared anyway in the constructor of the static distributed tensor
258 // set_tile(q, number<0>{}); // use per-dword clear to avoid scratch
259 load_tile_raw(q, q_dram_window);
260 __builtin_amdgcn_sched_barrier(0);
261
262 using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
263 auto s_acc = SaccBlockTileType{};
264
265 // reduction function for softmax
266 const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
267 const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
268
269 // infer Sacc, S, P, M, L, Oacc type
270 using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_acc));
271
272 using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
273 SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
274
275 using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
276
277 // init Oacc, M, L
278 auto o_acc = OaccBlockTileType{};
279 auto m = MLBlockTileType{};
280 auto l = MLBlockTileType{};
281
282 clear_tile(o_acc);
284 clear_tile(l);
285
286 __builtin_amdgcn_sched_barrier(0);
287 const auto q_origin = q_dram_window.get_window_origin();
288 const auto [seqlen_k_start, seqlen_k_end] =
289 mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
290
291 const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
292
293 // check early exit if no work to do
294 if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
295 {
296 if(num_total_loop <= 0)
297 {
298 if constexpr(kStoreLSE)
299 {
300 auto lse =
301 make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
302
304
305 store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
306 }
307 buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0)
308 // otherwise will have compute error(maybe compiler bug?)
309
310 // Note: here o_acc is fully cleared, return it
311 return o_acc;
312 }
313 __builtin_amdgcn_sched_barrier(0); // make sure sched_barrier(0) for this check
314 }
315
316 auto k_dram_block_window =
317 make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
318 k_dram_block_window_tmp.get_window_lengths(),
319 {seqlen_k_start, 0});
320
321 auto k_dist = Policy::template MakeKDramTileDistribution<Problem>();
322 auto k_coord = k_dist.calculate_index();
323 using KDstrEncode = typename decltype(k_dist)::DstrEncode;
324 constexpr index_t NRepeat = KDstrEncode::hs_lengthss_[I0][I0];
326 static_for<0, NRepeat, 1>{}([&](auto n0) {
327 k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * n0.value] * stride_k;
328 });
329 auto k_dram_window = make_tile_scatter_gather(k_dram_block_window.get_bottom_tensor_view(),
330 k_dram_block_window.get_window_lengths(),
331 k_dram_block_window.get_window_origin(),
332 k_dist,
333 k_offsets); // K DRAM tile window for
334 k_dram_window.init_raw();
335 constexpr auto k_oob_ck = bool_constant<true>{};
336 constexpr auto k_pre_np = [&]() {
337 if constexpr(kPadSeqLenK &&
340 return bool_constant<true>{};
341 else
342 return bool_constant<false>{};
343 }();
344
345 const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
346 auto bias_dram_window =
347 make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
348 bias_dram_block_window_tmp.get_window_lengths(),
349 {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
350 Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
351
352 auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
353 randval_dram_block_window_tmp, seqlen_k_start);
354
355 auto v_dist = Policy::template MakeVDramTileDistribution<Problem>();
356 auto v_coord = v_dist.calculate_index();
357 const auto VPageIndexDim = I1;
358 using VDstrEncode = typename decltype(v_dist)::DstrEncode;
359 constexpr index_t V_KRepeat = VDstrEncode::hs_lengthss_[I1][I3];
361 (void)stride_k;
362 static_for<0, V_KRepeat, 1>{}([&](auto k0) {
363 v_offsets[k0] = page_idx[v_coord[VPageIndexDim] + k0.value] * stride_v;
364 });
365
366 auto v_dram_window =
367 make_tile_scatter_gather(v_dram_block_window_tmp.get_bottom_tensor_view(),
368 v_dram_block_window_tmp.get_window_lengths(),
369 {0, seqlen_k_start}, // TODO: hdim split?
370 v_dist,
371 v_offsets,
372 VPageIndexDim);
373
374 // prefetch K tile
376 k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np);
377 move_tile_window(k_dram_window, {0, kK0});
378 __builtin_amdgcn_sched_barrier(0);
379
380 buffer_load_fence(k_dram_window.get_num_of_access(), q.get_thread_buffer());
381 (void)q_element_func; // ??? rocm-6.x if use q element func will have scratch on hdim=64/32
382 // auto q_tile = q; // tile_elementwise_in(q_element_func, q);
383
384 index_t i_total_loops = 0;
385 constexpr index_t k0_loops = kQKHeaddim / kK0;
386 constexpr index_t k1_loops = kN0 / kK1;
387
388 static_assert(1 <= k0_loops);
389 static_assert(1 <= k1_loops);
390 // main loop
391 do
392 {
393 // STAGE 1, QK gemm
394 clear_tile(s_acc); // initialize C
395 if constexpr(k0_loops > 1)
396 {
397 static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
398 async_load_tile_raw(k_lds_store(number<LdsSeq.at(number<i_k0 + 1>{})>{}),
399 k_dram_window,
400 number<-1>{},
401 k_oob_ck,
402 k_pre_np);
403 if constexpr(i_k0 < k0_loops - 1)
404 move_tile_window(k_dram_window, {0, kK0});
405
406 async_load_fence(k_dram_window.get_num_of_access());
407 __builtin_amdgcn_s_barrier();
408 __builtin_amdgcn_sched_barrier(0);
409 gemm_0(s_acc,
411 q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
412 get_slice_tile(k_lds_load,
413 sequence<(LdsSeq.at(number<i_k0>{})) * kN0, 0>{},
414 sequence<(LdsSeq.at(number<i_k0>{}) + 1) * kN0, kK0>{}));
415 });
416 }
417
418 // TODO: this is to fix a bug when the loop count is no larger than 2 (k0_loops <= 2),
419 // where the following fence/barrier would be scheduled inside the 1st loop
420 if constexpr(k0_loops <= 2)
421 __builtin_amdgcn_sched_barrier(0);
422
424 __builtin_amdgcn_s_barrier();
425
426 const auto bias_tile = load_tile(bias_dram_window); // load bias tile
427 auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
428 static_for<0, V_KRepeat, 1>{}([&](auto k0) {
429 v_offsets[k0] = page_idx[kK1 + v_coord[VPageIndexDim] + k0.value] * stride_v;
430 });
431 v_dram_window.update_page_idx(v_offsets);
432
433 __builtin_amdgcn_sched_barrier(0);
434 { // tail
435 gemm_0(
436 s_acc,
438 q, sequence<0, (k0_loops - 1) * kK0>{}, sequence<kM0, k0_loops * kK0>{}),
439 get_slice_tile(k_lds_load,
440 sequence<(LdsSeq.at(number<k0_loops - 1>{})) * kN0, 0>{},
441 sequence<(LdsSeq.at(number<k0_loops - 1>{}) + 1) * kN0, kK0>{}));
442 }
443 __builtin_amdgcn_sched_barrier(1);
444
445 // STAGE 2, scale_s, add bias, mask, softmax
447 {
448 s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
449 tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
451 [&](auto& x, const auto& y) {
452#if !CK_TILE_FMHA_FWD_FAST_EXP2
453 x += type_convert<SaccDataType>(bias_element_func(y));
454#else
456 type_convert<SaccDataType>(bias_element_func(y));
457#endif
458 },
459 s_acc,
460 bias_tile);
461 }
462 else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
463 {
464 const auto k_origin = k_dram_block_window.get_window_origin();
465 constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
466 s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
467 sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
468 sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
469 const auto tile_idx = get_x_indices_from_distributed_indices(
470 s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
471
472 const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
473 const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
474 constexpr auto i_j_idx = make_tuple(idx0, idx1);
475
476 s_acc(i_j_idx) *= scale_s;
477 position_encoding.update(s_acc(i_j_idx), row, col);
478 });
479 });
480 }
481 else
482 {
483 s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
484 if constexpr(kHasLogitsSoftCap)
485 {
486 auto apply_logits_transform =
487 [&variant, &variant_params, &block_indices](auto& x) {
488 x = variant.LogitsTransform(variant_params,
489 variant.QueryTransform(variant_params, x),
490 block_indices.batch_idx,
491 block_indices.qo_head_idx,
492 block_indices.kv_head_idx);
493 };
494#if !CK_TILE_FMHA_FWD_FAST_EXP2
495 for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
496 {
497 apply_logits_transform(s_acc.thread_buf_[i]);
498 }
499#else
500 for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
501 {
502#if(defined(__gfx90a__) || defined(__gfx94__)) && \
503 (CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN && \
504 CK_TILE_ATTENTION_USE_SOFTSIGN_ASM)
505 // Avoid a data hazard if v_mfma is followed by inline-asm consumer
506 // instructions. In this case, the compiler won't add s_nop for us
507 if(i == s_acc.thread_buf_.size() / 2)
508 {
509 __builtin_amdgcn_sched_barrier(0);
510 }
511#endif
512 apply_logits_transform(s_acc.thread_buf_[i]);
513 }
514#endif
515 }
516 else
517 {
518#if !CK_TILE_FMHA_FWD_FAST_EXP2
519 tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
520#endif
521 }
522 }
523 move_tile_window(bias_dram_window, {0, kN0});
524 if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
525 {
526 const auto k_origin = k_dram_block_window.get_window_origin();
527 bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
528 k_origin.at(number<0>{}),
529 number<kM0>{},
530 number<kN0>{});
531
532 if(need_perpixel_check)
533 {
535 s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
536 const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
537 const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
538 return !variant.LogitsMask(variant_params,
539 block_indices.batch_idx,
540 row,
541 col,
542 block_indices.qo_head_idx,
543 block_indices.kv_head_idx);
544 });
545 }
546 }
547
548 const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
550 s,
551 sequence<1>{},
552 f_max,
553 -numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
555
556 const auto m_old = m; // m{j-1}
558 [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
559
561 s.get_tile_distribution()); // Pcompute{j}
562
563 __builtin_amdgcn_sched_barrier(0x7F);
564 // store & prefetch next v, after the max reduction
565 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
566 {
568 Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
569 shuffle_tile(v_shuffle_tmp, v_buf);
570
571 auto v_lds_window_tmp =
572 get_slice_tile(v_lds_window,
573 sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
574 sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
575
577 v_lds_window_tmp,
578 tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
579 }
580 else
581 {
582 auto v_lds_window_tmp =
583 get_slice_tile(v_lds_window,
584 sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
585 sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
586 store_tile(v_lds_window_tmp,
587 tile_elementwise_in(v_element_func, v_buf)); // store the prefetch
588 }
589
590 if constexpr(k1_loops > 1)
591 {
593 v_dram_window,
594 {0, kK1}); // will have scratch if move this right after load_tile(v_dram)...
595 v_buf = load_tile(
596 v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
597 static_for<0, V_KRepeat, 1>{}([&](auto k0) {
598 v_offsets[k0] =
599 page_idx[kK1 * 2 + v_coord[VPageIndexDim] + k0.value] * stride_v;
600 });
601 v_dram_window.update_page_idx(v_offsets);
602 }
603 __builtin_amdgcn_sched_barrier(0);
604
605 static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
609 FmhaMask::IsMasking)
610 {
613 : raw_m;
614 }
615 else
616 {
617 return raw_m;
618 }
619 };
620
621 constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
622 sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
623 constexpr auto i_idx = make_tuple(idx0);
624#if CK_TILE_FMHA_FWD_FAST_EXP2
625 auto row_max = scale_s * get_validated_m(m[i_idx]);
626#endif
627 sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
628 constexpr auto i_j_idx = make_tuple(idx0, idx1);
629#if CK_TILE_FMHA_FWD_FAST_EXP2
632 {
633 p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
634 }
635 else
636 {
637 if constexpr(kHasLogitsSoftCap)
638 {
639 p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
640 }
641 else
642 {
643 p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
644 }
645 }
646#else
647 p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
648#endif
649 });
650 });
651
653 p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
654
656 // l{j}, Oacc{j}
657 constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
658 sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
659 constexpr auto i_idx = make_tuple(idx0);
660#if CK_TILE_FMHA_FWD_FAST_EXP2
661 const auto tmp = [&]() {
664 {
665 return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
666 }
667 else
668 {
669 if constexpr(kHasLogitsSoftCap)
670 {
671 return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
672 }
673 else
674 {
675 auto row_max = scale_s * get_validated_m(m[i_idx]);
676 return exp2(scale_s * m_old[i_idx] - row_max);
677 }
678 }
679 }();
680#else
681 const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
682#endif
683 l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
684 sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
685 constexpr auto i_j_idx = make_tuple(idx0, idx1);
686 // FIXME: this uses a different equation from the FA v2 paper,
687 // but produces the correct result.
688 // Is the equation wrong?
689 o_acc(i_j_idx) *= tmp;
690 });
691 });
692
693 if constexpr(kHasDropout)
694 {
695 auto randval_ptr =
696 reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
697 dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
698 randval_ptr,
699 seqlen_k_start + i_total_loops * kN0,
700 p_compute,
701 randval_dram_window);
702 }
703
704 const auto p = [&]() {
705#if CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN
706 // For fp32 to fp16,
707 // impl::cast_tile_pk_fp16_fp32 would cause a precision issue,
708 // since it uses __builtin_amdgcn_cvt_pkrtz, which rounds toward zero.
709 return cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
710#else
711 if constexpr(std::is_same_v<PDataType, fp16_t>)
713 tile_elementwise_in(p_compute_element_func, p_compute));
714 else
716 tile_elementwise_in(p_compute_element_func, p_compute));
717#endif
718 }();
719
720 // STAGE 3, KV gemm
721 if constexpr(k1_loops > 1)
722 {
723 static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
724 if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1)
725 {
726 v_buf = load_tile(
727 v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
728 static_for<0, V_KRepeat, 1>{}([&](auto k0) {
729 v_offsets[k0] = page_idx[kK1 * 2 + i_k1.value * kK1 +
730 v_coord[VPageIndexDim] + k0.value] *
731 stride_v;
732 });
733 v_dram_window.update_page_idx(v_offsets);
734 }
736 gemm_1(o_acc,
738 p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
740 v_lds_window,
741 sequence<(LdsSeq.at(number<k0_loops + i_k1>{})) * kN1, 0>{},
742 sequence<(LdsSeq.at(number<k0_loops + i_k1>{}) + 1) * kN1, kK1>{}));
743
744 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
745 {
747 Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
748 shuffle_tile(v_shuffle_tmp, v_buf);
749 auto v_lds_window_tmp = get_slice_tile(
750 v_lds_window,
751 sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
752 sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1, kK1>{});
753 store_tile(v_lds_window_tmp,
754 tile_elementwise_in(v_element_func,
755 v_shuffle_tmp)); // store the prefetch
756 }
757 else
758 {
759 auto v_lds_window_tmp = get_slice_tile(
760 v_lds_window,
761 sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
762 sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1, kK1>{});
763 store_tile(v_lds_window_tmp,
764 tile_elementwise_in(v_element_func, v_buf)); // store next v_buf
765 }
766 if constexpr(i_k1 < k1_loops - 1)
767 move_tile_window(v_dram_window, {0, kK1});
768 });
769 }
770 i_total_loops++;
771 if(i_total_loops < num_total_loop)
772 {
773 page_idx += kN0;
774 // move K tile windows
775 move_tile_window(k_dram_block_window, {kN0, 0});
776 k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
777
778 static_for<0, NRepeat, 1>{}([&](auto n0) {
779 k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * n0.value] * stride_k;
780 });
781 k_dram_window.update_page_idx(k_offsets);
782 if constexpr(k1_loops >= 2 &&
783 LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
784 __builtin_amdgcn_s_barrier();
785 async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})),
786 k_dram_window,
787 number<-1>{},
788 k_oob_ck,
789 k_pre_np);
790 move_tile_window(k_dram_window, {0, kK0});
791 }
792 // tail
793 {
795 gemm_1(
796 o_acc,
797 get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
799 v_lds_window,
800 sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{})) * kN1, 0>{},
801 sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{}) + 1) * kN1, kK1>{}));
802 }
803 } while(i_total_loops < num_total_loop);
804
805 // store lse
806 if constexpr(kStoreLSE)
807 {
808 auto lse = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
809
810 constexpr auto lse_spans = decltype(lse)::get_distributed_spans();
811 sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
812 constexpr auto i_idx = make_tuple(idx0);
813#if CK_TILE_FMHA_FWD_FAST_EXP2
816 {
817 lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]);
818 }
819 else
820 {
821 if constexpr(kHasLogitsSoftCap)
822 {
823 lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]);
824 }
825 else
826 {
827 lse(i_idx) = m_[i_idx] * scale_s * R_LOG2E + log(l_[i_idx]);
828 }
829 }
830#else
831 lse(i_idx) = m_[i_idx] + log(l_[i_idx]);
832#endif
833 });
834
835 store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
836 }
837
838 // finally, O
839 constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
840
841 sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
842 constexpr auto i_idx = make_tuple(idx0);
843 const auto tmp = [&]() {
844 if constexpr(FmhaMask::IsMasking)
845 {
846 return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
847 }
848 else
849 return 1 / l[i_idx];
850 }();
851 sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
852 constexpr auto i_j_idx = make_tuple(idx0, idx1);
853 o_acc(i_j_idx) *= tmp;
854 });
855 });
856
857 o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
858
859 return o_acc;
860 }
861
862 template <typename QDramBlockWindowTmp,
863 typename KDramBlockWindowTmp,
864 typename VDramBlockWindowTmp,
865 typename BiasDramBlockWindowTmp,
866 typename RandValDramBlockWindowTmp,
867 typename LSEDramBlockWindowTmp,
868 typename PositionEncoding,
869 typename AttentionVariantParams,
870 typename BlockIndices>
872 operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
873 const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
874 const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
875 const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
876 RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
877 LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
878 FmhaMask mask,
879 PositionEncoding position_encoding,
880 float scale_s,
881 const AttentionVariant& variant,
882 const AttentionVariantParams& variant_params,
883 const BlockIndices& block_indices,
884 void* smem_ptr,
885 const index_t* page_idx,
886 const index_t stride_k,
887 const index_t stride_v,
888 DropoutType& dropout) const
889 {
890 return operator()(q_dram_block_window_tmp,
891 identity{},
892 k_dram_block_window_tmp,
893 identity{},
894 v_dram_block_window_tmp,
895 identity{},
896 bias_dram_block_window_tmp,
897 identity{},
898 randval_dram_block_window_tmp,
899 lse_dram_block_window_tmp,
900 identity{},
901 identity{},
902 identity{},
903 identity{},
904 mask,
905 position_encoding,
906 scale_s,
907 variant,
908 variant_params,
909 block_indices,
910 smem_ptr,
911 page_idx,
912 stride_k,
913 stride_v,
914 dropout);
915 }
916};
917
918} // namespace ck_tile
#define CK_TILE_FMHA_FWD_FAST_EXP2
Definition config.hpp:234
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
CK_TILE_DEVICE auto cast_tile_pk_fp16_fp32(const InTensor &in_dstr_tensors)
Definition tile_elementwise.hpp:231
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition bfloat16.hpp:428
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
CK_TILE_DEVICE void set_tile(DstrTensors &dstr_tensor, const T &value)
Definition tile_elementwise.hpp:95
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
CK_TILE_DEVICE auto async_load_fence(index_t cnt=0)
Definition load_tile.hpp:145
CK_TILE_HOST_DEVICE constexpr auto get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, DistributedIndices distributed_indices)
Definition static_distributed_tensor.hpp:159
CK_TILE_DEVICE constexpr auto get_slice_tile(const tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile, sequence< SliceBegins... > slice_begins, sequence< SliceEnds... > slice_ends)
Definition slice_tile.hpp:23
@ ALIBI
Definition block_attention_bias_enum.hpp:15
@ NO_BIAS
Definition block_attention_bias_enum.hpp:13
@ ELEMENTWISE_BIAS
Definition block_attention_bias_enum.hpp:14
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_ &acc_tensor, const ReduceFunc &reduce_func, bool_constant< WithBroadcast >={}, bool_constant< CrossWarp >={})
Definition block_reduce.hpp:21
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
constexpr T log2e_v
Definition tile/core/numeric/math.hpp:488
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_DEVICE void shuffle_tile(OutTensor &out, const InTensor &in)
Definition shuffle_tile.hpp:154
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_ &acc_tensor, const InDistributedTensor_ &in_tensor, sequence< InReduceDims... >, const ReduceFunc &reduce_func)
Definition block_reduce.hpp:191
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition bfloat16.hpp:419
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_DEVICE auto load_tile_raw(T &tile, const tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Loads a tile of data using inline assembly.
Definition load_tile.hpp:81
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_DEVICE void buffer_load_fence(index_t cnt=0)
Definition tile/core/arch/amd_buffer_addressing.hpp:815
CK_TILE_HOST_DEVICE void set_tile_if(static_distributed_tensor< DataType, StaticTileDistribution > &out_tensor, DataType value, XIndicesPredicate predicate)
Definition static_distributed_tensor.hpp:175
CK_TILE_DEVICE constexpr auto make_tile_scatter_gather(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, const StaticPageIndexArray_ &page_idx, number< HsGatherDim >={}, number< NumCoord >={})
Definition tile_scatter_gather.hpp:906
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_ &&lds_tile, const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Definition load_tile.hpp:133
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition tile_elementwise.hpp:177
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
BlockFmhaPipelineQXKSVSCustomPolicy< true, true, 3, 3 > BlockFmhaBatchPrefillPipelineQRKSVSAsyncDefaultPolicy
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp:12
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition bfloat16.hpp:425
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
tuple_array< T, N > statically_indexed_array
Definition tile/core/container/statically_indexed_array.hpp:16
ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:20
static constexpr bool kPadSeqLenK
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:64
remove_cvref_t< typename Problem::BiasDataType > BiasDataType
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:28
remove_cvref_t< typename Problem::ODataType > ODataType
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:33
static constexpr index_t kK1
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:48
static constexpr index_t kAlignmentV
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:81
remove_cvref_t< typename Problem::KDataType > KDataType
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:24
static constexpr index_t kN1
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:47
static constexpr auto I1
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:52
remove_cvref_t< typename Problem::BlockFmhaShape > BlockFmhaShape
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:37
remove_cvref_t< typename Problem::FmhaMask > FmhaMask
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:35
CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const QElementFunction &q_element_func, const KDramBlockWindowTmp &k_dram_block_window_tmp, const KElementFunction &, const VDramBlockWindowTmp &v_dram_block_window_tmp, const VElementFunction &v_element_func, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, const BiasElementFunction &bias_element_func, RandValDramBlockWindowTmp &randval_dram_block_window_tmp, LSEDramBlockWindowTmp &lse_dram_window_tmp, const LSEElementFunction &lse_element_func, const SAccElementFunction &s_acc_element_func, const PComputeElementFunction &p_compute_element_func, const OAccElementFunction &o_acc_element_func, FmhaMask mask, PositionEncoding position_encoding, float scale_s, const AttentionVariant &variant, const AttentionVariantParams &variant_params, const BlockIndices &block_indices, void *smem_ptr, const index_t *page_idx, const index_t stride_k, const index_t stride_v, DropoutType &dropout) const
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:173
static constexpr index_t kM0
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:44
remove_cvref_t< typename Problem::AttentionVariant > AttentionVariant
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:34
static constexpr index_t kAlignmentO
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:87
static constexpr bool kHasLogitsSoftCap
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:67
static constexpr auto BiasEnum
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:68
remove_cvref_t< Policy_ > Policy
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:22
static constexpr auto I0
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:51
static constexpr index_t kSubQKHeaddim
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:50
static constexpr bool kHasDropout
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:70
static constexpr index_t kBlockSize
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:42
static constexpr bool kPadHeadDimQ
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:65
static constexpr index_t kQKHeaddim
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:49
static constexpr index_t kAlignmentK
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:80
static constexpr index_t kN0
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:45
remove_cvref_t< typename Problem::OaccDataType > OaccDataType
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:32
static constexpr index_t kBlockPerCu
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:95
static constexpr bool kStoreLSE
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:69
static constexpr auto I3
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:54
remove_cvref_t< typename Problem::QDataType > QDataType
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:23
static constexpr bool kPadHeadDimV
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:66
static constexpr bool kPadSeqLenQ
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:63
static constexpr index_t kAlignmentBias
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:88
remove_cvref_t< typename BlockFmhaShape::VLayout > VLayout
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:38
std::conditional_t< kHasDropout, BlockDropout, NullBlockDropout > DropoutType
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:148
CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowTmp &k_dram_block_window_tmp, const VDramBlockWindowTmp &v_dram_block_window_tmp, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, RandValDramBlockWindowTmp &randval_dram_block_window_tmp, LSEDramBlockWindowTmp &lse_dram_block_window_tmp, FmhaMask mask, PositionEncoding position_encoding, float scale_s, const AttentionVariant &variant, const AttentionVariantParams &variant_params, const BlockIndices &block_indices, void *smem_ptr, const index_t *page_idx, const index_t stride_k, const index_t stride_v, DropoutType &dropout) const
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:872
remove_cvref_t< typename Problem::SaccDataType > SaccDataType
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:26
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:150
remove_cvref_t< typename Problem::RandValOutputDataType > RandValOutputDataType
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:29
static constexpr const char * name
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:146
static constexpr index_t kK0
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:46
remove_cvref_t< typename Problem::VDataType > VDataType
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:25
remove_cvref_t< typename Problem::PDataType > PDataType
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:31
remove_cvref_t< Problem_ > Problem
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:21
static constexpr index_t kAlignmentQ
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:79
static constexpr auto I2
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:53
remove_cvref_t< typename Problem::SMPLComputeDataType > SMPLComputeDataType
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:27
static constexpr bool kQLoadOnce
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:39
static constexpr bool kIsGroupMode
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:58
remove_cvref_t< typename Problem::LSEDataType > LSEDataType
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:30
Definition tile/core/utility/functional.hpp:86
static CK_TILE_HOST_DEVICE constexpr T infinity()
Definition tile/core/numeric/numeric.hpp:38
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43