device_gemm_xdl_layernorm_cshuffle.hpp Source File

device_gemm_xdl_layernorm_cshuffle.hpp Source File#

Composable Kernel: device_gemm_xdl_layernorm_cshuffle.hpp Source File
device_gemm_xdl_layernorm_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
10#include "ck/utility/env.hpp"
19
20namespace ck {
21namespace tensor_operation {
22namespace device {
23
24// The GEMM + Layernorm implementation is a specialized kernel which fuses both layers
25// together, provided that the GEMM extent N (of M, N, K) is spanned by a single workgroup. For
26// example, a kernel configured with NPerBlock = 128 can handle any GEMM size with N <= 128
27//
28// Note: the inter-wave loop scheduler is rolled out to the c-shuffle version first, because the
29// non c-shuffle version currently has compiler issues with register spill, which in turn cause
30// validation failures.
31//
32// D = Layernorm(acc_element_op(A * B + broadcast(bias)) + add) * broadcast(gamma) + broadcast(beta)
33template <typename ALayout,
34 typename BLayout,
35 typename CLayout,
36 typename ADataType,
37 typename BDataType,
38 typename CDataType,
39 typename C0DataType,
40 typename GemmAccDataType,
41 typename CShuffleDataType,
42 typename ReduceAccDataType,
43 typename AElementwiseOperation,
44 typename BElementwiseOperation,
45 typename AccElementwiseOperation,
46 typename CElementwiseOperation,
47 GemmSpecialization GemmSpec,
48 index_t NumGemmKPrefetchStage,
49 index_t BlockSize,
50 index_t MPerBlock,
51 index_t NPerBlock,
52 index_t KPerBlock,
53 index_t AK1,
54 index_t BK1,
55 index_t MPerXDL,
56 index_t NPerXDL,
57 index_t MXdlPerWave,
58 index_t NXdlPerWave,
59 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
60 typename ABlockTransferThreadClusterArrangeOrder,
61 typename ABlockTransferSrcAccessOrder,
62 index_t ABlockTransferSrcVectorDim,
63 index_t ABlockTransferSrcScalarPerVector,
64 index_t ABlockTransferDstScalarPerVector_AK1,
65 bool ABlockLdsExtraM,
66 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
67 typename BBlockTransferThreadClusterArrangeOrder,
68 typename BBlockTransferSrcAccessOrder,
69 index_t BBlockTransferSrcVectorDim,
70 index_t BBlockTransferSrcScalarPerVector,
71 index_t BBlockTransferDstScalarPerVector_BK1,
72 bool BBlockLdsExtraN,
73 index_t CShuffleMXdlPerWavePerShuffle,
74 index_t CShuffleNXdlPerWavePerShuffle,
75 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
76 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
77 typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
78 index_t CReduceThreadCopySrcDstScalarPerVector_NPerBlock,
81{
83
85 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
86 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
87
88 static constexpr auto I0 = Number<0>{};
89 static constexpr auto I1 = Number<1>{};
90 static constexpr auto I2 = Number<2>{};
91
// Builds and returns the A grid descriptor (a_grid_desc_ak0_m_ak1) from the raw extents
// MRaw x KRaw and the stride StrideA.
//
// M and K are MRaw and KRaw rounded up to the next multiple of MPerBlock and KPerBlock;
// MPad and KPad are the corresponding right-pad amounts. GemmSpec selects, at compile time,
// which of the two dimensions is actually padded. Every branch asserts that the K extent it
// uses (padded K, or KRaw when K is not padded) is divisible by AK1 before deriving
// AK0 = K / AK1.
92 static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
93 {
// Raw (unpadded) descriptor; the two alternatives differ only in which dimension gets
// StrideA and which gets unit stride.
94 const auto a_grid_desc_mraw_kraw = [&]() {
96 {
98 make_tuple(StrideA, I1));
99 }
101 {
102 return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
103 make_tuple(I1, StrideA));
104 }
105 }();
106
// Round the extents up to whole tiles.
107 const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
108 const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
109
110 const auto MPad = M - MRaw;
111 const auto KPad = K - KRaw;
112
113 if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
115 {
116 // pad both M and K
117 assert(K % AK1 == 0);
118
119 const auto AK0 = K / AK1;
120
121 const auto a_grid_desc_m_k =
122 transform_tensor_descriptor(a_grid_desc_mraw_kraw,
124 make_right_pad_transform(KRaw, KPad)),
127
128 const auto a_grid_desc_ak0_m_ak1 =
129 transform_tensor_descriptor(a_grid_desc_m_k,
134
135 return a_grid_desc_ak0_m_ak1;
136 }
137 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
139 {
140 // pad M, but not K
141 assert(KRaw % AK1 == 0);
142
143 const auto AK0 = KRaw / AK1;
144
145 const auto a_grid_desc_ak0_m_ak1 =
146 transform_tensor_descriptor(a_grid_desc_mraw_kraw,
148 make_right_pad_transform(MRaw, MPad)),
151
152 return a_grid_desc_ak0_m_ak1;
153 }
154 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
156 {
157 // pad K, but not M
158 assert(K % AK1 == 0);
159
160 const auto AK0 = K / AK1;
161
162 const auto a_grid_desc_m_k = transform_tensor_descriptor(
163 a_grid_desc_mraw_kraw,
167
168 const auto a_grid_desc_ak0_m_ak1 =
169 transform_tensor_descriptor(a_grid_desc_m_k,
174
175 return a_grid_desc_ak0_m_ak1;
176 }
177 else
178 {
179 // not pad M or K
180 assert(KRaw % AK1 == 0);
181
182 const auto AK0 = KRaw / AK1;
183
184 const auto a_grid_desc_ak0_m_ak1 =
185 transform_tensor_descriptor(a_grid_desc_mraw_kraw,
190
191 return a_grid_desc_ak0_m_ak1;
192 }
193 }
194
// Builds and returns the B grid descriptor (b_grid_desc_bk0_n_bk1) from the raw extents
// KRaw x NRaw and the stride StrideB.
//
// N and K are NRaw and KRaw rounded up to the next multiple of NPerBlock and KPerBlock;
// NPad and KPad are the corresponding right-pad amounts. GemmSpec selects, at compile time,
// which of the two dimensions is actually padded. Every branch asserts that the K extent it
// uses (padded K, or KRaw when K is not padded) is divisible by BK1 before deriving
// BK0 = K / BK1.
195 static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
196 {
// Raw (unpadded) descriptor, always expressed with lengths (NRaw, KRaw); the two alternatives
// differ only in which dimension gets StrideB and which gets unit stride.
197 const auto b_grid_desc_nraw_kraw = [&]() {
199 {
200 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
201 make_tuple(I1, StrideB));
202 }
204 {
205 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
206 make_tuple(StrideB, I1));
207 }
208 }();
209
// Round the extents up to whole tiles.
210 const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
211 const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
212
213 const auto NPad = N - NRaw;
214 const auto KPad = K - KRaw;
215
216 if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
218 {
219 // pad both N and K
220 assert(K % BK1 == 0);
221
222 const auto BK0 = K / BK1;
223
224 const auto b_grid_desc_n_k =
225 transform_tensor_descriptor(b_grid_desc_nraw_kraw,
227 make_right_pad_transform(KRaw, KPad)),
230
231 const auto b_grid_desc_bk0_n_bk1 =
232 transform_tensor_descriptor(b_grid_desc_n_k,
237
238 return b_grid_desc_bk0_n_bk1;
239 }
240 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
242 {
243 // pad N, but not K
244 assert(KRaw % BK1 == 0);
245
246 const auto BK0 = KRaw / BK1;
247
248 const auto b_grid_desc_bk0_n_bk1 =
249 transform_tensor_descriptor(b_grid_desc_nraw_kraw,
251 make_right_pad_transform(NRaw, NPad)),
254
255 return b_grid_desc_bk0_n_bk1;
256 }
257 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
259 {
260 // pad K, but not N
261 assert(K % BK1 == 0);
262
263 const auto BK0 = K / BK1;
264
265 const auto b_grid_desc_n_k = transform_tensor_descriptor(
266 b_grid_desc_nraw_kraw,
270
271 const auto b_grid_desc_bk0_n_bk1 =
272 transform_tensor_descriptor(b_grid_desc_n_k,
277
278 return b_grid_desc_bk0_n_bk1;
279 }
280 else
281 {
282 // not pad N or K
283 assert(KRaw % BK1 == 0);
284
285 const auto BK0 = KRaw / BK1;
286
287 const auto b_grid_desc_bk0_n_bk1 =
288 transform_tensor_descriptor(b_grid_desc_nraw_kraw,
293
294 return b_grid_desc_bk0_n_bk1;
295 }
296 }
297
// Builds and returns the C grid descriptor over (M, N) from the raw extents MRaw x NRaw and
// the stride StrideC.
//
// M and N are MRaw and NRaw rounded up to the next multiple of MPerBlock and NPerBlock;
// MPad and NPad are the corresponding right-pad amounts. GemmSpec selects, at compile time,
// whether M, N, both, or neither is padded; when neither is padded the raw descriptor is
// returned unchanged.
298 static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
299 {
// Raw (unpadded) descriptor; the two alternatives differ only in which dimension gets
// StrideC and which gets unit stride.
300 const auto c_grid_desc_mraw_nraw = [&]() {
302 {
303 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
304 make_tuple(StrideC, I1));
305 }
307 {
308 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
309 make_tuple(I1, StrideC));
310 }
311 }();
312
// Round the extents up to whole tiles.
313 const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
314 const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
315
316 const auto MPad = M - MRaw;
317 const auto NPad = N - NRaw;
318
319 if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
321 {
322 // pad M and N
323 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
325 make_right_pad_transform(NRaw, NPad)),
328 }
329 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
331 {
332 // pad M, but not N
334 c_grid_desc_mraw_nraw,
338 }
339 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
341 {
342 // pad N, but not M
344 c_grid_desc_mraw_nraw,
348 }
349 else
350 {
351 // not pad M or N
352 return c_grid_desc_mraw_nraw;
353 }
354 }
355
357 {
// Builds a 1-D descriptor of length NRaw (packed layout) and, depending on GemmSpec,
// returns either that descriptor unchanged or a transformed version of it.
// N is NRaw rounded up to the next multiple of NPerBlock; NPad is the difference.
// The padded variant is chosen when GemmSpec is NPadding, MNPadding or NKPadding
// (a further alternative in the condition is not shown here).
358 const auto grid_desc_nraw = make_naive_tensor_descriptor_packed(make_tuple(NRaw));
359
360 const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
361 const auto NPad = N - NRaw;
362
363 if constexpr(GemmSpec == GemmSpecialization::NPadding ||
364 GemmSpec == GemmSpecialization::MNPadding ||
365 GemmSpec == GemmSpecialization::NKPadding ||
367 {
368 // pad N
369 return transform_tensor_descriptor(grid_desc_nraw,
373 }
374 else
375 {
376 // not pad N
377 return grid_desc_nraw;
378 }
379 }
380
383 using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
384 using C0GridDesc_N = decltype(MakeGridDescriptor_N(1));
385
386 // GridwiseGemm
// Template over NXdlPerWave_ that forwards this device op's data types, element-wise
// operations, tile sizes and block-transfer settings to the gridwise kernel type.
// NOTE(review): NXdlPerWave_ appears twice in the argument list below - once directly after
// MXdlPerWave and once directly after CShuffleMXdlPerWavePerShuffle. The device-level
// CShuffleNXdlPerWavePerShuffle parameter does not appear among the visible arguments;
// confirm that substituting NXdlPerWave_ in that position is intended.
387 template <index_t NXdlPerWave_>
389 ADataType, // TODO: distinguish A/B datatype
390 GemmAccDataType,
391 CShuffleDataType,
392 CDataType,
393 C0DataType,
394 ReduceAccDataType,
395 AElementwiseOperation,
396 BElementwiseOperation,
397 AccElementwiseOperation,
398 CElementwiseOperation,
404 NumGemmKPrefetchStage,
405 BlockSize,
406 MPerBlock,
407 NPerBlock,
408 KPerBlock,
409 AK1,
410 BK1,
411 MPerXDL,
412 NPerXDL,
413 MXdlPerWave,
414 NXdlPerWave_,
415 ABlockTransferThreadClusterLengths_AK0_M_AK1,
416 ABlockTransferThreadClusterArrangeOrder,
417 ABlockTransferSrcAccessOrder,
418 ABlockTransferSrcVectorDim,
419 ABlockTransferSrcScalarPerVector,
420 ABlockTransferDstScalarPerVector_AK1,
421 false,
422 ABlockLdsExtraM,
423 BBlockTransferThreadClusterLengths_BK0_N_BK1,
424 BBlockTransferThreadClusterArrangeOrder,
425 BBlockTransferSrcAccessOrder,
426 BBlockTransferSrcVectorDim,
427 BBlockTransferSrcScalarPerVector,
428 BBlockTransferDstScalarPerVector_BK1,
429 false,
430 BBlockLdsExtraN,
431 CShuffleMXdlPerWavePerShuffle,
432 NXdlPerWave_,
433 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
434 CShuffleBlockTransferScalarPerVector_NPerBlock,
435 CReduceThreadClusterLengths_MPerBlock_NPerBlock,
436 CReduceThreadCopySrcDstScalarPerVector_NPerBlock,
437 LoopSched>;
440
442
443 // Argument
// Bundle of everything one launch needs: the A/B/C grid pointers, the four C0 pointers
// (bias, add, gamma, beta) and the four element-wise operation functors. The pointers are
// stored as given; nothing here takes ownership of the memory they refer to.
444 struct Argument : public BaseArgument
445 {
// Stores every pointer/functor parameter in the member of the same name (plus trailing
// underscore).
// NOTE(review): this constructor declares p_c0_grid_add BEFORE p_c0_grid_bias, whereas
// MakeArgument and MakeArgumentPointer in this file pass their bias pointer before their add
// pointer. As written, a caller's "bias" pointer therefore ends up in p_c0_grid_add_ and the
// "add" pointer in p_c0_grid_bias_. Confirm whether this ordering is intentional.
446 Argument(const ADataType* p_a_grid,
447 const BDataType* p_b_grid,
448 CDataType* p_c_grid,
449 const C0DataType* p_c0_grid_add,
450 const C0DataType* p_c0_grid_bias,
451 const C0DataType* p_c0_grid_gamma,
452 const C0DataType* p_c0_grid_beta,
453 index_t MRaw,
454 index_t NRaw,
455 index_t KRaw,
456 index_t StrideA,
457 index_t StrideB,
458 index_t StrideC,
459 AElementwiseOperation a_element_op,
460 BElementwiseOperation b_element_op,
461 AccElementwiseOperation acc_element_op,
462 CElementwiseOperation c_element_op)
463 : p_a_grid_{p_a_grid},
464 p_b_grid_{p_b_grid},
465 p_c_grid_{p_c_grid},
466 p_c0_grid_bias_{p_c0_grid_bias},
467 p_c0_grid_add_{p_c0_grid_add},
468 p_c0_grid_gamma_{p_c0_grid_gamma},
469 p_c0_grid_beta_{p_c0_grid_beta},
475 a_element_op_{a_element_op},
476 b_element_op_{b_element_op},
477 acc_element_op_{acc_element_op},
478 c_element_op_{c_element_op}
479 {
480 }
481
// Members are public (the access specifier below is commented out); the invoker reads them
// directly.
482 // private:
483 const ADataType* p_a_grid_;
484 const BDataType* p_b_grid_;
485 CDataType* p_c_grid_;
486 const C0DataType* p_c0_grid_bias_;
487 const C0DataType* p_c0_grid_add_;
488 const C0DataType* p_c0_grid_gamma_;
489 const C0DataType* p_c0_grid_beta_;
495 AElementwiseOperation a_element_op_;
496 BElementwiseOperation b_element_op_;
497 AccElementwiseOperation acc_element_op_;
498 CElementwiseOperation c_element_op_;
499 };
500
501 // Invoker
502 struct Invoker : public BaseInvoker
503 {
505
// Launches the fused GEMM + layernorm kernel for the given GridwiseGemm instantiation and
// returns the value reported by launch_and_time_kernel.
//
// Steps:
//  1. If the CK_LOGGING environment switch is enabled, print the lengths of the A, B and C
//     grid descriptors to std::cout.
//  2. Validate the descriptors with GridwiseGemm::CheckValidity; throws std::runtime_error
//     when the check fails.
//  3. Derive the block-level C / C0 descriptors and the grid size.
//  4. Compute K as GetLength(I0) * GetLength(I2) of the A descriptor and use
//     GridwiseGemm::CalculateHasMainKBlockLoop(K) to pick between two kernel instantiations
//     that differ only in their final boolean template argument (true / false).
//
// The kernel is launched with dim3(grid_size) blocks of dim3(BlockSize) threads and a
// dynamic LDS byte count of 0.
506 template <typename GridwiseGemm>
507 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
508 {
509 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
510 {
511 std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
512 << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
513 << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
514 << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
515
516 std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
517 << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
518 << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
519 << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
520
521 std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
522 << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
523 }
524
// Reject configurations the gridwise kernel cannot handle before launching anything.
525 if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
529 {
530 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
531 }
532 auto c_grid_desc_mblock_mperblock_nblock_nperblock =
533 GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
534 arg.c_grid_desc_m_n_);
535
536 auto c0_grid_desc_nblock_nperblock =
537 GridwiseGemm::MakeC0GridDescriptor_NBlock_NPerBlock(arg.c0_grid_desc_n_);
538 const index_t grid_size =
539 arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
540
// Total K extent = AK0 * AK1 (dimensions I0 and I2 of the A descriptor).
541 const auto K =
542 arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
543
544 float ave_time = 0;
545
546 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
547 {
548 const auto kernel = kernel_gemm_layernorm_xdl_cshuffle_v1<
549 GridwiseGemm,
550 ADataType, // TODO: distinguish A/B datatype
551 CDataType,
552 C0DataType,
553 AElementwiseOperation,
554 BElementwiseOperation,
555 AccElementwiseOperation,
556 CElementwiseOperation,
559 typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
560 typename GridwiseGemm::C0GridDescriptor_NBlock_NPerBlock,
562 true>;
563
564 ave_time = launch_and_time_kernel(stream_config,
565 kernel,
566 dim3(grid_size),
567 dim3(BlockSize),
568 0,
569 arg.p_a_grid_,
570 arg.p_b_grid_,
571 arg.p_c_grid_,
572 arg.p_c0_grid_bias_,
573 arg.p_c0_grid_add_,
575 arg.p_c0_grid_beta_,
576 arg.a_element_op_,
577 arg.b_element_op_,
578 arg.acc_element_op_,
579 arg.c_element_op_,
582 c_grid_desc_mblock_mperblock_nblock_nperblock,
583 c0_grid_desc_nblock_nperblock,
585 }
586 else
587 {
// Same launch as above, with the final boolean template argument set to false.
588 const auto kernel = kernel_gemm_layernorm_xdl_cshuffle_v1<
589 GridwiseGemm,
590 ADataType, // TODO: distinguish A/B datatype
591 CDataType,
592 C0DataType,
593 AElementwiseOperation,
594 BElementwiseOperation,
595 AccElementwiseOperation,
596 CElementwiseOperation,
599 typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
600 typename GridwiseGemm::C0GridDescriptor_NBlock_NPerBlock,
602 false>;
603 ave_time = launch_and_time_kernel(stream_config,
604 kernel,
605 dim3(grid_size),
606 dim3(BlockSize),
607 0,
608 arg.p_a_grid_,
609 arg.p_b_grid_,
610 arg.p_c_grid_,
611 arg.p_c0_grid_bias_,
612 arg.p_c0_grid_add_,
614 arg.p_c0_grid_beta_,
615 arg.a_element_op_,
616 arg.b_element_op_,
617 arg.acc_element_op_,
618 arg.c_element_op_,
621 c_grid_desc_mblock_mperblock_nblock_nperblock,
622 c0_grid_desc_nblock_nperblock,
624 }
625
626 return ave_time;
627 }
628
630 // polymorphic
// Type-erased entry point: downcasts p_arg to this op's Argument and forwards to the Run
// overload that takes an Argument reference, passing stream_config through unchanged.
// NOTE(review): the dynamic_cast result is dereferenced without a null check, so a null
// p_arg, or one whose dynamic type is not Argument, is dereferenced as a null pointer.
631 float Run(const BaseArgument* p_arg,
632 const StreamConfig& stream_config = StreamConfig{}) override
633 {
634 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
635 }
636 };
637
// Compile-time parameter validation hook. Currently a placeholder that unconditionally
// returns true; no template parameter is inspected.
638 static constexpr bool IsValidCompilationParameter()
639 {
640 // TODO: properly implement this check
641 return true;
642 }
643
// Reports whether this device op can run the given argument.
// Returns false when the initial guard fails. Otherwise it branches on get_warp_size():
// the 64-wide branch is only compiled in when NXdlPerWave64 > 0, and the other branch only
// when NXdlPerWave32 > 0. Any path that does not return from inside one of those branches
// falls through to the final `return false`.
644 static bool IsSupportedArgument(const Argument& arg)
645 {
647 {
648 return false;
649 }
650 if(get_warp_size() == 64)
651 {
652 if constexpr(NXdlPerWave64 > 0)
653 {
658 }
659 }
660 else
661 {
662 if constexpr(NXdlPerWave32 > 0)
663 {
668 }
669 }
670 return false;
671 }
672
673 // polymorphic
// Type-erased overload: downcasts p_arg to this op's Argument and forwards to the static
// IsSupportedArgument(const Argument&).
// NOTE(review): the dynamic_cast result is dereferenced without a null check, so a null
// p_arg, or one whose dynamic type is not Argument, is dereferenced as a null pointer
// rather than being reported as "not supported".
674 bool IsSupportedArgument(const BaseArgument* p_arg) override
675 {
676 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
677 }
678
// Typed factory: constructs and returns an Argument by value, forwarding every parameter
// positionally in the order given.
// NOTE(review): p_c0_bias is passed before p_c0_add, but Argument's constructor declares
// p_c0_grid_add before p_c0_grid_bias. With positional forwarding the bias pointer is
// therefore received as the constructor's "add" parameter and vice versa. Confirm whether
// this is intentional.
679 static auto MakeArgument(const ADataType* p_a,
680 const BDataType* p_b,
681 CDataType* p_c,
682 const C0DataType* p_c0_bias,
683 const C0DataType* p_c0_add,
684 const C0DataType* p_c0_gamma,
685 const C0DataType* p_c0_beta,
686 index_t MRaw,
687 index_t NRaw,
688 index_t KRaw,
689 index_t StrideA,
690 index_t StrideB,
691 index_t StrideC,
692 AElementwiseOperation a_element_op,
693 BElementwiseOperation b_element_op,
694 AccElementwiseOperation acc_element_op,
695 CElementwiseOperation c_element_op)
696 {
697 return Argument{p_a,
698 p_b,
699 p_c,
700 p_c0_bias,
701 p_c0_add,
702 p_c0_gamma,
703 p_c0_beta,
704 MRaw,
705 NRaw,
706 KRaw,
707 StrideA,
708 StrideB,
709 StrideC,
710 a_element_op,
711 b_element_op,
712 acc_element_op,
713 c_element_op};
714 }
715
// Returns a default-constructed Invoker by value.
716 static auto MakeInvoker() { return Invoker{}; }
717
// Type-erased factory: static_casts the untyped pointers to this op's data types and
// returns a heap-allocated Argument owned by a std::unique_ptr<BaseArgument>.
// The trailing index_t parameter (commented name: KBatch, default 1) is unnamed and unused.
// NOTE(review): p_c0_bias is passed before p_c0_add, but Argument's constructor declares
// p_c0_grid_add before p_c0_grid_bias. With positional forwarding the bias pointer is
// therefore received as the constructor's "add" parameter and vice versa. Confirm whether
// this is intentional.
718 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
719 const void* p_b,
720 void* p_c,
721 const void* p_c0_bias,
722 const void* p_c0_add,
723 const void* p_c0_gamma,
724 const void* p_c0_beta,
725 index_t MRaw,
726 index_t NRaw,
727 index_t KRaw,
728 index_t StrideA,
729 index_t StrideB,
730 index_t StrideC,
731 AElementwiseOperation a_element_op,
732 BElementwiseOperation b_element_op,
733 AccElementwiseOperation acc_element_op,
734 CElementwiseOperation c_element_op,
735 index_t /* KBatch */ = 1)
736 {
737 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
738 static_cast<const BDataType*>(p_b),
739 static_cast<CDataType*>(p_c),
740 static_cast<const C0DataType*>(p_c0_bias),
741 static_cast<const C0DataType*>(p_c0_add),
742 static_cast<const C0DataType*>(p_c0_gamma),
743 static_cast<const C0DataType*>(p_c0_beta),
744 MRaw,
745 NRaw,
746 KRaw,
747 StrideA,
748 StrideB,
749 StrideC,
750 a_element_op,
751 b_element_op,
752 acc_element_op,
753 c_element_op);
754 }
755
// Type-erased factory: returns a heap-allocated Invoker (constructed from a temporary
// default Invoker) owned by a std::unique_ptr<BaseInvoker>.
756 std::unique_ptr<BaseInvoker> MakeInvokerPointer()
757 {
758 return std::make_unique<Invoker>(Invoker{});
759 }
760
761 // polymorphic
// Returns a textual identifier for this instantiation: the literal
// "DeviceGemmLayerNorm_Xdl_CShuffle" followed by "<...>" containing, comma-separated and in
// this order, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL,
// MXdlPerWave, NXdlPerWave, ABlockTransferSrcScalarPerVector,
// BBlockTransferSrcScalarPerVector, CShuffleMXdlPerWavePerShuffle and
// CShuffleNXdlPerWavePerShuffle.
762 std::string GetTypeString() const override
763 {
764 auto str = std::stringstream();
765
766 // clang-format off
767 str << "DeviceGemmLayerNorm_Xdl_CShuffle"
768 << "<"
769 << BlockSize << ", "
770 << MPerBlock << ", "
771 << NPerBlock << ", "
772 << KPerBlock << ", "
773 << AK1 << ", "
774 << BK1 << ", "
775 << MPerXDL << ", "
776 << NPerXDL << ", "
777 << MXdlPerWave << ", "
778 << NXdlPerWave << ", "
779 << ABlockTransferSrcScalarPerVector << ", "
780 << BBlockTransferSrcScalarPerVector << ", "
781 << CShuffleMXdlPerWavePerShuffle << ", "
782 << CShuffleNXdlPerWavePerShuffle
783 << ">";
784 // clang-format on
785
786 return str.str();
787 }
788};
789
790} // namespace device
791} // namespace tensor_operation
792} // namespace ck
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
__global__ void kernel_gemm_layernorm_xdl_cshuffle_v1(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const FloatC0 *__restrict__ p_c0_bias_grid, const FloatC0 *__restrict__ p_c0_add_grid, const FloatC0 *__restrict__ p_c0_gamma_grid, const FloatC0 *__restrict__ p_c0_beta_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const AccElementwiseOperation acc_element_op, const CElementwiseOperation c_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const C0GridDescriptor_NBlock_NPerBlock c0_grid_desc_nblock_nperblock, const Block2CTileMap block_2_ctile_map)
Definition gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp:41
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp:160
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp:268
Definition utility/sequence.hpp:43
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition device_base.hpp:197
Definition device_gemm_xdl_layernorm_cshuffle.hpp:445
C0GridDesc_N c0_grid_desc_n_
Definition device_gemm_xdl_layernorm_cshuffle.hpp:493
CGridDesc_M_N c_grid_desc_m_n_
Definition device_gemm_xdl_layernorm_cshuffle.hpp:492
BElementwiseOperation b_element_op_
Definition device_gemm_xdl_layernorm_cshuffle.hpp:496
CElementwiseOperation c_element_op_
Definition device_gemm_xdl_layernorm_cshuffle.hpp:498
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_gemm_xdl_layernorm_cshuffle.hpp:490
Argument(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, const C0DataType *p_c0_grid_add, const C0DataType *p_c0_grid_bias, const C0DataType *p_c0_grid_gamma, const C0DataType *p_c0_grid_beta, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, AccElementwiseOperation acc_element_op, CElementwiseOperation c_element_op)
Definition device_gemm_xdl_layernorm_cshuffle.hpp:446
const C0DataType * p_c0_grid_gamma_
Definition device_gemm_xdl_layernorm_cshuffle.hpp:488
AccElementwiseOperation acc_element_op_
Definition device_gemm_xdl_layernorm_cshuffle.hpp:497
const BDataType * p_b_grid_
Definition device_gemm_xdl_layernorm_cshuffle.hpp:484
const C0DataType * p_c0_grid_add_
Definition device_gemm_xdl_layernorm_cshuffle.hpp:487
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_gemm_xdl_layernorm_cshuffle.hpp:491
const C0DataType * p_c0_grid_beta_
Definition device_gemm_xdl_layernorm_cshuffle.hpp:489
const ADataType * p_a_grid_
Definition device_gemm_xdl_layernorm_cshuffle.hpp:483
AElementwiseOperation a_element_op_
Definition device_gemm_xdl_layernorm_cshuffle.hpp:495
const C0DataType * p_c0_grid_bias_
Definition device_gemm_xdl_layernorm_cshuffle.hpp:486
CDataType * p_c_grid_
Definition device_gemm_xdl_layernorm_cshuffle.hpp:485
Block2CTileMap block_2_ctile_map_
Definition device_gemm_xdl_layernorm_cshuffle.hpp:494
Definition device_gemm_xdl_layernorm_cshuffle.hpp:503
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_xdl_layernorm_cshuffle.hpp:507
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_xdl_layernorm_cshuffle.hpp:631
DeviceOp::Argument Argument
Definition device_gemm_xdl_layernorm_cshuffle.hpp:504
Definition device_gemm_xdl_layernorm_cshuffle.hpp:81
std::string GetTypeString() const override
Definition device_gemm_xdl_layernorm_cshuffle.hpp:762
decltype(MakeGridDescriptor_N(1)) C0GridDesc_N
Definition device_gemm_xdl_layernorm_cshuffle.hpp:384
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_xdl_layernorm_cshuffle.hpp:85
static auto MakeInvoker()
Definition device_gemm_xdl_layernorm_cshuffle.hpp:716
std::unique_ptr< BaseInvoker > MakeInvokerPointer()
Definition device_gemm_xdl_layernorm_cshuffle.hpp:756
decltype(MakeCGridDescriptor_M_N(1, 1, 1)) CGridDesc_M_N
Definition device_gemm_xdl_layernorm_cshuffle.hpp:383
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_xdl_layernorm_cshuffle.hpp:438
static auto MakeGridDescriptor_N(index_t NRaw)
Definition device_gemm_xdl_layernorm_cshuffle.hpp:356
static constexpr auto I1
Definition device_gemm_xdl_layernorm_cshuffle.hpp:89
static constexpr auto NXdlPerWave32
Definition device_gemm_xdl_layernorm_cshuffle.hpp:86
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, const C0DataType *p_c0_bias, const C0DataType *p_c0_add, const C0DataType *p_c0_gamma, const C0DataType *p_c0_beta, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, AccElementwiseOperation acc_element_op, CElementwiseOperation c_element_op)
Definition device_gemm_xdl_layernorm_cshuffle.hpp:679
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, const void *p_c0_bias, const void *p_c0_add, const void *p_c0_gamma, const void *p_c0_beta, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, AccElementwiseOperation acc_element_op, CElementwiseOperation c_element_op, index_t=1)
Definition device_gemm_xdl_layernorm_cshuffle.hpp:718
static constexpr auto I0
Definition device_gemm_xdl_layernorm_cshuffle.hpp:88
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_xdl_layernorm_cshuffle.hpp:674
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
Definition device_gemm_xdl_layernorm_cshuffle.hpp:195
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_xdl_layernorm_cshuffle.hpp:644
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
Definition device_gemm_xdl_layernorm_cshuffle.hpp:298
decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)) BGridDesc_BK0_N_BK1
Definition device_gemm_xdl_layernorm_cshuffle.hpp:382
typename GridwiseGemm64::DefaultBlock2CTileMap Block2CTileMap
Definition device_gemm_xdl_layernorm_cshuffle.hpp:441
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
Definition device_gemm_xdl_layernorm_cshuffle.hpp:92
GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, C0DataType, ReduceAccDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, C0GridDesc_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, NXdlPerWave_, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadCopySrcDstScalarPerVector_NPerBlock, LoopSched > GridwiseGemmBase
Definition device_gemm_xdl_layernorm_cshuffle.hpp:388
static constexpr auto I2
Definition device_gemm_xdl_layernorm_cshuffle.hpp:90
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_xdl_layernorm_cshuffle.hpp:638
DeviceGemmLayerNorm_Xdl_CShuffle DeviceOp
Definition device_gemm_xdl_layernorm_cshuffle.hpp:82
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_xdl_layernorm_cshuffle.hpp:439
decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)) AGridDesc_AK0_M_AK1
Definition device_gemm_xdl_layernorm_cshuffle.hpp:381
#define CK_ENV(name)
Definition utility/env.hpp:129