blockwise_gemm_pipeline_xdlops_mx_moe_v3.hpp Source File

blockwise_gemm_pipeline_xdlops_mx_moe_v3.hpp Source File#

Composable Kernel: blockwise_gemm_pipeline_xdlops_mx_moe_v3.hpp Source File
blockwise_gemm_pipeline_xdlops_mx_moe_v3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Naive pipeline with lowest resource request per WGP
11// GlobalPrefetchStages: 2
12// LocalPreFillStages: 1
13// LocalPreFetchStages: 1
14// LocalSharedMemoryBuffer: 2 (ping-pong: Run() alternates between a_block_bufs/b_block_bufs(I0) and (I1))
15
16template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17 index_t ThreadBlockSize,
18 index_t ScaleBlockSize,
19 typename ADataType,
20 typename AScaleDataType,
21 typename BDataType,
22 typename BScaleDataType,
23 typename ATileDesc,
24 typename BTileDesc,
25 typename AMmaTileDesc,
26 typename BMmaTileDesc,
27 index_t ABlockTransferSrcScalarPerVector,
28 index_t BBlockTransferSrcScalarPerVector,
29 index_t MPerBlock,
30 index_t NPerBlock,
31 index_t KPerBlock,
32 index_t MPerXDL,
33 index_t NPerXDL,
34 index_t MRepeat, // MXdlPerWave
35 index_t NRepeat, // NXdlPerWave
36 index_t KPack>
38{
39};
40
41template <index_t ThreadBlockSize,
42 index_t ScaleBlockSize,
43 typename ADataType,
44 typename AScaleDataType,
45 typename BDataType,
46 typename BScaleDataType,
47 typename ATileDesc,
48 typename BTileDesc,
49 typename AMmaTileDesc,
50 typename BMmaTileDesc,
51 index_t ABlockTransferSrcScalarPerVector,
52 index_t BBlockTransferSrcScalarPerVector,
53 index_t MPerBlock,
54 index_t NPerBlock,
55 index_t KPerBlock,
56 index_t MPerXDL,
57 index_t NPerXDL,
58 index_t MRepeat, // MXdlPerWave
59 index_t NRepeat, // NXdlPerWave
60 index_t KPack>
62 ThreadBlockSize,
63 ScaleBlockSize,
64 ADataType,
65 AScaleDataType,
66 BDataType,
67 BScaleDataType,
68 ATileDesc,
69 BTileDesc,
70 AMmaTileDesc,
71 BMmaTileDesc,
72 ABlockTransferSrcScalarPerVector,
73 BBlockTransferSrcScalarPerVector,
74 MPerBlock,
75 NPerBlock,
76 KPerBlock,
77 MPerXDL,
78 NPerXDL,
79 MRepeat,
80 NRepeat,
81 KPack>
83 ADataType,
84 BDataType,
85 ATileDesc,
86 BTileDesc,
87 AMmaTileDesc,
88 BMmaTileDesc,
89 ABlockTransferSrcScalarPerVector,
90 BBlockTransferSrcScalarPerVector,
91 MPerBlock,
92 NPerBlock,
93 KPerBlock,
94 MPerXDL,
95 NPerXDL,
96 MRepeat,
97 NRepeat,
98 KPack>
99
100{
101
103 ADataType,
104 BDataType,
105 ATileDesc,
106 BTileDesc,
107 AMmaTileDesc,
108 BMmaTileDesc,
109 ABlockTransferSrcScalarPerVector,
110 BBlockTransferSrcScalarPerVector,
111 MPerBlock,
112 NPerBlock,
113 KPerBlock,
114 MPerXDL,
115 NPerXDL,
116 MRepeat,
117 NRepeat,
118 KPack>;
// Names re-exported from the base pipeline class (Base). Base depends on this class's
// template parameters, so unqualified lookup would not find these members without the
// using-declarations below.
119 using Base::I0;
120 using Base::I1;
121 using Base::KRepeat;
122 using Base::MWaves;
123 using Base::NWaves;
124 using Base::WaveSize;
125 using Base::xdlops_gemm;
126 using typename Base::HotLoopInstList;
127
136 using Base::GetWaveIdx;
139
142
// Strides / packing factors used when splitting K reads into chunks in Run()
// (e.g. xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk) chunks per read).
143 using Base::AMmaKStride;
144 using Base::APackedSize;
145 using Base::BMmaKStride;
146 using Base::BPackedSize;
147 using Base::KThreadChunk;
148
// Number of xdlops sub-tiles along K / M / N whose scales are consumed together;
// Run() iterates MRepeat / MXdlPack, NRepeat / NXdlPack and KRepeat / KXdlPack groups.
149 using Base::KXdlPack;
150 using Base::MXdlPack;
151 using Base::NXdlPack;
152
153 using AccType = typename Base::AccType;
154 using Tuple5 = typename Base::Tuple5;
157
// Global-memory prefetch depth: Run() issues two global tile loads ("Global prefetch 1"
// and "Global prefetch 2") before the hot loop; BlockHasHotloop() compares against it.
158 static constexpr index_t PrefetchStages = 2;
// Number of LDS -> register prefills performed before the hot loop ("Local prefetch 1").
159 static constexpr index_t PrefillStages = 1;
// NOTE(review): not referenced in the visible code of this file; presumably consumed by
// the enclosing gridwise kernel — confirm before changing.
160 static constexpr index_t GlobalBufferNum = 1;
161
162 static constexpr auto ScalesPerKBlockSize =
163 KPerBlock / ScaleBlockSize; // How many mx-vectors per K block
164
165 //> How many mx-vectors in each row/col are processed in one call to xdlops_gemm.Run()
166 static constexpr auto ScalesPerXdlopsRun =
167 (APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
168
169 //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run()
//> (Run() static_asserts that this is at least 1).
170 static constexpr auto ScalesPerXdlopsRunPerThread =
171 ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
172
// Number of mx_scale_t scale values packed into one AScaleDataType / BScaleDataType element
// (ratio of the type sizes).
174 static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
175 static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
// The KXdlPack * MXdlPack (resp. KXdlPack * NXdlPack) scales handled together must split
// evenly into packed scale elements. `*` and `%` have equal precedence and associate
// left-to-right, so each condition reads (KXdlPack * XXdlPack) % scale_pack_size == 0.
176 static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
177 "A scale pack data type too large!");
178 static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
179 "B scale pack data type too large!");
182
183 __host__ static constexpr bool BlockHasHotloop(index_t num_loop)
184 {
185 return num_loop > PrefetchStages;
186 }
187
188 __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
189 {
190 return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
191 }
192
193 __device__ static constexpr auto HotLoopScheduler()
194 {
195 // A/B split schedule
196 // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
197 constexpr auto num_ds_read_inst_a =
198 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
201 constexpr auto num_ds_read_inst_b =
202 HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
205
206 constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
207 constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
208
209 constexpr auto num_buffer_load_a_scale = MRepeat / MXdlPack * KRepeat / KXdlPack;
210 constexpr auto num_buffer_load_b_scale = NRepeat / NXdlPack * KRepeat / KXdlPack;
211
212 constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * APackedSize;
213
214 constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
215 constexpr auto ds_read_a_issue_cycle =
216 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
217 constexpr auto ds_read_b_issue_cycle =
218 HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
219
220 constexpr auto ds_read_a_mfma_rate =
221 (mfma_cycle - 8 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
222 constexpr auto ds_read_b_mfma_rate =
223 (mfma_cycle - 8 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
224
225 constexpr auto num_dsread_a_mfma =
226 (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
227 constexpr auto num_dsread_b_mfma =
228 (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
229
230 // stage 1
231 constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
232 constexpr auto num_buffer_load_total = num_buffer_load_inst_a + num_buffer_load_inst_b +
233 num_buffer_load_a_scale + num_buffer_load_b_scale;
234
235 constexpr auto mfma_perstage_more =
236 math::integer_divide_ceil(num_mfma_stage1, num_buffer_load_total);
237 constexpr auto mfma_perstage_less =
238 math::integer_divide_floor(num_mfma_stage1, num_buffer_load_total);
239
240 constexpr auto mfma_stages_more =
241 num_mfma_stage1 - mfma_perstage_less * num_buffer_load_total;
242
244 if constexpr(i < mfma_stages_more)
245 {
246 static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
247 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
248 });
249 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
250 }
251 else
252 {
253 static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
254 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
255 });
256 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
257 }
258 });
259
261 if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more)
262 {
263 static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
264 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
265 });
266 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
267 }
268 else
269 {
270 static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
271 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
272 });
273 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
274 }
275 });
276
278 if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b) < mfma_stages_more)
279 {
280 static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
281 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
282 });
283 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
284 }
285 else
286 {
287 static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
288 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
289 });
290 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
291 }
292 });
293
295 if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b +
296 num_buffer_load_a_scale) < mfma_stages_more)
297 {
298 static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
299 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
300 });
301 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
302 }
303 else
304 {
305 static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
306 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
307 });
308 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
309 }
310 });
311
312 // stage 2
314 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
315 if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
316 ds_read_a_mfma_rate)
317 {
318 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
319 }
320 else
321 {
322 __builtin_amdgcn_sched_group_barrier(0x100,
323 num_ds_read_inst_a - (num_dsread_a_mfma - 1) *
324 ds_read_a_mfma_rate,
325 0); // DS read
326 }
327 });
328
330 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
331 if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
332 ds_read_b_mfma_rate)
333 {
334 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
335 }
336 else
337 {
338 __builtin_amdgcn_sched_group_barrier(0x100,
339 num_ds_read_inst_b - (num_dsread_b_mfma - 1) *
340 ds_read_b_mfma_rate,
341 0); // DS read
342 }
343 });
344 }
345
346 template <bool HasMainLoop,
347 TailNumber TailNum,
348 typename AGridDesc,
349 typename ABlockDesc,
350 typename ABlockTransfer,
351 typename AGridBuffer,
352 typename ABlockBuffer,
353 typename ABlockTransferStep,
354 typename BGridDesc,
355 typename BBlockDesc,
356 typename BBlockTransfer,
357 typename BGridBuffer,
358 typename BBlockBuffer,
359 typename BBlockTransferStep,
360 typename CThreadBuffer,
361 typename AScaleGridBuffer,
362 typename AScaleGridDesc,
363 typename AScaleThreadTransfer,
364 typename BScaleGridBuffer,
365 typename BScaleGridDesc,
366 typename BScaleThreadTransfer>
367 __device__ void Run(
368 // ABlockCopy
369 const AGridDesc& a_grid_desc,
370 const ABlockDesc& a_block_desc,
371 ABlockTransfer& a_blockwise_copy,
372 const AGridBuffer& a_grid_buf,
373 ABlockBuffer& a_block_bufs,
374 const ABlockTransferStep& a_block_copy_step,
375 // BBlockCopy
376 const BGridDesc& b_grid_desc,
377 const BBlockDesc& b_block_desc,
378 BBlockTransfer& b_blockwise_copy,
379 const BGridBuffer& b_grid_buf,
380 BBlockBuffer& b_block_bufs,
381 const BBlockTransferStep& b_block_copy_step,
382 // CThread
383 CThreadBuffer& c_thread_buf,
384 // A and B scales
385 const AScaleGridDesc& a_scale_grid_desc,
386 AScaleThreadTransfer& a_scale_thread_copy,
387 const AScaleGridBuffer& a_scale_grid_buf,
388 const BScaleGridDesc& b_scale_grid_desc,
389 BScaleThreadTransfer& b_scale_thread_copy,
390 const BScaleGridBuffer& b_scale_grid_buf,
391 index_t num_loop) const
392 {
394 a_thread_desc_.GetElementSpaceSize());
396 b_thread_desc_.GetElementSpaceSize());
397
399 a_scale_thread_desc.GetElementSpaceSize());
400
402 b_scale_thread_desc.GetElementSpaceSize());
403
404 StaticallyIndexedArray<decltype(a_scale_thread_buf), Number<2>{}> a_scale_thread_bufs;
405 StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;
406
407 // Global prefetch 1
408 a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I0));
409 b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_bufs(I0));
410
411 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
412 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
413
414 // Prefetch a_scales
415 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
416 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
417 a_scale_thread_copy.Run(a_scale_grid_desc,
418 a_scale_grid_buf,
420 make_tuple(m0, k0, I0),
421 a_scale_thread_bufs(I0));
422
423 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
424 make_multi_index(0, I1, 0));
425 });
426 a_scale_thread_copy.MoveSrcSliceWindow(
427 a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
428 });
429
430 // restore row id and advance to the next set of scales
431 a_scale_thread_copy.MoveSrcSliceWindow(
432 a_scale_grid_desc,
433 make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
434
435 // Prefetch b_scales
436 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
437 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
438 b_scale_thread_copy.Run(b_scale_grid_desc,
439 b_scale_grid_buf,
441 make_tuple(n0, k0, I0),
442 b_scale_thread_bufs(I0));
443
444 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
445 make_multi_index(0, I1, 0));
446 });
447 b_scale_thread_copy.MoveSrcSliceWindow(
448 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
449 });
450
451 // restore col id and advance to the next set of scales
452 // NWaves * NPerXDL * NRepeat == NPerBlock
453 b_scale_thread_copy.MoveSrcSliceWindow(
454 b_scale_grid_desc,
455 make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
456
457 // Local prefetch 1, sync the async load
458 __builtin_amdgcn_s_waitcnt(3952);
460 static_for<0, KRepeat, 1>{}([&](auto k) {
461 constexpr auto k_step = k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops;
462 static_for<0, MRepeat, 1>{}([&](auto m0) {
463 static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
464 [&](auto chunk) {
465 constexpr auto a_k_step_chunk =
466 k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
469 I0,
471 I0,
473 a_block_bufs(I0),
476 I0,
478 k,
480 a_thread_buf);
481 });
482 });
483 static_for<0, NRepeat, 1>{}([&](auto n0) {
484 // read block data in chunks to assemble correct thread vectors
485 static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}(
486 [&](auto chunk) {
487 constexpr auto b_k_step_chunk =
488 k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
491 I0,
493 I0,
495 b_block_bufs(I0),
498 I0,
500 k,
502 b_thread_buf);
503 });
504 });
505 });
506
507 // Global prefetch 2
508 a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I1));
509 b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_bufs(I1));
510
511 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
512 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
513
514 // Initialize C
515 c_thread_buf.Clear();
516 __builtin_amdgcn_sched_barrier(0);
517
518 // main body
519 if constexpr(HasMainLoop)
520 {
521 // loop over k with the step KPerBlock
522 index_t i = 0;
523 do
524 {
525 auto LoopFunc = [&](auto scale_comp_buf, auto scale_mem_buf) {
526 __builtin_amdgcn_s_waitcnt(3952);
528
529 a_blockwise_copy.Run(
530 a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(scale_comp_buf));
531 b_blockwise_copy.Run(
532 b_grid_desc, b_grid_buf, b_block_desc, b_block_bufs(scale_comp_buf));
533
534 // Prefetch a_scales
535 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
536 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
537 a_scale_thread_copy.Run(a_scale_grid_desc,
538 a_scale_grid_buf,
540 make_tuple(m0, k0, I0),
541 a_scale_thread_bufs(scale_mem_buf));
542
543 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
544 make_multi_index(0, I1, 0));
545 });
546 a_scale_thread_copy.MoveSrcSliceWindow(
547 a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
548 });
549
550 // restore row id and advance to the next set of scales
551 a_scale_thread_copy.MoveSrcSliceWindow(
552 a_scale_grid_desc,
553 make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
554
555 // Prefetch b_scales
556 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
557 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
558 b_scale_thread_copy.Run(b_scale_grid_desc,
559 b_scale_grid_buf,
561 make_tuple(n0, k0, I0),
562 b_scale_thread_bufs(scale_mem_buf));
563
564 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
565 make_multi_index(0, I1, 0));
566 });
567 b_scale_thread_copy.MoveSrcSliceWindow(
568 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
569 });
570
571 // restore col id and advance to the next set of scales
572 // NWaves * NPerXDL * NRepeat == NPerBlock
573 b_scale_thread_copy.MoveSrcSliceWindow(
574 b_scale_grid_desc,
575 make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
576
577 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
578 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
579
580 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
581 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
582 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
583 constexpr index_t a_scale_offset =
584 a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
585 constexpr index_t b_scale_offset =
586 b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
587
588 static_assert(0 < ScalesPerXdlopsRunPerThread,
589 "Must have at least one scale per Xdlops "
590 "per Thread.");
591
593 a_scale_thread_vec;
595 b_scale_thread_vec;
596
597 // Pack scale_thread_buf into scale_thread_vec
599 a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
600 a_scale_thread_bufs(
601 scale_comp_buf)[Number<a_scale_offset + s>{}];
602 });
603
605 b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
606 b_scale_thread_bufs(
607 scale_comp_buf)[Number<b_scale_offset + s>{}];
608 });
609
610 static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
611 static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
612 static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
613 constexpr auto kxdl = ikxdl + k0 * KXdlPack;
614
617
618 static_for<0, KPack, 1>{}([&](auto ik) {
619 a_thread_vec.template AsType<ComputeTypeA>()(
620 ik) = a_thread_buf
621 [Number<a_thread_desc_.CalculateOffset(
622 make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
623 b_thread_vec.template AsType<ComputeTypeB>()(
624 ik) = b_thread_buf
625 [Number<b_thread_desc_.CalculateOffset(
626 make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
627 });
628
629 using mfma_input_type_a = typename vector_type< //
631 xdlops_gemm.K1PerXdlops / APackedSize>::type;
632
633 using mfma_input_type_b = typename vector_type< //
635 xdlops_gemm.K1PerXdlops / BPackedSize>::type;
636
637 using mfma_scale_input_type_a = typename vector_type< //
638 AScaleDataType,
640 using mfma_scale_input_type_b = typename vector_type< //
641 BScaleDataType,
643
644 constexpr index_t c_offset =
645 c_thread_desc_.CalculateOffset(
646 make_tuple(m0, n0, imxdl, inxdl, 0));
647
648 // MFMA accumulation
649 xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
650 ikxdl * NXdlPack + inxdl>(
651 a_thread_vec.template AsType<mfma_input_type_a>(),
652 a_scale_thread_vec
653 .template AsType<mfma_scale_input_type_a>(),
654 b_thread_vec.template AsType<mfma_input_type_b>(),
655 b_scale_thread_vec
656 .template AsType<mfma_scale_input_type_b>(),
657 c_thread_buf.GetVectorTypeReference(
659 });
660 });
661 });
662 });
663 });
664 });
665
666 // k indexes mapping to threads for 32x32x64:
667 // t0 : |0 --> 15 32 --> 47 | 64 --> 79 96 --> 111 | etc.
668 // t32: |16 --> 31 48 --> 63 | 80 --> 95 112 --> 127 | etc.
669 // k = 0 k = 1
670
671 // k indexes mapping to threads for 16x16x128:
672 // t0 : |0 --> 15 64 --> 79 | 128 --> 143 192 --> 207| etc.
673 // t16: |16 --> 31 80 --> 95 | 144 --> 159 208 --> 223| etc.
674 // t32: |32 --> 47 96 --> 111| 160 --> 175 224 --> 239| etc.
675 // t48: |48 --> 63 112 --> 127| 176 --> 191 240 --> 255| etc.
676 // k = 0 k = 1
677 // __builtin_amdgcn_s_waitcnt(3952);
678 // block_sync_lds();
679 static_for<0, KRepeat, 1>{}([&](auto k) {
680 constexpr auto k_step =
681 k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops;
682 static_for<0, MRepeat, 1>{}([&](auto m0) {
683 static_for<0,
684 xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk),
685 1>{}([&](auto chunk) {
686 constexpr auto a_k_step_chunk =
687 k_step +
688 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
691 I0,
693 I0,
695 a_block_bufs(scale_mem_buf),
698 I0,
700 k,
702 a_thread_buf);
703 });
704 });
705 static_for<0, NRepeat, 1>{}([&](auto n0) {
706 // read block data in chunks to assemble correct thread vectors
707 static_for<0,
708 xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk),
709 1>{}([&](auto chunk) {
710 constexpr auto b_k_step_chunk =
711 k_step +
712 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
715 I0,
717 I0,
719 b_block_bufs(scale_mem_buf),
722 I0,
724 k,
726 b_thread_buf);
727 });
728 });
729 });
730
732 __builtin_amdgcn_sched_barrier(0);
733 };
734
735 LoopFunc(I0, I1);
736 LoopFunc(I1, I0);
737
738 i += 2;
739 } while(i < (num_loop - 2));
740 }
741
742 // tail
743 if constexpr(TailNum == TailNumber::Even)
744 {
745 // Prefetch a_scales
746 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
747 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
748 a_scale_thread_copy.Run(a_scale_grid_desc,
749 a_scale_grid_buf,
751 make_tuple(m0, k0, I0),
752 a_scale_thread_bufs(I1));
753
754 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
755 make_multi_index(0, I1, 0));
756 });
757 a_scale_thread_copy.MoveSrcSliceWindow(
758 a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
759 });
760
761 // Prefetch b_scales
762 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
763 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
764 b_scale_thread_copy.Run(b_scale_grid_desc,
765 b_scale_grid_buf,
767 make_tuple(n0, k0, I0),
768 b_scale_thread_bufs(I1));
769
770 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
771 make_multi_index(0, I1, 0));
772 });
773 b_scale_thread_copy.MoveSrcSliceWindow(
774 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
775 });
776
777 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
778 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
779 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
780 constexpr index_t a_scale_offset =
781 a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
782 constexpr index_t b_scale_offset =
783 b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
784
785 static_assert(0 < ScalesPerXdlopsRunPerThread,
786 "Must have at least one scale per Xdlops "
787 "per Thread.");
788
791
792 // Pack scale_thread_buf into scale_thread_vec
794 a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
795 a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
796 });
797
799 b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
800 b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
801 });
802
803 static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
804 static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
805 static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
806 constexpr auto kxdl = ikxdl + k0 * KXdlPack;
807
810
811 static_for<0, KPack, 1>{}([&](auto ik) {
812 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
813 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
814 make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
815 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
816 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
817 make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
818 });
819
820 using mfma_input_type_a = typename vector_type< //
822 xdlops_gemm.K1PerXdlops / APackedSize>::type;
823
824 using mfma_input_type_b = typename vector_type< //
826 xdlops_gemm.K1PerXdlops / BPackedSize>::type;
827
828 using mfma_scale_input_type_a = typename vector_type< //
829 AScaleDataType,
831 using mfma_scale_input_type_b = typename vector_type< //
832 BScaleDataType,
834
835 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
836 make_tuple(m0, n0, imxdl, inxdl, 0));
837
838 // MFMA accumulation
839 xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
840 ikxdl * NXdlPack + inxdl>(
841 a_thread_vec.template AsType<mfma_input_type_a>(),
842 a_scale_thread_vec
843 .template AsType<mfma_scale_input_type_a>(),
844 b_thread_vec.template AsType<mfma_input_type_b>(),
845 b_scale_thread_vec
846 .template AsType<mfma_scale_input_type_b>(),
847 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
848 });
849 });
850 });
851 });
852 });
853 });
854
855 __builtin_amdgcn_s_waitcnt(3952);
857
858 static_for<0, KRepeat, 1>{}([&](auto k) {
859 constexpr auto k_step =
860 k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops;
861 static_for<0, MRepeat, 1>{}([&](auto m0) {
862 static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
863 [&](auto chunk) {
864 constexpr auto a_k_step_chunk =
865 k_step +
866 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
869 I0,
871 I0,
873 a_block_bufs(I1),
876 I0,
878 k,
880 a_thread_buf);
881 });
882 });
883 static_for<0, NRepeat, 1>{}([&](auto n0) {
884 // read block data in chunks to assemble correct thread vectors
885 static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}(
886 [&](auto chunk) {
887 constexpr auto b_k_step_chunk =
888 k_step +
889 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
892 I0,
894 I0,
896 b_block_bufs(I1),
899 I0,
901 k,
903 b_thread_buf);
904 });
905 });
906 });
907
908 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
909 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
910 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
911 constexpr index_t a_scale_offset =
912 a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
913 constexpr index_t b_scale_offset =
914 b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
915
916 static_assert(0 < ScalesPerXdlopsRunPerThread,
917 "Must have at least one scale per Xdlops "
918 "per Thread.");
919
922
923 // Pack scale_thread_buf into scale_thread_vec
925 a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
926 a_scale_thread_bufs(I1)[Number<a_scale_offset + s>{}];
927 });
928
930 b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
931 b_scale_thread_bufs(I1)[Number<b_scale_offset + s>{}];
932 });
933
934 static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
935 static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
936 static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
937 constexpr auto kxdl = ikxdl + k0 * KXdlPack;
938
941
942 static_for<0, KPack, 1>{}([&](auto ik) {
943 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
944 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
945 make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
946 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
947 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
948 make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
949 });
950
951 using mfma_input_type_a = typename vector_type< //
953 xdlops_gemm.K1PerXdlops / APackedSize>::type;
954
955 using mfma_input_type_b = typename vector_type< //
957 xdlops_gemm.K1PerXdlops / BPackedSize>::type;
958
959 using mfma_scale_input_type_a = typename vector_type< //
960 AScaleDataType,
962 using mfma_scale_input_type_b = typename vector_type< //
963 BScaleDataType,
965
966 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
967 make_tuple(m0, n0, imxdl, inxdl, 0));
968
969 // MFMA accumulation
970 xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
971 ikxdl * NXdlPack + inxdl>(
972 a_thread_vec.template AsType<mfma_input_type_a>(),
973 a_scale_thread_vec
974 .template AsType<mfma_scale_input_type_a>(),
975 b_thread_vec.template AsType<mfma_input_type_b>(),
976 b_scale_thread_vec
977 .template AsType<mfma_scale_input_type_b>(),
978 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
979 });
980 });
981 });
982 });
983 });
984 });
985 }
986 else if constexpr(TailNum == TailNumber::Odd)
987 {
988 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
989 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
990 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
991 constexpr index_t a_scale_offset =
992 a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
993 constexpr index_t b_scale_offset =
994 b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
995
996 static_assert(0 < ScalesPerXdlopsRunPerThread,
997 "Must have at least one scale per Xdlops "
998 "per Thread.");
999
1002
1003 // Pack scale_thread_buf into scale_thread_vec
1005 a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
1006 a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
1007 });
1008
1010 b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
1011 b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
1012 });
1013
1014 static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
1015 static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
1016 static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
1017 constexpr auto kxdl = ikxdl + k0 * KXdlPack;
1018
1021
1022 static_for<0, KPack, 1>{}([&](auto ik) {
1023 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
1024 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
1025 make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
1026 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
1027 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
1028 make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
1029 });
1030
1031 using mfma_input_type_a = typename vector_type< //
1033 xdlops_gemm.K1PerXdlops / APackedSize>::type;
1034
1035 using mfma_input_type_b = typename vector_type< //
1037 xdlops_gemm.K1PerXdlops / BPackedSize>::type;
1038
1039 using mfma_scale_input_type_a = typename vector_type< //
1040 AScaleDataType,
1042 using mfma_scale_input_type_b = typename vector_type< //
1043 BScaleDataType,
1045
1046 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
1047 make_tuple(m0, n0, imxdl, inxdl, 0));
1048
1049 // MFMA accumulation
1050 xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
1051 ikxdl * NXdlPack + inxdl>(
1052 a_thread_vec.template AsType<mfma_input_type_a>(),
1053 a_scale_thread_vec
1054 .template AsType<mfma_scale_input_type_a>(),
1055 b_thread_vec.template AsType<mfma_input_type_b>(),
1056 b_scale_thread_vec
1057 .template AsType<mfma_scale_input_type_b>(),
1058 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
1059 });
1060 });
1061 });
1062 });
1063 });
1064 });
1065 }
1066 }
1067
1068 // TODO: make this field protected when a_scale_thread_copy_ is moved
1069 // here
1070 static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
1072 Number<KRepeat / KXdlPack>{},
1074
1075 // TODO: make this field protected when b_scale_thread_copy_ is moved
1076 // here
1077 static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
1079 Number<KRepeat / KXdlPack>{},
1081
1082 protected:
// Per-thread register descriptors and LDS->register copy objects inherited from Base.
// Run() refers to a_thread_desc_, b_thread_desc_ and c_thread_desc_ unqualified; the
// *_thread_copy_ members are presumably used by the LDS read calls in Run() as well.
1083 using Base::a_thread_copy_;
1084 using Base::a_thread_desc_;
1085 using Base::b_thread_copy_;
1086 using Base::b_thread_desc_;
1087 using Base::c_thread_desc_;
1088};
1089
1090} // namespace ck
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
Definition utility/math.hpp:66
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ void block_sync_lds()
Definition synchronization.hpp:16
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:33
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops,(packed_size_v< ComputeTypeA > > 1||packed_size_v< ComputeTypeB > > 1)> HotLoopInstList
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:88
__host__ __device__ BlockwiseGemmXdlops_mx_pipeline_base(Tuple5 a_origin=CalculateAThreadOriginDataIndex(), Tuple5 b_origin=CalculateBThreadOriginDataIndex())
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:204
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, const AScaleGridDesc &a_scale_grid_desc, AScaleThreadTransfer &a_scale_thread_copy, const AScaleGridBuffer &a_scale_grid_buf, const BScaleGridDesc &b_scale_grid_desc, BScaleThreadTransfer &b_scale_thread_copy, const BScaleGridBuffer &b_scale_grid_buf, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp:389
BlockwiseGemmXdlops_mx_pipeline_base< ThreadBlockSize, ADataType, BDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack > Base
Definition blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp:102
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_bufs, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_bufs, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, const AScaleGridDesc &a_scale_grid_desc, AScaleThreadTransfer &a_scale_thread_copy, const AScaleGridBuffer &a_scale_grid_buf, const BScaleGridDesc &b_scale_grid_desc, BScaleThreadTransfer &b_scale_thread_copy, const BScaleGridBuffer &b_scale_grid_buf, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_mx_moe_v3.hpp:367
Definition blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp:38
Unsigned representation of a conventional biased Float32 exponent.
Definition utility/e8m0.hpp:26
Definition functional2.hpp:33
Definition dtype_vector.hpp:10