blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp Source File

blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp Source File

Composable Kernel: blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp Source File
blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Naive pipeline with lowest resource request per WGP
11// GlobalPrefetchStages: 2
12// LocalPreFillStages: 1
13// LocalPreFetchStages: 2
14// LocalSharedMemoryBuffer: 2 (A tile only; pre-shuffled B is loaded straight into registers)
15
16template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17 index_t ThreadBlockSize,
18 index_t ScaleBlockSize,
19 typename ADataType,
20 typename AScaleDataType,
21 typename BDataType,
22 typename BScaleDataType,
23 typename ATileDesc,
24 typename BTileDesc,
25 typename AMmaTileDesc,
26 typename BMmaTileDesc,
27 index_t ABlockTransferSrcScalarPerVector,
28 index_t BBlockTransferSrcScalarPerVector,
29 index_t MPerBlock,
30 index_t NPerBlock,
31 index_t KPerBlock,
32 index_t MPerXDL,
33 index_t NPerXDL,
34 index_t MRepeat, // MXdlPerWave
35 index_t NRepeat, // NXdlPerWave
36 index_t KPack>
40
41template <index_t ThreadBlockSize,
42 index_t ScaleBlockSize,
43 typename ADataType,
44 typename AScaleDataType,
45 typename BDataType,
46 typename BScaleDataType,
47 typename ATileDesc,
48 typename BTileDesc,
49 typename AMmaTileDesc,
50 typename BMmaTileDesc,
51 index_t ABlockTransferSrcScalarPerVector,
52 index_t BBlockTransferSrcScalarPerVector,
53 index_t MPerBlock,
54 index_t NPerBlock,
55 index_t KPerBlock,
56 index_t MPerXDL,
57 index_t NPerXDL,
58 index_t MRepeat, // MXdlPerWave
59 index_t NRepeat, // NXdlPerWave
60 index_t KPack>
62 ThreadBlockSize,
63 ScaleBlockSize,
64 ADataType,
65 AScaleDataType,
66 BDataType,
67 BScaleDataType,
68 ATileDesc,
69 BTileDesc,
70 AMmaTileDesc,
71 BMmaTileDesc,
72 ABlockTransferSrcScalarPerVector,
73 BBlockTransferSrcScalarPerVector,
74 MPerBlock,
75 NPerBlock,
76 KPerBlock,
77 MPerXDL,
78 NPerXDL,
79 MRepeat,
80 NRepeat,
81 KPack>
83 ADataType,
84 BDataType,
85 ATileDesc,
86 BTileDesc,
87 AMmaTileDesc,
88 BMmaTileDesc,
89 ABlockTransferSrcScalarPerVector,
90 BBlockTransferSrcScalarPerVector,
91 MPerBlock,
92 NPerBlock,
93 KPerBlock,
94 MPerXDL,
95 NPerXDL,
96 MRepeat,
97 NRepeat,
98 KPack>
99
100{
101
103 ADataType,
104 BDataType,
105 ATileDesc,
106 BTileDesc,
107 AMmaTileDesc,
108 BMmaTileDesc,
109 ABlockTransferSrcScalarPerVector,
110 BBlockTransferSrcScalarPerVector,
111 MPerBlock,
112 NPerBlock,
113 KPerBlock,
114 MPerXDL,
115 NPerXDL,
116 MRepeat,
117 NRepeat,
118 KPack>;
119 using Base::A_K1;
120 using Base::I0;
121 using Base::I1;
122 using Base::KRepeat;
123 using Base::MWaves;
124 using Base::NWaves;
125 using Base::WaveSize;
126 using Base::xdlops_gemm;
127 using typename Base::HotLoopInstList;
128
137 using Base::GetWaveIdx;
140
143
144 using Base::AMmaKStride;
145 using Base::APackedSize;
146 using Base::BMmaKStride;
147 using Base::BPackedSize;
148 using Base::KThreadChunk;
149
150 using Base::KXdlPack;
151 using Base::MXdlPack;
152 using Base::NXdlPack;
153
154 using AccType = typename Base::AccType;
155 using Tuple5 = typename Base::Tuple5;
158
159 static constexpr index_t PrefetchStages = 2;
160 static constexpr index_t LocalPrefetchStages = 2;
161 static constexpr index_t PrefillStages = 1;
162 static constexpr index_t GlobalBufferNum = 1;
163 static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1;
164
165 static constexpr auto num_buffer_load_a_scale = MRepeat / MXdlPack * KRepeat / KXdlPack;
166 static constexpr auto num_buffer_load_b_scale = NRepeat / NXdlPack * KRepeat / KXdlPack;
167 static constexpr auto async_vmcnt =
169 static constexpr auto async_vmcnt_encoding = 3952 + async_vmcnt % 16 + async_vmcnt / 16 * 16384;
170
171 static constexpr auto ScalesPerKBlockSize =
172 KPerBlock / ScaleBlockSize; // How many mx-vectors per K block
173
174 //> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run()
175 static constexpr auto ScalesPerXdlopsRun =
176 (APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
177
178 //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run()
179 static constexpr auto ScalesPerXdlopsRunPerThread =
180 ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
181
183 static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
184 static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
185 static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
186 "A scale pack data type too large!");
187 static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
188 "B scale pack data type too large!");
191
192 __host__ static constexpr bool BlockHasHotloop(index_t num_loop)
193 {
194 return num_loop > PrefetchStages;
195 }
196
197 __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
198 {
199 return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
200 }
201
202 __device__ static constexpr auto HotLoopScheduler()
203 {
204 // A/B split schedule
205 // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
206 constexpr auto num_ds_read_inst_a =
207 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
210
211 constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
212 constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
213 constexpr auto num_buffer_load_stage1 =
214 num_buffer_load_inst_b + num_buffer_load_a_scale + num_buffer_load_b_scale;
215
216 constexpr auto num_buffer_load_stage2 = num_buffer_load_inst_a;
217
218 constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * APackedSize;
219 constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
220
221 constexpr auto ds_read_a_issue_cycle =
222 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
223 constexpr auto ds_read_a_mfma_rate =
224 math::integer_divide_ceil(mfma_cycle - 8, 2 * ds_read_a_issue_cycle);
225
226 // constexpr auto num_dsread_a_mfma =
227 // (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
228
229 constexpr auto num_total_stages = std::max(2, MRepeat);
230 if constexpr(num_total_stages > 2)
231 {
232
233 // Group num_mfma_perstage num_ds_read_a_perstage
234 // since we want to reuse a local register buffer
235 constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages;
236 constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages;
237
238 constexpr auto num_ds_read_a_mfma_perstage =
239 math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate);
240
241 constexpr auto num_ds_read_a_prefetch_stages = 2;
242
243 constexpr auto buffer_load_perstage_more =
244 math::integer_divide_ceil((num_buffer_load_stage1), (num_total_stages - 2));
245 constexpr auto buffer_load_perstage_less =
246 math::integer_divide_floor((num_buffer_load_stage1), (num_total_stages - 2));
247 constexpr auto buffer_load_perstage_stage2 =
248 math::integer_divide_floor((num_buffer_load_stage2), 2);
249
250 constexpr auto buffer_load_stages_more =
251 num_buffer_load_stage1 -
252 math::integer_divide_floor(num_buffer_load_stage1, (num_total_stages - 2)) *
253 ((num_total_stages - 2));
254
255 constexpr auto buffer_load_issue_point_interval_more =
256 num_mfma_perstage / buffer_load_perstage_more;
257 constexpr auto buffer_load_issue_point_interval_less =
258 num_mfma_perstage / buffer_load_perstage_less;
259 constexpr auto buffer_load_issue_point_interval_stage2 =
260 num_mfma_perstage / buffer_load_perstage_stage2;
261
262 // Stage 1
263 // global read more
265 static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
266 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
267
268 if constexpr(imfma % buffer_load_issue_point_interval_more == 0)
269 {
270 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
271 }
272
273 if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
274 {
275 __builtin_amdgcn_sched_group_barrier(
276 0x100, ds_read_a_mfma_rate, 0); // DS read
277 }
278 });
279 });
280
281 // global read less
282 static_for<0, (num_total_stages - 2 - buffer_load_stages_more), 1>{}([&](auto /*i*/) {
283 static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
284 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
285 if constexpr(imfma % buffer_load_issue_point_interval_less == 0)
286 {
287 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
288 }
289 if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
290 {
291 __builtin_amdgcn_sched_group_barrier(
292 0x100, ds_read_a_mfma_rate, 0); // DS read
293 }
294 });
295 });
296
297 // Stage 2, Sync
298 // lds synchronization, prefetch next loop local A
300 static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
301 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
302 if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0)
303 {
304 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
305 }
306 if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
307 {
308 __builtin_amdgcn_sched_group_barrier(
309 0x100, ds_read_a_mfma_rate, 0); // DS read
310 }
311 });
312 });
313 }
314 else
315 {
316 constexpr auto num_buffer_load_total = num_buffer_load_inst_a + num_buffer_load_inst_b +
319 constexpr auto num_dsread_a_mfma = math::integer_divide_ceil(
320 num_ds_read_inst_a, ds_read_a_mfma_rate); // how many mfma per dsread_a
321
322 // stage 1
323 constexpr auto num_mfma_stage1 = num_mfma_inst - num_dsread_a_mfma;
324
325 constexpr auto mfma_perstage_more =
326 math::integer_divide_ceil(num_mfma_stage1, num_buffer_load_total);
327 constexpr auto mfma_perstage_less =
328 math::integer_divide_floor(num_mfma_stage1, num_buffer_load_total);
329
330 constexpr auto mfma_stages_more =
331 num_mfma_stage1 - mfma_perstage_less * num_buffer_load_total;
332
334 if constexpr(i < mfma_stages_more)
335 {
337 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
338 });
339 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
340 }
341 else
342 {
344 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
345 });
346 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
347 }
348 });
349
351 if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more)
352 {
353 static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
354 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
355 });
356 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
357 }
358 else
359 {
360 static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
361 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
362 });
363 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
364 }
365 });
366
368 if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b) <
369 mfma_stages_more)
370 {
371 static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
372 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
373 });
374 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
375 }
376 else
377 {
378 static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
379 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
380 });
381 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
382 }
383 });
384
386 if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b +
387 num_buffer_load_a_scale) < mfma_stages_more)
388 {
389 static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
390 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
391 });
392 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
393 }
394 else
395 {
396 static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
397 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
398 });
399 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
400 }
401 });
402
403 // stage 2
405 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
406 if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
407 ds_read_a_mfma_rate)
408 {
409 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
410 }
411 else
412 {
413 __builtin_amdgcn_sched_group_barrier(
414 0x100,
415 num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate,
416 0); // DS read
417 }
418 });
419 }
420 }
421
422 template <bool HasMainLoop,
423 TailNumber TailNum,
424 typename AGridDesc,
425 typename ABlockDesc,
426 typename ABlockTransfer,
427 typename AGridBuffer,
428 typename ABlockBuffer,
429 typename ABlockTransferStep,
430 typename BGridDesc,
431 typename BBlockDesc,
432 typename BBlockTransfer,
433 typename BGridBuffer,
434 typename BBlockBuffer,
435 typename BBlockTransferStep,
436 typename CThreadBuffer,
437 typename AScaleGridBuffer,
438 typename AScaleGridDesc,
439 typename AScaleThreadTransfer,
440 typename BScaleGridBuffer,
441 typename BScaleGridDesc,
442 typename BScaleThreadTransfer>
443 __device__ void Run(
444 // ABlockCopy
445 const AGridDesc& a_grid_desc,
446 const ABlockDesc& a_block_desc,
447 ABlockTransfer& a_blockwise_copy,
448 const AGridBuffer& a_grid_buf,
449 ABlockBuffer& a_block_bufs,
450 const ABlockTransferStep& a_block_copy_step,
451 // BBlockCopy
452 const BGridDesc& b_grid_desc,
453 const BBlockDesc& b_block_desc,
454 BBlockTransfer& b_blockwise_copy,
455 const BGridBuffer& b_grid_buf,
456 BBlockBuffer& b_block_bufs,
457 const BBlockTransferStep& b_block_copy_step,
458 // CThread
459 CThreadBuffer& c_thread_buf,
460 // A and B scales
461 const AScaleGridDesc& a_scale_grid_desc,
462 AScaleThreadTransfer& a_scale_thread_copy,
463 const AScaleGridBuffer& a_scale_grid_buf,
464 const BScaleGridDesc& b_scale_grid_desc,
465 BScaleThreadTransfer& b_scale_thread_copy,
466 const BScaleGridBuffer& b_scale_grid_buf,
467 index_t num_loop) const
468 {
469 ignore = b_block_bufs;
471 a_thread_desc_.GetElementSpaceSize());
473 b_thread_desc_.GetElementSpaceSize());
474 StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
475 constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0);
476
478 a_scale_thread_desc.GetElementSpaceSize());
480 b_scale_thread_desc.GetElementSpaceSize());
481
482 StaticallyIndexedArray<decltype(a_scale_thread_buf), Number<2>{}> a_scale_thread_bufs;
483 StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;
484
485 // Global prefetch 1
486 a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I0));
487 b_blockwise_copy.Run(
488 b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_thread_bufs(I0));
489
490 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
491 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
492
493 // Prefetch a_scales
494 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
495 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
496 a_scale_thread_copy.Run(a_scale_grid_desc,
497 a_scale_grid_buf,
499 make_tuple(m0, k0, I0),
500 a_scale_thread_bufs(I0));
501
502 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
503 make_multi_index(0, I1, 0));
504 });
505 a_scale_thread_copy.MoveSrcSliceWindow(
506 a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
507 });
508
509 // restore row id and advance to the next set of scales
510 a_scale_thread_copy.MoveSrcSliceWindow(
511 a_scale_grid_desc,
512 make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
513
514 // Prefetch b_scales
515 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
516 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
517 b_scale_thread_copy.Run(b_scale_grid_desc,
518 b_scale_grid_buf,
520 make_tuple(n0, k0, I0),
521 b_scale_thread_bufs(I0));
522
523 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
524 make_multi_index(0, I1, 0));
525 });
526 b_scale_thread_copy.MoveSrcSliceWindow(
527 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
528 });
529
530 // restore col id and advance to the next set of scales
531 // NWaves * NPerXDL * NRepeat == NPerBlock
532 b_scale_thread_copy.MoveSrcSliceWindow(
533 b_scale_grid_desc,
534 make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
535
536 // Local prefetch 1, sync the async load
537 __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding);
540 static_for<0, KRepeat, 1>{}([&](auto k) {
541 constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
542 (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
543 static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
544 [&](auto chunk) {
545 constexpr auto a_k_step_chunk =
546 k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
547 a_thread_copy_.Run(
551 a_block_bufs(I0),
555 a_thread_buf);
556 });
557 });
558 });
559
560 // Global prefetch 2
561 a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I1));
562 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
563
564 // Initialize C
565 c_thread_buf.Clear();
566 __builtin_amdgcn_sched_barrier(0);
567 constexpr index_t SwitchM = MRepeat - LocalPrefetchStages;
568 // main body
569 if constexpr(HasMainLoop)
570 {
571 // loop over k with the step KPerBlock
572 index_t i = 0;
573 do
574 {
575 auto LoopFunc = [&](auto scale_comp_buf, auto scale_mem_buf) {
576 b_blockwise_copy.Run(b_grid_desc,
577 b_grid_buf,
578 b_block_desc,
579 b_block_origin_idx,
580 b_thread_bufs(scale_mem_buf));
581
582 // Prefetch a_scales
583 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
584 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
585 a_scale_thread_copy.Run(a_scale_grid_desc,
586 a_scale_grid_buf,
588 make_tuple(m0, k0, I0),
589 a_scale_thread_bufs(scale_mem_buf));
590
591 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
592 make_multi_index(0, I1, 0));
593 });
594 a_scale_thread_copy.MoveSrcSliceWindow(
595 a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
596 });
597
598 // restore row id and advance to the next set of scales
599 a_scale_thread_copy.MoveSrcSliceWindow(
600 a_scale_grid_desc,
601 make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
602
603 // Prefetch b_scales
604 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
605 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
606 b_scale_thread_copy.Run(b_scale_grid_desc,
607 b_scale_grid_buf,
609 make_tuple(n0, k0, I0),
610 b_scale_thread_bufs(scale_mem_buf));
611
612 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
613 make_multi_index(0, I1, 0));
614 });
615 b_scale_thread_copy.MoveSrcSliceWindow(
616 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
617 });
618
619 // restore col id and advance to the next set of scales
620 // NWaves * NPerXDL * NRepeat == NPerBlock
621 b_scale_thread_copy.MoveSrcSliceWindow(
622 b_scale_grid_desc,
623 make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
624
625 // a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
626 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
627
628 static_for<0, MRepeat, 1>{}([&](auto m0) {
629 constexpr auto im_major = m0 / MXdlPack;
630 constexpr auto im_minor = m0 % MXdlPack;
631 static_for<0, KRepeat, 1>{}([&](auto k0) {
632 constexpr auto ik_major = k0 / KXdlPack;
633 constexpr auto ik_minor = k0 % KXdlPack;
634 static_for<0, NRepeat, 1>{}([&](auto n0) {
635 constexpr auto in_major = n0 / NXdlPack;
636 constexpr auto in_minor = n0 % NXdlPack;
637
638 constexpr index_t a_scale_offset =
639 a_scale_thread_desc.CalculateOffset(
640 make_tuple(im_major, ik_major, I0));
641 constexpr index_t b_scale_offset =
642 b_scale_thread_desc.CalculateOffset(
643 make_tuple(in_major, ik_major, I0));
644
645 static_assert(0 < ScalesPerXdlopsRunPerThread,
646 "Must have at least one scale per Xdlops "
647 "per Thread.");
648
650 a_scale_thread_vec;
652 b_scale_thread_vec;
653
654 // Pack scale_thread_buf into scale_thread_vec
656 a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
657 a_scale_thread_bufs(
658 scale_comp_buf)[Number<a_scale_offset + s>{}];
659 });
660
662 b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
663 b_scale_thread_bufs(
664 scale_comp_buf)[Number<b_scale_offset + s>{}];
665 });
666
669
670 static_for<0, KPack, 1>{}([&](auto ik) {
671 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
672 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
673 make_tuple(I0, I0, im_minor, k0, ik))>{}];
674 b_thread_vec.template AsType<ComputeTypeB>()(ik) = b_thread_bufs
675 [scale_comp_buf][Number<b_thread_desc_.CalculateOffset(
676 make_tuple(in_major, I0, in_minor, k0, ik))>{}];
677 });
678
679 using mfma_input_type_a =
680 typename vector_type<ComputeTypeA,
681 xdlops_gemm.K1PerXdlops /
682 APackedSize>::type;
683
684 using mfma_input_type_b =
685 typename vector_type<ComputeTypeB,
686 xdlops_gemm.K1PerXdlops /
687 BPackedSize>::type;
688
689 using mfma_scale_input_type_a =
690 typename vector_type<AScaleDataType,
692 using mfma_scale_input_type_b =
693 typename vector_type<BScaleDataType,
695
696 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
697 make_tuple(im_major, in_major, im_minor, in_minor, 0));
698
699 // MFMA accumulation
700 xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
701 ik_minor * NXdlPack + in_minor>(
702 a_thread_vec.template AsType<mfma_input_type_a>(),
703 a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
704 b_thread_vec.template AsType<mfma_input_type_b>(),
705 b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
706 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
707 });
708 });
709
710 if constexpr(m0.value == SwitchM)
711 {
712 __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding);
714 a_blockwise_copy.Run(a_grid_desc,
715 a_grid_buf,
716 a_block_desc,
717 a_block_bufs(scale_comp_buf));
718 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
719 }
720
721 constexpr auto lds_buf =
722 m0.value >= SwitchM ? scale_mem_buf : scale_comp_buf;
723
724 static_for<0, KRepeat, 1>{}([&](auto k) {
725 constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
726 (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
727 static_for<0,
728 xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk),
729 1>{}([&](auto chunk) {
730 constexpr auto a_k_step_chunk =
731 k_step +
732 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
733 a_thread_copy_.Run(
736 (MRepeat / MXdlPack)>{},
737 I0,
739 I0,
741 a_block_bufs(Number<lds_buf>{}),
744 I0,
746 k,
748 a_thread_buf);
749 });
750 });
751 });
752
754 __builtin_amdgcn_sched_barrier(0);
755 };
756
757 LoopFunc(I0, I1);
758 LoopFunc(I1, I0);
759
760 i += 2;
761 } while(i < (num_loop - 2));
762 }
763
764 // tail
765 if constexpr(TailNum == TailNumber::Even)
766 {
767 b_blockwise_copy.Run(
768 b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_thread_bufs(I1));
769
770 // Prefetch a_scales
771 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
772 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
773 a_scale_thread_copy.Run(a_scale_grid_desc,
774 a_scale_grid_buf,
776 make_tuple(m0, k0, I0),
777 a_scale_thread_bufs(I1));
778
779 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
780 make_multi_index(0, I1, 0));
781 });
782 a_scale_thread_copy.MoveSrcSliceWindow(
783 a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
784 });
785
786 // Prefetch b_scales
787 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
788 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
789 b_scale_thread_copy.Run(b_scale_grid_desc,
790 b_scale_grid_buf,
792 make_tuple(n0, k0, I0),
793 b_scale_thread_bufs(I1));
794
795 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
796 make_multi_index(0, I1, 0));
797 });
798 b_scale_thread_copy.MoveSrcSliceWindow(
799 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
800 });
801
802 static_for<0, MRepeat, 1>{}([&](auto m0) {
803 constexpr auto im_major = m0 / MXdlPack;
804 constexpr auto im_minor = m0 % MXdlPack;
805 static_for<0, KRepeat, 1>{}([&](auto k0) {
806 constexpr auto ik_major = k0 / KXdlPack;
807 constexpr auto ik_minor = k0 % KXdlPack;
808 static_for<0, NRepeat, 1>{}([&](auto n0) {
809 constexpr auto in_major = n0 / NXdlPack;
810 constexpr auto in_minor = n0 % NXdlPack;
811
812 constexpr index_t a_scale_offset =
813 a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
814 constexpr index_t b_scale_offset =
815 b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
816
817 static_assert(0 < ScalesPerXdlopsRunPerThread,
818 "Must have at least one scale per Xdlops "
819 "per Thread.");
820
823
824 // Pack scale_thread_buf into scale_thread_vec
826 a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
827 a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
828 });
829
831 b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
832 b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
833 });
834
837
838 static_for<0, KPack, 1>{}([&](auto ik) {
839 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
840 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
841 make_tuple(I0, I0, im_minor, k0, ik))>{}];
842 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
843 b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
844 make_tuple(in_major, I0, in_minor, k0, ik))>{}];
845 });
846
847 using mfma_input_type_a =
848 typename vector_type<ComputeTypeA,
849 xdlops_gemm.K1PerXdlops / APackedSize>::type;
850
851 using mfma_input_type_b =
852 typename vector_type<ComputeTypeB,
853 xdlops_gemm.K1PerXdlops / BPackedSize>::type;
854
855 using mfma_scale_input_type_a =
857 using mfma_scale_input_type_b =
859
860 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
861 make_tuple(im_major, in_major, im_minor, in_minor, 0));
862
863 // MFMA accumulation
864 xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
865 ik_minor * NXdlPack + in_minor>(
866 a_thread_vec.template AsType<mfma_input_type_a>(),
867 a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
868 b_thread_vec.template AsType<mfma_input_type_b>(),
869 b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
870 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
871 });
872 });
873 if constexpr(m0.value == SwitchM)
874 {
875 __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding);
877 }
878
879 constexpr auto lds_buf = m0.value >= SwitchM ? I1 : I0;
880
881 static_for<0, KRepeat, 1>{}([&](auto k) {
882 constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
883 (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
884 static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
885 [&](auto chunk) {
886 constexpr auto a_k_step_chunk =
887 k_step +
888 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
889 a_thread_copy_.Run(
892 (MRepeat / MXdlPack)>{},
893 I0,
895 I0,
897 a_block_bufs(Number<lds_buf>{}),
901 a_thread_buf);
902 });
903 });
904 });
905
906 static_for<0, MRepeat, 1>{}([&](auto m0) {
907 constexpr auto im_major = m0 / MXdlPack;
908 constexpr auto im_minor = m0 % MXdlPack;
909 static_for<0, KRepeat, 1>{}([&](auto k0) {
910 constexpr auto ik_major = k0 / KXdlPack;
911 constexpr auto ik_minor = k0 % KXdlPack;
912 static_for<0, NRepeat, 1>{}([&](auto n0) {
913 constexpr auto in_major = n0 / NXdlPack;
914 constexpr auto in_minor = n0 % NXdlPack;
915
916 constexpr index_t a_scale_offset =
917 a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
918 constexpr index_t b_scale_offset =
919 b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
920
921 static_assert(0 < ScalesPerXdlopsRunPerThread,
922 "Must have at least one scale per Xdlops "
923 "per Thread.");
924
927
928 // Pack scale_thread_buf into scale_thread_vec
930 a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
931 a_scale_thread_bufs(I1)[Number<a_scale_offset + s>{}];
932 });
933
935 b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
936 b_scale_thread_bufs(I1)[Number<b_scale_offset + s>{}];
937 });
938
941
942 static_for<0, KPack, 1>{}([&](auto ik) {
943 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
944 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
945 make_tuple(I0, I0, im_minor, k0, ik))>{}];
946 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
947 b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
948 make_tuple(in_major, I0, in_minor, k0, ik))>{}];
949 });
950
951 using mfma_input_type_a =
952 typename vector_type<ComputeTypeA,
953 xdlops_gemm.K1PerXdlops / APackedSize>::type;
954
955 using mfma_input_type_b =
956 typename vector_type<ComputeTypeB,
957 xdlops_gemm.K1PerXdlops / BPackedSize>::type;
958
959 using mfma_scale_input_type_a =
961 using mfma_scale_input_type_b =
963
964 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
965 make_tuple(im_major, in_major, im_minor, in_minor, 0));
966
967 // MFMA accumulation
968 xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
969 ik_minor * NXdlPack + in_minor>(
970 a_thread_vec.template AsType<mfma_input_type_a>(),
971 a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
972 b_thread_vec.template AsType<mfma_input_type_b>(),
973 b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
974 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
975 });
976 });
977 if constexpr(m0.value < (MRepeat - LocalPrefetchStages))
978 {
979 static_for<0, KRepeat, 1>{}([&](auto k) {
980 constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
981 (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
982 static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
983 [&](auto chunk) {
984 constexpr auto a_k_step_chunk =
985 k_step +
986 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
987 a_thread_copy_.Run(
990 (MRepeat / MXdlPack)>{},
991 I0,
993 I0,
995 a_block_bufs(I1),
998 I0,
1000 k,
1002 a_thread_buf);
1003 });
1004 });
1005 }
1006 });
1007 }
1008 else if constexpr(TailNum == TailNumber::Odd)
1009 {
1010 static_for<0, MRepeat, 1>{}([&](auto m0) {
1011 constexpr auto im_major = m0 / MXdlPack;
1012 constexpr auto im_minor = m0 % MXdlPack;
1013 static_for<0, KRepeat, 1>{}([&](auto k0) {
1014 constexpr auto ik_major = k0 / KXdlPack;
1015 constexpr auto ik_minor = k0 % KXdlPack;
1016 static_for<0, NRepeat, 1>{}([&](auto n0) {
1017 constexpr auto in_major = n0 / NXdlPack;
1018 constexpr auto in_minor = n0 % NXdlPack;
1019
1020 constexpr index_t a_scale_offset =
1021 a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
1022 constexpr index_t b_scale_offset =
1023 b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
1024
1025 static_assert(0 < ScalesPerXdlopsRunPerThread,
1026 "Must have at least one scale per Xdlops "
1027 "per Thread.");
1028
1031
1032 // Pack scale_thread_buf into scale_thread_vec
1034 a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
1035 a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
1036 });
1037
1039 b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
1040 b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
1041 });
1042
1045
1046 static_for<0, KPack, 1>{}([&](auto ik) {
1047 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
1048 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
1049 make_tuple(I0, I0, im_minor, k0, ik))>{}];
1050 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
1051 b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
1052 make_tuple(in_major, I0, in_minor, k0, ik))>{}];
1053 });
1054
1055 using mfma_input_type_a =
1056 typename vector_type<ComputeTypeA,
1057 xdlops_gemm.K1PerXdlops / APackedSize>::type;
1058
1059 using mfma_input_type_b =
1060 typename vector_type<ComputeTypeB,
1061 xdlops_gemm.K1PerXdlops / BPackedSize>::type;
1062
1063 using mfma_scale_input_type_a =
1065 using mfma_scale_input_type_b =
1067
1068 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
1069 make_tuple(im_major, in_major, im_minor, in_minor, 0));
1070
1071 // MFMA accumulation
1072 xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
1073 ik_minor * NXdlPack + in_minor>(
1074 a_thread_vec.template AsType<mfma_input_type_a>(),
1075 a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
1076 b_thread_vec.template AsType<mfma_input_type_b>(),
1077 b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
1078 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
1079 });
1080 });
1081 if constexpr(m0.value < (MRepeat - LocalPrefetchStages))
1082 {
1083 static_for<0, KRepeat, 1>{}([&](auto k) {
1084 constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
1085 (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
1086 static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
1087 [&](auto chunk) {
1088 constexpr auto a_k_step_chunk =
1089 k_step +
1090 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
1091 a_thread_copy_.Run(
1094 (MRepeat / MXdlPack)>{},
1095 I0,
1097 I0,
1099 a_block_bufs(I0),
1101 make_tuple(I0,
1102 I0,
1104 k,
1106 a_thread_buf);
1107 });
1108 });
1109 }
1110 });
1111 }
1112 }
1113
1114 // Length: A[ARegBuf, MWave, MXdlPack, KRepeat, KPack]
1115 // Order: 1 0 3 2 4
1116 static constexpr auto ARegBuf = 2;
1119
1123 decltype(a_thread_desc_),
1126 4,
1127 A_K1,
1128 A_K1>;
1130
1131 // TODO: make this field protected when a_scale_thread_copy_ is moved
1132 // here
1135 Number<KRepeat / KXdlPack>{},
1137
1138 // TODO: make this field protected when b_scale_thread_copy_ is moved
1139 // here
1142 Number<KRepeat / KXdlPack>{},
1144
1145 protected:
1146 // using Base::a_thread_copy_;
1147 // using Base::a_thread_desc_;
1148 using Base::b_thread_copy_;
1149 using Base::b_thread_desc_;
1150 using Base::c_thread_desc_;
1151};
1152
1153} // namespace ck
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
Definition utility/math.hpp:66
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ void block_sync_lds()
Definition synchronization.hpp:16
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops,(packed_size_v< ComputeTypeA > > 1||packed_size_v< ComputeTypeB > > 1)> HotLoopInstList
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:88
__host__ __device__ BlockwiseGemmXdlops_mx_pipeline_base(Tuple5 a_origin=CalculateAThreadOriginDataIndex(), Tuple5 b_origin=CalculateBThreadOriginDataIndex())
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:204
BlockwiseGemmXdlops_mx_pipeline_base< ThreadBlockSize, ADataType, BDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack > Base
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp:102
ThreadwiseTensorSliceTransfer_v4< ADataType, ComputeTypeA, decltype(a_block_desc_m0_m1_m2_m3_k), decltype(a_thread_desc_), Sequence< 1, 1, 1, 1, KThreadChunk >, Sequence< 0, 1, 2, 3, 4 >, 4, A_K1, A_K1 > AThreadCopy
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp:1120
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_bufs, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_bufs, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, const AScaleGridDesc &a_scale_grid_desc, AScaleThreadTransfer &a_scale_thread_copy, const AScaleGridBuffer &a_scale_grid_buf, const BScaleGridDesc &b_scale_grid_desc, BScaleThreadTransfer &b_scale_thread_copy, const BScaleGridBuffer &b_scale_grid_buf, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp:443
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp:38
Definition utility/sequence.hpp:43
Definition threadwise_tensor_slice_transfer.hpp:1260
Unsigned representation of a conventional biased Float32 exponent.
Definition utility/e8m0.hpp:26
Definition functional2.hpp:33
Definition dtype_vector.hpp:10