device_batched_gemm_reduce_xdl_cshuffle.hpp Source File

device_batched_gemm_reduce_xdl_cshuffle.hpp Source File#

Composable Kernel: device_batched_gemm_reduce_xdl_cshuffle.hpp Source File
device_batched_gemm_reduce_xdl_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
10#include "ck/utility/env.hpp"
19
20namespace ck {
21namespace tensor_operation {
22namespace device {
23
// Batched GEMM + reduction kernel entry point. The line carrying the kernel name is elided from
// this listing; the Doxygen index names it kernel_batched_gemm_reduce_xdl_cshuffle_v1.
//
// What the visible body does:
//  - derives the batch index of the current block as
//    g_idx = get_block_1d_id() / (get_grid_size() / batch_count);
//  - asks compute_base_ptr_of_batch_ for the A, B, C and per-reduction D offsets of that batch and
//    advances the corresponding pointers by them;
//  - declares a __shared__ char buffer of GridwiseGemm::GetSharedMemoryNumberOfByte() bytes and
//    forwards everything to GridwiseGemm::Run<HasMainK0BlockLoop>.
// The body is only compiled for __gfx9__ / __gfx11__ / __gfx12__, and only when
// GridwiseGemm::IsValidCompilationParameter<>() holds; on other targets every parameter is merely
// assigned to `ignore`.
// NOTE(review): the integer division assumes get_grid_size() is a multiple of batch_count. RunImp
// launches CalculateGridSize(...) * Batch_ blocks, which satisfies this -- confirm for other callers.
24template <typename GridwiseGemm,
25 typename FloatAB,
26 typename FloatC,
27 typename ReducePtrsGlobal,
28 typename AElementwiseOperation,
29 typename BElementwiseOperation,
30 typename CElementwiseOperation,
31 typename ReduceInElementwiseOperations,
32 typename ReduceAccElementwiseOperations,
33 typename AGridDesc_AK0_M_AK1,
34 typename BGridDesc_BK0_N_BK1,
35 typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
36 typename ReduceGridDescriptor_MBlock_MPerBlock,
37 typename ComputeBasePrtOfBatch,
38 typename Block2CTileMap,
39 bool HasMainK0BlockLoop>
40__global__ void
41#if CK_USE_LAUNCH_BOUNDS
43#endif
45 const FloatAB* __restrict__ p_a_grid,
46 const FloatAB* __restrict__ p_b_grid,
47 FloatC* __restrict__ p_c_grid,
48 ReducePtrsGlobal p_reduces_grid,
49 const index_t batch_count,
50 const AElementwiseOperation a_element_op,
51 const BElementwiseOperation b_element_op,
52 const CElementwiseOperation c_element_op,
53 const ReduceInElementwiseOperations reduce_in_element_ops,
54 const ReduceAccElementwiseOperations reduce_out_element_ops,
55 const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
56 const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
57 const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
58 c_grid_desc_mblock_mperblock_nblock_nperblock,
59 const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
60 const ComputeBasePrtOfBatch compute_base_ptr_of_batch_,
61 const Block2CTileMap block_2_ctile_map)
62{
63#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
64 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
65 {
// Blocks are assigned to batches in contiguous runs of num_blocks_per_batch block ids.
// The values go through __builtin_amdgcn_readfirstlane, presumably to keep them wave-uniform
// -- TODO confirm against the AMDGPU builtin documentation.
66 const index_t num_blocks_per_batch =
67 __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
68 const index_t g_idx =
69 __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
70
// Per-batch offsets, later added directly to the typed A/B/C pointers.
71 const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
72 static_cast<long_index_t>(compute_base_ptr_of_batch_.GetABasePtr(g_idx)));
73 const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
74 static_cast<long_index_t>(compute_base_ptr_of_batch_.GetBBasePtr(g_idx)));
75 const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
76 static_cast<long_index_t>(compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
77
// Advance every reduction output pointer in the tuple to this batch (done in place on the
// by-value kernel parameter).
78 static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) {
79 const long_index_t d_batch_offset = __builtin_amdgcn_readfirstlane(
80 static_cast<long_index_t>(compute_base_ptr_of_batch_.GetDBasePtr(g_idx, In)));
81 p_reduces_grid(In) = p_reduces_grid(In) + d_batch_offset;
82 });
83
84 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
85
86 GridwiseGemm::template Run<HasMainK0BlockLoop>(
87 p_a_grid + a_batch_offset,
88 p_b_grid + b_batch_offset,
89 p_c_grid + c_batch_offset,
90 p_reduces_grid,
91 p_shared,
92 a_element_op,
93 b_element_op,
94 c_element_op,
95 reduce_in_element_ops,
96 reduce_out_element_ops,
97 a_grid_desc_ak0_m_ak1,
98 b_grid_desc_bk0_n_bk1,
99 c_grid_desc_mblock_mperblock_nblock_nperblock,
100 reduce_grid_desc_mblock_mperblock,
101 block_2_ctile_map);
102 }
103#else
// Other targets: no work is done; each parameter is consumed by assigning it to `ignore`.
104 ignore = p_a_grid;
105 ignore = p_b_grid;
106 ignore = p_c_grid;
107 ignore = p_reduces_grid;
108 ignore = batch_count;
109 ignore = a_element_op;
110 ignore = b_element_op;
111 ignore = c_element_op;
112 ignore = reduce_in_element_ops;
113 ignore = reduce_out_element_ops;
114 ignore = a_grid_desc_ak0_m_ak1;
115 ignore = b_grid_desc_bk0_n_bk1;
116 ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
117 ignore = reduce_grid_desc_mblock_mperblock;
118 ignore = compute_base_ptr_of_batch_;
119 ignore = block_2_ctile_map;
120#endif
121}
122
123// Note: the inter-wave loop scheduler is rolled out to the c-shuffle version first, because the
124// non c-shuffle version currently has compiler issues with register spill, which in turn cause
125// validation failures.
126template <typename ALayout,
127 typename BLayout,
128 typename CLayout,
129 typename ADataType,
130 typename BDataType,
131 typename CDataType,
132 typename GemmAccDataType,
133 typename CShuffleDataType,
134 typename ReduceAccDataType,
135 typename ReducePtrsGlobal,
136 typename AElementwiseOperation,
137 typename BElementwiseOperation,
138 typename CElementwiseOperation,
139 typename ReduceOperations,
140 typename ReduceInElementwiseOperations,
141 typename ReduceAccElementwiseOperations,
142 typename ReduceGlobalMemoryDataOperation,
143 GemmSpecialization GemmSpec,
144 index_t NumGemmKPrefetchStage,
145 index_t BlockSize,
146 index_t MPerBlock,
147 index_t NPerBlock,
148 index_t KPerBlock,
149 index_t AK1,
150 index_t BK1,
151 index_t MPerXDL,
152 index_t NPerXDL,
153 index_t MXdlPerWave,
154 index_t NXdlPerWave,
155 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
156 typename ABlockTransferThreadClusterArrangeOrder,
157 typename ABlockTransferSrcAccessOrder,
158 index_t ABlockTransferSrcVectorDim,
159 index_t ABlockTransferSrcScalarPerVector,
160 index_t ABlockTransferDstScalarPerVector_AK1,
161 bool ABlockLdsExtraM,
162 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
163 typename BBlockTransferThreadClusterArrangeOrder,
164 typename BBlockTransferSrcAccessOrder,
165 index_t BBlockTransferSrcVectorDim,
166 index_t BBlockTransferSrcScalarPerVector,
167 index_t BBlockTransferDstScalarPerVector_BK1,
168 bool BBlockLdsExtraN,
169 index_t CShuffleMXdlPerWavePerShuffle,
170 index_t CShuffleNXdlPerWavePerShuffle,
171 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
172 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
173 typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
174 index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
175 index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
177struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperations::Size()>
178{
// Results of GetNXdlPerWave<true>() and GetNXdlPerWave<false>(). Judging by the names and by
// IsSupportedArgument (which consults NXdlPerWave64 when get_warp_size() == 64 and NXdlPerWave32
// otherwise) these are the N-XDL-per-wave values for wave64 and wave32; the definition of
// GetNXdlPerWave is not visible in this listing -- TODO confirm.
181 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
182 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
183
// Compile-time index constants Number<0..2>, used e.g. as GetLength(I0) and as unit strides.
184 static constexpr auto I0 = Number<0>{};
185 static constexpr auto I1 = Number<1>{};
186 static constexpr auto I2 = Number<2>{};
187
// Builds the grid descriptor of A from the raw sizes MRaw x KRaw and the stride StrideA.
//  1. A naive (MRaw, KRaw) descriptor is created with strides (StrideA, 1) or (1, StrideA). The
//     condition selecting between the two is elided from this listing (presumably a check on
//     ALayout -- TODO confirm).
//  2. M and K are MRaw / KRaw rounded up to multiples of MPerBlock / KPerBlock; MPad and KPad are
//     the amounts added.
//  3. Depending on GemmSpec, M and/or K are right-padded. AK0 is the K extent in use (padded K, or
//     KRaw when K is not padded) divided by AK1, and the result is returned as
//     a_grid_desc_ak0_m_ak1.
// Each branch asserts that the K extent it uses is divisible by AK1.
// NOTE(review): part of every branch condition and most transform arguments are missing from this
// listing, so the exact transforms are intentionally not described here.
188 static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
189 {
190 const auto a_grid_desc_mraw_kraw = [&]() {
192 {
193 return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
194 make_tuple(StrideA, I1));
195 }
197 {
198 return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
199 make_tuple(I1, StrideA));
200 }
201 }();
202
// Round the raw extents up to whole blocks.
203 const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
204 const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
205
206 const auto MPad = M - MRaw;
207 const auto KPad = K - KRaw;
208
209 if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
211 {
212 // pad both M and K
213 assert(K % AK1 == 0);
214
215 const auto AK0 = K / AK1;
216
217 const auto a_grid_desc_m_k =
218 transform_tensor_descriptor(a_grid_desc_mraw_kraw,
220 make_right_pad_transform(KRaw, KPad)),
223
224 const auto a_grid_desc_ak0_m_ak1 =
225 transform_tensor_descriptor(a_grid_desc_m_k,
230
231 return a_grid_desc_ak0_m_ak1;
232 }
233 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
235 {
236 // pad M, but not K
237 assert(KRaw % AK1 == 0);
238
239 const auto AK0 = KRaw / AK1;
240
241 const auto a_grid_desc_ak0_m_ak1 =
242 transform_tensor_descriptor(a_grid_desc_mraw_kraw,
244 make_right_pad_transform(MRaw, MPad)),
247
248 return a_grid_desc_ak0_m_ak1;
249 }
250 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
252 {
253 // pad K, but not M
254 assert(K % AK1 == 0);
255
256 const auto AK0 = K / AK1;
257
258 const auto a_grid_desc_m_k = transform_tensor_descriptor(
259 a_grid_desc_mraw_kraw,
263
264 const auto a_grid_desc_ak0_m_ak1 =
265 transform_tensor_descriptor(a_grid_desc_m_k,
270
271 return a_grid_desc_ak0_m_ak1;
272 }
273 else
274 {
275 // not pad M or K
276 assert(KRaw % AK1 == 0);
277
278 const auto AK0 = KRaw / AK1;
279
280 const auto a_grid_desc_ak0_m_ak1 =
281 transform_tensor_descriptor(a_grid_desc_mraw_kraw,
286
287 return a_grid_desc_ak0_m_ak1;
288 }
289 }
290
// Builds the grid descriptor of B from the raw sizes KRaw x NRaw and the stride StrideB.
//  1. A naive (NRaw, KRaw) descriptor is created with strides (1, StrideB) or (StrideB, 1). The
//     condition selecting between the two is elided from this listing (presumably a check on
//     BLayout -- TODO confirm).
//  2. N and K are NRaw / KRaw rounded up to multiples of NPerBlock / KPerBlock; NPad and KPad are
//     the amounts added.
//  3. Depending on GemmSpec, N and/or K are right-padded. BK0 is the K extent in use (padded K, or
//     KRaw when K is not padded) divided by BK1, and the result is returned as
//     b_grid_desc_bk0_n_bk1.
// Each branch asserts that the K extent it uses is divisible by BK1.
// NOTE(review): part of every branch condition and most transform arguments are missing from this
// listing, so the exact transforms are intentionally not described here.
291 static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
292 {
293 const auto b_grid_desc_nraw_kraw = [&]() {
295 {
296 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
297 make_tuple(I1, StrideB));
298 }
300 {
301 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
302 make_tuple(StrideB, I1));
303 }
304 }();
305
// Round the raw extents up to whole blocks.
306 const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
307 const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
308
309 const auto NPad = N - NRaw;
310 const auto KPad = K - KRaw;
311
312 if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
314 {
315 // pad both N and K
316 assert(K % BK1 == 0);
317
318 const auto BK0 = K / BK1;
319
320 const auto b_grid_desc_n_k =
321 transform_tensor_descriptor(b_grid_desc_nraw_kraw,
323 make_right_pad_transform(KRaw, KPad)),
326
327 const auto b_grid_desc_bk0_n_bk1 =
328 transform_tensor_descriptor(b_grid_desc_n_k,
333
334 return b_grid_desc_bk0_n_bk1;
335 }
336 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
338 {
339 // pad N, but not K
340 assert(KRaw % BK1 == 0);
341
342 const auto BK0 = KRaw / BK1;
343
344 const auto b_grid_desc_bk0_n_bk1 =
345 transform_tensor_descriptor(b_grid_desc_nraw_kraw,
347 make_right_pad_transform(NRaw, NPad)),
350
351 return b_grid_desc_bk0_n_bk1;
352 }
353 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
355 {
356 // pad K, but not N
357 assert(K % BK1 == 0);
358
359 const auto BK0 = K / BK1;
360
361 const auto b_grid_desc_n_k = transform_tensor_descriptor(
362 b_grid_desc_nraw_kraw,
366
367 const auto b_grid_desc_bk0_n_bk1 =
368 transform_tensor_descriptor(b_grid_desc_n_k,
373
374 return b_grid_desc_bk0_n_bk1;
375 }
376 else
377 {
378 // not pad N or K
379 assert(KRaw % BK1 == 0);
380
381 const auto BK0 = KRaw / BK1;
382
383 const auto b_grid_desc_bk0_n_bk1 =
384 transform_tensor_descriptor(b_grid_desc_nraw_kraw,
389
390 return b_grid_desc_bk0_n_bk1;
391 }
392 }
393
// Builds the (M, N) grid descriptor of C from the raw sizes MRaw x NRaw and the stride StrideC.
//  1. A naive (MRaw, NRaw) descriptor is created with strides (StrideC, 1) or (1, StrideC). The
//     condition selecting between the two is elided from this listing (presumably a check on
//     CLayout -- TODO confirm).
//  2. M and N are MRaw / NRaw rounded up to multiples of MPerBlock / NPerBlock; MPad and NPad are
//     the amounts added.
//  3. Depending on GemmSpec, M and/or N are padded; when neither is padded the naive descriptor is
//     returned unchanged.
// NOTE(review): part of every branch condition and most transform arguments are missing from this
// listing.
394 static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
395 {
396 const auto c_grid_desc_mraw_nraw = [&]() {
398 {
399 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
400 make_tuple(StrideC, I1));
401 }
403 {
404 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
405 make_tuple(I1, StrideC));
406 }
407 }();
408
// Round the raw extents up to whole blocks.
409 const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
410 const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
411
412 const auto MPad = M - MRaw;
413 const auto NPad = N - NRaw;
414
415 if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
417 {
418 // pad M and N
419 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
421 make_right_pad_transform(NRaw, NPad)),
424 }
425 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
427 {
428 // pad M, but not N
430 c_grid_desc_mraw_nraw,
434 }
435 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
437 {
438 // pad N, but not M
440 c_grid_desc_mraw_nraw,
444 }
445 else
446 {
447 // not pad M or N
448 return c_grid_desc_mraw_nraw;
449 }
450 }
451
452 // assume D is packed tensor
454 {
455 const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
456
457 const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
458 const auto MPad = M - MRaw;
459
460 if constexpr(GemmSpec == GemmSpecialization::MPadding ||
461 GemmSpec == GemmSpecialization::MNPadding ||
462 GemmSpec == GemmSpecialization::MKPadding ||
464 {
465 // pad M
466 return transform_tensor_descriptor(d_grid_desc_mraw,
470 }
471 else
472 {
473 // not pad M
474 return d_grid_desc_mraw;
475 }
476 }
477
480 using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
482
484 {
486 index_t BatchStrideB,
487 index_t BatchStrideC,
488 index_t BatchStrideD)
489 : BatchStrideA_(BatchStrideA),
490 BatchStrideB_(BatchStrideB),
491 BatchStrideC_(BatchStrideC),
492 BatchStrideD_(BatchStrideD)
493 {
494 }
495
496 __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
497 {
498 return g_idx * static_cast<long_index_t>(BatchStrideA_);
499 }
500
501 __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
502 {
503 return g_idx * static_cast<long_index_t>(BatchStrideB_);
504 }
505
506 __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
507 {
508 return g_idx * static_cast<long_index_t>(BatchStrideC_);
509 }
510
511 template <index_t I>
512 __host__ __device__ constexpr long_index_t GetDBasePtr(index_t g_idx,
513 Number<I> reduction_idx) const
514 {
515 // TODO - Support sequence of StrideD in MakeArgument()
516 (void)reduction_idx;
517 return g_idx * static_cast<long_index_t>(BatchStrideD_);
518 }
519
520 private:
521 index_t BatchStrideA_;
522 index_t BatchStrideB_;
523 index_t BatchStrideC_;
524 index_t BatchStrideD_;
525 };
526
527 // GridwiseGemm
528 template <index_t NXdlPerWave_>
530 ADataType, // TODO: distinguish A/B datatype
531 GemmAccDataType,
532 CShuffleDataType,
533 CDataType,
534 ReduceAccDataType,
535 ReducePtrsGlobal,
536 AElementwiseOperation,
537 BElementwiseOperation,
538 CElementwiseOperation,
539 ReduceOperations,
540 ReduceInElementwiseOperations,
541 ReduceAccElementwiseOperations,
543 ReduceGlobalMemoryDataOperation,
548 NumGemmKPrefetchStage,
549 BlockSize,
550 MPerBlock,
551 NPerBlock,
552 KPerBlock,
553 AK1,
554 BK1,
555 MPerXDL,
556 NPerXDL,
557 MXdlPerWave,
558 NXdlPerWave_,
559 ABlockTransferThreadClusterLengths_AK0_M_AK1,
560 ABlockTransferThreadClusterArrangeOrder,
561 ABlockTransferSrcAccessOrder,
562 ABlockTransferSrcVectorDim,
563 ABlockTransferSrcScalarPerVector,
564 ABlockTransferDstScalarPerVector_AK1,
565 false,
566 ABlockLdsExtraM,
567 BBlockTransferThreadClusterLengths_BK0_N_BK1,
568 BBlockTransferThreadClusterArrangeOrder,
569 BBlockTransferSrcAccessOrder,
570 BBlockTransferSrcVectorDim,
571 BBlockTransferSrcScalarPerVector,
572 BBlockTransferDstScalarPerVector_BK1,
573 false,
574 BBlockLdsExtraN,
575 CShuffleMXdlPerWavePerShuffle,
576 CShuffleNXdlPerWavePerShuffle,
577 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
578 CShuffleBlockTransferScalarPerVector_NPerBlock,
579 CReduceThreadClusterLengths_MPerBlock_NPerBlock,
580 CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
581 CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
582 LoopSched>;
585
586 // Argument
// Host-side bundle of everything one launch needs. The constructor copies the A/B/C pointers, the
// tuple of reduction output pointers, the batch count and the five element-wise operator objects
// from its arguments, and builds block_2_ctile_map_ from c_grid_desc_m_n_ via
// GridwiseGemm64::MakeDefaultBlock2CTileMap.
// NOTE(review): the initializers on listing lines 610-614 and the member declarations on lines
// 633-639 are elided here. Per the Doxygen index the missing members are Batch_, the A/B/C/reduce
// grid descriptors, compute_base_ptr_of_batch_ and block_2_ctile_map_.
// NOTE(review): the four GetElementSpaceSize() values are narrowed to index_t (int32_t) with
// type_convert, presumably as the batch strides handed to compute_base_ptr_of_batch_ -- confirm
// they cannot exceed the index_t range for large problems.
587 struct Argument : public BaseArgument
588 {
589 Argument(const ADataType* p_a_grid,
590 const BDataType* p_b_grid,
591 CDataType* p_c_grid,
592 ReducePtrsGlobal p_reduces_grid,
593 index_t MRaw,
594 index_t NRaw,
595 index_t KRaw,
596 index_t StrideA,
597 index_t StrideB,
598 index_t StrideC,
599 AElementwiseOperation a_element_op,
600 BElementwiseOperation b_element_op,
601 CElementwiseOperation c_element_op,
602 ReduceInElementwiseOperations reduce_in_element_ops,
603 ReduceAccElementwiseOperations reduce_out_element_ops,
604 index_t Batch)
605 : p_a_grid_{p_a_grid},
606 p_b_grid_{p_b_grid},
607 p_c_grid_{p_c_grid},
608 p_reduces_grid_{p_reduces_grid},
609 Batch_(Batch),
615 type_convert<index_t>(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize()),
616 type_convert<index_t>(b_grid_desc_bk0_n_bk1_.GetElementSpaceSize()),
617 type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize()),
618 type_convert<index_t>(reduce_grid_desc_m_.GetElementSpaceSize())},
619 block_2_ctile_map_{GridwiseGemm64::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
620 a_element_op_{a_element_op},
621 b_element_op_{b_element_op},
622 c_element_op_{c_element_op},
623 reduce_in_element_ops_{reduce_in_element_ops},
624 reduce_out_element_ops_{reduce_out_element_ops}
625 {
626 }
627
// Members are left public (the access specifier below is commented out); RunImp reads them
// directly.
628 // private:
629 const ADataType* p_a_grid_;
630 const BDataType* p_b_grid_;
631 CDataType* p_c_grid_;
632 ReducePtrsGlobal p_reduces_grid_;
640 AElementwiseOperation a_element_op_;
641 BElementwiseOperation b_element_op_;
642 CElementwiseOperation c_element_op_;
643 ReduceInElementwiseOperations reduce_in_element_ops_;
644 ReduceAccElementwiseOperations reduce_out_element_ops_;
645 };
646
647 // Invoker
648 struct Invoker : public BaseInvoker
649 {
651
// Launches the batched GEMM + reduce kernel for one GridwiseGemm instantiation and returns the
// value produced by launch_and_time_kernel.
//  - Throws std::runtime_error when GridwiseGemm::CheckValidity(...) rejects the argument.
//  - When the CK_LOGGING environment switch is enabled, prints Batch_ and the descriptor lengths
//    to std::cout.
//  - Grid size is CalculateGridSize(c_grid_desc_m_n_) * Batch_; block size is BlockSize; the
//    dynamic LDS size passed to the launcher is 0.
//  - K is recovered as GetLength(I0) * GetLength(I2) of the A descriptor, and
//    CalculateHasMainKBlockLoop(K) selects the kernel instantiation whose last template argument
//    is true or false. The two branches are otherwise identical in the visible lines.
// NOTE(review): the kernel alias lines and several launch arguments are elided from this listing.
652 template <typename GridwiseGemm>
653 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
654 {
655 if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
659 {
660 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
661 }
662
663 auto c_grid_desc_mblock_mperblock_nblock_nperblock =
664 GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
665 arg.c_grid_desc_m_n_);
666
667 auto reduce_grid_desc_mblock_mperblock =
668 GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(arg.reduce_grid_desc_m_);
669
// Optional diagnostic dump of the problem shape.
670 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
671 {
672 {
673 std::cout << "arg.Batch_ = " << arg.Batch_ << std::endl;
674
675 std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
676 << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
677 << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
678 << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
679
680 std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
681 << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
682 << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
683 << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
684
685 std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0)
686 << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
687
688 std::cout << "arg.reduce_grid_desc_m_{ "
689 << arg.reduce_grid_desc_m_.GetLength(I0) << "}" << std::endl;
690 }
691 }
// One set of C tiles per batch; the kernel recovers the batch index from the block id.
692 const index_t grid_size =
693 arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Batch_;
694
// K = AK0 * AK1.
695 const auto K =
696 arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
697
698 float elapsed_time = 0.0f;
699 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
700 {
702 GridwiseGemm,
703 ADataType, // TODO: distinguish A/B datatype
704 CDataType,
705 ReducePtrsGlobal,
706 AElementwiseOperation,
707 BElementwiseOperation,
708 CElementwiseOperation,
709 ReduceInElementwiseOperations,
710 ReduceAccElementwiseOperations,
713 typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
714 typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
715 ComputeBasePtrOfStridedBatch,
716 typename GridwiseGemm::DefaultBlock2CTileMap,
717 true>;
718
719 elapsed_time = launch_and_time_kernel(stream_config,
720 kernel,
721 dim3(grid_size),
722 dim3(BlockSize),
723 0,
724 arg.p_a_grid_,
725 arg.p_b_grid_,
726 arg.p_c_grid_,
727 arg.p_reduces_grid_,
728 arg.Batch_,
729 arg.a_element_op_,
730 arg.b_element_op_,
731 arg.c_element_op_,
736 c_grid_desc_mblock_mperblock_nblock_nperblock,
737 reduce_grid_desc_mblock_mperblock,
740 }
741 else
742 {
744 GridwiseGemm,
745 ADataType, // TODO: distinguish A/B datatype
746 CDataType,
747 ReducePtrsGlobal,
748 AElementwiseOperation,
749 BElementwiseOperation,
750 CElementwiseOperation,
751 ReduceInElementwiseOperations,
752 ReduceAccElementwiseOperations,
755 typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
756 typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
757 ComputeBasePtrOfStridedBatch,
758 typename GridwiseGemm::DefaultBlock2CTileMap,
759 false>;
760
761 elapsed_time = launch_and_time_kernel(stream_config,
762 kernel,
763 dim3(grid_size),
764 dim3(BlockSize),
765 0,
766 arg.p_a_grid_,
767 arg.p_b_grid_,
768 arg.p_c_grid_,
769 arg.p_reduces_grid_,
770 arg.Batch_,
771 arg.a_element_op_,
772 arg.b_element_op_,
773 arg.c_element_op_,
778 c_grid_desc_mblock_mperblock_nblock_nperblock,
779 reduce_grid_desc_mblock_mperblock,
782 }
783
784 return elapsed_time;
785 }
786
788
789 // polymorphic
790 float Run(const BaseArgument* p_arg,
791 const StreamConfig& stream_config = StreamConfig{}) override
792 {
793 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
794 }
795 };
796
// Compile-time validity gate for this device operation's template parameters.
// Currently a placeholder: it returns true unconditionally (see the TODO), so it rejects nothing.
797 static constexpr bool IsValidCompilationParameter()
798 {
799 // TODO: properly implement this check
800 return true;
801 }
802
// Host-side support check for a concrete Argument.
// NOTE(review): the guard condition (listing line 805) and the bodies of both `if constexpr`
// branches (lines 812-816 and 822-826) are elided from this listing. What is visible:
//  - an early `return false` behind the elided guard;
//  - a dispatch on get_warp_size() == 64, where each side is only compiled in when the matching
//    NXdlPerWave64 / NXdlPerWave32 is positive;
//  - a final `return false` when no branch returned.
803 static bool IsSupportedArgument(const Argument& arg)
804 {
806 {
807 return false;
808 }
809 if(get_warp_size() == 64)
810 {
811 if constexpr(NXdlPerWave64 > 0)
812 {
817 }
818 }
819 else
820 {
821 if constexpr(NXdlPerWave32 > 0)
822 {
827 }
828 }
829 return false;
830 }
831
832 // polymorphic
833 bool IsSupportedArgument(const BaseArgument* p_arg) override
834 {
835 auto casted_p_arg = dynamic_cast<const Argument*>(p_arg);
836 if(casted_p_arg == nullptr)
837 {
838 return false;
839 }
840 else
841 {
842 return IsSupportedArgument(*casted_p_arg);
843 }
844 }
845
// Number of reduction outputs, taken from ReduceOperations::Size().
846 static constexpr int NumReduce = ReduceOperations::Size();
// Builds an Argument by value from type-erased inputs.
//  - p_bias, p_ds, StrideDs and d_element_ops are accepted but unused.
//  - p_reduces[I] is cast to the I-th pointer type of ReducePtrsGlobal.
//  - reduce_in_element_op[I], reduce_out_element_op[I] and gemm_element_ops[0..2] are void*
//    that are cast to the expected operator type and dereferenced (copied), so each must point
//    to a live object of exactly that type.
//  - The pointers, sizes, strides and Batch are forwarded to the Argument constructor unchanged.
// NOTE(review): the closing arguments of the three generate_tuple calls are elided from this
// listing.
847 static auto MakeArgument(const void* p_a,
848 const void* p_b,
849 const void* p_bias,
850 std::array<const void*, 0> p_ds,
851 void* p_c,
852 std::array<void*, NumReduce> p_reduces,
853 ck::index_t M,
854 ck::index_t N,
855 ck::index_t K,
856 ck::index_t StrideA,
857 ck::index_t StrideB,
858 ck::index_t StrideC,
859 std::array<ck::index_t, 0> StrideDs,
860 std::array<void*, 3> gemm_element_ops,
861 std::array<void*, 0> d_element_ops,
862 std::array<void*, NumReduce> reduce_in_element_op,
863 std::array<void*, NumReduce> reduce_out_element_op,
864 index_t Batch)
865 {
866 (void)p_bias;
867 (void)p_ds;
868 (void)StrideDs;
869 (void)d_element_ops;
870
// Recover the typed reduction output pointers.
871 ReducePtrsGlobal reduce_tuple = generate_tuple(
872 [&](auto I) {
873 auto tmp = ReducePtrsGlobal{}[I];
874 using T = remove_pointer_t<decltype(tmp)>;
875 return static_cast<T*>(p_reduces[I]);
876 },
878
// Copy the per-reduction input / output element-wise operators out of the void* arrays.
879 ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
880 [&](auto I) {
881 auto tmp = ReduceInElementwiseOperations{}[I];
882 using T = remove_pointer_t<decltype(tmp)>;
883 return *(static_cast<T*>(reduce_in_element_op[I]));
884 },
886 ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
887 [&](auto I) {
888 auto tmp = ReduceAccElementwiseOperations{}[I];
889 using T = remove_pointer_t<decltype(tmp)>;
890 return *(static_cast<T*>(reduce_out_element_op[I]));
891 },
893
// gemm_element_ops holds the A, B and C operators, in that order.
894 AElementwiseOperation a_element_op =
895 *(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
896 BElementwiseOperation b_element_op =
897 *(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
898 CElementwiseOperation c_element_op =
899 *(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
900
901 return Argument{static_cast<const ADataType*>(p_a),
902 static_cast<const BDataType*>(p_b),
903 static_cast<CDataType*>(p_c),
904 reduce_tuple,
905 M,
906 N,
907 K,
908 StrideA,
909 StrideB,
910 StrideC,
911 a_element_op,
912 b_element_op,
913 c_element_op,
914 reduce_in_element_ops,
915 reduce_out_element_ops,
916 Batch};
917 }
918
919 static auto MakeInvoker() { return Invoker{}; }
920
921 // polymorphic
// Heap-allocating counterpart of MakeArgument: performs the same casts on the type-erased inputs
// and returns the resulting Argument as std::unique_ptr<BaseArgument>.
//  - p_bias, p_ds, StrideDs and d_element_ops are accepted but unused.
//  - Every void* operator pointer is cast to the expected operator type and dereferenced
//    (copied), so each must point to a live object of exactly that type.
//  - Batch defaults to 1.
// NOTE(review): the body duplicates MakeArgument line for line; the closing arguments of the
// three generate_tuple calls are elided from this listing.
922 std::unique_ptr<BaseArgument>
923 MakeArgumentPointer(const void* p_a,
924 const void* p_b,
925 const void* p_bias,
926 std::array<const void*, 0> p_ds,
927 void* p_c,
928 std::array<void*, NumReduce> p_reduces,
929 ck::index_t M,
930 ck::index_t N,
931 ck::index_t K,
932 ck::index_t StrideA,
933 ck::index_t StrideB,
934 ck::index_t StrideC,
935 std::array<ck::index_t, 0> StrideDs,
936 std::array<void*, 3> gemm_element_ops,
937 std::array<void*, 0> d_element_ops,
938 std::array<void*, NumReduce> reduce_in_element_op,
939 std::array<void*, NumReduce> reduce_out_element_op,
940 index_t Batch = 1) override
941 {
942 (void)p_bias;
943 (void)p_ds;
944 (void)StrideDs;
945 (void)d_element_ops;
946
// Recover the typed reduction output pointers.
947 ReducePtrsGlobal reduce_tuple = generate_tuple(
948 [&](auto I) {
949 auto tmp = ReducePtrsGlobal{}[I];
950 using T = remove_pointer_t<decltype(tmp)>;
951 return static_cast<T*>(p_reduces[I]);
952 },
954
// Copy the per-reduction input / output element-wise operators out of the void* arrays.
955 ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
956 [&](auto I) {
957 auto tmp = ReduceInElementwiseOperations{}[I];
958 using T = remove_pointer_t<decltype(tmp)>;
959 return *(static_cast<T*>(reduce_in_element_op[I]));
960 },
962 ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
963 [&](auto I) {
964 auto tmp = ReduceAccElementwiseOperations{}[I];
965 using T = remove_pointer_t<decltype(tmp)>;
966 return *(static_cast<T*>(reduce_out_element_op[I]));
967 },
969
// gemm_element_ops holds the A, B and C operators, in that order.
970 AElementwiseOperation a_element_op =
971 *(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
972 BElementwiseOperation b_element_op =
973 *(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
974 CElementwiseOperation c_element_op =
975 *(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
976
977 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
978 static_cast<const BDataType*>(p_b),
979 static_cast<CDataType*>(p_c),
980 reduce_tuple,
981 M,
982 N,
983 K,
984 StrideA,
985 StrideB,
986 StrideC,
987 a_element_op,
988 b_element_op,
989 c_element_op,
990 reduce_in_element_ops,
991 reduce_out_element_ops,
992 Batch);
993 }
994
995 // polymorphic
996 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
997 {
998 return std::make_unique<Invoker>(Invoker{});
999 }
1000
1001 // polymorphic
1002 std::string GetTypeString() const override
1003 {
1004 auto str = std::stringstream();
1005
1006 // clang-format off
1007 str << "DeviceBatchedGemmReduce_Xdl_CShuffle"
1008 << "<"
1009 << BlockSize << ", "
1010 << MPerBlock << ", "
1011 << NPerBlock << ", "
1012 << KPerBlock << ", "
1013 << AK1 << ", "
1014 << BK1
1015 << ">";
1016 // clang-format on
1017
1018 return str.str();
1019 }
1020};
1021
1022} // namespace device
1023} // namespace tensor_operation
1024} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
__global__ void kernel_batched_gemm_reduce_xdl_cshuffle_v1(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, ReducePtrsGlobal p_reduces_grid, const index_t batch_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const ReduceInElementwiseOperations reduce_in_element_ops, const ReduceAccElementwiseOperations reduce_out_element_ops, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, const ComputeBasePrtOfBatch compute_base_ptr_of_batch_, const Block2CTileMap block_2_ctile_map)
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:44
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__device__ index_t get_grid_size()
Definition get_id.hpp:49
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
typename remove_pointer< T >::type remove_pointer_t
Definition type.hpp:300
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:152
remove_cvref_t< decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))> DefaultBlock2CTileMap
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:347
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:251
Definition utility/sequence.hpp:43
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:588
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:638
ReduceGridDesc_M reduce_grid_desc_m_
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:637
BElementwiseOperation b_element_op_
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:641
ReducePtrsGlobal p_reduces_grid_
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:632
index_t Batch_
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:633
CGridDesc_M_N c_grid_desc_m_n_
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:636
CDataType * p_c_grid_
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:631
GridwiseGemm64::DefaultBlock2CTileMap block_2_ctile_map_
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:639
CElementwiseOperation c_element_op_
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:642
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:635
Argument(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, ReducePtrsGlobal p_reduces_grid, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, ReduceInElementwiseOperations reduce_in_element_ops, ReduceAccElementwiseOperations reduce_out_element_ops, index_t Batch)
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:589
const ADataType * p_a_grid_
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:629
ReduceAccElementwiseOperations reduce_out_element_ops_
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:644
ReduceInElementwiseOperations reduce_in_element_ops_
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:643
const BDataType * p_b_grid_
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:630
AElementwiseOperation a_element_op_
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:640
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:634
__host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:501
ComputeBasePtrOfStridedBatch(index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideC, index_t BatchStrideD)
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:485
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:496
__host__ __device__ constexpr long_index_t GetDBasePtr(index_t g_idx, Number< I > reduction_idx) const
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:512
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:506
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:649
DeviceOp::Argument Argument
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:650
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:790
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:653
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:178
static auto MakeInvoker()
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:919
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:291
static constexpr bool IsValidCompilationParameter()
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:797
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:188
static constexpr auto I0
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:184
std::string GetTypeString() const override
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:1002
static constexpr int NumReduce
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:846
static auto MakeArgument(const void *p_a, const void *p_b, const void *p_bias, std::array< const void *, 0 > p_ds, void *p_c, std::array< void *, NumReduce > p_reduces, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC, std::array< ck::index_t, 0 > StrideDs, std::array< void *, 3 > gemm_element_ops, std::array< void *, 0 > d_element_ops, std::array< void *, NumReduce > reduce_in_element_op, std::array< void *, NumReduce > reduce_out_element_op, index_t Batch)
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:847
decltype(MakeReduceGridDescriptor_M(1)) ReduceGridDesc_M
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:481
GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched > GridwiseGemmBase
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:529
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:394
decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)) AGridDesc_AK0_M_AK1
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:478
static bool IsSupportedArgument(const Argument &arg)
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:803
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:584
DeviceBatchedGemmReduce_Xdl_CShuffle DeviceOp
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:179
static constexpr auto I2
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:186
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:583
decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)) BGridDesc_BK0_N_BK1
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:479
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:996
static auto MakeReduceGridDescriptor_M(index_t MRaw)
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:453
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:181
static constexpr auto NXdlPerWave32
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:182
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const void *p_bias, std::array< const void *, 0 > p_ds, void *p_c, std::array< void *, NumReduce > p_reduces, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC, std::array< ck::index_t, 0 > StrideDs, std::array< void *, 3 > gemm_element_ops, std::array< void *, 0 > d_element_ops, std::array< void *, NumReduce > reduce_in_element_op, std::array< void *, NumReduce > reduce_out_element_op, index_t Batch=1) override
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:923
decltype(MakeCGridDescriptor_M_N(1, 1, 1)) CGridDesc_M_N
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:480
static constexpr auto I1
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:185
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:833
Definition device_gemm_reduce.hpp:17
#define CK_ENV(name)
Definition utility/env.hpp:129