blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp Source File

blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp Source File#

Composable Kernel: blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp Source File
blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Compute-optimal pipeline with the highest resource request
11// GlobalPrefetchStages: 3 (PrefetchStages)
12// LocalPreFillStages: 2
13// LocalPreFetchStages: 1
14// LocalSharedMemoryBuffer: 2
15
// Primary template of the v4 B-scale blockwise GEMM pipeline, selected by the
// scheduler kind BlkGemmPipelineVer. The implementation is provided by the
// specialization that follows.
// NOTE(review): the last parameter is spelled KPacks here but KPack in the
// specialization; harmless, but worth unifying.
16template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17 index_t BlockSize,
18 typename ADataType,
19 typename BDataType,
20 typename ComputeDataType,
21 typename AccDataType,
22 typename ATileDesc,
23 typename BTileDesc,
24 typename AMmaTileDesc,
25 typename BMmaTileDesc,
26 index_t ABlockTransferSrcScalarPerVector,
27 index_t BBlockTransferSrcScalarPerVector,
28 index_t MPerBlock,
29 index_t NPerBlock,
30 index_t KPerBlock,
31 index_t MPerXDL,
32 index_t NPerXDL,
33 index_t MRepeat,
34 index_t NRepeat,
35 index_t KPacks>
39
// Specialization of the v4 B-scale pipeline. It derives from
// BlockwiseGemmXdlops_pipeline_base with the same tile/XDL parameters and
// reuses the base's constants, xdlops_gemm object and hot-loop instruction counts.
40template <index_t BlockSize,
41 typename ADataType,
42 typename BDataType,
43 typename ComputeDataType,
44 typename AccDataType,
45 typename ATileDesc,
46 typename BTileDesc,
47 typename AMmaTileDesc,
48 typename BMmaTileDesc,
49 index_t ABlockTransferSrcScalarPerVector,
50 index_t BBlockTransferSrcScalarPerVector,
51 index_t MPerBlock,
52 index_t NPerBlock,
53 index_t KPerBlock,
54 index_t MPerXDL,
55 index_t NPerXDL,
56 index_t MRepeat,
57 index_t NRepeat,
58 index_t KPack
59 // ,bool TransposeC //disable transposec right now...
60 >
62 BlockSize,
63 ADataType,
64 BDataType,
65 ComputeDataType,
66 AccDataType,
67 ATileDesc,
68 BTileDesc,
69 AMmaTileDesc,
70 BMmaTileDesc,
71 ABlockTransferSrcScalarPerVector,
72 BBlockTransferSrcScalarPerVector,
73 MPerBlock,
74 NPerBlock,
75 KPerBlock,
76 MPerXDL,
77 NPerXDL,
78 MRepeat,
79 NRepeat,
80 KPack>
82 ADataType,
83 BDataType,
84 ComputeDataType,
85 AccDataType,
86 ATileDesc,
87 BTileDesc,
88 AMmaTileDesc,
89 BMmaTileDesc,
90 ABlockTransferSrcScalarPerVector,
91 BBlockTransferSrcScalarPerVector,
92 MPerBlock,
93 NPerBlock,
94 KPerBlock,
95 MPerXDL,
96 NPerXDL,
97 MRepeat,
98 NRepeat,
99 KPack>
100
101{
103 ADataType,
104 BDataType,
105 ComputeDataType,
106 AccDataType,
107 ATileDesc,
108 BTileDesc,
109 AMmaTileDesc,
110 BMmaTileDesc,
111 ABlockTransferSrcScalarPerVector,
112 BBlockTransferSrcScalarPerVector,
113 MPerBlock,
114 NPerBlock,
115 KPerBlock,
116 MPerXDL,
117 NPerXDL,
118 MRepeat,
119 NRepeat,
120 KPack>;
// Names reused from the base pipeline.
121 using Base::I0;
122 using Base::I1;
123 using Base::KRepeat;
124 using Base::xdlops_gemm;
125 using typename Base::HotLoopInstList;
126
138
141
// K-direction strides of the A/B MMA tiles, taken from the base pipeline.
142 using Base::AMmaKStride;
143 using Base::BMmaKStride;
144
146
// Pipeline depth constants.
// PrefetchStages: number of global A/B prefetches issued before the main loop in
//   Run() (three RunRead calls); also the threshold used by BlockHasHotloop and the
//   hot-loop bound in Run().
147 static constexpr index_t PrefetchStages = 3;
// PrefillStages: number of LDS prefills before the main loop (LDS buffers I0 and I1).
148 static constexpr index_t PrefillStages = 2;
// GlobalBufferNum: not referenced in the visible code; presumably the number of
//   global-read staging buffers — confirm.
149 static constexpr index_t GlobalBufferNum = 1;
// HotloopUnroll: K-tiles processed per hot-loop iteration (two LoopFunc calls).
150 static constexpr index_t HotloopUnroll = 2;
151
151
152 __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
153 {
154 return num_loop > PrefetchStages;
155 }
156
157 __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
158 {
159 if(num_loop % HotloopUnroll == 1)
160 {
161 return TailNumber::Odd;
162 }
163 else
164 {
165 return TailNumber::Even;
166 }
167 }
168
// Emits scheduling hints for one hot-loop body via sched_group_barrier.
// For every A buffer-load issue, then every B buffer-load issue, it:
//   - pairs each of that issue's DS reads with one MFMA (mask 0x100 / 0x008);
//   - pairs each of that issue's DS writes with one MFMA (mask 0x200 / 0x008);
//   - places one VMEM read (mask 0x020) followed by the rest of that issue's MFMA share.
// It finishes with sched_barrier(0).
169 __device__ static constexpr void HotLoopScheduler()
170 {
171 // TODO: Take data type into consideration as pipe ver 3
172 // A-B split schedule: all A issues first, then all B issues
// DS-read instruction counts, chosen by whether the LDS read width in bytes is 16.
173 constexpr auto num_ds_read_inst_a =
174 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
177 constexpr auto num_ds_read_inst_b =
178 HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
181
// Per-issue shares. DS writes are spread with a ceiling division; DS reads with a
// floor division (any remainder is not scheduled explicitly).
182 constexpr auto num_issue_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
183 constexpr auto num_dswrite_per_issue_a =
184 (HotLoopInstList::A_LDS_Write_Inst_Num + num_issue_a - 1) / num_issue_a;
185 constexpr auto num_dsread_per_issue_a = num_ds_read_inst_a / num_issue_a;
186
187 constexpr auto num_issue_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
188 constexpr auto num_dswrite_per_issue_b =
189 (HotLoopInstList::B_LDS_Write_Inst_Num + num_issue_b - 1) / num_issue_b;
190 constexpr auto num_dsread_per_issue_b = num_ds_read_inst_b / num_issue_b;
191
// MFMAs are divided evenly across all A and B buffer-load issues.
192 constexpr auto num_mfma_per_issue =
193 HotLoopInstList::C_MFMA_Inst_Num / (num_issue_a + num_issue_b);
194
195 static_for<0, num_issue_a, 1>{}([&](auto i) {
196 ignore = i;
197 static_for<0, num_dsread_per_issue_a, 1>{}([&](auto idsread) {
198 ignore = idsread;
199 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
200 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
201 });
202
203 static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
204 ignore = idswrite;
205 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
206 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
207 });
208
// One VMEM read, then the MFMAs not already paired with DS reads/writes above.
209 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
210 __builtin_amdgcn_sched_group_barrier(0x008,
211 num_mfma_per_issue - num_dsread_per_issue_a -
212 num_dswrite_per_issue_a,
213 0); // MFMA
214 });
215
216 static_for<0, num_issue_b, 1>{}([&](auto i) {
217 ignore = i;
218 static_for<0, num_dsread_per_issue_b, 1>{}([&](auto idsread) {
219 ignore = idsread;
220 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
221 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
222 });
223
224 static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
225 ignore = idswrite;
226 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
227 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
228 });
229
// NOTE(review): this B-issue count subtracts num_dsread_per_issue_a, although this
// loop paired num_dsread_per_issue_b DS reads with MFMAs above. It looks like a
// copy-paste from the A loop and is only correct when the two values are equal;
// the B share should most likely use num_dsread_per_issue_b — confirm and fix.
230 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
231 __builtin_amdgcn_sched_group_barrier(0x008,
232 num_mfma_per_issue - num_dsread_per_issue_a -
233 num_dswrite_per_issue_b,
234 0); // MFMA
235 });
236 __builtin_amdgcn_sched_barrier(0);
237 }
238
// Run: executes the whole K loop for this block tile and accumulates into c_thread_buf.
// Structure:
//   - prologue: three global prefetches of the A/B tiles and of the per-n0 B scales,
//     two LDS prefills (LDS buffers I0 and I1), and one LDS->register prefetch from I0;
//   - hot loop (only if HasMainLoop): HotloopUnroll (2) K-tiles per do-iteration, with
//     the two LDS/register buffers swapping roles between the two LoopFunc calls;
//   - tail: TailNumber::Odd drains three tiles, TailNumber::Even drains two.
// num_loop: K-tile count; the hot loop runs while i < num_loop - PrefetchStages.
// num_loop_per_scale: only used in modulo tests that choose how the B-scale window
//   steps; presumably the number of K-tiles sharing one scale — confirm with the caller.
// NOTE(review): the embedded line numbers of this rendered listing have gaps, so some
// statements of this function are not shown here.
239 template <bool HasMainLoop,
240 TailNumber TailNum,
241 typename AGridDesc,
242 typename ABlockDesc,
243 typename ABlockTransfer,
244 typename AGridBuffer,
245 typename ABlockBuffer,
246 typename ABlockTransferStep,
247 typename BGridDesc,
248 typename BBlockDesc,
249 typename BBlockTransfer,
250 typename BGridBuffer,
251 typename BBlockBuffer,
252 typename BBlockTransferStep,
253 typename CThreadBuffer,
254 typename BScaleGridBuffer,
255 typename BScaleGridDesc,
256 typename BScaleThreadDesc,
257 typename BScaleThreadTransfer,
258 typename BScaleThreadTransferStep>
259 __device__ void Run(const AGridDesc& a_grid_desc,
260 const ABlockDesc& a_block_desc,
261 ABlockTransfer& a_blockwise_copy,
262 const AGridBuffer& a_grid_buf,
263 ABlockBuffer& a_block_buf,
264 const ABlockTransferStep& a_block_copy_step,
265 const BGridDesc& b_grid_desc,
266 const BBlockDesc& b_block_desc,
267 BBlockTransfer& b_blockwise_copy,
268 const BGridBuffer& b_grid_buf,
269 BBlockBuffer& b_block_buf,
270 const BBlockTransferStep& b_block_copy_step,
271 CThreadBuffer& c_thread_buf,
272 // BScaleThreadCopy
273 const BScaleGridDesc& b_scale_grid_desc,
274 const BScaleThreadDesc& b_scale_thread_desc,
275 BScaleThreadTransfer& b_scale_thread_copy,
276 const BScaleGridBuffer& b_scale_grid_buf,
277 const BScaleThreadTransferStep& b_scale_thread_copy_step,
278 // num loop
279 index_t num_loop,
280 index_t num_loop_per_scale) const
281 {
// Per-thread A/B operand buffers, sized from the thread descriptors.
283 a_thread_desc_.GetElementSpaceSize());
285 b_thread_desc_.GetElementSpaceSize());
286
287 // B scale buffer
289 b_scale_thread_desc.GetElementSpaceSize());
290
// Two copies of each buffer (index I0 / I1) so that one can be filled while the
// other is consumed.
291 StaticallyIndexedArray<decltype(a_thread_buf), Number<2>{}> a_thread_bufs;
292 StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
293 StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;
294
295 // Global prefetch 1
296 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
297 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
298
299 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
300 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
301
// Scale fetch #1: one B scale per n0 into slot I0, stepping by At<0> after each.
302 static_for<0, NRepeat, 1>{}([&](auto n0) {
303 b_scale_thread_copy.Run(b_scale_grid_desc,
304 b_scale_grid_buf,
305 b_scale_thread_desc,
306 make_tuple(n0, I0),
307 b_scale_thread_bufs(I0));
308
309 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
310 b_scale_thread_copy_step.At(Number<0>{}));
311 });
312
// After each scale fetch, the window steps by At<2> when the running fetch count
// (1 here) is a multiple of num_loop_per_scale, and by At<1> otherwise.
// `num_loop_per_scale == 1` is the same test as `1 % num_loop_per_scale == 0` for
// positive values.
// NOTE(review): the step semantics come from the caller-supplied
// b_scale_thread_copy_step; presumably At<2> moves on to the next scale along K and
// At<1> returns to the current one — confirm.
313 if(num_loop_per_scale == 1)
314 {
315 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
316 b_scale_thread_copy_step.At(Number<2>{}));
317 }
318 else
319 {
320 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
321 b_scale_thread_copy_step.At(Number<1>{}));
322 }
323
324 // Local prefill 1
325 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0));
326 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0));
327
328 // Global prefetch 2
329 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
330 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
331
332 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
333 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
334
// Scale fetch #2 into slot I1 (pairs with the tile that will be prefilled to LDS I1).
335 static_for<0, NRepeat, 1>{}([&](auto n0) {
336 b_scale_thread_copy.Run(b_scale_grid_desc,
337 b_scale_grid_buf,
338 b_scale_thread_desc,
339 make_tuple(n0, I0),
340 b_scale_thread_bufs(I1));
341
342 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
343 b_scale_thread_copy_step.At(Number<0>{}));
344 });
345
346 if(2 % num_loop_per_scale == 0)
347 {
348 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
349 b_scale_thread_copy_step.At(Number<2>{}));
350 }
351 else
352 {
353 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
354 b_scale_thread_copy_step.At(Number<1>{}));
355 }
356
357 // Local prefetch 1
// LDS buffer I0 -> register buffers I0. The B copy receives b_scale_thread_bufs(I0)[n0];
// presumably the scale is applied during this copy — confirm in the thread-copy type.
359 static_for<0, KRepeat, 1>{}([&](auto k) {
360 static_for<0, MRepeat, 1>{}([&](auto m0) {
363 a_block_buf.At(I0),
365 make_tuple(m0, I0, k, I0),
366 a_thread_bufs(I0));
367 static_for<0, NRepeat, 1>{}([&](auto n0) {
370 b_block_buf.At(I0),
371 b_scale_thread_bufs(I0)[n0],
373 make_tuple(n0, I0, k, I0),
374 b_thread_bufs(I0));
375 });
376 });
377 });
378
379 // Local prefill 2
380 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
381 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I1));
382
383 // Global prefetch 3
384 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
385 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
386
387 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
388 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
389
// Scale fetch #3 reuses slot I0, whose contents were consumed by local prefetch 1 above.
390 static_for<0, NRepeat, 1>{}([&](auto n0) {
391 b_scale_thread_copy.Run(b_scale_grid_desc,
392 b_scale_grid_buf,
393 b_scale_thread_desc,
394 make_tuple(n0, I0),
395 b_scale_thread_bufs(I0));
396
397 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
398 b_scale_thread_copy_step.At(Number<0>{}));
399 });
400
401 if(3 % num_loop_per_scale == 0)
402 {
403 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
404 b_scale_thread_copy_step.At(Number<2>{}));
405 }
406 else
407 {
408 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
409 b_scale_thread_copy_step.At(Number<1>{}));
410 }
411
412 // Initialize C
413 c_thread_buf.Clear();
414
415 // main body
416 if constexpr(HasMainLoop)
417 {
418 index_t i = 0;
419 // Each do-iteration calls LoopFunc twice (HotloopUnroll) so the two LDS/register buffers alternate roles (double local-buffer strategy)
420 do
421 {
// LoopFunc(lds_read_buf, lds_read_reg_buf, lds_write_buf, mfma_reg_buf):
//   LDS[lds_read_buf] -> registers[lds_read_reg_buf] (B scale slot lds_read_buf),
//   fetch the next B scales into slot lds_read_reg_buf, store the pending global
//   tile into LDS[lds_write_buf], issue the next global read, then run MFMAs on
//   registers[mfma_reg_buf].
422 auto LoopFunc = [&](auto lds_read_buf,
423 auto lds_read_reg_buf,
424 auto lds_write_buf,
425 auto mfma_reg_buf) {
427
428 static_for<0, KRepeat, 1>{}([&](auto k) {
429 static_for<0, MRepeat, 1>{}([&](auto m0) {
432 a_block_buf.At(lds_read_buf),
434 make_tuple(m0, I0, k, I0),
435 a_thread_bufs(lds_read_buf == lds_read_buf ? lds_read_reg_buf : lds_read_reg_buf));
436 });
437 static_for<0, NRepeat, 1>{}([&](auto n0) {
440 b_block_buf.At(lds_read_buf),
441 b_scale_thread_bufs(lds_read_buf)[n0],
443 make_tuple(n0, I0, k, I0),
444 b_thread_bufs(lds_read_reg_buf));
445 });
446 });
447
448 // B scale copy
// Both LoopFunc calls pass equal lds_read_buf / lds_read_reg_buf, so the scale
// slot refilled here is the one whose values were just consumed above.
449 static_for<0, NRepeat, 1>{}([&](auto n0) {
450 b_scale_thread_copy.Run(b_scale_grid_desc,
451 b_scale_grid_buf,
452 b_scale_thread_desc,
453 make_tuple(n0, I0),
454 b_scale_thread_bufs(lds_read_reg_buf));
455
456 b_scale_thread_copy.MoveSrcSliceWindow(
457 b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{}));
458 });
459
// (i + 4 + mfma_reg_buf) continues the 1, 2, 3 fetch count of the prologue:
// this is scale fetch #4, #5, ... in order.
460 if((i + 4 + mfma_reg_buf.value) % num_loop_per_scale == 0)
461 {
462 b_scale_thread_copy.MoveSrcSliceWindow(
463 b_scale_grid_desc, b_scale_thread_copy_step.At(Number<2>{}));
464 }
465 else
466 {
467 b_scale_thread_copy.MoveSrcSliceWindow(
468 b_scale_grid_desc, b_scale_thread_copy_step.At(Number<1>{}));
469 }
470
471 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf));
472 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf));
473
474 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
475 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
476
477 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
478 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
479
// For each (k0, m0, n0): gather KPack A and B elements into vectors and issue
// xdlops_gemm.Run, accumulating into C at offset (m0, n0, 0).
480 static_for<0, KRepeat, 1>{}([&](auto k0) {
481 static_for<0, MRepeat, 1>{}([&](auto m0) {
482 static_for<0, NRepeat, 1>{}([&](auto n0) {
485
486 static_for<0, KPack, 1>{}([&](auto ik) {
487 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
488 a_thread_bufs[mfma_reg_buf]
489 [Number<a_thread_desc_.CalculateOffset(
490 make_tuple(m0, I0, k0, ik))>{}];
491 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
492 b_thread_bufs[mfma_reg_buf]
493 [Number<b_thread_desc_.CalculateOffset(
494 make_tuple(n0, I0, k0, ik))>{}];
495 });
496
497 using mfma_input_type =
499 xdlops_gemm.K1PerXdlops>::type;
500
501 constexpr index_t c_offset =
502 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
503
504 xdlops_gemm.Run(
505 a_thread_vec.template AsType<mfma_input_type>(),
506 b_thread_vec.template AsType<mfma_input_type>(),
507 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
508 });
509 });
510 });
511
513 };
514
515 LoopFunc(I1, I1, I0, I0);
516 LoopFunc(I0, I0, I1, I1);
517
518 i += HotloopUnroll;
519 } while(i < (num_loop - PrefetchStages));
520 }
521
// Tail step: LDS[lds_read_buf] -> registers[lds_read_reg_buf], store the last pending
// global tile into LDS[lds_write_buf] (no new global or scale fetch), then MFMAs on
// registers[mfma_reg_buf].
522 auto ReadWriteCompFunc = [&](auto lds_read_buf,
523 auto lds_read_reg_buf,
524 auto lds_write_buf,
525 auto mfma_reg_buf) {
527
528 static_for<0, KRepeat, 1>{}([&](auto k) {
529 static_for<0, MRepeat, 1>{}([&](auto m0) {
532 a_block_buf.At(lds_read_buf),
534 make_tuple(m0, I0, k, I0),
535 a_thread_bufs(lds_read_reg_buf));
536 });
537 static_for<0, NRepeat, 1>{}([&](auto n0) {
540 b_block_buf.At(lds_read_buf),
541 b_scale_thread_bufs(lds_read_buf)[n0],
543 make_tuple(n0, I0, k, I0),
544 b_thread_bufs(lds_read_reg_buf));
545 });
546 });
547
548 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf));
549 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf));
550
551 static_for<0, KRepeat, 1>{}([&](auto k0) {
552 static_for<0, MRepeat, 1>{}([&](auto m0) {
553 static_for<0, NRepeat, 1>{}([&](auto n0) {
556
557 static_for<0, KPack, 1>{}([&](auto ik) {
558 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
559 a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
560 make_tuple(m0, I0, k0, ik))>{}];
561 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
562 b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
563 make_tuple(n0, I0, k0, ik))>{}];
564 });
565
566 using mfma_input_type =
567 typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
568
569 constexpr index_t c_offset =
570 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
571
572 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
573 b_thread_vec.template AsType<mfma_input_type>(),
574 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
575 });
576 });
577 });
578
580 };
581
// Tail step: LDS[lds_read_buf] -> registers[lds_read_reg_buf], then MFMAs on
// registers[mfma_reg_buf]; no LDS write.
582 auto ReadCompFunc = [&](auto lds_read_buf, auto lds_read_reg_buf, auto mfma_reg_buf) {
584
585 static_for<0, KRepeat, 1>{}([&](auto k) {
586 static_for<0, MRepeat, 1>{}([&](auto m0) {
589 a_block_buf.At(lds_read_buf),
591 make_tuple(m0, I0, k, I0),
592 a_thread_bufs(lds_read_reg_buf));
593 });
594 static_for<0, NRepeat, 1>{}([&](auto n0) {
597 b_block_buf.At(lds_read_buf),
598 b_scale_thread_bufs(lds_read_buf)[n0],
600 make_tuple(n0, I0, k, I0),
601 b_thread_bufs(lds_read_reg_buf));
602 });
603 });
604
605 static_for<0, KRepeat, 1>{}([&](auto k0) {
606 static_for<0, MRepeat, 1>{}([&](auto m0) {
607 static_for<0, NRepeat, 1>{}([&](auto n0) {
610
611 static_for<0, KPack, 1>{}([&](auto ik) {
612 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
613 a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
614 make_tuple(m0, I0, k0, ik))>{}];
615 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
616 b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
617 make_tuple(n0, I0, k0, ik))>{}];
618 });
619
620 using mfma_input_type =
621 typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
622
623 constexpr index_t c_offset =
624 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
625
626 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
627 b_thread_vec.template AsType<mfma_input_type>(),
628 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
629 });
630 });
631 });
632
634 };
635
// Final tail step: MFMAs on registers[mfma_reg_buf] only.
636 auto CompFunc = [&](auto mfma_reg_buf) {
637 static_for<0, KRepeat, 1>{}([&](auto k0) {
638 static_for<0, MRepeat, 1>{}([&](auto m0) {
639 static_for<0, NRepeat, 1>{}([&](auto n0) {
642
643 static_for<0, KPack, 1>{}([&](auto ik) {
644 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
645 a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
646 make_tuple(m0, I0, k0, ik))>{}];
647 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
648 b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
649 make_tuple(n0, I0, k0, ik))>{}];
650 });
651
652 using mfma_input_type =
653 typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
654
655 constexpr index_t c_offset =
656 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
657
658 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
659 b_thread_vec.template AsType<mfma_input_type>(),
660 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
661 });
662 });
663 });
664 };
665
666 // tail
// Odd: three compute stages remain (read+write+compute, read+compute, compute).
// Even: two compute stages remain (read+compute, compute).
667 if constexpr(TailNum == TailNumber::Odd)
668 {
669 ReadWriteCompFunc(I1, I1, I0, I0);
670 ReadCompFunc(I0, I0, I1);
671 CompFunc(I0);
672 }
673 else if constexpr(TailNum == TailNumber::Even)
674 {
675 ReadCompFunc(I1, I1, I0);
676 CompFunc(I1);
677 }
678 }
679
680 protected:
// Thread copies and thread descriptors inherited from the base pipeline and
// re-exposed here. a_thread_desc_, b_thread_desc_ and c_thread_desc_ are used by
// Run() for buffer sizing and operand/accumulator offsets.
681 using Base::a_thread_copy_;
682 using Base::a_thread_desc_;
683 using Base::b_thread_copy_;
684 using Base::b_thread_desc_;
685 using Base::c_thread_desc_;
686};
687
688} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ BlockwiseGemmXdlops_pipeline_base(Tuple4 a_origin=CalculateAThreadOriginDataIndex(), Tuple4 b_origin=CalculateBThreadOriginDataIndex())
Constructor for BlockwiseGemmXdlops_pipeline_base.
Definition blockwise_gemm_pipeline_xdlops_base.hpp:222
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:280
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:239
static constexpr auto xdlops_gemm
Definition blockwise_gemm_pipeline_xdlops_base.hpp:54
conditional_t< std::is_same< ComputeDataType, ck::tf32_t >::value, float, ComputeDataType > ComputeDataTypeBuf
Definition blockwise_gemm_pipeline_xdlops_base.hpp:57
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k
Definition blockwise_gemm_pipeline_xdlops_base.hpp:360
static constexpr auto I1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:37
__host__ static __device__ constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:266
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:294
static constexpr index_t AMmaKStride
Definition blockwise_gemm_pipeline_xdlops_base.hpp:60
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:253
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops > HotLoopInstList
Definition blockwise_gemm_pipeline_xdlops_base.hpp:82
__host__ __device__ constexpr auto & GetCThreadBuffer()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:111
static constexpr auto I0
Definition blockwise_gemm_pipeline_xdlops_base.hpp:36
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:160
static __device__ auto CalculateCThreadOriginDataIndex8D(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:189
static constexpr index_t KRepeat
Definition blockwise_gemm_pipeline_xdlops_base.hpp:64
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k
Definition blockwise_gemm_pipeline_xdlops_base.hpp:359
static constexpr index_t BMmaKStride
Definition blockwise_gemm_pipeline_xdlops_base.hpp:61
__host__ static __device__ constexpr auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:341
__host__ static __device__ constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:307
__host__ static __device__ constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:324
BlockwiseGemmXdlops_pipeline_base< BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack > Base
Definition blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp:102
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, const BScaleGridDesc &b_scale_grid_desc, const BScaleThreadDesc &b_scale_thread_desc, BScaleThreadTransfer &b_scale_thread_copy, const BScaleGridBuffer &b_scale_grid_buf, const BScaleThreadTransferStep &b_scale_thread_copy_step, index_t num_loop, index_t num_loop_per_scale) const
Definition blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp:259
Definition blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp:37
Definition functional2.hpp:33
Definition dtype_vector.hpp:10