device_gemm_reduce_xdl_cshuffle.hpp Source File

device_gemm_reduce_xdl_cshuffle.hpp Source File

Composable Kernel: device_gemm_reduce_xdl_cshuffle.hpp Source File
device_gemm_reduce_xdl_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
18
19namespace ck {
20namespace tensor_operation {
21namespace device {
22
23// Note: the inter-wave loop scheduler is rolled out to the c-shuffle version first, because the non-c-shuffle
24// version currently has compiler issues with register spills, which in turn cause validation
25// failures.
26template <typename ALayout,
27 typename BLayout,
28 typename CLayout,
29 typename ADataType,
30 typename BDataType,
31 typename CDataType,
32 typename GemmAccDataType,
33 typename CShuffleDataType,
34 typename ReduceAccDataType,
35 typename ReducePtrsGlobal,
36 typename AElementwiseOperation,
37 typename BElementwiseOperation,
38 typename CElementwiseOperation,
39 typename ReduceOperations,
40 typename ReduceInElementwiseOperations,
41 typename ReduceAccElementwiseOperations,
42 typename ReduceGlobalMemoryDataOperation,
43 GemmSpecialization GemmSpec,
44 index_t NumGemmKPrefetchStage,
45 index_t BlockSize,
46 index_t MPerBlock,
47 index_t NPerBlock,
48 index_t KPerBlock,
49 index_t AK1,
50 index_t BK1,
51 index_t MPerXDL,
52 index_t NPerXDL,
53 index_t MXdlPerWave,
54 index_t NXdlPerWave,
55 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
56 typename ABlockTransferThreadClusterArrangeOrder,
57 typename ABlockTransferSrcAccessOrder,
58 index_t ABlockTransferSrcVectorDim,
59 index_t ABlockTransferSrcScalarPerVector,
60 index_t ABlockTransferDstScalarPerVector_AK1,
61 bool ABlockLdsExtraM,
62 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
63 typename BBlockTransferThreadClusterArrangeOrder,
64 typename BBlockTransferSrcAccessOrder,
65 index_t BBlockTransferSrcVectorDim,
66 index_t BBlockTransferSrcScalarPerVector,
67 index_t BBlockTransferDstScalarPerVector_BK1,
68 bool BBlockLdsExtraN,
69 index_t CShuffleMXdlPerWavePerShuffle,
70 index_t CShuffleNXdlPerWavePerShuffle,
71 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
72 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
73 typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
74 index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
75 index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
77struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperations::Size()>
78{
80
82 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
83 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
84
85 static constexpr auto I0 = Number<0>{};
86 static constexpr auto I1 = Number<1>{};
87 static constexpr auto I2 = Number<2>{};
88
89 static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
90 {
91 const auto a_grid_desc_mraw_kraw = [&]() {
93 {
95 make_tuple(StrideA, I1));
96 }
98 {
100 make_tuple(I1, StrideA));
101 }
102 }();
103
104 const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
105 const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
106
107 const auto MPad = M - MRaw;
108 const auto KPad = K - KRaw;
109
110 if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
112 {
113 // pad both M and K
114 assert(K % AK1 == 0);
115
116 const auto AK0 = K / AK1;
117
118 const auto a_grid_desc_m_k =
119 transform_tensor_descriptor(a_grid_desc_mraw_kraw,
121 make_right_pad_transform(KRaw, KPad)),
124
125 const auto a_grid_desc_ak0_m_ak1 =
126 transform_tensor_descriptor(a_grid_desc_m_k,
131
132 return a_grid_desc_ak0_m_ak1;
133 }
134 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
136 {
137 // pad M, but not K
138 assert(KRaw % AK1 == 0);
139
140 const auto AK0 = KRaw / AK1;
141
142 const auto a_grid_desc_ak0_m_ak1 =
143 transform_tensor_descriptor(a_grid_desc_mraw_kraw,
145 make_right_pad_transform(MRaw, MPad)),
148
149 return a_grid_desc_ak0_m_ak1;
150 }
151 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
153 {
154 // pad K, but not M
155 assert(K % AK1 == 0);
156
157 const auto AK0 = K / AK1;
158
159 const auto a_grid_desc_m_k = transform_tensor_descriptor(
160 a_grid_desc_mraw_kraw,
164
165 const auto a_grid_desc_ak0_m_ak1 =
166 transform_tensor_descriptor(a_grid_desc_m_k,
171
172 return a_grid_desc_ak0_m_ak1;
173 }
174 else
175 {
176 // not pad M or K
177 assert(KRaw % AK1 == 0);
178
179 const auto AK0 = KRaw / AK1;
180
181 const auto a_grid_desc_ak0_m_ak1 =
182 transform_tensor_descriptor(a_grid_desc_mraw_kraw,
187
188 return a_grid_desc_ak0_m_ak1;
189 }
190 }
191
192 static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
193 {
194 const auto b_grid_desc_nraw_kraw = [&]() {
196 {
197 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
198 make_tuple(I1, StrideB));
199 }
201 {
202 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
203 make_tuple(StrideB, I1));
204 }
205 }();
206
207 const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
208 const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
209
210 const auto NPad = N - NRaw;
211 const auto KPad = K - KRaw;
212
213 if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
215 {
216 // pad both N and K
217 assert(K % BK1 == 0);
218
219 const auto BK0 = K / BK1;
220
221 const auto b_grid_desc_n_k =
222 transform_tensor_descriptor(b_grid_desc_nraw_kraw,
224 make_right_pad_transform(KRaw, KPad)),
227
228 const auto b_grid_desc_bk0_n_bk1 =
229 transform_tensor_descriptor(b_grid_desc_n_k,
234
235 return b_grid_desc_bk0_n_bk1;
236 }
237 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
239 {
240 // pad N, but not K
241 assert(KRaw % BK1 == 0);
242
243 const auto BK0 = KRaw / BK1;
244
245 const auto b_grid_desc_bk0_n_bk1 =
246 transform_tensor_descriptor(b_grid_desc_nraw_kraw,
248 make_right_pad_transform(NRaw, NPad)),
251
252 return b_grid_desc_bk0_n_bk1;
253 }
254 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
256 {
257 // pad K, but not N
258 assert(K % BK1 == 0);
259
260 const auto BK0 = K / BK1;
261
262 const auto b_grid_desc_n_k = transform_tensor_descriptor(
263 b_grid_desc_nraw_kraw,
267
268 const auto b_grid_desc_bk0_n_bk1 =
269 transform_tensor_descriptor(b_grid_desc_n_k,
274
275 return b_grid_desc_bk0_n_bk1;
276 }
277 else
278 {
279 // not pad N or K
280 assert(KRaw % BK1 == 0);
281
282 const auto BK0 = KRaw / BK1;
283
284 const auto b_grid_desc_bk0_n_bk1 =
285 transform_tensor_descriptor(b_grid_desc_nraw_kraw,
290
291 return b_grid_desc_bk0_n_bk1;
292 }
293 }
294
295 static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
296 {
297 const auto c_grid_desc_mraw_nraw = [&]() {
299 {
300 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
301 make_tuple(StrideC, I1));
302 }
304 {
305 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
306 make_tuple(I1, StrideC));
307 }
308 }();
309
310 const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
311 const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
312
313 const auto MPad = M - MRaw;
314 const auto NPad = N - NRaw;
315
316 if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
318 {
319 // pad M and N
320 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
322 make_right_pad_transform(NRaw, NPad)),
325 }
326 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
328 {
329 // pad M, but not N
331 c_grid_desc_mraw_nraw,
335 }
336 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
338 {
339 // pad N, but not M
341 c_grid_desc_mraw_nraw,
345 }
346 else
347 {
348 // not pad M or N
349 return c_grid_desc_mraw_nraw;
350 }
351 }
352
353 // assume Reduce is packed tensor
355 {
356 const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
357
358 const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
359 const auto MPad = M - MRaw;
360
361 if constexpr(GemmSpec == GemmSpecialization::MPadding ||
362 GemmSpec == GemmSpecialization::MNPadding ||
363 GemmSpec == GemmSpecialization::MKPadding ||
365 {
366 // pad M
367 return transform_tensor_descriptor(d_grid_desc_mraw,
371 }
372 else
373 {
374 // not pad M
375 return d_grid_desc_mraw;
376 }
377 }
378
381 using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
383
384 // GridwiseGemm
385 template <index_t NXdlPerWave_>
387 ADataType, // TODO: distinguish A/B datatype
388 GemmAccDataType,
389 CShuffleDataType,
390 CDataType,
391 ReduceAccDataType,
392 ReducePtrsGlobal,
393 AElementwiseOperation,
394 BElementwiseOperation,
395 CElementwiseOperation,
396 ReduceOperations,
397 ReduceInElementwiseOperations,
398 ReduceAccElementwiseOperations,
400 ReduceGlobalMemoryDataOperation,
405 NumGemmKPrefetchStage,
406 BlockSize,
407 MPerBlock,
408 NPerBlock,
409 KPerBlock,
410 AK1,
411 BK1,
412 MPerXDL,
413 NPerXDL,
414 MXdlPerWave,
415 NXdlPerWave_,
416 ABlockTransferThreadClusterLengths_AK0_M_AK1,
417 ABlockTransferThreadClusterArrangeOrder,
418 ABlockTransferSrcAccessOrder,
419 ABlockTransferSrcVectorDim,
420 ABlockTransferSrcScalarPerVector,
421 ABlockTransferDstScalarPerVector_AK1,
422 false,
423 ABlockLdsExtraM,
424 BBlockTransferThreadClusterLengths_BK0_N_BK1,
425 BBlockTransferThreadClusterArrangeOrder,
426 BBlockTransferSrcAccessOrder,
427 BBlockTransferSrcVectorDim,
428 BBlockTransferSrcScalarPerVector,
429 BBlockTransferDstScalarPerVector_BK1,
430 false,
431 BBlockLdsExtraN,
432 CShuffleMXdlPerWavePerShuffle,
433 CShuffleNXdlPerWavePerShuffle,
434 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
435 CShuffleBlockTransferScalarPerVector_NPerBlock,
436 CReduceThreadClusterLengths_MPerBlock_NPerBlock,
437 CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
438 CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
439 LoopSched>;
442
443 // Argument
444 struct Argument : public BaseArgument
445 {
446 Argument(const ADataType* p_a_grid,
447 const BDataType* p_b_grid,
448 CDataType* p_c_grid,
449 ReducePtrsGlobal p_reduces_grid,
450 index_t MRaw,
451 index_t NRaw,
452 index_t KRaw,
453 index_t StrideA,
454 index_t StrideB,
455 index_t StrideC,
456 AElementwiseOperation a_element_op,
457 BElementwiseOperation b_element_op,
458 CElementwiseOperation c_element_op,
459 ReduceInElementwiseOperations reduce_in_element_ops,
460 ReduceAccElementwiseOperations reduce_out_element_ops)
461 : p_a_grid_{p_a_grid},
462 p_b_grid_{p_b_grid},
463 p_c_grid_{p_c_grid},
464 p_reduces_grid_{p_reduces_grid},
469 block_2_ctile_map_{GridwiseGemm64::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
470 a_element_op_{a_element_op},
471 b_element_op_{b_element_op},
472 c_element_op_{c_element_op},
473 reduce_in_element_ops_{reduce_in_element_ops},
474 reduce_out_element_ops_{reduce_out_element_ops}
475 {
476 }
477
478 // private:
479 const ADataType* p_a_grid_;
480 const BDataType* p_b_grid_;
481 CDataType* p_c_grid_;
482 ReducePtrsGlobal p_reduces_grid_;
488 AElementwiseOperation a_element_op_;
489 BElementwiseOperation b_element_op_;
490 CElementwiseOperation c_element_op_;
491 ReduceInElementwiseOperations reduce_in_element_ops_;
492 ReduceAccElementwiseOperations reduce_out_element_ops_;
493 };
494
495 // Invoker
496 struct Invoker : public BaseInvoker
497 {
499
500 template <typename GridwiseGemm>
501 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
502 {
503 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
504 {
505 std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
506 << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
507 << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
508 << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
509
510 std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
511 << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
512 << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
513 << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
514
515 std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
516 << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
517
518 std::cout << "arg.reduce_grid_desc_m_{ " << arg.reduce_grid_desc_m_.GetLength(I0)
519 << "}" << std::endl;
520 }
521
522 if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
526 {
527 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
528 }
529 auto c_grid_desc_mblock_mperblock_nblock_nperblock =
530 GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
531 arg.c_grid_desc_m_n_);
532
533 auto reduce_grid_desc_mblock_mperblock =
534 GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(arg.reduce_grid_desc_m_);
535
536 const index_t grid_size =
537 arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
538
539 const auto K =
540 arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
541
542 float elapsed_time = 0.0f;
543 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
544 {
545 const auto kernel = kernel_gemm_reduce_xdl_cshuffle_v1<
546 GridwiseGemm,
547 ADataType, // TODO: distiguish A/B datatype
548 CDataType,
549 ReducePtrsGlobal,
550 AElementwiseOperation,
551 BElementwiseOperation,
552 CElementwiseOperation,
553 ReduceInElementwiseOperations,
554 ReduceAccElementwiseOperations,
557 typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
558 typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
559 typename GridwiseGemm::DefaultBlock2CTileMap,
560 true>;
561
562 elapsed_time = launch_and_time_kernel(stream_config,
563 kernel,
564 dim3(grid_size),
565 dim3(BlockSize),
566 0,
567 arg.p_a_grid_,
568 arg.p_b_grid_,
569 arg.p_c_grid_,
570 arg.p_reduces_grid_,
571 arg.a_element_op_,
572 arg.b_element_op_,
573 arg.c_element_op_,
578 c_grid_desc_mblock_mperblock_nblock_nperblock,
579 reduce_grid_desc_mblock_mperblock,
581 }
582 else
583 {
584 const auto kernel = kernel_gemm_reduce_xdl_cshuffle_v1<
585 GridwiseGemm,
586 ADataType, // TODO: distiguish A/B datatype
587 CDataType,
588 ReducePtrsGlobal,
589 AElementwiseOperation,
590 BElementwiseOperation,
591 CElementwiseOperation,
592 ReduceInElementwiseOperations,
593 ReduceAccElementwiseOperations,
596 typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
597 typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
598 typename GridwiseGemm::DefaultBlock2CTileMap,
599 false>;
600
601 elapsed_time = launch_and_time_kernel(stream_config,
602 kernel,
603 dim3(grid_size),
604 dim3(BlockSize),
605 0,
606 arg.p_a_grid_,
607 arg.p_b_grid_,
608 arg.p_c_grid_,
609 arg.p_reduces_grid_,
610 arg.a_element_op_,
611 arg.b_element_op_,
612 arg.c_element_op_,
617 c_grid_desc_mblock_mperblock_nblock_nperblock,
618 reduce_grid_desc_mblock_mperblock,
620 }
621
622 return elapsed_time;
623 }
624
626
627 // polymorphic
628 float Run(const BaseArgument* p_arg,
629 const StreamConfig& stream_config = StreamConfig{}) override
630 {
631 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
632 }
633 };
634
// Compile-time hook meant to validate the template parameter combination.
// TODO: properly implement this check; currently every combination is accepted.
static constexpr bool IsValidCompilationParameter() { return true; }
640
641 static bool IsSupportedArgument(const Argument& arg)
642 {
644 {
645 return false;
646 }
647 if(get_warp_size() == 64)
648 {
649 if constexpr(NXdlPerWave64 > 0)
650 {
655 }
656 }
657 else
658 {
659 if constexpr(NXdlPerWave32 > 0)
660 {
665 }
666 }
667 return false;
668 }
669
670 // polymorphic
671 bool IsSupportedArgument(const BaseArgument* p_arg) override
672 {
673 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
674 }
675
676 static constexpr int NumReduce = ReduceOperations::Size();
677 static auto MakeArgument(const void* p_a,
678 const void* p_b,
679 const void* p_bias,
680 std::array<const void*, 0> p_ds,
681 void* p_c,
682 std::array<void*, NumReduce> p_reduces,
683 ck::index_t M,
684 ck::index_t N,
685 ck::index_t K,
686 ck::index_t StrideA,
687 ck::index_t StrideB,
688 ck::index_t StrideC,
689 std::array<ck::index_t, 0> StrideDs,
690 std::array<void*, 3> gemm_element_ops,
691 std::array<void*, 0> d_element_ops,
692 std::array<void*, NumReduce> reduce_in_element_op,
693 std::array<void*, NumReduce> reduce_out_element_op)
694 {
695 (void)p_bias;
696 (void)p_ds;
697 (void)StrideDs;
698 (void)d_element_ops;
699
700 ReducePtrsGlobal reduce_tuple = generate_tuple(
701 [&](auto I) {
702 auto tmp = ReducePtrsGlobal{}[I];
703 using T = remove_pointer_t<decltype(tmp)>;
704 return static_cast<T*>(p_reduces[I]);
705 },
707
708 ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
709 [&](auto I) {
710 auto tmp = ReduceInElementwiseOperations{}[I];
711 using T = remove_pointer_t<decltype(tmp)>;
712 return *(static_cast<T*>(reduce_in_element_op[I]));
713 },
715 ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
716 [&](auto I) {
717 auto tmp = ReduceAccElementwiseOperations{}[I];
718 using T = remove_pointer_t<decltype(tmp)>;
719 return *(static_cast<T*>(reduce_out_element_op[I]));
720 },
722
723 AElementwiseOperation a_element_op =
724 *(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
725 BElementwiseOperation b_element_op =
726 *(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
727 CElementwiseOperation c_element_op =
728 *(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
729
730 return Argument{static_cast<const ADataType*>(p_a),
731 static_cast<const BDataType*>(p_b),
732 static_cast<CDataType*>(p_c),
733 reduce_tuple,
734 M,
735 N,
736 K,
737 StrideA,
738 StrideB,
739 StrideC,
740 a_element_op,
741 b_element_op,
742 c_element_op,
743 reduce_in_element_ops,
744 reduce_out_element_ops};
745 }
746
747 static auto MakeInvoker() { return Invoker{}; }
748
749 // polymorphic
750 std::unique_ptr<BaseArgument>
751 MakeArgumentPointer(const void* p_a,
752 const void* p_b,
753 const void* p_bias,
754 std::array<const void*, 0> p_ds,
755 void* p_c,
756 std::array<void*, NumReduce> p_reduces,
757 ck::index_t M,
758 ck::index_t N,
759 ck::index_t K,
760 ck::index_t StrideA,
761 ck::index_t StrideB,
762 ck::index_t StrideC,
763 std::array<ck::index_t, 0> StrideDs,
764 std::array<void*, 3> gemm_element_ops,
765 std::array<void*, 0> d_element_ops,
766 std::array<void*, NumReduce> reduce_in_element_op,
767 std::array<void*, NumReduce> reduce_out_element_op,
768 ck::index_t = 1) override
769 {
770 (void)p_bias;
771 (void)p_ds;
772 (void)StrideDs;
773 (void)d_element_ops;
774
775 ReducePtrsGlobal reduce_tuple = generate_tuple(
776 [&](auto I) {
777 auto tmp = ReducePtrsGlobal{}[I];
778 using T = remove_pointer_t<decltype(tmp)>;
779 return static_cast<T*>(p_reduces[I]);
780 },
782
783 ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
784 [&](auto I) {
785 auto tmp = ReduceInElementwiseOperations{}[I];
786 using T = remove_pointer_t<decltype(tmp)>;
787 return *(static_cast<T*>(reduce_in_element_op[I]));
788 },
790 ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
791 [&](auto I) {
792 auto tmp = ReduceAccElementwiseOperations{}[I];
793 using T = remove_pointer_t<decltype(tmp)>;
794 return *(static_cast<T*>(reduce_out_element_op[I]));
795 },
797
798 AElementwiseOperation a_element_op =
799 *(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
800 BElementwiseOperation b_element_op =
801 *(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
802 CElementwiseOperation c_element_op =
803 *(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
804
805 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
806 static_cast<const BDataType*>(p_b),
807 static_cast<CDataType*>(p_c),
808 reduce_tuple,
809 M,
810 N,
811 K,
812 StrideA,
813 StrideB,
814 StrideC,
815 a_element_op,
816 b_element_op,
817 c_element_op,
818 reduce_in_element_ops,
819 reduce_out_element_ops);
820 }
821
822 // polymorphic
823 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
824 {
825 return std::make_unique<Invoker>(Invoker{});
826 }
827
828 // polymorphic
829 std::string GetTypeString() const override
830 {
831 auto str = std::stringstream();
832
833 // clang-format off
834 str << "DeviceGemmReduce_Xdl_CShuffle"
835 << "<"
836 << BlockSize << ", "
837 << MPerBlock << ", "
838 << NPerBlock << ", "
839 << KPerBlock << ", "
840 << AK1 << ", "
841 << BK1 << ", "
842 << MPerXDL << ", "
843 << NPerXDL << ", "
844 << MXdlPerWave << ", "
845 << NXdlPerWave << ", "
846 << ABlockTransferSrcScalarPerVector << ", "
847 << BBlockTransferSrcScalarPerVector << ", "
848 << CShuffleMXdlPerWavePerShuffle << ", "
849 << CShuffleNXdlPerWavePerShuffle
850 << ">";
851 // clang-format on
852
853 return str.str();
854 }
855};
856
857} // namespace device
858} // namespace tensor_operation
859} // namespace ck
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
__global__ void kernel_gemm_reduce_xdl_cshuffle_v1(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, ReducePtrsGlobal p_reduces_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const ReduceInElementwiseOperations reduce_in_element_ops, const ReduceAccElementwiseOperations reduce_out_element_ops, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, const Block2CTileMap block_2_ctile_map)
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:40
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
typename remove_pointer< T >::type remove_pointer_t
Definition type.hpp:300
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:152
remove_cvref_t< decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))> DefaultBlock2CTileMap
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:347
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:251
Definition utility/sequence.hpp:43
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition device_base.hpp:197
Definition device_gemm_reduce_xdl_cshuffle.hpp:445
CGridDesc_M_N c_grid_desc_m_n_
Definition device_gemm_reduce_xdl_cshuffle.hpp:485
const BDataType * p_b_grid_
Definition device_gemm_reduce_xdl_cshuffle.hpp:480
ReducePtrsGlobal p_reduces_grid_
Definition device_gemm_reduce_xdl_cshuffle.hpp:482
BElementwiseOperation b_element_op_
Definition device_gemm_reduce_xdl_cshuffle.hpp:489
CDataType * p_c_grid_
Definition device_gemm_reduce_xdl_cshuffle.hpp:481
CElementwiseOperation c_element_op_
Definition device_gemm_reduce_xdl_cshuffle.hpp:490
Argument(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, ReducePtrsGlobal p_reduces_grid, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, ReduceInElementwiseOperations reduce_in_element_ops, ReduceAccElementwiseOperations reduce_out_element_ops)
Definition device_gemm_reduce_xdl_cshuffle.hpp:446
ReduceInElementwiseOperations reduce_in_element_ops_
Definition device_gemm_reduce_xdl_cshuffle.hpp:491
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_gemm_reduce_xdl_cshuffle.hpp:483
AElementwiseOperation a_element_op_
Definition device_gemm_reduce_xdl_cshuffle.hpp:488
ReduceAccElementwiseOperations reduce_out_element_ops_
Definition device_gemm_reduce_xdl_cshuffle.hpp:492
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_gemm_reduce_xdl_cshuffle.hpp:484
const ADataType * p_a_grid_
Definition device_gemm_reduce_xdl_cshuffle.hpp:479
ReduceGridDesc_M reduce_grid_desc_m_
Definition device_gemm_reduce_xdl_cshuffle.hpp:486
GridwiseGemm64::DefaultBlock2CTileMap block_2_ctile_map_
Definition device_gemm_reduce_xdl_cshuffle.hpp:487
Definition device_gemm_reduce_xdl_cshuffle.hpp:497
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_reduce_xdl_cshuffle.hpp:501
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_reduce_xdl_cshuffle.hpp:628
DeviceOp::Argument Argument
Definition device_gemm_reduce_xdl_cshuffle.hpp:498
Definition device_gemm_reduce_xdl_cshuffle.hpp:78
static constexpr auto NXdlPerWave32
Definition device_gemm_reduce_xdl_cshuffle.hpp:83
static constexpr auto I0
Definition device_gemm_reduce_xdl_cshuffle.hpp:85
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const void *p_bias, std::array< const void *, 0 > p_ds, void *p_c, std::array< void *, NumReduce > p_reduces, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC, std::array< ck::index_t, 0 > StrideDs, std::array< void *, 3 > gemm_element_ops, std::array< void *, 0 > d_element_ops, std::array< void *, NumReduce > reduce_in_element_op, std::array< void *, NumReduce > reduce_out_element_op, ck::index_t=1) override
Definition device_gemm_reduce_xdl_cshuffle.hpp:751
static constexpr auto I2
Definition device_gemm_reduce_xdl_cshuffle.hpp:87
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_reduce_xdl_cshuffle.hpp:440
static constexpr auto I1
Definition device_gemm_reduce_xdl_cshuffle.hpp:86
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
Definition device_gemm_reduce_xdl_cshuffle.hpp:295
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_reduce_xdl_cshuffle.hpp:635
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_reduce_xdl_cshuffle.hpp:823
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
Definition device_gemm_reduce_xdl_cshuffle.hpp:192
decltype(MakeCGridDescriptor_M_N(1, 1, 1)) CGridDesc_M_N
Definition device_gemm_reduce_xdl_cshuffle.hpp:381
decltype(MakeReduceGridDescriptor_M(1)) ReduceGridDesc_M
Definition device_gemm_reduce_xdl_cshuffle.hpp:382
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_reduce_xdl_cshuffle.hpp:671
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_reduce_xdl_cshuffle.hpp:641
static auto MakeReduceGridDescriptor_M(index_t MRaw)
Definition device_gemm_reduce_xdl_cshuffle.hpp:354
static auto MakeInvoker()
Definition device_gemm_reduce_xdl_cshuffle.hpp:747
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_reduce_xdl_cshuffle.hpp:82
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
Definition device_gemm_reduce_xdl_cshuffle.hpp:89
decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)) AGridDesc_AK0_M_AK1
Definition device_gemm_reduce_xdl_cshuffle.hpp:379
decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)) BGridDesc_BK0_N_BK1
Definition device_gemm_reduce_xdl_cshuffle.hpp:380
GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched > GridwiseGemmBase
Definition device_gemm_reduce_xdl_cshuffle.hpp:386
static auto MakeArgument(const void *p_a, const void *p_b, const void *p_bias, std::array< const void *, 0 > p_ds, void *p_c, std::array< void *, NumReduce > p_reduces, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC, std::array< ck::index_t, 0 > StrideDs, std::array< void *, 3 > gemm_element_ops, std::array< void *, 0 > d_element_ops, std::array< void *, NumReduce > reduce_in_element_op, std::array< void *, NumReduce > reduce_out_element_op)
Definition device_gemm_reduce_xdl_cshuffle.hpp:677
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_reduce_xdl_cshuffle.hpp:441
DeviceGemmReduce_Xdl_CShuffle DeviceOp
Definition device_gemm_reduce_xdl_cshuffle.hpp:79
static constexpr int NumReduce
Definition device_gemm_reduce_xdl_cshuffle.hpp:676
std::string GetTypeString() const override
Definition device_gemm_reduce_xdl_cshuffle.hpp:829
Definition device_gemm_reduce.hpp:17
#define CK_ENV(name)
Definition utility/env.hpp:129