blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp Source File

blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp Source File

Composable Kernel: blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp Source File
blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Compute optimized pipeline
11// GlobalPrefetchStages: 3
12// LocalPreFillStages: 2
13// LocalPreFetchStages: 2
14// LocalSharedMemoryBuffer: 2
15
16template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17 index_t BlockSize,
18 typename ADataType,
19 typename BDataType,
20 typename ComputeDataType,
21 typename AccDataType,
22 typename ATileDesc,
23 typename BTileDesc,
24 typename AMmaTileDesc,
25 typename BMmaTileDesc,
26 index_t ABlockTransferSrcScalarPerVector,
27 index_t BBlockTransferSrcScalarPerVector,
28 index_t MPerBlock,
29 index_t NPerBlock,
30 index_t KPerBlock,
31 index_t MPerXDL,
32 index_t NPerXDL,
33 index_t MRepeat,
34 index_t NRepeat,
35 index_t KPacks>
39
40template <index_t BlockSize,
41 typename ADataType,
42 typename BDataType,
43 typename ComputeDataType,
44 typename AccDataType,
45 typename ATileDesc,
46 typename BTileDesc,
47 typename AMmaTileDesc,
48 typename BMmaTileDesc,
49 index_t ABlockTransferSrcScalarPerVector,
50 index_t BBlockTransferSrcScalarPerVector,
51 index_t MPerBlock,
52 index_t NPerBlock,
53 index_t KPerBlock,
54 index_t MPerXDL,
55 index_t NPerXDL,
56 index_t MRepeat,
57 index_t NRepeat,
58 index_t KPack
59 // ,bool TransposeC //disable transposec right now...
60 >
62 BlockSize,
63 ADataType,
64 BDataType,
65 ComputeDataType,
66 AccDataType,
67 ATileDesc,
68 BTileDesc,
69 AMmaTileDesc,
70 BMmaTileDesc,
71 ABlockTransferSrcScalarPerVector,
72 BBlockTransferSrcScalarPerVector,
73 MPerBlock,
74 NPerBlock,
75 KPerBlock,
76 MPerXDL,
77 NPerXDL,
78 MRepeat,
79 NRepeat,
80 KPack>
82 ADataType,
83 BDataType,
84 ComputeDataType,
85 AccDataType,
86 ATileDesc,
87 BTileDesc,
88 AMmaTileDesc,
89 BMmaTileDesc,
90 ABlockTransferSrcScalarPerVector,
91 BBlockTransferSrcScalarPerVector,
92 MPerBlock,
93 NPerBlock,
94 KPerBlock,
95 MPerXDL,
96 NPerXDL,
97 MRepeat,
98 NRepeat,
99 KPack>
100
101{
103 ADataType,
104 BDataType,
105 ComputeDataType,
106 AccDataType,
107 ATileDesc,
108 BTileDesc,
109 AMmaTileDesc,
110 BMmaTileDesc,
111 ABlockTransferSrcScalarPerVector,
112 BBlockTransferSrcScalarPerVector,
113 MPerBlock,
114 NPerBlock,
115 KPerBlock,
116 MPerXDL,
117 NPerXDL,
118 MRepeat,
119 NRepeat,
120 KPack>;
// Names re-exported from Base so the members below can use them unqualified.
121 using Base::A_K1;
122 using Base::B_K1;
123 using Base::I0;
124 using Base::I1;
125 using Base::KGroup;
126 using Base::KRepeat;
127 using Base::xdlops_gemm;
128 using typename Base::HotLoopInstList;
129
// NOTE(review): listing lines 130-141 are absent from this rendering; any
// declarations they held are not visible here.
142
143 using Base::AMmaKStride;
144 using Base::BMmaKStride;
145 using Base::WaveSize;
146
// Pipeline stage counts. PrefetchStages is the threshold that
// BlockHasHotloop compares num_loop against. The file header lists
// "GlobalPrefetchStages: 3" and "LocalPreFillStages: 2", matching the first
// two values.
147 static constexpr index_t PrefetchStages = 3;
148 static constexpr index_t PrefillStages = 2;
// GlobalBufferNum is not referenced in the visible part of this file.
149 static constexpr index_t GlobalBufferNum = 2;
150
// Builds the A-side MMA tile descriptor from a tile-descriptor *type*.
// The function argument is unnamed: only its type is used, and a
// default-constructed TileDesc_M0_M1_M2_K is queried for its first three
// lengths.
//
// NOTE(review): the return statement (listing lines 161 and 163-169) is
// missing from this rendering, so how M0/M1/M2/K0/K1/K2 are combined is not
// visible here. The visible fragment shows only that a default-constructed
// TileDesc_M0_M1_M2_K is the first argument of the elided call. Confirm
// against the original header before relying on the layout.
151 template <typename TileDesc_M0_M1_M2_K>
152 __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
153 {
// Lengths of dimensions 0, 1 and 2 of the incoming descriptor.
154 constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
155 constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
156 constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
// Three K-side factors derived from class-level constants:
//   K2 = KPack / KGroup, K1 = WaveSize / NPerXDL, K0 = KRepeat * KGroup.
// Presumably the K dimension is split into (K0, K1, K2) -- TODO confirm,
// the consuming lines are elided.
157 constexpr index_t K2 = KPack / KGroup;
158 constexpr index_t K1 = WaveSize / NPerXDL;
159 constexpr index_t K0 = KRepeat * KGroup;
160
162 TileDesc_M0_M1_M2_K{},
170 }
171
172 static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
174
175 __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
176 {
177 return num_loop > PrefetchStages;
178 }
179
180 __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
181 {
182
183 return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
184 }
185
// Emits a fixed sequence of __builtin_amdgcn_sched_group_barrier hints.
// Per the trailing comments on each call, the mask values used here are:
//   0x008 = MFMA, 0x020 = VMEM read, 0x100 = DS read, 0x200 = DS write.
// Every call passes 1 and 0 as its second and third arguments.
// The function has no return statement in the visible lines; it only issues
// the hints. Group sizes come from HotLoopInstList instruction counts.
186 __device__ static constexpr auto HotLoopScheduler()
187 {
188 // constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
// NOTE(review): num_buffer_load_inst_a has no visible use below; presumably
// it bounds the "A global" loop whose header line is elided -- confirm.
189 constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
190 constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
191
192 // B global + A local
// First half of the B loads: two MFMA slots per (VMEM read, DS read) pair.
193 static_for<0, num_buffer_load_inst_b / 2, 1>{}([&](auto i) {
194 ignore = i;
195 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
196 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read B
197 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
198 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read A
199 });
200
// Second half of the B loads: a single MFMA slot per (VMEM read, DS read) pair.
201 static_for<0, num_buffer_load_inst_b / 2, 1>{}([&](auto i) {
202 ignore = i;
203 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
204 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read B
205 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read A
206 });
207
208 // A global
// NOTE(review): listing line 209 (the loop header that opens this group and
// defines `i`) is missing from this rendering.
210 ignore = i;
211 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
212 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
213 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
214 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
215 });
216
// The "A local" group below is disabled (commented out) in the source.
217 // A local
218 // static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) {
219 // ignore = i;
220 // __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
221 // __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read
222 // });
223 }
224
// Runs the block-level GEMM pipeline in three phases: a prologue of
// prefetches, an optional main loop (HasMainLoop), and a tail sequence
// selected by TailNum.
//
// Template parameters:
//   HasMainLoop - compile-time switch; the do/while main loop is only
//                 instantiated when true.
//   TailNum     - TailNumber::Even or TailNumber::Odd; selects the tail.
//   others      - types of the descriptor / transfer / buffer arguments.
//
// Parameters:
//   a_grid_desc, a_grid_buf   - passed to a_blockwise_copy.RunRead and
//                               MoveSrcSliceWindow.
//   a_block_desc, a_block_buf - passed to a_blockwise_copy.RunWrite;
//                               a_block_buf.At(I0)/.At(I1) is also passed to
//                               a_thread_copy_.Run.
//   a_block_copy_step,
//   b_block_copy_step         - steps handed to MoveSrcSliceWindow after reads.
//   b_grid_desc, b_grid_buf   - passed to b_blockwise_copy.Run, whose last
//                               argument is one of the two b_thread_bufs.
//   b_block_buf               - unused here (assigned to `ignore`).
//   c_thread_buf              - cleared before the main body, then passed by
//                               vector reference to xdlops_gemm.Run.
//   num_loop                  - loop count; the main loop advances i by 2 and
//                               repeats while i < num_loop - 3.
//
// NOTE(review): this rendering omits several listing lines inside the body
// (e.g. 255, 257, 268, 283, 287-288, 290, 316, 320, 349-350, 378, 391, 445,
// 450, 500). In particular the declarations of a_thread_buf, b_thread_buf,
// a_thread_vec and b_thread_vec, and some call arguments, are not visible.
// Comments below describe only what the visible lines show.
225 template <bool HasMainLoop,
226 TailNumber TailNum,
227 typename AGridDesc,
228 typename ABlockDesc,
229 typename ABlockTransfer,
230 typename AGridBuffer,
231 typename ABlockBuffer,
232 typename ABlockTransferStep,
233 typename BGridDesc,
234 typename BBlockTransfer,
235 typename BGridBuffer,
236 typename BBlockBuffer,
237 typename BBlockTransferStep,
238 typename CThreadBuffer>
239 __device__ void Run(const AGridDesc& a_grid_desc,
240 const ABlockDesc& a_block_desc,
241 ABlockTransfer& a_blockwise_copy,
242 const AGridBuffer& a_grid_buf,
243 ABlockBuffer& a_block_buf,
244 const ABlockTransferStep& a_block_copy_step,
245 const BGridDesc& b_grid_desc,
246 BBlockTransfer& b_blockwise_copy,
247 const BGridBuffer& b_grid_buf,
248 BBlockBuffer& b_block_buf,
249 const BBlockTransferStep& b_block_copy_step,
250 CThreadBuffer& c_thread_buf,
251 index_t num_loop) const
252 {
// b_block_buf is deliberately unused in this pipeline.
253 ignore = b_block_buf;
254 __builtin_amdgcn_sched_barrier(0);
// Tail ends of the (elided) a_thread_buf / b_thread_buf declarations; each is
// sized by its thread descriptor's element space size.
256 a_thread_desc_.GetElementSpaceSize());
258 b_thread_desc_.GetElementSpaceSize());
259
// Two thread buffers per operand, indexed with I0 / I1 throughout.
260 StaticallyIndexedArray<decltype(a_thread_buf), Number<2>{}> a_thread_bufs;
261 StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
262 constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
263
264 // Global prefetch A1, B1
265 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
266 b_blockwise_copy.Run(b_grid_desc,
267 b_grid_buf,
269 b_block_origin_idx,
270 b_thread_bufs(I0));
271
272 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
273 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
274
275 // Local prefill A1
276 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0), I0);
277
278 // Global prefetch A2
279 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
280 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
281
282 // Local prefetch A1
// NOTE(review): the call that these arguments belong to (listing 287-288) is
// elided; by analogy with the later loops it is presumably a_thread_copy_.Run.
284 static_for<0, MRepeat, 1>{}([&](auto m0) {
285 static_for<0, KRepeat, 1>{}([&](auto k0) {
286 static_for<0, KGroup, 1>{}([&](auto kg0) {
289 a_block_buf.At(I0),
291 make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
292 a_thread_bufs(I0));
293 });
294 });
295 });
296
297 // Local prefill A2
298 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1), I1);
299
300 // Global prefetch A3
301 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
302 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
303
304 // Initialize C
305 c_thread_buf.Clear();
306
307 __builtin_amdgcn_sched_barrier(0);
308
309 // main body
310 if constexpr(HasMainLoop)
311 {
312 index_t i = 0;
313 do
314 {
// One pipeline step. mfma_reg_buf indexes the thread buffers that feed
// xdlops_gemm.Run in this step; local_read_buf indexes the thread buffers
// handed to the B copy and to a_thread_copy_.Run in this step.
315 auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
317
318 b_blockwise_copy.Run(b_grid_desc,
319 b_grid_buf,
321 b_block_origin_idx,
322 b_thread_bufs(local_read_buf));
323 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
324
325 // main loop A matrix prefetch
326 static_for<0, MRepeat, 1>{}([&](auto m0) {
327 static_for<0, KRepeat, 1>{}([&](auto k0) {
328 static_for<0, KGroup, 1>{}([&](auto kg0) {
329 a_thread_copy_.Run(
332 a_block_buf.At(local_read_buf),
334 make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
335 a_thread_bufs(local_read_buf));
336 });
337 });
338 });
339
// Write into a_block_buf.At(mfma_reg_buf), then issue the next read from
// a_grid_buf and advance the A source window.
340 a_blockwise_copy.RunWrite(
341 a_block_desc, a_block_buf.At(mfma_reg_buf), mfma_reg_buf);
342
343 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
344 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
345
// Compute: for every (m0, n0, k0) gather KPack elements of A and B from the
// mfma_reg_buf thread buffers into a_thread_vec / b_thread_vec and call
// xdlops_gemm.Run on the C vector at offset (m0, n0, 0).
346 static_for<0, MRepeat, 1>{}([&](auto m0) {
347 static_for<0, NRepeat, 1>{}([&](auto n0) {
348 static_for<0, KRepeat, 1>{}([&](auto k0) {
351
352 static_for<0, KPack, 1>{}([&](auto ik) {
353 a_thread_vec.template AsType<ComputeDataType>()(ik) =
354 a_thread_bufs[mfma_reg_buf]
355 [Number<a_thread_desc_.CalculateOffset(
356 make_tuple(m0, I0, I0, k0, I0, ik))>{}];
357 b_thread_vec.template AsType<ComputeDataType>()(ik) =
358 b_thread_bufs[mfma_reg_buf]
359 [Number<b_thread_desc_.CalculateOffset(
360 make_tuple(n0, I0, k0, ik))>{}];
361 });
362
363 using mfma_input_type =
364 typename vector_type<ComputeDataType,
365 xdlops_gemm.K1PerXdlops>::type;
366
367 constexpr index_t c_offset =
368 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
369
370 xdlops_gemm.Run(
371 a_thread_vec.template AsType<mfma_input_type>(),
372 b_thread_vec.template AsType<mfma_input_type>(),
373 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
374 });
375 });
376 });
377
379 __builtin_amdgcn_sched_barrier(0);
380 };
381
// Two steps per iteration with the two buffer indices swapped, hence i += 2.
382 LoopFunc(I0, I1);
383 LoopFunc(I1, I0);
384
385 i += 2;
386 } while(i < (num_loop - 3));
387 }
388 // tail
389
// Tail step, visible lines: B copy + B window move, A thread copy, A RunWrite,
// then compute from the mfma_reg buffers. No A RunRead is visible here.
390 auto ReadWriteCompFunc = [&](auto mfma_reg, auto local_read_reg) {
392
393 b_blockwise_copy.Run(b_grid_desc,
394 b_grid_buf,
396 b_block_origin_idx,
397 b_thread_bufs(local_read_reg));
398 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
399
400 // tail prefetch A
401 static_for<0, MRepeat, 1>{}([&](auto m0) {
402 static_for<0, KRepeat, 1>{}([&](auto k0) {
403 static_for<0, KGroup, 1>{}([&](auto kg0) {
404 a_thread_copy_.Run(
407 a_block_buf.At(local_read_reg),
409 make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
410 a_thread_bufs(local_read_reg));
411 });
412 });
413 });
414
415 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(mfma_reg), mfma_reg);
416
417 static_for<0, MRepeat, 1>{}([&](auto m0) {
418 static_for<0, NRepeat, 1>{}([&](auto n0) {
419 static_for<0, KRepeat, 1>{}([&](auto k0) {
422
423 static_for<0, KPack, 1>{}([&](auto ik) {
424 a_thread_vec.template AsType<ComputeDataType>()(ik) =
425 a_thread_bufs[mfma_reg][Number<a_thread_desc_.CalculateOffset(
426 make_tuple(m0, I0, I0, k0, I0, ik))>{}];
427 b_thread_vec.template AsType<ComputeDataType>()(ik) =
428 b_thread_bufs[mfma_reg][Number<b_thread_desc_.CalculateOffset(
429 make_tuple(n0, I0, k0, ik))>{}];
430 });
431
432 using mfma_input_type =
433 typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
434
435 constexpr index_t c_offset =
436 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
437
438 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
439 b_thread_vec.template AsType<mfma_input_type>(),
440 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
441 });
442 });
443 });
444
446 __builtin_amdgcn_sched_barrier(0);
447 };
448
// Tail step, visible lines: B copy (no window move), A thread copy, then
// compute from the mfma_reg buffers. No A RunWrite or RunRead is visible here.
449 auto ReadCompFunc = [&](auto mfma_reg, auto local_read_reg) {
451
452 b_blockwise_copy.Run(b_grid_desc,
453 b_grid_buf,
455 b_block_origin_idx,
456 b_thread_bufs(local_read_reg));
457
458 static_for<0, MRepeat, 1>{}([&](auto m0) {
459 static_for<0, KRepeat, 1>{}([&](auto k0) {
460 static_for<0, KGroup, 1>{}([&](auto kg0) {
461 a_thread_copy_.Run(
464 a_block_buf.At(local_read_reg),
466 make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
467 a_thread_bufs(local_read_reg));
468 });
469 });
470 });
471
472 static_for<0, MRepeat, 1>{}([&](auto m0) {
473 static_for<0, NRepeat, 1>{}([&](auto n0) {
474 static_for<0, KRepeat, 1>{}([&](auto k0) {
477
478 static_for<0, KPack, 1>{}([&](auto ik) {
479 a_thread_vec.template AsType<ComputeDataType>()(ik) =
480 a_thread_bufs[mfma_reg][Number<a_thread_desc_.CalculateOffset(
481 make_tuple(m0, I0, I0, k0, I0, ik))>{}];
482 b_thread_vec.template AsType<ComputeDataType>()(ik) =
483 b_thread_bufs[mfma_reg][Number<b_thread_desc_.CalculateOffset(
484 make_tuple(n0, I0, k0, ik))>{}];
485 });
486
487 using mfma_input_type =
488 typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
489
490 constexpr index_t c_offset =
491 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
492
493 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
494 b_thread_vec.template AsType<mfma_input_type>(),
495 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
496 });
497 });
498 });
499
501 __builtin_amdgcn_sched_barrier(0);
502 };
503
// Final tail step: compute only, from the mfma_reg buffers; no copies.
504 auto CompFunc = [&](auto mfma_reg) {
505 static_for<0, MRepeat, 1>{}([&](auto m0) {
506 static_for<0, NRepeat, 1>{}([&](auto n0) {
507 static_for<0, KRepeat, 1>{}([&](auto k0) {
510
511 static_for<0, KPack, 1>{}([&](auto ik) {
512 a_thread_vec.template AsType<ComputeDataType>()(ik) =
513 a_thread_bufs[mfma_reg][Number<a_thread_desc_.CalculateOffset(
514 make_tuple(m0, I0, I0, k0, I0, ik))>{}];
515 b_thread_vec.template AsType<ComputeDataType>()(ik) =
516 b_thread_bufs[mfma_reg][Number<b_thread_desc_.CalculateOffset(
517 make_tuple(n0, I0, k0, ik))>{}];
518 });
519
520 using mfma_input_type =
521 typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
522
523 constexpr index_t c_offset =
524 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
525
526 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
527 b_thread_vec.template AsType<mfma_input_type>(),
528 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
529 });
530 });
531 });
532 };
533
// Tail dispatch. Even: one read+compute step, then a compute-only step.
// Odd: one read+write+compute step, one read+compute step, then compute-only.
// Any other TailNum value instantiates neither branch.
534 if constexpr(TailNum == TailNumber::Even)
535 {
536 ReadCompFunc(I0, I1);
537 CompFunc(I1);
538 }
539 else if constexpr(TailNum == TailNumber::Odd)
540 {
541 ReadWriteCompFunc(I0, I1);
542 ReadCompFunc(I1, I0);
543 CompFunc(I0);
544 }
545 }
546
547 protected:
548 // MRepeat MWave MLane KRepeat KLane KPack
549 // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack
552
554 ComputeDataType,
556 decltype(a_thread_desc_),
557 Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
559 5,
560 A_K1,
561 A_K1>;
562
564
567
568 static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
569
571};
572
573} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ BlockwiseGemmXdlops_pipeline_base(Tuple4 a_origin=CalculateAThreadOriginDataIndex(), Tuple4 b_origin=CalculateBThreadOriginDataIndex())
Constructor for BlockwiseGemmXdlops_pipeline_base.
Definition blockwise_gemm_pipeline_xdlops_base.hpp:222
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:280
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:239
static constexpr auto c_thread_desc_
Definition blockwise_gemm_pipeline_xdlops_base.hpp:378
static constexpr auto xdlops_gemm
Definition blockwise_gemm_pipeline_xdlops_base.hpp:54
static constexpr index_t KGroup
Definition blockwise_gemm_pipeline_xdlops_base.hpp:67
static constexpr auto I1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:37
__host__ static __device__ constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:266
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:294
static constexpr index_t AMmaKStride
Definition blockwise_gemm_pipeline_xdlops_base.hpp:60
static __device__ auto CalculateAThreadOriginDataIndex6D()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:136
static constexpr index_t WaveSize
Definition blockwise_gemm_pipeline_xdlops_base.hpp:46
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:253
static constexpr index_t B_K1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:51
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops > HotLoopInstList
Definition blockwise_gemm_pipeline_xdlops_base.hpp:82
__host__ __device__ constexpr auto & GetCThreadBuffer()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:111
static constexpr auto I0
Definition blockwise_gemm_pipeline_xdlops_base.hpp:36
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:160
static __device__ auto CalculateCThreadOriginDataIndex8D(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:189
static constexpr index_t KRepeat
Definition blockwise_gemm_pipeline_xdlops_base.hpp:64
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k
Definition blockwise_gemm_pipeline_xdlops_base.hpp:359
static constexpr index_t A_K1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:50
static constexpr index_t BMmaKStride
Definition blockwise_gemm_pipeline_xdlops_base.hpp:61
__host__ static __device__ constexpr auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:341
__host__ static __device__ constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:307
__host__ static __device__ constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:324
BlockwiseGemmXdlops_pipeline_base< BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack > Base
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp:102
ThreadwiseTensorSliceTransfer_v4< ADataType, ComputeDataType, decltype(a_block_desc_m0_m1_m2_k0_k1_k2), decltype(a_thread_desc_), Sequence< 1, 1, 1, 1, 1, KPack/KGroup >, Sequence< 0, 1, 2, 3, 4, 5 >, 5, A_K1, A_K1 > AThreadCopy
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp:553
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp:239
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp:37
Definition utility/sequence.hpp:43
Definition threadwise_tensor_slice_transfer.hpp:1260
Definition functional2.hpp:33
Definition dtype_vector.hpp:10