blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp Source File

blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp Source File

Composable Kernel: blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp Source File
blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Compute optimized pipeline
11// GlobalPrefetchStages: 2
12// LocalPreFillStages: 1
13// LocalPreFetchStages: 1
14// LocalSharedMemoryBuffer: 1
15
16template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17 index_t BlockSize,
18 typename ADataType,
19 typename BDataType,
20 typename ComputeDataType,
21 typename AccDataType,
22 typename ATileDesc,
23 typename BTileDesc,
24 typename AMmaTileDesc,
25 typename BMmaTileDesc,
26 index_t ABlockTransferSrcScalarPerVector,
27 index_t BBlockTransferSrcScalarPerVector,
28 index_t MPerBlock,
29 index_t NPerBlock,
30 index_t KPerBlock,
31 index_t MPerXDL,
32 index_t NPerXDL,
33 index_t MRepeat,
34 index_t NRepeat,
35 index_t KPacks>
39
40template <index_t BlockSize,
41 typename ADataType,
42 typename BDataType,
43 typename ComputeDataType,
44 typename AccDataType,
45 typename ATileDesc,
46 typename BTileDesc,
47 typename AMmaTileDesc,
48 typename BMmaTileDesc,
49 index_t ABlockTransferSrcScalarPerVector,
50 index_t BBlockTransferSrcScalarPerVector,
51 index_t MPerBlock,
52 index_t NPerBlock,
53 index_t KPerBlock,
54 index_t MPerXDL,
55 index_t NPerXDL,
56 index_t MRepeat,
57 index_t NRepeat,
58 index_t KPack
59 // ,bool TransposeC //disable transposec right now...
60 >
62 BlockSize,
63 ADataType,
64 BDataType,
65 ComputeDataType,
66 AccDataType,
67 ATileDesc,
68 BTileDesc,
69 AMmaTileDesc,
70 BMmaTileDesc,
71 ABlockTransferSrcScalarPerVector,
72 BBlockTransferSrcScalarPerVector,
73 MPerBlock,
74 NPerBlock,
75 KPerBlock,
76 MPerXDL,
77 NPerXDL,
78 MRepeat,
79 NRepeat,
80 KPack>
82 ADataType,
83 BDataType,
84 ComputeDataType,
85 AccDataType,
86 ATileDesc,
87 BTileDesc,
88 AMmaTileDesc,
89 BMmaTileDesc,
90 ABlockTransferSrcScalarPerVector,
91 BBlockTransferSrcScalarPerVector,
92 MPerBlock,
93 NPerBlock,
94 KPerBlock,
95 MPerXDL,
96 NPerXDL,
97 MRepeat,
98 NRepeat,
99 KPack,
100 true>
101
102{
104 ADataType,
105 BDataType,
106 ComputeDataType,
107 AccDataType,
108 ATileDesc,
109 BTileDesc,
110 AMmaTileDesc,
111 BMmaTileDesc,
112 ABlockTransferSrcScalarPerVector,
113 BBlockTransferSrcScalarPerVector,
114 MPerBlock,
115 NPerBlock,
116 KPerBlock,
117 MPerXDL,
118 NPerXDL,
119 MRepeat,
120 NRepeat,
121 KPack,
122 true>;
123 using Base::I0;
124 using Base::KRepeat;
125 using Base::xdlops_gemm;
126 using typename Base::HotLoopInstList;
127
139
142
143 using Base::AMmaKStride;
144 using Base::BMmaKStride;
145
147
148 static constexpr index_t PrefetchStages = 2;
149 static constexpr index_t PrefillStages = 1;
150 static constexpr index_t GlobalBufferNum = 1;
151
152 __host__ static constexpr bool BlockHasHotloop(index_t num_loop)
153 {
154 return num_loop > PrefetchStages;
155 }
156
157 __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
158 {
159 ignore = num_loop;
160 return TailNumber::Full;
161 }
162
163 __device__ static constexpr auto HotLoopScheduler()
164 {
165 // A/B split schedule
166 // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
167 constexpr auto num_ds_read_inst_a =
168 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
171 constexpr auto num_ds_read_inst_b =
172 HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
175
176 constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
177 constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;
178
179 constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
180 constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
181
182 constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
183
184 constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
185 constexpr auto ds_read_a_issue_cycle =
186 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
187 constexpr auto ds_read_b_issue_cycle =
188 HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
189 constexpr auto ds_read_a_mfma_rate =
190 (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
191 constexpr auto ds_read_b_mfma_rate =
192 (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
193
194 constexpr auto num_dsread_a_mfma =
195 (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
196 constexpr auto num_dsread_b_mfma =
197 (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
198
199 // stage 1
200 // Separate this part?
201 // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataTypeBuf) / sizeof(ADataType) >
202 // sizeof(ComputeDataTypeBuf) / sizeof(BDataType)
203 // ? sizeof(ComputeDataTypeBuf) / sizeof(ADataType)
204 // : sizeof(ComputeDataTypeBuf) / sizeof(BDataType);
205 constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
206 constexpr auto num_mfma_per_issue =
207 num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
208 constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
209 constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
210
212 ignore = i;
213 static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
214 ignore = idswrite;
215 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
216 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
217 });
218 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
219 __builtin_amdgcn_sched_group_barrier(
220 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
221 });
223 ignore = i;
224 static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
225 ignore = idswrite;
226 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
227 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
228 });
229 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
230 __builtin_amdgcn_sched_group_barrier(
231 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA
232 });
233
234 // stage 2
236 if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
237 ds_read_a_mfma_rate)
238 {
239 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
240 }
241 else
242 {
243 __builtin_amdgcn_sched_group_barrier(0x100,
244 num_ds_read_inst_a - (num_dsread_a_mfma - 1) *
245 ds_read_a_mfma_rate,
246 0); // DS read
247 }
248 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
249 });
250
252 if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
253 ds_read_b_mfma_rate)
254 {
255 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
256 }
257 else
258 {
259 __builtin_amdgcn_sched_group_barrier(0x100,
260 num_ds_read_inst_b - (num_dsread_b_mfma - 1) *
261 ds_read_b_mfma_rate,
262 0); // DS read
263 }
264 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
265 });
266 }
267
268 template <bool HasMainLoop,
269 int NumKBlockPerScale,
270 TailNumber TailNum,
271 typename AGridDesc,
272 typename ABlockDesc,
273 typename ABlockTransfer,
274 typename AGridBuffer,
275 typename ABlockBuffer,
276 typename ABlockTransferStep,
277 typename BGridDesc,
278 typename BBlockDesc,
279 typename BBlockTransfer,
280 typename BGridBuffer,
281 typename BBlockBuffer,
282 typename BBlockTransferStep,
283 typename CScaleThreadDesc,
284 typename CThreadBuffer,
285 typename AScaleGridBuffer,
286 typename AScaleGridDesc,
287 typename AScaleThreadDesc,
288 typename AScaleThreadTransfer,
289 typename AScaleThreadTransferStep,
290 typename BScaleGridBuffer,
291 typename BScaleGridDesc,
292 typename BScaleThreadDesc,
293 typename BScaleThreadTransfer,
294 typename BScaleThreadTransferStep>
295 __device__ void Run(
296 // ABlockCopy
297 const AGridDesc& a_grid_desc,
298 const ABlockDesc& a_block_desc,
299 ABlockTransfer& a_blockwise_copy,
300 const AGridBuffer& a_grid_buf,
301 ABlockBuffer& a_block_buf,
302 const ABlockTransferStep& a_block_copy_step,
303 // BBlockCopy
304 const BGridDesc& b_grid_desc,
305 const BBlockDesc& b_block_desc,
306 BBlockTransfer& b_blockwise_copy,
307 const BGridBuffer& b_grid_buf,
308 BBlockBuffer& b_block_buf,
309 const BBlockTransferStep& b_block_copy_step,
310 // CThread
311 const CScaleThreadDesc& c_scale_thread_desc,
312 CThreadBuffer& c_thread_buf,
313 // AScaleThreadCopy
314 const AScaleGridDesc& a_scale_grid_desc,
315 const AScaleThreadDesc& a_scale_thread_desc,
316 AScaleThreadTransfer& a_scale_thread_copy,
317 const AScaleGridBuffer& a_scale_grid_buf,
318 const AScaleThreadTransferStep& a_scale_thread_copy_step,
319 // BScaleThreadCopy
320 const BScaleGridDesc& b_scale_grid_desc,
321 const BScaleThreadDesc& b_scale_thread_desc,
322 BScaleThreadTransfer& b_scale_thread_copy,
323 const BScaleGridBuffer& b_scale_grid_buf,
324 const BScaleThreadTransferStep& b_scale_thread_copy_step,
325 // num_loop
326 index_t num_loop) const
327 {
328 __builtin_amdgcn_sched_barrier(0);
329 static_assert(CScaleThreadDesc{}.GetLength(Number<0>{}) == 1,
330 "Pipeline v3 only support scaleblocksliceK=1");
331 static_assert(CScaleThreadDesc{}.GetLength(Number<2>{}) == 1,
332 "Pipeline v3 only support scaleblocksliceN=1");
333 // assume kperblock = scaleblockk
335 a_thread_desc_.GetElementSpaceSize());
337 b_thread_desc_.GetElementSpaceSize());
339 a_scale_thread_desc.GetElementSpaceSize());
341 b_scale_thread_desc.GetElementSpaceSize());
343 c_scale_thread_desc.GetElementSpaceSize());
344
345 // Global prefetch 1
346 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
347 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
348
349 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
350 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
351
352 static_for<0, MRepeat, 1>{}([&](auto m0) {
353 a_scale_thread_copy.Run(a_scale_grid_desc,
354 a_scale_grid_buf,
355 a_scale_thread_desc,
356 make_tuple(m0, I0),
357 a_scale_thread_buf);
358 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
359 a_scale_thread_copy_step.At(Number<0>{}));
360 });
361
362 if constexpr(NumKBlockPerScale == 1)
363 {
364 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
365 a_scale_thread_copy_step.At(Number<2>{}));
366 }
367 else
368 {
369 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
370 a_scale_thread_copy_step.At(Number<1>{}));
371 }
372
373 b_scale_thread_copy.Run(b_scale_grid_desc,
374 b_scale_grid_buf,
375 b_scale_thread_desc,
376 make_tuple(I0, I0),
377 b_scale_thread_buf);
378
379 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
380
381 static_for<0, MRepeat, 1>{}([&](auto m0) {
382 c_scale_thread_buf(m0) = a_scale_thread_buf[m0] * b_scale_thread_buf[I0];
383 });
384
385 // Local prefill 1
386 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
387 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
388
389 // Global prefetch 2
390 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
391 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
392
393 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
394 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
395
396 static_for<0, MRepeat, 1>{}([&](auto m0) {
397 a_scale_thread_copy.Run(a_scale_grid_desc,
398 a_scale_grid_buf,
399 a_scale_thread_desc,
400 make_tuple(m0, I0),
401 a_scale_thread_buf);
402 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
403 a_scale_thread_copy_step.At(Number<0>{}));
404 });
405
406 if constexpr(NumKBlockPerScale == 1)
407 {
408 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
409 a_scale_thread_copy_step.At(Number<2>{}));
410 }
411 else
412 {
413 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
414 a_scale_thread_copy_step.At(Number<1>{}));
415 }
416
417 b_scale_thread_copy.Run(b_scale_grid_desc,
418 b_scale_grid_buf,
419 b_scale_thread_desc,
420 make_tuple(I0, I0),
421 b_scale_thread_buf);
422
423 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
424
425 // Initialize C
426 c_thread_buf.Clear();
427
429 AccDataType,
430 1,
431 xdlops_gemm.GetRegSizePerXdlops(),
432 true>
433 c_thread_buf_per_scale;
434
435 // Local prefetch 1
437 static_for<0, KRepeat, 1>{}([&](auto k0) {
438 static_for<0, MRepeat, 1>{}([&](auto m0) {
441 a_block_buf,
443 make_tuple(m0, I0, k0, I0),
444 a_thread_buf);
445 });
446 static_for<0, NRepeat, 1>{}([&](auto n0) {
449 b_block_buf,
451 make_tuple(n0, I0, k0, I0),
452 b_thread_buf);
453 });
454 });
455
456 __builtin_amdgcn_sched_barrier(0);
457
458 // main body
459 if constexpr(HasMainLoop)
460 {
461 index_t i = 0;
462 do
463 {
465 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
466 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
467
468 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
469 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
470
471 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
472 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
473
474 static_for<0, MRepeat, 1>{}([&](auto m0) {
475 static_for<0, NRepeat, 1>{}([&](auto n0) {
476 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
477 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
478 .template AsType<AccDataType>()(Number<t>{}) = 0;
479 });
480 static_for<0, KRepeat, 1>{}([&](auto k0) {
483
484 static_for<0, KPack, 1>{}([&](auto ik) {
485 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
486 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
487 make_tuple(m0, I0, k0, ik))>{}];
488 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
489 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
490 make_tuple(n0, I0, k0, ik))>{}];
491 });
492
493 using mfma_input_type =
495 xdlops_gemm.K1PerXdlops>::type;
496
497 xdlops_gemm.template Run<>(
498 a_thread_vec.template AsType<mfma_input_type>(),
499 b_thread_vec.template AsType<mfma_input_type>(),
500 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
501 });
502 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
503 constexpr index_t c_offset =
504 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
505 c_thread_buf(Number<c_offset>{}) +=
506 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
507 .template AsType<AccDataType>()[Number<t>{}] *
508 type_convert<AccDataType>(c_scale_thread_buf[m0]);
509 });
510 });
511 });
512
513 static_for<0, MRepeat, 1>{}([&](auto m0) {
514 c_scale_thread_buf(m0) = a_scale_thread_buf[m0] * b_scale_thread_buf[I0];
515 });
516
518 static_for<0, KRepeat, 1>{}([&](auto k) {
519 static_for<0, MRepeat, 1>{}([&](auto m0) {
522 a_block_buf,
524 make_tuple(m0, I0, k, I0),
525 a_thread_buf);
526 });
527 static_for<0, NRepeat, 1>{}([&](auto n0) {
530 b_block_buf,
532 make_tuple(n0, I0, k, I0),
533 b_thread_buf);
534 });
535 });
536
538 __builtin_amdgcn_sched_barrier(0);
539
540 static_for<0, MRepeat, 1>{}([&](auto m0) {
541 a_scale_thread_copy.Run(a_scale_grid_desc,
542 a_scale_grid_buf,
543 a_scale_thread_desc,
544 make_tuple(m0, I0),
545 a_scale_thread_buf);
546 a_scale_thread_copy.MoveSrcSliceWindow(
547 a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
548 });
549
550 if constexpr(NumKBlockPerScale == 1)
551 {
552 a_scale_thread_copy.MoveSrcSliceWindow(
553 a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{}));
554 }
555 else
556 {
557 a_scale_thread_copy.MoveSrcSliceWindow(
558 a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
559 }
560
561 b_scale_thread_copy.Run(b_scale_grid_desc,
562 b_scale_grid_buf,
563 b_scale_thread_desc,
564 make_tuple(I0, I0),
565 b_scale_thread_buf);
566
567 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
568
569 i += 1;
570 } while(i < (num_loop - 1));
571 }
572
573 // tail
574 if constexpr(TailNum == TailNumber::Full)
575 {
576 static_for<0, MRepeat, 1>{}([&](auto m0) {
577 static_for<0, NRepeat, 1>{}([&](auto n0) {
578 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
579 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
580 .template AsType<AccDataType>()(Number<t>{}) = 0;
581 });
582 static_for<0, KRepeat, 1>{}([&](auto k0) {
585
586 static_for<0, KPack, 1>{}([&](auto ik) {
587 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
588 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
589 make_tuple(m0, I0, k0, ik))>{}];
590 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
591 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
592 make_tuple(n0, I0, k0, ik))>{}];
593 });
594
595 using mfma_input_type =
596 typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
597
598 xdlops_gemm.template Run<>(
599 a_thread_vec.template AsType<mfma_input_type>(),
600 b_thread_vec.template AsType<mfma_input_type>(),
601 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
602 });
603 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
604 constexpr index_t c_offset =
605 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
606 c_thread_buf(Number<c_offset>{}) +=
607 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
608 .template AsType<AccDataType>()[Number<t>{}] *
609 type_convert<AccDataType>(c_scale_thread_buf[m0]);
610 });
611 });
612 });
613 __builtin_amdgcn_sched_barrier(0);
614 }
615 }
616
617 protected:
618 using Base::a_thread_copy_;
619 using Base::a_thread_desc_;
620 using Base::b_thread_copy_;
621 using Base::b_thread_desc_;
622 using Base::c_thread_desc_;
623};
624
625} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Full
Definition blkgemmpipe_scheduler.hpp:49
@ Vgpr
Definition amd_address_space.hpp:20
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ void block_sync_lds()
Definition synchronization.hpp:16
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops > HotLoopInstList
Definition blockwise_gemm_pipeline_xdlops_base.hpp:82
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, const CScaleThreadDesc &c_scale_thread_desc, CThreadBuffer &c_thread_buf, const AScaleGridDesc &a_scale_grid_desc, const AScaleThreadDesc &a_scale_thread_desc, AScaleThreadTransfer &a_scale_thread_copy, const AScaleGridBuffer &a_scale_grid_buf, const AScaleThreadTransferStep &a_scale_thread_copy_step, const BScaleGridDesc &b_scale_grid_desc, const BScaleThreadDesc &b_scale_thread_desc, BScaleThreadTransfer &b_scale_thread_copy, const BScaleGridBuffer &b_scale_grid_buf, const BScaleThreadTransferStep &b_scale_thread_copy_step, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp:295
BlockwiseGemmXdlops_pipeline_base< BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack, true > Base
Definition blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp:103
Definition blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp:37
Definition static_buffer.hpp:75
Definition functional2.hpp:33
Definition dtype_vector.hpp:10