device_gemm_multiple_abd_xdl_cshuffle.hpp Source File

device_gemm_multiple_abd_xdl_cshuffle.hpp Source File

Composable Kernel: device_gemm_multiple_abd_xdl_cshuffle.hpp Source File
device_gemm_multiple_abd_xdl_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
19
20namespace ck {
21namespace tensor_operation {
22namespace device {
23
24template <typename AsLayout,
25 typename BsLayout,
26 typename DsLayout,
27 typename CLayout,
28 typename AsDataType,
29 typename BsDataType,
30 typename GemmAccDataType,
31 typename CShuffleDataType,
32 typename DsDataType,
33 typename CDataType,
34 typename AElementwiseOperation,
35 typename BElementwiseOperation,
36 typename CElementwiseOperation,
37 GemmSpecialization GemmSpec,
38 index_t NumGemmKPrefetchStage,
39 index_t BlockSize,
40 index_t MPerBlock,
41 index_t NPerBlock,
42 index_t KPerBlock,
43 index_t AK1,
44 index_t BK1,
45 index_t MPerXDL,
46 index_t NPerXDL,
47 index_t MXdlPerWave,
48 index_t NXdlPerWave,
49 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
50 typename ABlockTransferThreadClusterArrangeOrder,
51 typename ABlockTransferSrcAccessOrder,
52 index_t ABlockTransferSrcVectorDim,
53 index_t ABlockTransferSrcScalarPerVector,
54 index_t ABlockTransferDstScalarPerVector_AK1,
55 bool ABlockLdsExtraM,
56 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
57 typename BBlockTransferThreadClusterArrangeOrder,
58 typename BBlockTransferSrcAccessOrder,
59 index_t BBlockTransferSrcVectorDim,
60 index_t BBlockTransferSrcScalarPerVector,
61 index_t BBlockTransferDstScalarPerVector_BK1,
62 bool BBlockLdsExtraN,
63 index_t CShuffleMXdlPerWavePerShuffle,
64 index_t CShuffleNXdlPerWavePerShuffle,
65 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
66 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
69 typename ComputeTypeA = CDataType,
70 typename ComputeTypeB = ComputeTypeA>
72 BsLayout,
73 DsLayout,
74 CLayout,
75 AsDataType,
76 BsDataType,
77 DsDataType,
78 CDataType,
79 AElementwiseOperation,
80 BElementwiseOperation,
81 CElementwiseOperation>
82{
// NXdlPerWave to use on 64-wide (GetNXdlPerWave<true>) and 32-wide
// (GetNXdlPerWave<false>) wavefronts. IsSupportedArgument only consults a
// variant whose value is > 0; when the value is 0 it skips that variant and
// reaches its final `return false`.
84 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
85 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
86
// Number of A, B and D tensors, taken from Size() of the corresponding
// data-type lists (AsDataType / BsDataType / DsDataType).
87 static constexpr index_t NumATensor = AsDataType::Size();
88 static constexpr index_t NumBTensor = BsDataType::Size();
89 static constexpr index_t NumDTensor = DsDataType::Size();
90
93
94 // GridwiseGemm
95 template <index_t NXdlPerWave_>
97 ALayout,
98 BLayout,
99 CLayout,
100 AsDataType,
101 BsDataType,
102 GemmAccDataType,
103 CShuffleDataType,
104 DsDataType,
105 CDataType,
106 AElementwiseOperation,
107 BElementwiseOperation,
108 CElementwiseOperation,
109 GemmSpec,
110 BlockSize,
111 MPerBlock,
112 NPerBlock,
113 KPerBlock,
114 AK1,
115 BK1,
116 MPerXDL,
117 NPerXDL,
118 MXdlPerWave,
119 NXdlPerWave_,
120 ABlockTransferThreadClusterLengths_AK0_M_AK1,
121 ABlockTransferThreadClusterArrangeOrder,
122 ABlockTransferSrcAccessOrder,
123 ABlockTransferSrcVectorDim,
124 ABlockTransferSrcScalarPerVector,
125 ABlockTransferDstScalarPerVector_AK1,
126 false,
127 ABlockLdsExtraM,
128 BBlockTransferThreadClusterLengths_BK0_N_BK1,
129 BBlockTransferThreadClusterArrangeOrder,
130 BBlockTransferSrcAccessOrder,
131 BBlockTransferSrcVectorDim,
132 BBlockTransferSrcScalarPerVector,
133 BBlockTransferDstScalarPerVector_BK1,
134 false,
135 BBlockLdsExtraN,
136 CShuffleMXdlPerWavePerShuffle,
137 CShuffleNXdlPerWavePerShuffle,
138 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
139 CShuffleBlockTransferScalarPerVector_NPerBlock,
140 BlkGemmPipeSched,
141 BlkGemmPipelineVer,
142 ComputeTypeA,
143 ComputeTypeB>;
146
// Argument type exposed by this device op: the wave64 GridwiseGemm's Argument.
// IsSupportedArgument reinterprets it as GridwiseGemm32::Argument when the
// device warp size is not 64.
147 using Argument = typename GridwiseGemm64::Argument;
148
149 // Invoker
150 struct Invoker : public BaseInvoker
151 {
152 template <typename GridwiseGemm>
153 float RunImp(const typename GridwiseGemm::Argument& arg,
154 const StreamConfig& stream_config = StreamConfig{})
155 {
156 if(stream_config.log_level_ > 0)
157 {
158 arg.Print();
159 }
160
161 if(!GridwiseGemm::CheckValidity(arg))
162 {
163 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
164 }
165
166 index_t gdx, gdy, gdz;
167 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
168
169 float ave_time = 0;
170
171 index_t k_grain = arg.KBatch * KPerBlock;
172 index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
173
174 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
175
176 const auto Run = [&](const auto& kernel) {
177 if(arg.KBatch > 1)
178 hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
179 0,
180 arg.M * arg.N * sizeof(CDataType),
181 stream_config.stream_id_));
182
183 ave_time = launch_and_time_kernel(
184 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
185 };
186
187 constexpr index_t minimum_occupancy =
188 BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
189
190 if(has_main_k_block_loop)
191 {
192 // Tail number always full
193 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
194 BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
195 {
196#if 0
197 if(arg.KBatch > 1)
198 {
199 const auto kernel =
200 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
201 true,
203 minimum_occupancy>;
204 Run(kernel);
205 }
206 else
207#endif
208 {
209 const auto kernel =
210 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
211 true,
213 minimum_occupancy>;
214 Run(kernel);
215 }
216 }
217 // Tail number could be One to Seven
218 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
219 {
220#if 0
221 if(arg.KBatch > 1)
222 {
223 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
224 {
225 const auto kernel =
226 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
227 true,
229 minimum_occupancy,
231 Run(kernel);
232 }
233 else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
235 {
236 const auto kernel =
237 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
238 true,
240 minimum_occupancy,
242 Run(kernel);
243 }
244
245 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
246 {
247 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
248 {
249 const auto kernel = kernel_gemm_xdl_cshuffle_v3<
250 GridwiseGemm,
251 true,
253 minimum_occupancy,
255 Run(kernel);
256 }
257 }
258
259 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
260 {
261 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
263 {
264 const auto kernel = kernel_gemm_xdl_cshuffle_v3<
265 GridwiseGemm,
266 true,
268 minimum_occupancy,
270 Run(kernel);
271 }
272 }
273
274 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
275 {
276 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
278 {
279 const auto kernel = kernel_gemm_xdl_cshuffle_v3<
280 GridwiseGemm,
281 true,
283 minimum_occupancy,
285 Run(kernel);
286 }
287 }
288
289 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
290 {
291 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
293 {
294 const auto kernel = kernel_gemm_xdl_cshuffle_v3<
295 GridwiseGemm,
296 true,
298 minimum_occupancy,
300 Run(kernel);
301 }
302 }
303
304 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
305 {
306 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
307 {
308 const auto kernel = kernel_gemm_xdl_cshuffle_v3<
309 GridwiseGemm,
310 true,
312 minimum_occupancy,
314 Run(kernel);
315 }
316 }
317
318 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
319 {
320 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
322 {
323 const auto kernel = kernel_gemm_xdl_cshuffle_v3<
324 GridwiseGemm,
325 true,
327 minimum_occupancy,
329 Run(kernel);
330 }
331 }
332 }
333 else
334#endif
335 {
336 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
337 {
338 const auto kernel =
339 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
340 true,
342 minimum_occupancy,
344 Run(kernel);
345 }
346 else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
348 {
349 const auto kernel =
350 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
351 true,
353 minimum_occupancy,
355 Run(kernel);
356 }
357
358 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
359 {
360 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
361 {
362 const auto kernel =
363 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
364 true,
366 minimum_occupancy,
368 Run(kernel);
369 }
370 }
371
372 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
373 {
374 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
376 {
377 const auto kernel =
378 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
379 true,
381 minimum_occupancy,
383 Run(kernel);
384 }
385 }
386
387 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
388 {
389 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
391 {
392 const auto kernel =
393 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
394 true,
396 minimum_occupancy,
398 Run(kernel);
399 }
400 }
401
402 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
403 {
404 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
406 {
407 const auto kernel =
408 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
409 true,
411 minimum_occupancy,
413 Run(kernel);
414 }
415 }
416
417 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
418 {
419 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
420 {
421 const auto kernel =
422 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
423 true,
425 minimum_occupancy,
427 Run(kernel);
428 }
429 }
430
431 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
432 {
433 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
435 {
436 const auto kernel =
437 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
438 true,
440 minimum_occupancy,
442 Run(kernel);
443 }
444 }
445 }
446 }
447 // Tail number could be Odd or Even
448 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
449 {
450#if 0
451 if(arg.KBatch > 1)
452 {
453 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
454 {
455 const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
456 GridwiseGemm,
457 true,
459 minimum_occupancy,
461 Run(kernel);
462 }
463 else
464 {
465 const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
466 GridwiseGemm,
467 true,
469 minimum_occupancy,
471 Run(kernel);
472 }
473 }
474 else
475#endif
476 {
477 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
478 {
479 const auto kernel =
481 true,
483 minimum_occupancy,
485 Run(kernel);
486 }
487 else
488 {
489 const auto kernel =
491 true,
493 minimum_occupancy,
495 Run(kernel);
496 }
497 }
498 }
499 else
500 {
501#if 0
502 if(arg.KBatch > 1)
503 {
504 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
505 {
506 const auto kernel =
507 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
508 true,
510 minimum_occupancy,
512 Run(kernel);
513 }
514 else
515 {
516 const auto kernel =
517 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
518 true,
520 minimum_occupancy,
522 Run(kernel);
523 }
524 }
525 else
526#endif
527 {
528 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
529 {
530 const auto kernel =
531 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
532 true,
534 minimum_occupancy,
536 Run(kernel);
537 }
538 else
539 {
540 const auto kernel =
541 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
542 true,
544 minimum_occupancy,
546 Run(kernel);
547 }
548 }
549 }
550 }
551 else
552 {
553 // Tail number always 1
554 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
555 {
556#if 0
557 if(arg.KBatch > 1)
558 {
559 const auto kernel =
560 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
561 false,
563 minimum_occupancy>;
564 Run(kernel);
565 }
566 else
567#endif
568 {
569 const auto kernel =
570 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
571 false,
573 minimum_occupancy>;
574 Run(kernel);
575 }
576 }
577 }
578
579 return ave_time;
580 }
581
583
584 // polymorphic
585 float Run(const BaseArgument* p_arg,
586 const StreamConfig& stream_config = StreamConfig{}) override
587 {
588 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
589 }
590 };
591
// Compile-time hook for rejecting invalid template-parameter combinations.
// No checks are implemented yet, so every instantiation is accepted.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool all_parameters_accepted = true;
    return all_parameters_accepted;
}
597
598 static bool IsSupportedArgument(const Argument& arg)
599 {
601 {
602 return false;
603 }
604 if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
605 GemmSpec == GemmSpecialization::NKPadding ||
606 GemmSpec == GemmSpecialization::MNKPadding ||
607 GemmSpec == GemmSpecialization::KPadding))
608 {
609 return false;
610 }
611
612 if(get_warp_size() == 64)
613 {
614 if constexpr(NXdlPerWave64 > 0)
615 {
617 }
618 }
619 else
620 {
621 if constexpr(NXdlPerWave32 > 0)
622 {
624 reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg));
625 }
626 }
627 return false;
628 }
629
630 // polymorphic
631 bool IsSupportedArgument(const BaseArgument* p_arg) override
632 {
633 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
634 }
635
636 static auto MakeArgument(std::array<const void*, NumATensor> p_as,
637 std::array<const void*, NumBTensor> p_bs,
638 std::array<const void*, NumDTensor> p_ds,
639 void* p_e,
640 index_t MRaw,
641 index_t NRaw,
642 index_t KRaw,
643 std::array<index_t, NumATensor> StrideAs,
644 std::array<index_t, NumBTensor> StrideBs,
645 std::array<index_t, NumDTensor> StrideDs,
646 index_t StrideE,
647 AElementwiseOperation a_element_op,
648 BElementwiseOperation b_element_op,
649 CElementwiseOperation c_element_op)
650 {
651
652 static_for<0, NumATensor, 1>{}([&](auto i) {
653 using ALayout_ = remove_cvref_t<tuple_element_t<i.value, AsLayout>>;
654
655 static_assert(is_same<ALayout_, ALayout>::value, "");
656 });
657
658 static_for<0, NumBTensor, 1>{}([&](auto i) {
659 using BLayout_ = remove_cvref_t<tuple_element_t<i.value, BsLayout>>;
660
661 static_assert(is_same<BLayout_, BLayout>::value, "");
662 });
663
664 static_for<0, NumDTensor, 1>{}([&](auto i) {
665 using DLayout_ = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
666
667 static_assert(is_same<DLayout_, CLayout>::value, "");
668 });
669
670 return Argument{p_as,
671 p_bs,
672 p_ds,
673 p_e,
674 MRaw,
675 NRaw,
676 KRaw,
677 StrideAs,
678 StrideBs,
679 StrideDs,
680 StrideE,
681 1,
682 a_element_op,
683 b_element_op,
684 c_element_op};
685 }
686
687 static auto MakeInvoker() { return Invoker{}; }
688
689 // polymorphic
690 std::unique_ptr<BaseArgument> MakeArgumentPointer(std::array<const void*, NumATensor> p_as,
691 std::array<const void*, NumBTensor> p_bs,
692 std::array<const void*, NumDTensor> p_ds,
693 void* p_e,
694 index_t MRaw,
695 index_t NRaw,
696 index_t KRaw,
697 std::array<ck::index_t, NumATensor> StrideAs,
698 std::array<ck::index_t, NumBTensor> StrideBs,
699 std::array<ck::index_t, NumDTensor> StrideDs,
700 index_t StrideE,
701 AElementwiseOperation a_element_op,
702 BElementwiseOperation b_element_op,
703 CElementwiseOperation c_element_op) override
704 {
705 return std::make_unique<Argument>(p_as,
706 p_bs,
707 p_ds,
708 p_e,
709 MRaw,
710 NRaw,
711 KRaw,
712 StrideAs,
713 StrideBs,
714 StrideDs,
715 StrideE,
716 1,
717 a_element_op,
718 b_element_op,
719 c_element_op);
720 }
721
722 // polymorphic
723 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
724 {
725 return std::make_unique<Invoker>(Invoker{});
726 }
727
728 // polymorphic
729 std::string GetTypeString() const override
730 {
731 auto str = std::stringstream();
732
733 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
736
737 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
743
744 // clang-format off
745 str << "DeviceGemmXdlUniversal"
746 << "<"
747 << getGemmSpecializationString(GemmSpec) << ", "
748 << std::string(ALayout::name)[0]
749 << std::string(BLayout::name)[0]
750 << std::string(CLayout::name)[0]
751 << ">"
752 << " BlkSize: "
753 << BlockSize << ", "
754 << "BlkTile: "
755 << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
756 << "WaveTile: "
757 << MPerXDL<<"x"<<NPerXDL << ", "
758 << "WaveMap: "
759 << MXdlPerWave<<"x" << NXdlPerWave<<", "
760 << "VmemReadVec: "
761 << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
762 << "BlkGemmPipelineScheduler: "
763 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
764 << "BlkGemmPipelineVersion: "
765 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
766 << "BlkGemmPipelinePrefetchStages: "
767 << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
768 // clang-format on
769
770 return str.str();
771 }
772};
773
774} // namespace device
775} // namespace tensor_operation
776} // namespace ck
#define INVOKER_RUN3_IMPL
Definition device_base.hpp:114
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
@ One
Definition blkgemmpipe_scheduler.hpp:37
@ Seven
Definition blkgemmpipe_scheduler.hpp:43
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Four
Definition blkgemmpipe_scheduler.hpp:40
@ Two
Definition blkgemmpipe_scheduler.hpp:38
@ Full
Definition blkgemmpipe_scheduler.hpp:49
@ Three
Definition blkgemmpipe_scheduler.hpp:39
@ Five
Definition blkgemmpipe_scheduler.hpp:41
@ Six
Definition blkgemmpipe_scheduler.hpp:42
__global__ void kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:62
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
__global__ void kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:38
Definition ck/stream_config.hpp:10
"Universal" GEMM kernel with SplitK support.
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:247
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:151
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:153
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:585
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:82
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:723
std::string GetTypeString() const override
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:729
typename GridwiseGemm64::Argument Argument
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:147
GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, AsDataType, BsDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB > GridwiseGemmBase
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:96
std::unique_ptr< BaseArgument > MakeArgumentPointer(std::array< const void *, NumATensor > p_as, std::array< const void *, NumBTensor > p_bs, std::array< const void *, NumDTensor > p_ds, void *p_e, index_t MRaw, index_t NRaw, index_t KRaw, std::array< ck::index_t, NumATensor > StrideAs, std::array< ck::index_t, NumBTensor > StrideBs, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:690
static auto MakeArgument(std::array< const void *, NumATensor > p_as, std::array< const void *, NumBTensor > p_bs, std::array< const void *, NumDTensor > p_ds, void *p_e, index_t MRaw, index_t NRaw, index_t KRaw, std::array< index_t, NumATensor > StrideAs, std::array< index_t, NumBTensor > StrideBs, std::array< index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:636
remove_cvref_t< tuple_element_t< 0, AsLayout > > ALayout
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:91
static constexpr index_t NumATensor
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:87
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:598
static constexpr auto NXdlPerWave32
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:85
static constexpr index_t NumBTensor
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:88
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:592
remove_cvref_t< tuple_element_t< 0, BsLayout > > BLayout
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:92
static auto MakeInvoker()
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:687
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:631
static constexpr index_t NumDTensor
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:89
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:84
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:145
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_multiple_abd_xdl_cshuffle.hpp:144
Definition device_gemm_multiple_abd.hpp:34