blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp Source File

blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp Source File

Composable Kernel: blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp Source File
blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Naive pipeline with lowest resource request per WGP
11// GlobalPrefetchStages: 2
12// LocalPreFillStages: 1
13// LocalPreFetchStages: 1
14// LocalSharedMemoryBuffer: 1
15
16template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17 index_t ThreadBlockSize,
18 index_t ScaleBlockSize,
19 typename ADataType,
20 typename AScaleDataType,
21 typename BDataType,
22 typename BScaleDataType,
23 typename ATileDesc,
24 typename BTileDesc,
25 typename AMmaTileDesc,
26 typename BMmaTileDesc,
27 index_t ABlockTransferSrcScalarPerVector,
28 index_t BBlockTransferSrcScalarPerVector,
29 index_t MPerBlock,
30 index_t NPerBlock,
31 index_t KPerBlock,
32 index_t MPerXDL,
33 index_t NPerXDL,
34 index_t MRepeat, // MXdlPerWave
35 index_t NRepeat, // NXdlPerWave
36 index_t KPack>
40
41template <index_t ThreadBlockSize,
42 index_t ScaleBlockSize,
43 typename ADataType,
44 typename AScaleDataType,
45 typename BDataType,
46 typename BScaleDataType,
47 typename ATileDesc,
48 typename BTileDesc,
49 typename AMmaTileDesc,
50 typename BMmaTileDesc,
51 index_t ABlockTransferSrcScalarPerVector,
52 index_t BBlockTransferSrcScalarPerVector,
53 index_t MPerBlock,
54 index_t NPerBlock,
55 index_t KPerBlock,
56 index_t MPerXDL,
57 index_t NPerXDL,
58 index_t MRepeat, // MXdlPerWave
59 index_t NRepeat, // NXdlPerWave
60 index_t KPack>
62 ThreadBlockSize,
63 ScaleBlockSize,
64 ADataType,
65 AScaleDataType,
66 BDataType,
67 BScaleDataType,
68 ATileDesc,
69 BTileDesc,
70 AMmaTileDesc,
71 BMmaTileDesc,
72 ABlockTransferSrcScalarPerVector,
73 BBlockTransferSrcScalarPerVector,
74 MPerBlock,
75 NPerBlock,
76 KPerBlock,
77 MPerXDL,
78 NPerXDL,
79 MRepeat,
80 NRepeat,
81 KPack>
83 ADataType,
84 BDataType,
85 ATileDesc,
86 BTileDesc,
87 AMmaTileDesc,
88 BMmaTileDesc,
89 ABlockTransferSrcScalarPerVector,
90 BBlockTransferSrcScalarPerVector,
91 MPerBlock,
92 NPerBlock,
93 KPerBlock,
94 MPerXDL,
95 NPerXDL,
96 MRepeat,
97 NRepeat,
98 KPack>
99
100{
101
103 ADataType,
104 BDataType,
105 ATileDesc,
106 BTileDesc,
107 AMmaTileDesc,
108 BMmaTileDesc,
109 ABlockTransferSrcScalarPerVector,
110 BBlockTransferSrcScalarPerVector,
111 MPerBlock,
112 NPerBlock,
113 KPerBlock,
114 MPerXDL,
115 NPerXDL,
116 MRepeat,
117 NRepeat,
118 KPack>;
119 using Base::I0;
120 using Base::I1;
121 using Base::KRepeat;
122 using Base::MWaves;
123 using Base::NWaves;
124 using Base::WaveSize;
125 using Base::xdlops_gemm;
126 using typename Base::HotLoopInstList;
127
136 using Base::GetWaveIdx;
139
142
143 using Base::AMmaKStride;
144 using Base::APackedSize;
145 using Base::BMmaKStride;
146 using Base::BPackedSize;
147 using Base::KThreadChunk;
148
149 using Base::KXdlPack;
150 using Base::MXdlPack;
151 using Base::NXdlPack;
152
153 using AccType = typename Base::AccType;
154 using Tuple5 = typename Base::Tuple5;
157
// Pipeline depth constants. PrefetchStages is the threshold BlockHasHotloop
// compares num_loop against; Run issues two global reads ("Global prefetch 1"
// and "Global prefetch 2") before entering the main loop.
158 static constexpr index_t PrefetchStages = 2;
159 static constexpr index_t PrefillStages = 1;
160 static constexpr index_t GlobalBufferNum = 1;
161
// Number of ScaleBlockSize-wide scale groups along K covered by one KPerBlock tile.
162 static constexpr auto ScalesPerKBlockSize =
163 KPerBlock / ScaleBlockSize; // How many mx-vectors per K block
164
165 //> How many mx-vectors in each row/col are processed in one call to xdlops_gemm.Run()
166 static constexpr auto ScalesPerXdlopsRun =
167 (APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
168
169 //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run()
// (Run static_asserts that this is at least 1.)
170 static constexpr auto ScalesPerXdlopsRunPerThread =
171 ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
172
174 static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
175 static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
176 static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
177 "A scale pack data type too large!");
178 static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
179 "B scale pack data type too large!");
182
183 __host__ static constexpr bool BlockHasHotloop(index_t num_loop)
184 {
185 return num_loop > PrefetchStages;
186 }
187
188 __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
189 {
190 return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
191 }
192
193 __device__ static constexpr auto HotLoopScheduler()
194 {
195 // A/B split schedule
196 // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
197 constexpr auto num_ds_read_inst_a =
198 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
201 constexpr auto num_ds_read_inst_b =
202 HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
205
206 constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
207 constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;
208
209 constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
210 constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
211
212 constexpr auto num_buffer_load_a_scale = MRepeat / MXdlPack * KRepeat / KXdlPack;
213 constexpr auto num_buffer_load_b_scale = NRepeat / NXdlPack * KRepeat / KXdlPack;
214
215 constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * APackedSize;
216
217 constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
218 constexpr auto ds_read_a_issue_cycle =
219 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
220 constexpr auto ds_read_b_issue_cycle =
221 HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
222
223 constexpr auto ds_read_a_mfma_rate =
224 (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
225 constexpr auto ds_read_b_mfma_rate =
226 (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
227
228 constexpr auto num_dsread_a_mfma =
229 (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
230 constexpr auto num_dsread_b_mfma =
231 (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
232
233 // stage 1
234 constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
235 constexpr auto num_buffer_load_total = num_buffer_load_inst_a + num_buffer_load_inst_b +
236 num_buffer_load_a_scale + num_buffer_load_b_scale;
237
238 constexpr auto mfma_perstage_more =
239 math::integer_divide_ceil(num_mfma_stage1, num_buffer_load_total);
240 constexpr auto mfma_perstage_less =
241 math::integer_divide_floor(num_mfma_stage1, num_buffer_load_total);
242
243 constexpr auto mfma_stages_more =
244 num_mfma_stage1 - mfma_perstage_less * num_buffer_load_total;
245
246 constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
247 constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
248
250 if constexpr(i < mfma_stages_more)
251 {
252 static_for<0, mfma_perstage_more, 1>{}([&](auto imfma) {
253 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
254 if constexpr(imfma < num_dswrite_per_issue_a)
255 {
256 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
257 }
258 });
259 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
260 }
261 else
262 {
263 static_for<0, mfma_perstage_less, 1>{}([&](auto imfma) {
264 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
265 if constexpr(imfma < num_dswrite_per_issue_a)
266 {
267 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
268 }
269 });
270 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
271 }
272 });
273
275 if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more)
276 {
277 static_for<0, mfma_perstage_more, 1>{}([&](auto imfma) {
278 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
279 if constexpr(imfma < num_dswrite_per_issue_a)
280 {
281 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
282 }
283 });
284 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
285 }
286 else
287 {
288 static_for<0, mfma_perstage_less, 1>{}([&](auto imfma) {
289 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
290 if constexpr(imfma < num_dswrite_per_issue_b)
291 {
292 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
293 }
294 });
295 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
296 }
297 });
298
300 if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b) < mfma_stages_more)
301 {
302 static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
303 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
304 });
305 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
306 }
307 else
308 {
309 static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
310 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
311 });
312 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
313 }
314 });
315
317 if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b +
318 num_buffer_load_a_scale) < mfma_stages_more)
319 {
320 static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
321 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
322 });
323 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
324 }
325 else
326 {
327 static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
328 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
329 });
330 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
331 }
332 });
333
334 // stage 2
336 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
337 if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
338 ds_read_a_mfma_rate)
339 {
340 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
341 }
342 else
343 {
344 __builtin_amdgcn_sched_group_barrier(0x100,
345 num_ds_read_inst_a - (num_dsread_a_mfma - 1) *
346 ds_read_a_mfma_rate,
347 0); // DS read
348 }
349 });
350
352 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
353 if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
354 ds_read_b_mfma_rate)
355 {
356 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
357 }
358 else
359 {
360 __builtin_amdgcn_sched_group_barrier(0x100,
361 num_ds_read_inst_b - (num_dsread_b_mfma - 1) *
362 ds_read_b_mfma_rate,
363 0); // DS read
364 }
365 });
366 }
367
368 template <bool HasMainLoop,
369 TailNumber TailNum,
370 typename AGridDesc,
371 typename ABlockDesc,
372 typename ABlockTransfer,
373 typename AGridBuffer,
374 typename ABlockBuffer,
375 typename ABlockTransferStep,
376 typename BGridDesc,
377 typename BBlockDesc,
378 typename BBlockTransfer,
379 typename BGridBuffer,
380 typename BBlockBuffer,
381 typename BBlockTransferStep,
382 typename CThreadBuffer,
383 typename AScaleGridBuffer,
384 typename AScaleGridDesc,
385 typename AScaleThreadTransfer,
386 typename BScaleGridBuffer,
387 typename BScaleGridDesc,
388 typename BScaleThreadTransfer>
389 __device__ void Run(
390 // ABlockCopy
391 const AGridDesc& a_grid_desc,
392 const ABlockDesc& a_block_desc,
393 ABlockTransfer& a_blockwise_copy,
394 const AGridBuffer& a_grid_buf,
395 ABlockBuffer& a_block_buf,
396 const ABlockTransferStep& a_block_copy_step,
397 // BBlockCopy
398 const BGridDesc& b_grid_desc,
399 const BBlockDesc& b_block_desc,
400 BBlockTransfer& b_blockwise_copy,
401 const BGridBuffer& b_grid_buf,
402 BBlockBuffer& b_block_buf,
403 const BBlockTransferStep& b_block_copy_step,
404 // CThread
405 CThreadBuffer& c_thread_buf,
406 // A and B scales
407 const AScaleGridDesc& a_scale_grid_desc,
408 AScaleThreadTransfer& a_scale_thread_copy,
409 const AScaleGridBuffer& a_scale_grid_buf,
410 const BScaleGridDesc& b_scale_grid_desc,
411 BScaleThreadTransfer& b_scale_thread_copy,
412 const BScaleGridBuffer& b_scale_grid_buf,
413 index_t num_loop) const
414 {
416 a_thread_desc_.GetElementSpaceSize());
418 b_thread_desc_.GetElementSpaceSize());
419
421 a_scale_thread_desc.GetElementSpaceSize());
422
424 b_scale_thread_desc.GetElementSpaceSize());
425
426 StaticallyIndexedArray<decltype(a_scale_thread_buf), Number<2>{}> a_scale_thread_bufs;
427 StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;
428
429 // Global prefetch 1
430 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
431 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
432
433 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
434 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
435
436 // Prefetch a_scales
437 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
438 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
439 a_scale_thread_copy.Run(a_scale_grid_desc,
440 a_scale_grid_buf,
442 make_tuple(m0, k0, I0),
443 a_scale_thread_bufs(I0));
444
445 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
446 make_multi_index(0, I1, 0));
447 });
448 a_scale_thread_copy.MoveSrcSliceWindow(
449 a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
450 });
451
452 // restore row id and advance to the next set of scales
453 a_scale_thread_copy.MoveSrcSliceWindow(
454 a_scale_grid_desc,
455 make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
456
457 // Prefetch b_scales
458 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
459 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
460 b_scale_thread_copy.Run(b_scale_grid_desc,
461 b_scale_grid_buf,
463 make_tuple(n0, k0, I0),
464 b_scale_thread_bufs(I0));
465
466 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
467 make_multi_index(0, I1, 0));
468 });
469 b_scale_thread_copy.MoveSrcSliceWindow(
470 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
471 });
472
473 // restore col id and advance to the next set of scales
474 // NWaves * NPerXDL * NRepeat == NPerBlock
475 b_scale_thread_copy.MoveSrcSliceWindow(
476 b_scale_grid_desc,
477 make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
478
479 // Local prefill 1
480 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
481 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
482
483 // Global prefetch 2
484 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
485 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
486
487 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
488 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
489
490 // Local prefetch 1
492 static_for<0, KRepeat, 1>{}([&](auto k) {
493 constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
494 (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
495 static_for<0, MRepeat, 1>{}([&](auto m0) {
496 static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
497 [&](auto chunk) {
498 constexpr auto a_k_step_chunk =
499 k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
502 I0,
504 I0,
506 a_block_buf,
509 I0,
511 k,
513 a_thread_buf);
514 });
515 });
516 static_for<0, NRepeat, 1>{}([&](auto n0) {
517 // read block data in chunks to assemble correct thread vectors
518 static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}(
519 [&](auto chunk) {
520 constexpr auto b_k_step_chunk =
521 k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
524 I0,
526 I0,
528 b_block_buf,
531 I0,
533 k,
535 b_thread_buf);
536 });
537 });
538 });
539
540 // Initialize C
541 c_thread_buf.Clear();
542 __builtin_amdgcn_sched_barrier(0);
543
544 // main body
545 if constexpr(HasMainLoop)
546 {
547 // loop over k with the step KPerBlock
548 index_t i = 0;
549 do
550 {
551 auto LoopFunc = [&](auto scale_comp_buf, auto scale_mem_buf) {
553
554 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
555 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
556
557 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
558 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
559
560 // Prefetch a_scales
561 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
562 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
563 a_scale_thread_copy.Run(a_scale_grid_desc,
564 a_scale_grid_buf,
566 make_tuple(m0, k0, I0),
567 a_scale_thread_bufs(scale_mem_buf));
568
569 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
570 make_multi_index(0, I1, 0));
571 });
572 a_scale_thread_copy.MoveSrcSliceWindow(
573 a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
574 });
575
576 // restore row id and advance to the next set of scales
577 a_scale_thread_copy.MoveSrcSliceWindow(
578 a_scale_grid_desc,
579 make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
580
581 // Prefetch b_scales
582 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
583 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
584 b_scale_thread_copy.Run(b_scale_grid_desc,
585 b_scale_grid_buf,
587 make_tuple(n0, k0, I0),
588 b_scale_thread_bufs(scale_mem_buf));
589
590 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
591 make_multi_index(0, I1, 0));
592 });
593 b_scale_thread_copy.MoveSrcSliceWindow(
594 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
595 });
596
597 // restore col id and advance to the next set of scales
598 // NWaves * NPerXDL * NRepeat == NPerBlock
599 b_scale_thread_copy.MoveSrcSliceWindow(
600 b_scale_grid_desc,
601 make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
602
603 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
604 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
605
606 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
607 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
608 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
609 constexpr index_t a_scale_offset =
610 a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
611 constexpr index_t b_scale_offset =
612 b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
613
614 static_assert(0 < ScalesPerXdlopsRunPerThread,
615 "Must have at least one scale per Xdlops "
616 "per Thread.");
617
619 a_scale_thread_vec;
621 b_scale_thread_vec;
622
623 // Pack scale_thread_buf into scale_thread_vec
625 a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
626 a_scale_thread_bufs(
627 scale_comp_buf)[Number<a_scale_offset + s>{}];
628 });
629
631 b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
632 b_scale_thread_bufs(
633 scale_comp_buf)[Number<b_scale_offset + s>{}];
634 });
635
636 static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
637 static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
638 static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
639 constexpr auto kxdl = ikxdl + k0 * KXdlPack;
640
643
644 static_for<0, KPack, 1>{}([&](auto ik) {
645 a_thread_vec.template AsType<ComputeTypeA>()(
646 ik) = a_thread_buf
647 [Number<a_thread_desc_.CalculateOffset(
648 make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
649 b_thread_vec.template AsType<ComputeTypeB>()(
650 ik) = b_thread_buf
651 [Number<b_thread_desc_.CalculateOffset(
652 make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
653 });
654
655 using mfma_input_type_a =
656 typename vector_type<ComputeTypeA,
657 xdlops_gemm.K1PerXdlops /
658 APackedSize>::type;
659
660 using mfma_input_type_b =
661 typename vector_type<ComputeTypeB,
662 xdlops_gemm.K1PerXdlops /
663 BPackedSize>::type;
664
665 using mfma_scale_input_type_a =
666 typename vector_type<AScaleDataType,
668 using mfma_scale_input_type_b =
669 typename vector_type<BScaleDataType,
671
672 constexpr index_t c_offset =
673 c_thread_desc_.CalculateOffset(
674 make_tuple(m0, n0, imxdl, inxdl, 0));
675
676 // MFMA accumulation
677 xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
678 ikxdl * NXdlPack + inxdl>(
679 a_thread_vec.template AsType<mfma_input_type_a>(),
680 a_scale_thread_vec
681 .template AsType<mfma_scale_input_type_a>(),
682 b_thread_vec.template AsType<mfma_input_type_b>(),
683 b_scale_thread_vec
684 .template AsType<mfma_scale_input_type_b>(),
685 c_thread_buf.GetVectorTypeReference(
687 });
688 });
689 });
690 });
691 });
692 });
693
694 // k indexes mapping to threads for 32x32x64:
695 // t0 : |0 --> 15 32 --> 47 | 64 --> 79 96 --> 111 | etc.
696 // t32: |16 --> 31 48 --> 63 | 80 --> 95 112 --> 127 | etc.
697 // k = 0 k = 1
698
699 // k indexes mapping to threads for 16x16x128:
700 // t0 : |0 --> 15 64 --> 79 | 128 --> 143 192 --> 207| etc.
701 // t16: |16 --> 31 80 --> 95 | 144 --> 159 208 --> 223| etc.
702 // t32: |32 --> 47 96 --> 111| 160 --> 175 224 --> 239| etc.
703 // t48: |48 --> 63 112 --> 127| 176 --> 191 240 --> 255| etc.
704 // k = 0 k = 1
706 static_for<0, KRepeat, 1>{}([&](auto k) {
707 constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
708 (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
709 static_for<0, MRepeat, 1>{}([&](auto m0) {
710 static_for<0,
711 xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk),
712 1>{}([&](auto chunk) {
713 constexpr auto a_k_step_chunk =
714 k_step +
715 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
718 I0,
720 I0,
722 a_block_buf,
725 I0,
727 k,
729 a_thread_buf);
730 });
731 });
732 static_for<0, NRepeat, 1>{}([&](auto n0) {
733 // read block data in chunks to assemble correct thread vectors
734 static_for<0,
735 xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk),
736 1>{}([&](auto chunk) {
737 constexpr auto b_k_step_chunk =
738 k_step +
739 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
742 I0,
744 I0,
746 b_block_buf,
749 I0,
751 k,
753 b_thread_buf);
754 });
755 });
756 });
757
759 __builtin_amdgcn_sched_barrier(0);
760 };
761
762 LoopFunc(I0, I1);
763 LoopFunc(I1, I0);
764
765 i += 2;
766 } while(i < (num_loop - 2));
767 }
768
769 // tail
770 if constexpr(TailNum == TailNumber::Even)
771 {
772 // Prefetch a_scales
773 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
774 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
775 a_scale_thread_copy.Run(a_scale_grid_desc,
776 a_scale_grid_buf,
778 make_tuple(m0, k0, I0),
779 a_scale_thread_bufs(I1));
780
781 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
782 make_multi_index(0, I1, 0));
783 });
784 a_scale_thread_copy.MoveSrcSliceWindow(
785 a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
786 });
787
788 // Prefetch b_scales
789 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
790 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
791 b_scale_thread_copy.Run(b_scale_grid_desc,
792 b_scale_grid_buf,
794 make_tuple(n0, k0, I0),
795 b_scale_thread_bufs(I1));
796
797 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
798 make_multi_index(0, I1, 0));
799 });
800 b_scale_thread_copy.MoveSrcSliceWindow(
801 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
802 });
803
805 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
806 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
807
808 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
809 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
810 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
811 constexpr index_t a_scale_offset =
812 a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
813 constexpr index_t b_scale_offset =
814 b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
815
816 static_assert(0 < ScalesPerXdlopsRunPerThread,
817 "Must have at least one scale per Xdlops "
818 "per Thread.");
819
822
823 // Pack scale_thread_buf into scale_thread_vec
825 a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
826 a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
827 });
828
830 b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
831 b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
832 });
833
834 static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
835 static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
836 static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
837 constexpr auto kxdl = ikxdl + k0 * KXdlPack;
838
841
842 static_for<0, KPack, 1>{}([&](auto ik) {
843 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
844 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
845 make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
846 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
847 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
848 make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
849 });
850
851 using mfma_input_type_a =
852 typename vector_type<ComputeTypeA,
853 xdlops_gemm.K1PerXdlops /
854 APackedSize>::type;
855
856 using mfma_input_type_b =
857 typename vector_type<ComputeTypeB,
858 xdlops_gemm.K1PerXdlops /
859 BPackedSize>::type;
860
861 using mfma_scale_input_type_a =
862 typename vector_type<AScaleDataType,
864 using mfma_scale_input_type_b =
865 typename vector_type<BScaleDataType,
867
868 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
869 make_tuple(m0, n0, imxdl, inxdl, 0));
870
871 // MFMA accumulation
872 xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
873 ikxdl * NXdlPack + inxdl>(
874 a_thread_vec.template AsType<mfma_input_type_a>(),
875 a_scale_thread_vec
876 .template AsType<mfma_scale_input_type_a>(),
877 b_thread_vec.template AsType<mfma_input_type_b>(),
878 b_scale_thread_vec
879 .template AsType<mfma_scale_input_type_b>(),
880 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
881 });
882 });
883 });
884 });
885 });
886 });
887
889
890 static_for<0, KRepeat, 1>{}([&](auto k) {
891 constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
892 (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
893 static_for<0, MRepeat, 1>{}([&](auto m0) {
894 static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
895 [&](auto chunk) {
896 constexpr auto a_k_step_chunk =
897 k_step +
898 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
901 I0,
903 I0,
905 a_block_buf,
908 I0,
910 k,
912 a_thread_buf);
913 });
914 });
915 static_for<0, NRepeat, 1>{}([&](auto n0) {
916 // read block data in chunks to assemble correct thread vectors
917 static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}(
918 [&](auto chunk) {
919 constexpr auto b_k_step_chunk =
920 k_step +
921 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
924 I0,
926 I0,
928 b_block_buf,
931 I0,
933 k,
935 b_thread_buf);
936 });
937 });
938 });
939
940 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
941 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
942 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
943 constexpr index_t a_scale_offset =
944 a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
945 constexpr index_t b_scale_offset =
946 b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
947
948 static_assert(0 < ScalesPerXdlopsRunPerThread,
949 "Must have at least one scale per Xdlops "
950 "per Thread.");
951
954
955 // Pack scale_thread_buf into scale_thread_vec
957 a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
958 a_scale_thread_bufs(I1)[Number<a_scale_offset + s>{}];
959 });
960
962 b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
963 b_scale_thread_bufs(I1)[Number<b_scale_offset + s>{}];
964 });
965
966 static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
967 static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
968 static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
969 constexpr auto kxdl = ikxdl + k0 * KXdlPack;
970
973
974 static_for<0, KPack, 1>{}([&](auto ik) {
975 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
976 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
977 make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
978 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
979 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
980 make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
981 });
982
983 using mfma_input_type_a =
984 typename vector_type<ComputeTypeA,
985 xdlops_gemm.K1PerXdlops /
986 APackedSize>::type;
987
988 using mfma_input_type_b =
989 typename vector_type<ComputeTypeB,
990 xdlops_gemm.K1PerXdlops /
991 BPackedSize>::type;
992
993 using mfma_scale_input_type_a =
994 typename vector_type<AScaleDataType,
996 using mfma_scale_input_type_b =
997 typename vector_type<BScaleDataType,
999
1000 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
1001 make_tuple(m0, n0, imxdl, inxdl, 0));
1002
1003 // MFMA accumulation
1004 xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
1005 ikxdl * NXdlPack + inxdl>(
1006 a_thread_vec.template AsType<mfma_input_type_a>(),
1007 a_scale_thread_vec
1008 .template AsType<mfma_scale_input_type_a>(),
1009 b_thread_vec.template AsType<mfma_input_type_b>(),
1010 b_scale_thread_vec
1011 .template AsType<mfma_scale_input_type_b>(),
1012 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
1013 });
1014 });
1015 });
1016 });
1017 });
1018 });
1019 }
1020 else if constexpr(TailNum == TailNumber::Odd)
1021 {
1022 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
1023 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
1024 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
1025 constexpr index_t a_scale_offset =
1026 a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
1027 constexpr index_t b_scale_offset =
1028 b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
1029
1030 static_assert(0 < ScalesPerXdlopsRunPerThread,
1031 "Must have at least one scale per Xdlops "
1032 "per Thread.");
1033
1036
1037 // Pack scale_thread_buf into scale_thread_vec
1039 a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
1040 a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
1041 });
1042
1044 b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
1045 b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
1046 });
1047
1048 static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
1049 static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
1050 static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
1051 constexpr auto kxdl = ikxdl + k0 * KXdlPack;
1052
1055
1056 static_for<0, KPack, 1>{}([&](auto ik) {
1057 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
1058 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
1059 make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
1060 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
1061 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
1062 make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
1063 });
1064
1065 using mfma_input_type_a =
1066 typename vector_type<ComputeTypeA,
1067 xdlops_gemm.K1PerXdlops /
1068 APackedSize>::type;
1069
1070 using mfma_input_type_b =
1071 typename vector_type<ComputeTypeB,
1072 xdlops_gemm.K1PerXdlops /
1073 BPackedSize>::type;
1074
1075 using mfma_scale_input_type_a =
1076 typename vector_type<AScaleDataType,
1078 using mfma_scale_input_type_b =
1079 typename vector_type<BScaleDataType,
1081
1082 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
1083 make_tuple(m0, n0, imxdl, inxdl, 0));
1084
1085 // MFMA accumulation
1086 xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
1087 ikxdl * NXdlPack + inxdl>(
1088 a_thread_vec.template AsType<mfma_input_type_a>(),
1089 a_scale_thread_vec
1090 .template AsType<mfma_scale_input_type_a>(),
1091 b_thread_vec.template AsType<mfma_input_type_b>(),
1092 b_scale_thread_vec
1093 .template AsType<mfma_scale_input_type_b>(),
1094 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
1095 });
1096 });
1097 });
1098 });
1099 });
1100 });
1101 }
1102 }
1103
1104 // TODO: make this field protected when a_scale_thread_copy_ is moved
1105 // here
1108 Number<KRepeat / KXdlPack>{},
1110
1111 // TODO: make this field protected when b_scale_thread_copy_ is moved
1112 // here
1115 Number<KRepeat / KXdlPack>{},
1117
1118 protected:
1119 using Base::a_thread_copy_;
1120 using Base::a_thread_desc_;
1121 using Base::b_thread_copy_;
1122 using Base::b_thread_desc_;
1123 using Base::c_thread_desc_;
1124};
1125
1126} // namespace ck
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
Definition utility/math.hpp:66
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ void block_sync_lds()
Definition synchronization.hpp:16
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops,(packed_size_v< ComputeTypeA > > 1||packed_size_v< ComputeTypeB > > 1)> HotLoopInstList
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:88
__host__ __device__ BlockwiseGemmXdlops_mx_pipeline_base(Tuple5 a_origin=CalculateAThreadOriginDataIndex(), Tuple5 b_origin=CalculateBThreadOriginDataIndex())
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:204
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, const AScaleGridDesc &a_scale_grid_desc, AScaleThreadTransfer &a_scale_thread_copy, const AScaleGridBuffer &a_scale_grid_buf, const BScaleGridDesc &b_scale_grid_desc, BScaleThreadTransfer &b_scale_thread_copy, const BScaleGridBuffer &b_scale_grid_buf, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp:389
BlockwiseGemmXdlops_mx_pipeline_base< ThreadBlockSize, ADataType, BDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack > Base
Definition blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp:102
Definition blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp:38
Unsigned representation of a conventional biased Float32 exponent.
Definition utility/e8m0.hpp:26
Definition functional2.hpp:33
Definition dtype_vector.hpp:10