blockwise_gemm_pipeline_xdlops.hpp Source File

blockwise_gemm_pipeline_xdlops.hpp Source File

Composable Kernel: blockwise_gemm_pipeline_xdlops.hpp Source File
blockwise_gemm_pipeline_xdlops.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
11
12// Double LDS buffer
13// Prefetch 2 stages
14// Local prefetch 1 stage
15
16namespace ck {
17
18template <index_t BlockSize,
19 index_t MPerBlock,
20 index_t NPerBlock,
21 index_t KPerBlock,
22 index_t ABufferLoadWidth,
23 index_t BBufferLoadWidth,
24 index_t ALDSWriteWidth,
25 index_t BLDSWriteWidth,
26 index_t ALDSReadWidth,
27 index_t BLDSReadWidth,
28 index_t MRepeat,
29 index_t NRepeat,
30 index_t MPerXDL,
31 index_t NPerXDL,
32 index_t KPerXDL>
34{
35 static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerXDL);
36 static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerXDL);
37 static constexpr index_t WaveSize = BlockSize / (WaveNumM * WaveNumN);
38
40 MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth);
42 NPerBlock * KPerBlock / (BlockSize * BBufferLoadWidth);
43
44 static constexpr index_t A_LDS_Write_Inst_Num =
45 MPerBlock * KPerBlock / (BlockSize * ALDSWriteWidth);
46 static constexpr index_t B_LDS_Write_Inst_Num =
47 NPerBlock * KPerBlock / (BlockSize * BLDSWriteWidth);
48
49 static constexpr index_t A_LDS_Read_Inst_Num =
50 WaveNumN * MPerBlock * KPerBlock / (BlockSize * ALDSReadWidth);
51 static constexpr index_t B_LDS_Read_Inst_Num =
52 WaveNumM * MPerBlock * KPerBlock / (BlockSize * BLDSReadWidth);
53
54 static constexpr index_t C_MFMA_Inst_Num =
55 MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
56
57 static constexpr auto Print()
58 {
59 printf(" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerXdl: %d, %d, %d\n",
60 BlockSize,
62 MPerBlock,
63 NPerBlock,
64 KPerBlock,
65 MPerXDL,
66 NPerXDL,
67 KPerXDL);
68
69 printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: "
70 "%d, %d\n C MFMA inst: %d\n",
78 }
79};
80
81template <
82 index_t BlockSize,
83 typename FloatAB,
84 typename FloatAcc,
85 typename ATileDesc,
86 typename BTileDesc,
87 typename AMmaTileDesc,
88 typename BMmaTileDesc,
89 index_t MPerBlock,
90 index_t NPerBlock,
91 index_t KPerBlock,
92 index_t MPerXDL,
93 index_t NPerXDL,
94 index_t MRepeat,
95 index_t NRepeat,
96 index_t KPack,
97 bool TransposeC = false,
98 index_t AMmaKStride =
99 KPack * XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops,
100 index_t BMmaKStride =
103{
104 static constexpr auto I0 = Number<0>{};
105 static constexpr auto I1 = Number<1>{};
106 static constexpr auto I2 = Number<2>{};
107 static constexpr auto I3 = Number<3>{};
108
110
111 static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
112 static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
113 static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
114
115 static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
116 static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
117 static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2);
118 static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);
119
122
123 static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
124 static constexpr index_t KRepeat = KPerThread / KPack;
125
127 MPerBlock,
128 NPerBlock,
129 KPerBlock,
130 A_K1,
131 B_K1,
132 A_K1,
133 B_K1,
134 KPack,
135 KPack,
136 MRepeat,
137 NRepeat,
138 MPerXDL,
139 NPerXDL,
140 xdlops_gemm.KPerXdlops>;
141
142 static_assert(KPerThread % KPack == 0,
143 "Wrong KPack setting; try increasing KPerThread or decreasing KPack");
144
146 FloatAcc,
147 MRepeat * NRepeat,
148 xdlops_gemm.GetRegSizePerXdlops(),
149 true>
151
152 __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
153
154 __device__ static auto GetWaveIdx()
155 {
156 const index_t thread_id = ThisThreadBlock::GetThreadId();
157
158 constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
162
163 return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
164 }
165
166 __device__ static auto CalculateAThreadOriginDataIndex()
167 {
168 const auto wave_idx = GetWaveIdx();
169
170 const auto waveId_m = wave_idx[I0];
171
172 const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
173
174 return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPack * xdlops_a_idx[I0]);
175 }
176
177 __device__ static auto CalculateBThreadOriginDataIndex()
178 {
179 const auto wave_idx = GetWaveIdx();
180
181 const auto waveId_n = wave_idx[I1];
182
183 const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
184
185 return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPack * xdlops_b_idx[I0]);
186 }
187
188 template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
189 __device__ static auto
191 {
192 const auto wave_idx = GetWaveIdx();
193
194 const auto waveId_m = wave_idx[I0];
195 const auto waveId_n = wave_idx[I1];
196
197 const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
198
199 constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
203
204 constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
208
209 const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
210 make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
211 const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
212 make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
213
214 return make_tuple(c_thread_m, c_thread_n);
215 }
216
217 template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
218 __device__ static auto
220 {
221 const auto wave_idx = GetWaveIdx();
222
223 const auto waveId_m = wave_idx[I0];
224 const auto waveId_n = wave_idx[I1];
225
226 const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
227
228 return make_tuple(
229 m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]);
230 }
231
233
234 __host__ __device__
237 : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
238 {
239#if defined(__HIP_DEVICE_COMPILE__)
240 static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
241 "wrong! Desc should be known at compile-time");
242
244 "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
245
246 static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
247 "wrong!");
248#endif
249 // HotLoopInstList::Print();
250 }
251
252 // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
253 __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
254 {
255 constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
256
257 constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
258 constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
259 constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
260 constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
261
263 make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, N, M0, M1, M2));
264 }
265
266 // XDL output supporting C_xdl = A_xdl * B_xdl
267 __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
268 {
269 constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
270
271 constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
272 constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
273 constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
274 constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
275
277 make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
278 }
279
280 __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
281 {
282 constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
283
284 constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
285 constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
286 constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
287 constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
288
290 make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
291 }
292
293 // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
294 __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
295 {
296 constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
302 Number<NPerXDL>{}));
303
304 return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
305 }
306
307 // XDL output supporting C_xdl = A_xdl * B_xdl
308 __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
309 {
310 constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
316 Number<NPerXDL>{}));
317
318 return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
319 }
320
321 __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
322 {
323 constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
330 Number<NPerXDL>{}));
331
332 return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
333 c_block_desc_g_m0_n0_m1_n1_m2_n2);
334 }
335
336 template <typename CGridDesc_M_N>
337 __host__ __device__ static constexpr auto
338 MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
339 {
340 const auto M = c_grid_desc_m_n.GetLength(I0);
341 const auto N = c_grid_desc_m_n.GetLength(I1);
342
343 const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
344 c_grid_desc_m_n,
345 make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
346 make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
349
350 return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
351 }
352
353 template <typename CGridDesc_G_M_N>
354 __host__ __device__ static constexpr auto
355 MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
356 {
357 const auto G = c_grid_desc_g_m_n.GetLength(I0);
358 const auto M = c_grid_desc_g_m_n.GetLength(I1);
359 const auto N = c_grid_desc_g_m_n.GetLength(I2);
360
361 const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
362 c_grid_desc_g_m_n,
364 make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
365 make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
368
369 return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
370 c_grid_desc_g_m0_n0_m1_n1_m2_n2);
371 }
372
373 __device__ static constexpr auto HotLoopScheduler()
374 {
375 // schedule
376 constexpr auto num_ds_read_inst =
378 constexpr auto num_ds_write_inst =
380 ;
381 constexpr auto num_buffer_load_inst =
383 ;
384 constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
385
386 constexpr auto num_issue = num_buffer_load_inst;
387
388 static_for<0, num_issue, 1>{}([&](auto i) {
389 ignore = i;
390 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
391 __builtin_amdgcn_sched_group_barrier(
392 0x100, num_ds_read_inst / num_buffer_load_inst, 0); // DS read
393 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
394 __builtin_amdgcn_sched_group_barrier(
395 0x200, num_ds_write_inst / num_buffer_load_inst, 0); // DS write
396 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
397 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
398 __builtin_amdgcn_sched_group_barrier(
399 0x008, num_mfma_inst / num_buffer_load_inst - 3, 0); // MFMA
400 });
401 }
402
403 template <index_t stage>
404 __device__ static constexpr auto TailScheduler()
405 {
406 }
407
408 template <>
409 __device__ constexpr auto TailScheduler<1>()
410 {
411 // schedule
412 constexpr auto num_ds_read_inst =
414 constexpr auto num_ds_write_inst =
416 ;
417 constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
418
419 constexpr auto num_issue = num_ds_write_inst;
420
421 static_for<0, num_issue, 1>{}([&](auto i) {
422 ignore = i;
423 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
424 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
425 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
426 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
427 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
428 __builtin_amdgcn_sched_group_barrier(
429 0x100, num_ds_read_inst / num_ds_write_inst - 1, 0); // DS read
430 __builtin_amdgcn_sched_group_barrier(
431 0x008, num_mfma_inst / num_ds_write_inst - 3, 0); // MFMA
432 });
433 }
434
435 template <>
436 __device__ constexpr auto TailScheduler<2>()
437 {
438 // schedule
439 constexpr auto num_ds_read_inst =
441 constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
442
443 constexpr auto num_issue = num_ds_read_inst;
444
445 static_for<0, num_issue, 1>{}([&](auto i) {
446 ignore = i;
447 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
448 __builtin_amdgcn_sched_group_barrier(
449 0x008, num_mfma_inst / num_ds_read_inst, 0); // MFMA
450 });
451 }
452
// Compile-time instances of the A / B MMA tile descriptors. Their types are the
// first descriptor arguments of AThreadCopy / BThreadCopy. The device-side
// constructor check requires both descriptors to be known at compile time.
453 static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
454 static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
455
456 template <bool HasMainLoop,
457 index_t TailNum,
458 typename AGridDesc,
459 typename ABlockDesc,
460 typename ABlockTransfer,
461 typename AGridBuffer,
462 typename ABlockBuffer,
463 typename ABlockTransferStep,
464 typename BGridDesc,
465 typename BBlockDesc,
466 typename BBlockTransfer,
467 typename BGridBuffer,
468 typename BBlockBuffer,
469 typename BBlockTransferStep,
470 typename CThreadBuffer>
471 __device__ void Run(const AGridDesc& a_grid_desc,
472 const ABlockDesc& a_block_desc,
473 ABlockTransfer& a_blockwise_copy,
474 const AGridBuffer& a_grid_buf,
475 ABlockBuffer& a_block_buf,
476 const ABlockTransferStep& a_block_copy_step,
477 const BGridDesc& b_grid_desc,
478 const BBlockDesc& b_block_desc,
479 BBlockTransfer& b_blockwise_copy,
480 const BGridBuffer& b_grid_buf,
481 BBlockBuffer& b_block_buf,
482 const BBlockTransferStep& b_block_copy_step,
483 CThreadBuffer& c_thread_buf,
484 index_t num_loop) const
485 {
486 __builtin_amdgcn_sched_barrier(0);
488 a_thread_desc_.GetElementSpaceSize());
490 b_thread_desc_.GetElementSpaceSize());
491
492 StaticallyIndexedArray<decltype(a_thread_buf), Number<2>{}> a_thread_bufs;
493 StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
494 // Inst List:
495 // ds_read_b128: 16
496 // ds_write_b128: 8
497 // buffer_load_dwordx4: 16
498 // v_mfma: 0
499 // -------------------------------------------------------------------------------------------
500
501 // Global prefetch 1th, Fill Ping LDS
502 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
503 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
504
505 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
506 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
507
508 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0));
509 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0));
510
511 // Local prefetch 1th, Fill Ping Reg
513 static_for<0, KRepeat, 1>{}([&](auto k) {
514 static_for<0, MRepeat, 1>{}([&](auto m0) {
517 a_block_buf.At(I0),
519 make_tuple(m0, I0, k, I0),
520 a_thread_bufs(I0));
521 static_for<0, NRepeat, 1>{}([&](auto n0) {
524 b_block_buf.At(I0),
526 make_tuple(n0, I0, k, I0),
527 b_thread_bufs(I0));
528 });
529 });
530 });
531
532 // Global prefetch 2th, Fill Pong LDS
533 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
534 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
535
536 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
537 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
538
539 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
540 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I1));
541
542 // Global prefetch 3rd
543 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
544 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
545
546 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
547 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
548
549 // Initialize C
550 c_thread_buf.Clear();
551
552 // main body
553 if constexpr(HasMainLoop)
554 {
555 index_t i = 0;
556 // This hot loop has two legacy loopover, to implement the double local buffer strategy
557 do
558 {
559 // -------------------------------------------------------------------------------------------
560 using PingP1 = Number<0>;
561 using PongP1 = Number<1>;
562 // MFMA: Ping Reg
563 // DS_WRITE: To Ping LDS
564 // DS_READ: Pong LDS to Pong Reg
566
567 static_for<0, KRepeat, 1>{}([&](auto k) {
568 static_for<0, MRepeat, 1>{}([&](auto m0) {
571 a_block_buf.At(PongP1{}),
573 make_tuple(m0, I0, k, I0),
574 a_thread_bufs(PongP1{}));
575 static_for<0, NRepeat, 1>{}([&](auto n0) {
578 b_block_buf.At(PongP1{}),
580 make_tuple(n0, I0, k, I0),
581 b_thread_bufs(PongP1{}));
582 });
583 });
584 });
585
586 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP1{}));
587 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP1{}));
588
589 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
590 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
591
592 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
593 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
594
595 static_for<0, KRepeat, 1>{}([&](auto k0) {
596 static_for<0, MRepeat, 1>{}([&](auto m0) {
597 static_for<0, NRepeat, 1>{}([&](auto n0) {
598 vector_type<FloatAB, KPack> a_thread_vec;
599 vector_type<FloatAB, KPack> b_thread_vec;
600
601 static_for<0, KPack, 1>{}([&](auto ik) {
602 a_thread_vec.template AsType<FloatAB>()(ik) =
603 a_thread_bufs[PingP1{}][Number<a_thread_desc_.CalculateOffset(
604 make_tuple(m0, I0, k0, ik))>{}];
605 b_thread_vec.template AsType<FloatAB>()(ik) =
606 b_thread_bufs[PingP1{}][Number<b_thread_desc_.CalculateOffset(
607 make_tuple(n0, I0, k0, ik))>{}];
608 });
609
610 using mfma_input_type =
611 typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
612
613 constexpr index_t c_offset =
614 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
615
616 xdlops_gemm.Run(
617 a_thread_vec.template AsType<mfma_input_type>(),
618 b_thread_vec.template AsType<mfma_input_type>(),
619 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
620 });
621 });
622 });
623
625 __builtin_amdgcn_sched_barrier(0);
626
627 // -------------------------------------------------------------------------------------------
628 using PingP2 = Number<1>;
629 using PongP2 = Number<0>;
630 // MFMA: Pong Reg
631 // DS_WRITE: To Pong LDS
632 // DS_READ: Ping LDS to Ping Reg
634
635 static_for<0, KRepeat, 1>{}([&](auto k) {
636 static_for<0, MRepeat, 1>{}([&](auto m0) {
639 a_block_buf.At(PongP2{}),
641 make_tuple(m0, I0, k, I0),
642 a_thread_bufs(PongP2{}));
643 static_for<0, NRepeat, 1>{}([&](auto n0) {
646 b_block_buf.At(PongP2{}),
648 make_tuple(n0, I0, k, I0),
649 b_thread_bufs(PongP2{}));
650 });
651 });
652 });
653
654 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP2{}));
655 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP2{}));
656
657 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
658 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
659
660 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
661 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
662
663 static_for<0, KRepeat, 1>{}([&](auto k0) {
664 static_for<0, MRepeat, 1>{}([&](auto m0) {
665 static_for<0, NRepeat, 1>{}([&](auto n0) {
666 vector_type<FloatAB, KPack> a_thread_vec;
667 vector_type<FloatAB, KPack> b_thread_vec;
668
669 static_for<0, KPack, 1>{}([&](auto ik) {
670 a_thread_vec.template AsType<FloatAB>()(ik) =
671 a_thread_bufs[PingP2{}][Number<a_thread_desc_.CalculateOffset(
672 make_tuple(m0, I0, k0, ik))>{}];
673 b_thread_vec.template AsType<FloatAB>()(ik) =
674 b_thread_bufs[PingP2{}][Number<b_thread_desc_.CalculateOffset(
675 make_tuple(n0, I0, k0, ik))>{}];
676 });
677
678 using mfma_input_type =
679 typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
680
681 constexpr index_t c_offset =
682 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
683
684 xdlops_gemm.Run(
685 a_thread_vec.template AsType<mfma_input_type>(),
686 b_thread_vec.template AsType<mfma_input_type>(),
687 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
688 });
689 });
690 });
691
693 __builtin_amdgcn_sched_barrier(0);
694
695 i += 2;
696 } while(i < (num_loop - 3));
697 }
698
699 // tail
700 if constexpr(TailNum == 3)
701 {
702 using PingP1 = Number<0>;
703 using PongP1 = Number<1>;
704 // MFMA: Ping Reg
705 // DS_WRITE: To Ping LDS
706 // DS_READ: Pong LDS to Pong Reg
708
709 static_for<0, KRepeat, 1>{}([&](auto k) {
710 static_for<0, MRepeat, 1>{}([&](auto m0) {
713 a_block_buf.At(PongP1{}),
715 make_tuple(m0, I0, k, I0),
716 a_thread_bufs(PongP1{}));
717 static_for<0, NRepeat, 1>{}([&](auto n0) {
720 b_block_buf.At(PongP1{}),
722 make_tuple(n0, I0, k, I0),
723 b_thread_bufs(PongP1{}));
724 });
725 });
726 });
727
728 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP1{}));
729 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP1{}));
730
731 static_for<0, KRepeat, 1>{}([&](auto k0) {
732 static_for<0, MRepeat, 1>{}([&](auto m0) {
733 static_for<0, NRepeat, 1>{}([&](auto n0) {
734 vector_type<FloatAB, KPack> a_thread_vec;
735 vector_type<FloatAB, KPack> b_thread_vec;
736
737 static_for<0, KPack, 1>{}([&](auto ik) {
738 a_thread_vec.template AsType<FloatAB>()(ik) =
739 a_thread_bufs[PingP1{}][Number<a_thread_desc_.CalculateOffset(
740 make_tuple(m0, I0, k0, ik))>{}];
741 b_thread_vec.template AsType<FloatAB>()(ik) =
742 b_thread_bufs[PingP1{}][Number<b_thread_desc_.CalculateOffset(
743 make_tuple(n0, I0, k0, ik))>{}];
744 });
745
746 using mfma_input_type =
747 typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
748
749 constexpr index_t c_offset =
750 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
751
752 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
753 b_thread_vec.template AsType<mfma_input_type>(),
754 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
755 });
756 });
757 });
758
760 __builtin_amdgcn_sched_barrier(0);
761
762 // -------------------------------------------------------------------------------------------
763 using PingP2 = Number<1>;
764 using PongP2 = Number<0>;
765 // MFMA: Pong Reg
766 // DS_WRITE: To Pong LDS
767 // DS_READ: Ping LDS to Ping Reg
769
770 static_for<0, KRepeat, 1>{}([&](auto k) {
771 static_for<0, MRepeat, 1>{}([&](auto m0) {
774 a_block_buf.At(PongP2{}),
776 make_tuple(m0, I0, k, I0),
777 a_thread_bufs(PongP2{}));
778 static_for<0, NRepeat, 1>{}([&](auto n0) {
781 b_block_buf.At(PongP2{}),
783 make_tuple(n0, I0, k, I0),
784 b_thread_bufs(PongP2{}));
785 });
786 });
787 });
788
789 static_for<0, KRepeat, 1>{}([&](auto k0) {
790 static_for<0, MRepeat, 1>{}([&](auto m0) {
791 static_for<0, NRepeat, 1>{}([&](auto n0) {
792 vector_type<FloatAB, KPack> a_thread_vec;
793 vector_type<FloatAB, KPack> b_thread_vec;
794
795 static_for<0, KPack, 1>{}([&](auto ik) {
796 a_thread_vec.template AsType<FloatAB>()(ik) =
797 a_thread_bufs[PingP2{}][Number<a_thread_desc_.CalculateOffset(
798 make_tuple(m0, I0, k0, ik))>{}];
799 b_thread_vec.template AsType<FloatAB>()(ik) =
800 b_thread_bufs[PingP2{}][Number<b_thread_desc_.CalculateOffset(
801 make_tuple(n0, I0, k0, ik))>{}];
802 });
803
804 using mfma_input_type =
805 typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
806
807 constexpr index_t c_offset =
808 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
809
810 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
811 b_thread_vec.template AsType<mfma_input_type>(),
812 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
813 });
814 });
815 });
816
818 __builtin_amdgcn_sched_barrier(0);
819
820 static_for<0, KRepeat, 1>{}([&](auto k) {
821 static_for<0, MRepeat, 1>{}([&](auto m0) {
822 static_for<0, NRepeat, 1>{}([&](auto n0) {
823 vector_type<FloatAB, KPack> a_thread_vec;
824 vector_type<FloatAB, KPack> b_thread_vec;
825
826 static_for<0, KPack, 1>{}([&](auto ik) {
827 a_thread_vec.template AsType<FloatAB>()(ik) =
828 a_thread_bufs[PongP2{}][Number<a_thread_desc_.CalculateOffset(
829 make_tuple(m0, I0, k, ik))>{}];
830 b_thread_vec.template AsType<FloatAB>()(ik) =
831 b_thread_bufs[PongP2{}][Number<b_thread_desc_.CalculateOffset(
832 make_tuple(n0, I0, k, ik))>{}];
833 });
834
835 using mfma_input_type =
836 typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
837
838 constexpr index_t c_offset =
839 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
840
841 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
842 b_thread_vec.template AsType<mfma_input_type>(),
843 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
844 });
845 });
846 });
847
848 // 64 v_mfma
849 __builtin_amdgcn_sched_group_barrier(0x008, 64, 0); // MFMA
850 __builtin_amdgcn_sched_barrier(0);
851 }
852 else if constexpr(TailNum == 2)
853 {
854 using PingP1 = Number<0>;
855 using PongP1 = Number<1>;
856 // MFMA: Ping Reg
857 // DS_WRITE: To Ping LDS
858 // DS_READ: Pong LDS to Pong Reg
860
861 static_for<0, KRepeat, 1>{}([&](auto k) {
862 static_for<0, MRepeat, 1>{}([&](auto m0) {
865 a_block_buf.At(PongP1{}),
867 make_tuple(m0, I0, k, I0),
868 a_thread_bufs(PongP1{}));
869 static_for<0, NRepeat, 1>{}([&](auto n0) {
872 b_block_buf.At(PongP1{}),
874 make_tuple(n0, I0, k, I0),
875 b_thread_bufs(PongP1{}));
876 });
877 });
878 });
879
880 static_for<0, KRepeat, 1>{}([&](auto k0) {
881 static_for<0, MRepeat, 1>{}([&](auto m0) {
882 static_for<0, NRepeat, 1>{}([&](auto n0) {
883 vector_type<FloatAB, KPack> a_thread_vec;
884 vector_type<FloatAB, KPack> b_thread_vec;
885
886 static_for<0, KPack, 1>{}([&](auto ik) {
887 a_thread_vec.template AsType<FloatAB>()(ik) =
888 a_thread_bufs[PingP1{}][Number<a_thread_desc_.CalculateOffset(
889 make_tuple(m0, I0, k0, ik))>{}];
890 b_thread_vec.template AsType<FloatAB>()(ik) =
891 b_thread_bufs[PingP1{}][Number<b_thread_desc_.CalculateOffset(
892 make_tuple(n0, I0, k0, ik))>{}];
893 });
894
895 using mfma_input_type =
896 typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
897
898 constexpr index_t c_offset =
899 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
900
901 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
902 b_thread_vec.template AsType<mfma_input_type>(),
903 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
904 });
905 });
906 });
907
909 __builtin_amdgcn_sched_barrier(0);
910
911 // -------------------------------------------------------------------------------------------
912 using PingP2 = Number<1>;
913 // MFMA: Pong Reg
914 // DS_WRITE: To Pong LDS
915 // DS_READ: Ping LDS to Ping Reg
916
917 static_for<0, KRepeat, 1>{}([&](auto k0) {
918 static_for<0, MRepeat, 1>{}([&](auto m0) {
919 static_for<0, NRepeat, 1>{}([&](auto n0) {
920 vector_type<FloatAB, KPack> a_thread_vec;
921 vector_type<FloatAB, KPack> b_thread_vec;
922
923 static_for<0, KPack, 1>{}([&](auto ik) {
924 a_thread_vec.template AsType<FloatAB>()(ik) =
925 a_thread_bufs[PingP2{}][Number<a_thread_desc_.CalculateOffset(
926 make_tuple(m0, I0, k0, ik))>{}];
927 b_thread_vec.template AsType<FloatAB>()(ik) =
928 b_thread_bufs[PingP2{}][Number<b_thread_desc_.CalculateOffset(
929 make_tuple(n0, I0, k0, ik))>{}];
930 });
931
932 using mfma_input_type =
933 typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
934
935 constexpr index_t c_offset =
936 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
937
938 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
939 b_thread_vec.template AsType<mfma_input_type>(),
940 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
941 });
942 });
943 });
944
945 // 64 v_mfma
946 __builtin_amdgcn_sched_group_barrier(0x008, 64, 0); // MFMA
947 __builtin_amdgcn_sched_barrier(0);
948 }
949 }
950
951 protected:
952 // M1, N1 as double buffer index
953 // Read buffer + Compute buffer
954 // A[M0, M1, M2, KPack]
959
960 // B[N0, N1, N2, KPack]
965
966 // C[M, N, NumRegXdlops]
968 make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
969
971 FloatAB,
972 decltype(a_block_desc_m0_m1_m2_k),
973 decltype(a_thread_desc_),
976 3,
977 A_K1,
978 A_K1>;
979
981 FloatAB,
982 decltype(b_block_desc_n0_n1_n2_k),
983 decltype(b_thread_desc_),
986 3,
987 B_K1,
988 B_K1>;
989
992};
993
994} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
Definition blockwise_gemm_pipeline_xdlops.hpp:34
static constexpr auto Print()
Definition blockwise_gemm_pipeline_xdlops.hpp:57
static constexpr auto I1
Definition blockwise_gemm_pipeline_xdlops.hpp:105
ThreadwiseTensorSliceTransfer_v4< FloatAB, FloatAB, decltype(b_block_desc_n0_n1_n2_k), decltype(b_thread_desc_), Sequence< 1, 1, 1, KPack >, Sequence< 0, 1, 2, 3 >, 3, B_K1, B_K1 > BThreadCopy
Definition blockwise_gemm_pipeline_xdlops.hpp:980
static constexpr index_t MWaves
Definition blockwise_gemm_pipeline_xdlops.hpp:111
__host__ static __device__ constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition blockwise_gemm_pipeline_xdlops.hpp:338
static __device__ constexpr auto HotLoopScheduler()
Definition blockwise_gemm_pipeline_xdlops.hpp:373
static constexpr index_t A_K1
Definition blockwise_gemm_pipeline_xdlops.hpp:117
static constexpr index_t A_K0
Definition blockwise_gemm_pipeline_xdlops.hpp:115
static constexpr auto b_thread_desc_
Definition blockwise_gemm_pipeline_xdlops.hpp:961
__host__ static __device__ constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops.hpp:321
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops.hpp:267
__host__ __device__ constexpr auto & GetCThreadBuffer()
Definition blockwise_gemm_pipeline_xdlops.hpp:152
__host__ static __device__ constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops.hpp:280
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops.hpp:253
static constexpr index_t WaveSize
Definition blockwise_gemm_pipeline_xdlops.hpp:113
static __device__ constexpr auto TailScheduler()
Definition blockwise_gemm_pipeline_xdlops.hpp:404
static constexpr auto c_thread_desc_
Definition blockwise_gemm_pipeline_xdlops.hpp:967
BThreadCopy b_thread_copy_
Definition blockwise_gemm_pipeline_xdlops.hpp:991
ThreadwiseTensorSliceTransfer_v4< FloatAB, FloatAB, decltype(a_block_desc_m0_m1_m2_k), decltype(a_thread_desc_), Sequence< 1, 1, 1, KPack >, Sequence< 0, 1, 2, 3 >, 3, A_K1, A_K1 > AThreadCopy
Definition blockwise_gemm_pipeline_xdlops.hpp:970
decltype(CalculateAThreadOriginDataIndex()) Tuple4
Definition blockwise_gemm_pipeline_xdlops.hpp:232
static constexpr auto I0
Definition blockwise_gemm_pipeline_xdlops.hpp:104
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k
Definition blockwise_gemm_pipeline_xdlops.hpp:453
__host__ static __device__ constexpr auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition blockwise_gemm_pipeline_xdlops.hpp:355
static __device__ auto CalculateBThreadOriginDataIndex()
Definition blockwise_gemm_pipeline_xdlops.hpp:177
AThreadCopy a_thread_copy_
Definition blockwise_gemm_pipeline_xdlops.hpp:990
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops.hpp:294
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops.hpp:308
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition blockwise_gemm_pipeline_xdlops.hpp:109
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k
Definition blockwise_gemm_pipeline_xdlops.hpp:454
static constexpr index_t KRepeat
Definition blockwise_gemm_pipeline_xdlops.hpp:124
static constexpr auto I3
Definition blockwise_gemm_pipeline_xdlops.hpp:107
static constexpr index_t B_K1
Definition blockwise_gemm_pipeline_xdlops.hpp:118
static constexpr auto I2
Definition blockwise_gemm_pipeline_xdlops.hpp:106
static constexpr index_t B_K0
Definition blockwise_gemm_pipeline_xdlops.hpp:116
static __device__ auto CalculateAThreadOriginDataIndex()
Definition blockwise_gemm_pipeline_xdlops.hpp:166
static constexpr index_t KPerThread
Definition blockwise_gemm_pipeline_xdlops.hpp:123
static __device__ auto GetWaveIdx()
Definition blockwise_gemm_pipeline_xdlops.hpp:154
static constexpr auto a_thread_desc_
Definition blockwise_gemm_pipeline_xdlops.hpp:955
static constexpr index_t NWaves
Definition blockwise_gemm_pipeline_xdlops.hpp:112
StaticBufferTupleOfVector< AddressSpaceEnum::Vgpr, FloatAcc, MRepeat *NRepeat, xdlops_gemm.GetRegSizePerXdlops(), true > c_thread_buf_
Definition blockwise_gemm_pipeline_xdlops.hpp:150
static constexpr auto xdlops_gemm
Definition blockwise_gemm_pipeline_xdlops.hpp:120
static __device__ auto CalculateCThreadOriginDataIndex8D(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops.hpp:219
BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, A_K1, B_K1, A_K1, B_K1, KPack, KPack, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops > HotLoopInstList
Definition blockwise_gemm_pipeline_xdlops.hpp:126
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops.hpp:471
__host__ __device__ BlockwiseGemmXdlops_pipeline_v4(Tuple4 a_origin=CalculateAThreadOriginDataIndex(), Tuple4 b_origin=CalculateBThreadOriginDataIndex())
Definition blockwise_gemm_pipeline_xdlops.hpp:235
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops.hpp:190
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:75
static __device__ constexpr index_t GetNumOfThread()
Definition thread_group.hpp:15
static __device__ index_t GetThreadId()
Definition thread_group.hpp:19
Definition threadwise_tensor_slice_transfer.hpp:1260
Definition xdlops_gemm.hpp:1821
static constexpr auto K0PerXdlops
Definition xdlops_gemm.hpp:2201
Definition functional2.hpp:33
Definition dtype_vector.hpp:10