device_gemm_xdl_cshuffle_streamk_v3.hpp Source File

device_gemm_xdl_cshuffle_streamk_v3.hpp Source File

Composable Kernel: device_gemm_xdl_cshuffle_streamk_v3.hpp Source File
device_gemm_xdl_cshuffle_streamk_v3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
19
20namespace ck {
21namespace tensor_operation {
22namespace device {
23
// Device-level GEMM (C = A * B, with elementwise operations applied to A, B and C) built on the
// Stream-K XDL/CShuffle gridwise kernel. The launch grid is sized from device occupancy, and the
// result is produced either by atomic accumulation into a zero-initialised C
// (StreamKReductionStrategy::Atomic) or through a device workspace whose semaphore region is
// cleared before launch (StreamKReductionStrategy::Reduction).
24template <typename ALayout,
25 typename BLayout,
26 typename CLayout,
// Element types. GemmAccDataType is the accumulation type (it also sizes the Reduction-strategy
// workspace); CShuffleDataType is forwarded unchanged to the gridwise kernel.
27 typename ADataType,
28 typename BDataType,
29 typename CDataType,
30 typename GemmAccDataType,
31 typename CShuffleDataType,
32 typename AElementwiseOperation,
33 typename BElementwiseOperation,
34 typename CElementwiseOperation,
// Padding specialization; the K-padding variants relax the "K divisible by AK1 and BK1" check
// performed in IsSupportedArgument.
35 GemmSpecialization GemmSpec,
// Threads per workgroup (used as the launch block dimension) and the per-workgroup M x N x K
// tile (printed as "BlkTile").
36 index_t BlockSize,
37 index_t MPerBlock,
38 index_t NPerBlock,
39 index_t KPerBlock,
40 index_t AK1,
41 index_t BK1,
// Per-XDL tile (printed as "WaveTile") and XDL repeats per wave (printed as "WaveMap").
42 index_t MPerXDL,
43 index_t NPerXDL,
44 index_t MXdlPerWave,
45 index_t NXdlPerWave,
// Block-transfer configuration for the A operand.
46 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
47 typename ABlockTransferThreadClusterArrangeOrder,
48 typename ABlockTransferSrcAccessOrder,
49 index_t ABlockTransferSrcVectorDim,
50 index_t ABlockTransferSrcScalarPerVector,
51 index_t ABlockTransferDstScalarPerVector_AK1,
52 bool ABlockLdsExtraM,
// Block-transfer configuration for the B operand.
53 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
54 typename BBlockTransferThreadClusterArrangeOrder,
55 typename BBlockTransferSrcAccessOrder,
56 index_t BBlockTransferSrcVectorDim,
57 index_t BBlockTransferSrcScalarPerVector,
58 index_t BBlockTransferDstScalarPerVector_BK1,
59 bool BBlockLdsExtraN,
// C-shuffle epilogue configuration.
60 index_t CShuffleMXdlPerWavePerShuffle,
61 index_t CShuffleNXdlPerWavePerShuffle,
62 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
63 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
// Compute types for A and B; both default to CDataType.
66 typename ComputeTypeA = CDataType,
67 typename ComputeTypeB = ComputeTypeA>
69 BLayout,
70 CLayout,
71 ADataType,
72 BDataType,
73 CDataType,
74 AElementwiseOperation,
75 BElementwiseOperation,
76 CElementwiseOperation>
77{
// N-direction XDL repeats per wave for 64-wide and 32-wide wavefronts. A value of 0 means this
// instance has no configuration for that wavefront size: IsSupportedArgument then returns false
// and MakeArgument passes IsValid = false for it.
79 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
80 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
81
82 // GridwiseGemm: gridwise Stream-K kernel configuration, templated on NXdlPerWave so that one
// variant can be instantiated per wavefront size (GridwiseGemm64 / GridwiseGemm32).
83 template <index_t NXdlPerWave_>
85 ALayout,
86 BLayout,
87 CLayout,
88 ADataType,
89 BDataType,
90 GemmAccDataType,
91 CShuffleDataType,
92 CDataType,
93 AElementwiseOperation,
94 BElementwiseOperation,
95 CElementwiseOperation,
96 GemmSpec,
97 BlockSize,
98 MPerBlock,
99 NPerBlock,
100 KPerBlock,
101 AK1,
102 BK1,
103 MPerXDL,
104 NPerXDL,
105 MXdlPerWave,
106 NXdlPerWave_,
107 ABlockTransferThreadClusterLengths_AK0_M_AK1,
108 ABlockTransferThreadClusterArrangeOrder,
109 ABlockTransferSrcAccessOrder,
110 ABlockTransferSrcVectorDim,
111 ABlockTransferSrcScalarPerVector,
112 ABlockTransferDstScalarPerVector_AK1,
// NOTE(review): literal `false` fills a bool parameter of the gridwise template (likewise for B
// below); see GridwiseGemm_xdl_cshuffle_streamk_v3 for its meaning.
113 false,
114 ABlockLdsExtraM,
115 BBlockTransferThreadClusterLengths_BK0_N_BK1,
116 BBlockTransferThreadClusterArrangeOrder,
117 BBlockTransferSrcAccessOrder,
118 BBlockTransferSrcVectorDim,
119 BBlockTransferSrcScalarPerVector,
120 BBlockTransferDstScalarPerVector_BK1,
121 false,
122 BBlockLdsExtraN,
123 CShuffleMXdlPerWavePerShuffle,
124 CShuffleNXdlPerWavePerShuffle,
125 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
126 CShuffleBlockTransferScalarPerVector_NPerBlock,
127 BlkGemmPipeSched,
128 BlkGemmPipelineVer,
129 ComputeTypeA,
130 ComputeTypeB>;
133
// The device-level Argument is the wave64 gridwise Argument; on wave32 devices it is
// reinterpret_cast to GridwiseGemm32::Argument (see IsSupportedArgument).
// NOTE(review): that cast relies on both Argument types sharing one layout — confirm.
134 using Argument = typename GridwiseGemm64::Argument;
135 //
136
137 // Invoker: validates an Argument, picks the kernel instantiation that matches the pipeline
// version and K-loop structure, then launches and times it.
138 struct Invoker : public BaseInvoker
139 {
// Runs the GEMM described by `arg` with gridwise configuration GridwiseGemm.
// - Prints `arg` when stream_config.log_level_ > 0.
// - Throws std::runtime_error if GridwiseGemm::CheckValidity(arg) fails.
// - For the Atomic strategy, zeroes C before launching.
// Returns the time reported by the launch helper. The result stays 0 when no kernel
// instantiation matches the pipeline/tail combination, or when the reduction strategy is
// neither Atomic nor Reduction (non-flush path).
140 template <typename GridwiseGemm>
141 float RunImp(const typename GridwiseGemm::Argument& arg,
142 const StreamConfig& stream_config = StreamConfig{})
143 {
144
145 if(stream_config.log_level_ > 0)
146 {
147 arg.Print();
148 }
149
150 if(!GridwiseGemm::CheckValidity(arg))
151 {
152 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
153 }
154
155 float ave_time = 0;
156
// K rounded up to a whole number of KPerBlock tiles; it drives the main-loop / tail-number
// selection below.
157 index_t k_grain = KPerBlock;
158 index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
159
160 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
161
// The Atomic strategy accumulates into C, so C is cleared here, once, before launching.
// NOTE(review): arg.M * arg.N is presumably index_t (int32_t) and is multiplied before being
// widened by sizeof(CDataType), so it can overflow for very large outputs. The byte count also
// assumes C holds exactly M*N contiguous elements — confirm for StrideC != N.
// NOTE(review): if the launch helper runs the kernel more than once (e.g. for timing), the
// atomically accumulated C includes every run.
162 if(arg.reduction_strategy == StreamKReductionStrategy::Atomic)
163 {
164
165 hip_check_error(hipMemsetAsync(
166 arg.p_c_grid, 0, arg.M * arg.N * sizeof(CDataType), stream_config.stream_id_));
167 }
168
// Launches one kernel instantiation. When arg.Grid_size < 0, the grid is set to
// multiProcessorCount * (max resident workgroups per CU for this kernel) and written back into
// arg.Grid_size (presumably a mutable member, since arg is a const reference).
169 const auto Run = [&](const auto& kernel) {
170 dim3 grid_dim;
171 if(arg.Grid_size < 0)
172 {
173 int occupancy, num_cu;
174 hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
175 &occupancy, kernel, BlockSize, 0));
176 hipDeviceProp_t dev_prop;
177 hipDevice_t dev;
178 hip_check_error(hipGetDevice(&dev));
179 hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
180 num_cu = dev_prop.multiProcessorCount;
181 arg.Grid_size = num_cu * occupancy;
182 grid_dim = arg.Grid_size;
183 }
184 else
185 grid_dim = arg.Grid_size;
186
// Cache-flush benchmarking: runs on a copy of arg with rotating A/B buffers (sized M*K and K*N
// elements). run_flush_cache flushes the instruction cache and rotates the buffers, and is
// presumably invoked before each timed launch.
// NOTE(review): unlike the non-flush path, this path does not clear the Reduction-strategy
// semaphores; the M*K / K*N byte counts have the same int32 overflow concern as above.
187 if(stream_config.flush_cache)
188 {
189 auto arg_ = arg;
191 arg_,
192 stream_config.rotating_count,
193 arg_.M * arg_.K * sizeof(ADataType),
194 arg_.K * arg_.N * sizeof(BDataType));
195 rotating_mem.Print();
196
197 auto run_flush_cache = [&]() {
198 // flush icache
200 // rotating mem
201 rotating_mem.Next();
202 };
203
205 stream_config, run_flush_cache, kernel, grid_dim, dim3(BlockSize), 0, arg_);
206 }
207 else
208 {
209
// Atomic: C was already zeroed above, so the kernel is launched directly.
210 if(arg.reduction_strategy == StreamKReductionStrategy::Atomic)
211 {
212 ave_time = launch_and_time_kernel(
213 stream_config, kernel, grid_dim, dim3(BlockSize), 0, arg);
214 }
// Reduction: the semaphore region starts right after the accumulator part of the workspace
// (arg.p_workspace_) and is cleared by `preprocess`, which is handed to the launch helper.
215 else if(arg.reduction_strategy == StreamKReductionStrategy::Reduction)
216 {
217 char* workspace_semaphore =
218 reinterpret_cast<char*>(arg.p_workspace_) +
219 arg.block_2_ctile_map_streamk.get_workspace_size_for_acc(
220 sizeof(GemmAccDataType));
221 auto preprocess = [&]() {
222 hipError_t status = hipMemsetAsync(
223 workspace_semaphore,
224 0,
225 // sizeof(uint32_t),
226 arg.block_2_ctile_map_streamk.get_workspace_size_for_semaphore(),
227 stream_config.stream_id_);
228
229 // Check the status
230 hip_check_error(status);
231 };
232
234 stream_config, preprocess, kernel, grid_dim, dim3(BlockSize), 0, arg);
235 }
236 }
237 };
238
// Minimum-occupancy template argument for the kernel: 1 for Intrawave scheduling, 2 otherwise.
239 constexpr index_t minimum_occupancy =
240 BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
241
// Select the kernel instantiation from the pipeline version and the K-loop tail of K_split.
242 if(has_main_k_block_loop)
243 {
244 // Tail number always full
245 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
246 BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
247 {
248
249 const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
250 true,
252 minimum_occupancy>;
253
254 Run(kernel);
255 }
256 // Tail number could be One to Seven
257 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
258 {
259
260 {
261 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
262 {
263 const auto kernel =
264 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
265 true,
267 minimum_occupancy,
269 Run(kernel);
270 }
271 else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
273 {
274 const auto kernel =
275 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
276 true,
278 minimum_occupancy,
280 Run(kernel);
281 }
282
// Tails Two..Seven are only compiled in when the pipeline prefetches enough stages.
283 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
284 {
285 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
286 {
287 const auto kernel =
288 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
289 true,
291 minimum_occupancy,
293 Run(kernel);
294 }
295 }
296
297 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
298 {
299 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
301 {
302 const auto kernel =
303 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
304 true,
306 minimum_occupancy,
308 Run(kernel);
309 }
310 }
311
312 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
313 {
314 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
316 {
317 const auto kernel =
318 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
319 true,
321 minimum_occupancy,
323 Run(kernel);
324 }
325 }
326
327 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
328 {
329 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
331 {
332 const auto kernel =
333 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
334 true,
336 minimum_occupancy,
338 Run(kernel);
339 }
340 }
341
342 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
343 {
344 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
345 {
346 const auto kernel =
347 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
348 true,
350 minimum_occupancy,
352 Run(kernel);
353 }
354 }
355
356 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
357 {
358 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
360 {
361 const auto kernel =
362 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
363 true,
365 minimum_occupancy,
367 Run(kernel);
368 }
369 }
370 }
371 }
372 // Tail number could be Odd or Even
373 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
374 {
375
376 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
377 {
378 const auto kernel =
380 true,
382 minimum_occupancy,
384 Run(kernel);
385 }
386 else
387 {
388 const auto kernel =
390 true,
392 minimum_occupancy,
394 Run(kernel);
395 }
396 }
397 else
398 {
// Any other pipeline version: the tail is Odd or Even.
399
400 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
401 {
402 const auto kernel =
403 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
404 true,
406 minimum_occupancy,
408 Run(kernel);
409 }
410 else
411 {
412 const auto kernel =
413 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
414 true,
416 minimum_occupancy,
418 Run(kernel);
419 }
420 }
421 }
422 else
423 {
424 // Tail number always 1
// NOTE(review): without a main K loop only pipeline v1 is handled; for any other version
// nothing is launched and RunImp silently returns 0.
425 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
426 {
427
428 const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
429 false,
431 minimum_occupancy>;
432 Run(kernel);
433 }
434 }
435
436 return ave_time;
437 }
438
440
441 // polymorphic: recovers the concrete Argument and forwards to the Run(const Argument&, ...) overload (presumably generated by INVOKER_RUN3_IMPL).
// NOTE(review): the dynamic_cast result is dereferenced without a null check, so passing an
// argument that belongs to another device operation is undefined behaviour.
442 float Run(const BaseArgument* p_arg,
443 const StreamConfig& stream_config = StreamConfig{}) override
444 {
445 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
446 }
447 };
448
449 size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
450 {
451 const Argument* p_arg = dynamic_cast<const Argument*>(pArg);
452 if(p_arg->reduction_strategy == StreamKReductionStrategy::Reduction)
453 {
454 return p_arg->block_2_ctile_map_streamk.get_workspace_size(sizeof(GemmAccDataType));
455 }
456 else
457 {
458 return 0;
459 }
460 }
461
463 void* p_workspace,
464 const StreamConfig& = StreamConfig{}) const override
465 {
// Records the caller-allocated device workspace (presumably sized via GetWorkSpaceSize) in the
// argument; RunImp reads it as arg.p_workspace_ for the Reduction strategy's accumulator and
// semaphore storage.
// NOTE(review): the dynamic_cast result is not checked for nullptr before use.
466 Argument* pArg_ = dynamic_cast<Argument*>(pArg);
467
468 pArg_->p_workspace_ = p_workspace;
469 }
470
// Compile-time hook for rejecting unsupported template configurations.
// No check is implemented yet, so every configuration is accepted.
// TODO: properly implement this check
static constexpr bool IsValidCompilationParameter()
{
    constexpr bool all_configurations_accepted = true;
    return all_configurations_accepted;
}
476
// Host-side check of whether this instance can run `arg` on the current device. In order, it:
// - applies the target-specific checks below;
// - rejects bf16 output with Streamk_sel > 0 when bf16 atomics are unsupported;
// - rejects K that is not a multiple of both AK1 and BK1 unless a K-padding GemmSpec is used;
// - finally defers to the gridwise CheckValidity for the runtime wavefront size.
// It returns false if this instance has no NXdlPerWave configuration for that wavefront size.
477 static bool IsSupportedArgument(const Argument& arg)
478 {
479 // gfx11 doesn't support float atomic
481 {
482 return false;
483 }
485 {
486 return false;
487 }
// bf16 output with Streamk_sel > 0 needs bf16 atomic support on the device.
488 if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> &&
489 arg.Streamk_sel > 0)
490 {
491 return false;
492 }
// K must be divisible by both AK1 and BK1 unless a K-padding specialization is selected.
493 if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
494 GemmSpec == GemmSpecialization::NKPadding ||
495 GemmSpec == GemmSpecialization::MNKPadding ||
496 GemmSpec == GemmSpecialization::KPadding))
497 {
498 return false;
499 }
500
// Dispatch on the runtime wavefront size. On wave32 the wave64-typed Argument is
// reinterpret_cast to GridwiseGemm32::Argument (NOTE(review): assumes identical layouts).
501 if(get_warp_size() == 64)
502 {
503 if constexpr(NXdlPerWave64 > 0)
504 {
506 }
507 }
508 else
509 {
510 if constexpr(NXdlPerWave32 > 0)
511 {
513 reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg));
514 }
515 }
// No configuration exists for this wavefront size.
516 return false;
517 }
518
519 // polymorphic
520 bool IsSupportedArgument(const BaseArgument* p_arg) override
521 {
522 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
523 }
// Builds an Argument for gridwise configuration GridwiseGemm.
// When IsValid is true, Grid_size is overwritten with
// multiProcessorCount * (max resident workgroups per CU) for the kernel instantiation chosen by a
// dispatch that mirrors RunImp's. The caller's Grid_size is kept when IsValid is false, or when no
// instantiation matches (no main K loop and a pipeline other than v1).
// NOTE(review): RunImp only computes a grid when Grid_size < 0, whereas this replaces any
// caller-supplied value whenever IsValid — confirm which behaviour is intended.
// NOTE(review): the elementwise operations are accepted but not stored in the Argument.
524 template <typename GridwiseGemm, bool IsValid>
525 static auto
526 MakeArgumentImp(const ADataType* p_a,
527 const BDataType* p_b,
528 CDataType* p_c,
529 index_t M,
530 index_t N,
531 index_t K,
532 index_t StrideA,
533 index_t StrideB,
534 index_t StrideC,
535 index_t streamk_sel,
536 index_t Grid_size,
537 AElementwiseOperation,
538 BElementwiseOperation,
539 CElementwiseOperation,
541 {
542
543 constexpr index_t minimum_occupancy =
544 BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
// K rounded up to a whole number of KPerBlock tiles (same rounding as RunImp).
545 index_t K_split = (K + KPerBlock - 1) / KPerBlock * KPerBlock;
546
// Overwritten by calculate_grid_size; the initial values are only defaults.
547 int occupancy = 1, num_cu = 1;
// Sets Grid_size to (CU count) * (max resident workgroups per CU for `kernel`).
548 const auto calculate_grid_size = [&](const auto& kernel) {
550 hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
551 hipDeviceProp_t dev_prop;
552 hipDevice_t dev;
553 hip_check_error(hipGetDevice(&dev));
554 hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
555 num_cu = dev_prop.multiProcessorCount;
556 Grid_size = num_cu * occupancy;
557 };
558
559 if constexpr(IsValid)
560 {
561 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
562 if(has_main_k_block_loop)
563 {
564 // Tail number always full
565 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
566 BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
567 {
568
569 const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
570 true,
572 minimum_occupancy>;
573 calculate_grid_size(kernel);
574 }
575 // Tail number could be One to Seven
576 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
577 {
578
579 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
580 {
581 const auto kernel =
582 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
583 true,
585 minimum_occupancy,
587 calculate_grid_size(kernel);
588 }
589 else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
590 {
591 const auto kernel =
592 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
593 true,
595 minimum_occupancy,
597 calculate_grid_size(kernel);
598 }
599
// Tails Two..Seven are only compiled in when the pipeline prefetches enough stages.
600 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
601 {
602 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
603 {
604 const auto kernel =
605 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
606 true,
608 minimum_occupancy,
610 calculate_grid_size(kernel);
611 }
612 }
613
614 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
615 {
616 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three)
617 {
618 const auto kernel =
619 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
620 true,
622 minimum_occupancy,
624 calculate_grid_size(kernel);
625 }
626 }
627
628 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
629 {
630 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four)
631 {
632 const auto kernel =
633 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
634 true,
636 minimum_occupancy,
638 calculate_grid_size(kernel);
639 }
640 }
641
642 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
643 {
644 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five)
645 {
646 const auto kernel =
647 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
648 true,
650 minimum_occupancy,
652 calculate_grid_size(kernel);
653 }
654 }
655
656 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
657 {
658 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
659 {
660 const auto kernel =
661 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
662 true,
664 minimum_occupancy,
666 calculate_grid_size(kernel);
667 }
668 }
669
670 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
671 {
672 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven)
673 {
674 const auto kernel =
675 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
676 true,
678 minimum_occupancy,
680 calculate_grid_size(kernel);
681 }
682 }
683 }
684 // Tail number could be Odd or Even
685 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
686 {
687
688 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
689 {
690 const auto kernel =
692 true,
694 minimum_occupancy,
696 calculate_grid_size(kernel);
697 }
698 else
699 {
700 const auto kernel =
702 true,
704 minimum_occupancy,
706 calculate_grid_size(kernel);
707 }
708 }
709 else
710 {
// Any other pipeline version: the tail is Odd or Even.
711
712 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
713 {
714 const auto kernel =
715 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
716 true,
718 minimum_occupancy,
720 calculate_grid_size(kernel);
721 }
722 else
723 {
724 const auto kernel =
725 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
726 true,
728 minimum_occupancy,
730 calculate_grid_size(kernel);
731 }
732 }
733 }
734 else
735 {
736 // Tail number always 1
// For pipelines other than v1 no grid size is computed; Grid_size keeps the caller's value.
737 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
738 {
739
740 const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
741 false,
743 minimum_occupancy>;
744 calculate_grid_size(kernel);
745 }
746 }
747 }
748
749 return Argument{p_a,
750 p_b,
751 p_c,
752 M,
753 N,
754 K,
755 StrideA,
756 StrideB,
757 StrideC,
758 streamk_sel,
759 Grid_size,
760 reduction_strategy};
761 }
762
// Public factory. Chooses the gridwise configuration for the device's runtime wavefront size
// (64, otherwise 32) and forwards the parameters to MakeArgumentImp (presumably instantiated
// with GridwiseGemm64 / GridwiseGemm32 respectively). IsValid is false when this instance has no
// NXdlPerWave configuration for that wavefront size, which skips MakeArgumentImp's grid-size
// query. reduction_strategy defaults to StreamKReductionStrategy::Atomic.
763 static auto
764 MakeArgument(const ADataType* p_a,
765 const BDataType* p_b,
766 CDataType* p_c,
767 index_t M,
768 index_t N,
769 index_t K,
770 index_t StrideA,
771 index_t StrideB,
772 index_t StrideC,
773 index_t streamk_sel,
774 index_t Grid_size,
775 AElementwiseOperation a_op,
776 BElementwiseOperation b_op,
777 CElementwiseOperation c_op,
779 {
780 if(get_warp_size() == 64)
781 {
782 constexpr bool IsValid = NXdlPerWave64 > 0;
784 p_b,
785 p_c,
786 M,
787 N,
788 K,
789 StrideA,
790 StrideB,
791 StrideC,
792 streamk_sel,
793 Grid_size,
794 a_op,
795 b_op,
796 c_op,
797 reduction_strategy);
798 }
799 else
800 {
801 constexpr bool IsValid = NXdlPerWave32 > 0;
803 p_b,
804 p_c,
805 M,
806 N,
807 K,
808 StrideA,
809 StrideB,
810 StrideC,
811 streamk_sel,
812 Grid_size,
813 a_op,
814 b_op,
815 c_op,
816 reduction_strategy);
817 }
818 }
819 static auto MakeInvoker() { return Invoker{}; }
820
821 // polymorphic: type-erased factory that casts the untyped buffers to the A/B/C element types.
// Unlike MakeArgument, no occupancy-based grid size is computed here: Grid_size is stored as
// given, and RunImp computes a grid at launch time only when it is negative. The elementwise
// operations are ignored.
822 std::unique_ptr<BaseArgument> MakeArgumentPointer(
823 const void* p_a,
824 const void* p_b,
825 void* p_c,
826 index_t M,
827 index_t N,
828 index_t K,
829 index_t StrideA,
830 index_t StrideB,
831 index_t StrideC,
832 index_t streamk_sel,
833 index_t Grid_size,
834 AElementwiseOperation,
835 BElementwiseOperation,
836 CElementwiseOperation,
838 {
839 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
840 static_cast<const BDataType*>(p_b),
841 static_cast<CDataType*>(p_c),
842 M,
843 N,
844 K,
845 StrideA,
846 StrideB,
847 StrideC,
848 streamk_sel,
849 Grid_size,
850 reduction_strategy);
851 }
852
853 // polymorphic
854 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
855 {
856 return std::make_unique<Invoker>(Invoker{});
857 }
858
859 // polymorphic: human-readable summary of this instance's compile-time configuration.
// NOTE(review): the printed name is "DeviceGemmXdlUniversal", and PrefetchStages is always taken
// from GridwiseGemm64 regardless of the runtime wavefront size — confirm both are intended.
860 std::string GetTypeString() const override
861 {
862 auto str = std::stringstream();
863
864 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
867
868 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
874
875 // clang-format off
876 str << "DeviceGemmXdlUniversal"
877 << "<"
878 << getGemmSpecializationString(GemmSpec) << ", "
879 << std::string(ALayout::name)[0]
880 << std::string(BLayout::name)[0]
881 << std::string(CLayout::name)[0]
882 << ">"
883 << " BlkSize: "
884 << BlockSize << ", "
885 << "BlkTile: "
886 << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
887 << "WaveTile: "
888 << MPerXDL<<"x"<<NPerXDL << ", "
889 << "WaveMap: "
890 << MXdlPerWave<<"x" << NXdlPerWave<<", "
891 << "VmemReadVec: "
892 << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
893 << "BlkGemmPipelineScheduler: "
894 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
895 << "BlkGemmPipelineVersion: "
896 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
897 << "BlkGemmPipelinePrefetchStages: "
898 << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
899 // clang-format on
900
901 return str.str();
902 }
903};
904
905} // namespace device
906} // namespace tensor_operation
907} // namespace ck
#define INVOKER_RUN3_IMPL
Definition device_base.hpp:114
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
void hip_check_error(hipError_t x)
Definition host_utility/hip_check_error.hpp:10
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:91
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
void flush_icache()
Definition flush_cache.hpp:383
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition flush_cache.hpp:398
Definition ck.hpp:268
StreamKReductionStrategy
Definition block_to_ctile_map.hpp:1011
@ Atomic
Definition block_to_ctile_map.hpp:1012
@ Reduction
Definition block_to_ctile_map.hpp:1013
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
@ One
Definition blkgemmpipe_scheduler.hpp:37
@ Seven
Definition blkgemmpipe_scheduler.hpp:43
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Four
Definition blkgemmpipe_scheduler.hpp:40
@ Two
Definition blkgemmpipe_scheduler.hpp:38
@ Full
Definition blkgemmpipe_scheduler.hpp:49
@ Three
Definition blkgemmpipe_scheduler.hpp:39
@ Five
Definition blkgemmpipe_scheduler.hpp:41
@ Six
Definition blkgemmpipe_scheduler.hpp:42
__global__ void kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:62
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
__global__ void kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:38
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
bool is_bf16_atomic_supported()
Definition host_utility/device_prop.hpp:108
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:133
Definition device_base.hpp:197
Definition device_gemm_streamk_v2.hpp:23
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:139
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:442
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:141
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:77
static constexpr auto NXdlPerWave32
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:80
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:131
size_t GetWorkSpaceSize(const BaseArgument *pArg) const override
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:449
typename GridwiseGemm64::Argument Argument
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:134
GridwiseGemm_xdl_cshuffle_streamk_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB > GridwiseGemmBase
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:84
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:471
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t streamk_sel, index_t Grid_size, AElementwiseOperation a_op, BElementwiseOperation b_op, CElementwiseOperation c_op, StreamKReductionStrategy reduction_strategy=StreamKReductionStrategy::Atomic)
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:764
void SetWorkSpacePointer(BaseArgument *pArg, void *p_workspace, const StreamConfig &=StreamConfig{}) const override
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:462
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:477
static auto MakeInvoker()
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:819
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t streamk_sel, index_t Grid_size, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, StreamKReductionStrategy reduction_strategy=StreamKReductionStrategy::Atomic) override
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:822
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:854
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:520
std::string GetTypeString() const override
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:860
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:132
static auto MakeArgumentImp(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t streamk_sel, index_t Grid_size, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, StreamKReductionStrategy reduction_strategy=StreamKReductionStrategy::Atomic)
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:526
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:79
Definition flush_cache.hpp:299