device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp Source File

device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp Source File#

Composable Kernel: device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp Source File
device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
20
21namespace ck {
22namespace tensor_operation {
23namespace device {
24
25template <typename GridwiseGemm,
26 typename GroupKernelArg,
27 typename AElementwiseOperation,
28 typename BElementwiseOperation,
29 typename AccElementwiseOperation,
30 typename B1ElementwiseOperation,
31 typename CElementwiseOperation,
32 bool HasMainKBlockLoop>
33__global__ void
34#if CK_USE_LAUNCH_BOUNDS
36#endif
38 const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
39 const index_t group_count,
40 const AElementwiseOperation a_element_op,
41 const BElementwiseOperation b_element_op,
42 const AccElementwiseOperation acc_element_op,
43 const B1ElementwiseOperation b1_element_op,
44 const CElementwiseOperation c_element_op)
45{
46#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
47 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
48 {
49 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
50
51 const index_t block_id = get_block_1d_id();
52
53 const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>(
54 cast_pointer_to_generic_address_space(group_kernel_args));
55
56 index_t left = 0;
57 index_t right = group_count;
58 index_t group_id = index_t((left + right) / 2);
59
60 while((!(block_id >= arg_ptr[group_id].block_start_ &&
61 block_id < arg_ptr[group_id].block_end_)))
62 {
63 if(block_id < arg_ptr[group_id].block_start_)
64 {
65 right = group_id;
66 }
67 else
68 {
69 left = group_id;
70 }
71 group_id = index_t((left + right) / 2);
72 }
73
74 // per-group batch offset
75 const index_t num_blocks_per_batch = arg_ptr[group_id].num_blocks_per_batch_;
76 const index_t g_idx = __builtin_amdgcn_readfirstlane(
77 (block_id - arg_ptr[group_id].block_start_) / num_blocks_per_batch);
78
79 const long_index_t a_batch_offset =
80 __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
81 arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx)));
82 const long_index_t b_batch_offset =
83 __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
84 arg_ptr[group_id].compute_base_ptr_of_batch_.GetBBasePtr(g_idx)));
85 const long_index_t b1_batch_offset =
86 __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
87 arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(g_idx)));
88 const long_index_t c_batch_offset =
89 __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
90 arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
91
92 GridwiseGemm::template Run<HasMainKBlockLoop>(
93 arg_ptr[group_id].p_a_grid_ + a_batch_offset,
94 arg_ptr[group_id].p_b_grid_ + b_batch_offset,
95 arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
96 arg_ptr[group_id].p_c_grid_ + c_batch_offset,
97 p_shared,
98 a_element_op,
99 b_element_op,
100 acc_element_op,
101 b1_element_op,
102 c_element_op,
103 arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
104 arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
105 arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
106 arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
107 arg_ptr[group_id].block_2_ctile_map_,
108 arg_ptr[group_id].c0_matrix_mask_);
109 }
110#else
111 ignore = group_kernel_args;
112 ignore = group_count;
113 ignore = a_element_op;
114 ignore = b_element_op;
115 ignore = acc_element_op;
116 ignore = b1_element_op;
117 ignore = c_element_op;
118#endif // end of if (defined(__gfx9__))
119}
120
121// Computes C = A * B0 * B1
122// ^^^^^^ (Acc0)
123// ^^^^^^^^^^^ (Acc1)
124template <index_t NumDimG,
125 index_t NumDimM,
126 index_t NumDimN,
127 index_t NumDimK,
128 index_t NumDimO, // NumDimGemm1N
129 typename ADataType,
130 typename BDataType,
131 typename B1DataType,
132 typename CDataType,
133 typename Acc0BiasDataType,
134 typename Acc1BiasDataType,
135 typename GemmAccDataType,
136 typename CShuffleDataType,
137 typename AElementwiseOperation,
138 typename BElementwiseOperation,
139 typename AccElementwiseOperation,
140 typename B1ElementwiseOperation,
141 typename CElementwiseOperation,
142 GemmSpecialization GemmSpec,
147 index_t NumGemmKPrefetchStage,
148 index_t BlockSize,
149 index_t MPerBlock,
150 index_t NPerBlock, // Gemm0NPerBlock
151 index_t KPerBlock, // Gemm0KPerBlock
152 index_t Gemm1NPerBlock,
153 index_t Gemm1KPerBlock,
154 index_t AK1,
155 index_t BK1,
156 index_t B1K1,
157 index_t MPerXDL,
158 index_t NPerXDL,
159 index_t MXdlPerWave,
160 index_t NXdlPerWave,
161 index_t Gemm1NXdlPerWave,
162 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
163 typename ABlockTransferThreadClusterArrangeOrder,
164 typename ABlockTransferSrcAccessOrder,
165 index_t ABlockTransferSrcVectorDim,
166 index_t ABlockTransferSrcScalarPerVector,
167 index_t ABlockTransferDstScalarPerVector_AK1,
168 bool ABlockLdsExtraM,
169 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
170 typename BBlockTransferThreadClusterArrangeOrder,
171 typename BBlockTransferSrcAccessOrder,
172 index_t BBlockTransferSrcVectorDim,
173 index_t BBlockTransferSrcScalarPerVector,
174 index_t BBlockTransferDstScalarPerVector_BK1,
175 bool BBlockLdsExtraN,
176 typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
177 typename B1BlockTransferThreadClusterArrangeOrder,
178 typename B1BlockTransferSrcAccessOrder,
179 index_t B1BlockTransferSrcVectorDim,
180 index_t B1BlockTransferSrcScalarPerVector,
181 index_t B1BlockTransferDstScalarPerVector_BK1,
182 bool B1BlockLdsExtraN,
183 index_t CShuffleMXdlPerWavePerShuffle,
184 index_t CShuffleNXdlPerWavePerShuffle,
185 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
186 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
187 MaskingSpecialization MaskingSpec,
191 NumDimM,
192 NumDimN,
193 NumDimK,
194 NumDimO,
195 ADataType,
196 BDataType,
197 B1DataType,
198 CDataType,
199 Acc0BiasDataType,
200 Acc1BiasDataType,
201 AElementwiseOperation,
202 BElementwiseOperation,
203 AccElementwiseOperation,
204 B1ElementwiseOperation,
205 CElementwiseOperation,
206 MaskingSpec>
207{
208 static constexpr auto MXdlPerWave64 =
209 GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, true>();
210 static constexpr auto MXdlPerWave32 =
211 GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, false>();
212
213 static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
214 "Number of dimension must be greater than 0");
215
216 static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
217 static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
218
219 // TODO ANT: implement bias combination
220 static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
221
222#if 0
223 // TODO ANT: use alias
224 static constexpr index_t NumDimGemm0M = NumDimM;
225 static constexpr index_t NumDimGemm0N = NumDimN;
226 static constexpr index_t NumDimGemm0K = NumDimK;
227 static constexpr index_t NumDimGemm1M = NumDimM;
228 static constexpr index_t NumDimGemm1N = NumDimO;
229 static constexpr index_t NumDimGemm1K = NumDimN;
230#endif
231
234 NumDimM,
235 NumDimN,
236 NumDimK,
237 NumDimO,
238 ADataType,
239 BDataType,
240 B1DataType,
241 CDataType,
242 Acc0BiasDataType,
243 Acc1BiasDataType,
244 AElementwiseOperation,
245 BElementwiseOperation,
246 AccElementwiseOperation,
247 B1ElementwiseOperation,
248 CElementwiseOperation,
249 MaskingSpec>::ProblemDesc;
250
251 static constexpr auto I0 = Number<0>{};
252 static constexpr auto I1 = Number<1>{};
253 static constexpr auto I2 = Number<2>{};
254
258 GemmSpec,
259 ASpec,
260 BSpec,
261 B1Spec,
262 CSpec>;
263
264 static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
265 const std::vector<index_t>& a_gs_ms_ks_strides_vec)
266 {
268 Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
269 Number<AK1>{});
270 }
271
272 static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
273 const std::vector<index_t>& b_gs_ns_ks_strides_vec)
274 {
276 Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec),
277 Number<BK1>{});
278 }
279
280 static auto
281 MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths_vec,
282 const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides_vec)
283 {
285 Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec,
286 b1_gs_gemm1ns_gemm1ks_strides_vec),
287 Number<B1K1>{});
288 }
289
298
299 constexpr static auto make_MaskOutPredicate()
300 {
301 if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled)
302 {
303 return MaskDisabledPredicate{};
304 }
305 else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
306 {
308 }
309 }
311
313 {
315 const BGridDesc_G_N_K& b_grid_desc_g_n_k,
316 const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
317 const CGridDesc_G_M_N& c_grid_desc_g_m_n)
318 : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
319 b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
320 b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
321 c_grid_desc_g_m_n_(c_grid_desc_g_m_n)
322 {
323 }
324
325 __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
326 {
327 return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
328 }
329
330 __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
331 {
332 return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
333 }
334
335 __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
336 {
337 return b1_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
338 }
339
340 __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
341 {
342 return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
343 }
344
345 private:
346 AGridDesc_G_M_K a_grid_desc_g_m_k_;
347 BGridDesc_G_N_K b_grid_desc_g_n_k_;
348 B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
349 CGridDesc_G_M_N c_grid_desc_g_m_n_;
350 };
351
352 // GridwiseGemm
353 template <index_t MXdlPerWave_>
355 ADataType, // TODO: distinguish A/B datatype
356 GemmAccDataType,
357 CShuffleDataType,
358 CDataType,
359 AElementwiseOperation,
360 BElementwiseOperation,
361 AccElementwiseOperation,
362 B1ElementwiseOperation,
363 CElementwiseOperation,
369 NumGemmKPrefetchStage,
370 BlockSize,
371 MPerBlock,
372 NPerBlock,
373 KPerBlock,
374 Gemm1NPerBlock,
375 Gemm1KPerBlock,
376 AK1,
377 BK1,
378 B1K1,
379 MPerXDL,
380 NPerXDL,
381 MXdlPerWave_,
382 NXdlPerWave,
383 Gemm1NXdlPerWave,
384 ABlockTransferThreadClusterLengths_AK0_M_AK1,
385 ABlockTransferThreadClusterArrangeOrder,
386 ABlockTransferSrcAccessOrder,
387 ABlockTransferSrcVectorDim,
388 ABlockTransferSrcScalarPerVector,
389 ABlockTransferDstScalarPerVector_AK1,
390 true,
391 ABlockLdsExtraM,
392 BBlockTransferThreadClusterLengths_BK0_N_BK1,
393 BBlockTransferThreadClusterArrangeOrder,
394 BBlockTransferSrcAccessOrder,
395 BBlockTransferSrcVectorDim,
396 BBlockTransferSrcScalarPerVector,
397 BBlockTransferDstScalarPerVector_BK1,
398 true,
399 BBlockLdsExtraN,
400 B1BlockTransferThreadClusterLengths_BK0_N_BK1,
401 B1BlockTransferThreadClusterArrangeOrder,
402 B1BlockTransferSrcAccessOrder,
403 B1BlockTransferSrcVectorDim,
404 B1BlockTransferSrcScalarPerVector,
405 B1BlockTransferDstScalarPerVector_BK1,
406 false,
407 B1BlockLdsExtraN,
408 CShuffleMXdlPerWavePerShuffle,
409 CShuffleNXdlPerWavePerShuffle,
410 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
411 CShuffleBlockTransferScalarPerVector_NPerBlock,
412 LoopSched,
417
419
447
449 {
450 // lengths for the last dimensions of overall problem for sanity check of vector load/store
451 std::vector<index_t> raw_lengths_mz_nz_kz_gemm1nz_;
452
453 // strides for the last dimensions of each tensor for sanity check of vector load/store
454 std::vector<index_t> a_mz_kz_strides_;
455 std::vector<index_t> b_nz_kz_strides_;
456 std::vector<index_t> b1_nz_kz_strides_;
457 std::vector<index_t> c_mz_gemm1nz_strides_;
458
459 // for gridwise gemm check
461 };
462
463 // Argument
464 // FIXME: constness
465 struct Argument : public BaseArgument
466 {
467 Argument(std::vector<const void*> p_a_vec,
468 std::vector<const void*> p_b_vec,
469 std::vector<const void*> p_b1_vec,
470 std::vector<void*> p_c_vec,
471 std::vector<std::vector<const void*>> p_acc0_biases_vec,
472 std::vector<std::vector<const void*>> p_acc1_biases_vec,
473 std::vector<ProblemDesc> problem_desc_vec,
474 AElementwiseOperation a_element_op,
475 BElementwiseOperation b_element_op,
476 AccElementwiseOperation acc_element_op,
477 B1ElementwiseOperation b1_element_op,
478 CElementwiseOperation c_element_op)
479 : a_element_op_{a_element_op},
480 b_element_op_{b_element_op},
481 acc_element_op_{acc_element_op},
482 b1_element_op_{b1_element_op},
483 c_element_op_{c_element_op}
484 {
485 // TODO ANT: implement bias addition
486 group_count_ = problem_desc_vec.size();
487
488 if(!(group_count_ == p_a_vec.size() && group_count_ == p_b_vec.size() &&
489 group_count_ == p_b1_vec.size() && group_count_ == p_c_vec.size()))
490 {
491 throw std::runtime_error("wrong! group_count_ != a/b/b1/c_vec.size");
492 }
493
494 if(!(p_acc0_biases_vec.size() == p_acc1_biases_vec.size()))
495 {
496 throw std::runtime_error("wrong! acc0_bias_vec.size != acc1_bias_vec.size");
497 }
498
499 grid_size_ = 0;
500
501 for(std::size_t i = 0; i < group_count_; i++)
502 {
503 const auto p_a_grid = static_cast<const ADataType*>(p_a_vec[i]);
504 const auto p_b_grid = static_cast<const BDataType*>(p_b_vec[i]);
505 const auto p_b1_grid = static_cast<const B1DataType*>(p_b1_vec[i]);
506 const auto p_c_grid = static_cast<CDataType*>(p_c_vec[i]);
507
508 const auto& problem_desc = problem_desc_vec[i];
509
510 const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
511 problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
512 const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
513 problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides);
514 const auto b1_grid_desc_bk0_n_bk1 = MakeB1GridDescriptor_BK0_N_BK1(
515 problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
516 const auto c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(
517 problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);
518
519 const auto a_grid_desc_g_m_k = Transform::MakeAGridDescriptor_G_M_K(
520 problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
521 const auto b_grid_desc_g_n_k = Transform::MakeB0GridDescriptor_G_N_K(
522 problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides);
523 const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K(
524 problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
525 const auto c_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
526 problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);
527
528 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
530 c_grid_desc_m_n);
531
532 const index_t BlockStart = grid_size_;
533 const auto block_2_ctile_map = Block2CTileMap(c_grid_desc_m_n, BlockStart);
534 const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0);
535 const index_t grid_size_grp =
536 block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n) * batch_count;
537 const index_t BlockEnd = grid_size_ + grid_size_grp;
538
539 // batch stride
540 const auto compute_base_ptr_of_batch = ComputeBasePtrOfStridedBatch(
541 a_grid_desc_g_m_k, b_grid_desc_g_n_k, b1_grid_desc_g_n_k, c_grid_desc_g_m_n);
542
543 // C0 mask
544 const auto c0_matrix_mask = C0MatrixMask(b_grid_desc_g_n_k.GetLength(I1));
545
546 grid_size_ += grid_size_grp;
547
548 // for each group, make sure acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias and
549 // so on
550 if(!(problem_desc.acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias &&
551 problem_desc.acc0_biases_gs_ms_ns_strides.size() == NumAcc0Bias &&
552 problem_desc.acc1_biases_gs_ms_os_lengths.size() == NumAcc1Bias &&
553 problem_desc.acc1_biases_gs_ms_os_strides.size() == NumAcc1Bias))
554 {
555 throw std::runtime_error(
556 "wrong! number of biases in function argument does not "
557 "match that in template argument");
558 }
559
560 group_kernel_args_.push_back({p_a_grid,
561 p_b_grid,
562 p_b1_grid,
563 p_c_grid,
564 a_grid_desc_ak0_m_ak1,
565 b_grid_desc_bk0_n_bk1,
566 b1_grid_desc_bk0_n_bk1,
567 c_grid_desc_mblock_mperblock_nblock_nperblock,
568 block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n),
569 compute_base_ptr_of_batch,
570 c0_matrix_mask,
571 block_2_ctile_map,
572 BlockStart,
573 BlockEnd});
574
575 group_device_args_.push_back(
576 {{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
577 problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
578 problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1],
579 problem_desc.b1_gs_os_ns_lengths[NumDimG + NumDimO - 1]},
580 {problem_desc.a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
581 problem_desc.a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]},
582 {problem_desc.b0_gs_ns_ks_strides[NumDimG + NumDimN - 1],
583 problem_desc.b0_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]},
584 {problem_desc.b1_gs_os_ns_strides[NumDimG + NumDimO - 1],
585 problem_desc.b1_gs_os_ns_strides[NumDimG + NumDimO + NumDimN - 1]},
586 {problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM - 1],
587 problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM + NumDimO - 1]},
588 c_grid_desc_m_n});
589 }
590 }
591
592 std::vector<GroupKernelArg> group_kernel_args_;
593 std::vector<GroupDeviceArg> group_device_args_;
594
595 std::size_t group_count_;
597
598 AElementwiseOperation a_element_op_;
599 BElementwiseOperation b_element_op_;
600 AccElementwiseOperation acc_element_op_;
601 B1ElementwiseOperation b1_element_op_;
602 CElementwiseOperation c_element_op_;
603 };
604
605 // Invoker
606 struct Invoker : public BaseInvoker
607 {
609
610 template <typename GridwiseGemm>
611 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
612 {
614 {
615 throw std::runtime_error("wrong! unsupported argument");
616 }
617
618 bool all_has_main_k_block_loop = true;
619 bool some_has_main_k_block_loop = false;
620 for(std::size_t i = 0; i < arg.group_count_; i++)
621 {
622 const auto K = arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
623 arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2);
624 const bool y = GridwiseGemm::CalculateHasMainKBlockLoop(K);
625 all_has_main_k_block_loop &= y;
626 some_has_main_k_block_loop |= y;
627 }
628
629 hipGetErrorString(
630 hipMemcpyWithStream(arg.p_workspace_,
631 arg.group_kernel_args_.data(),
632 arg.group_kernel_args_.size() * sizeof(GroupKernelArg),
633 hipMemcpyHostToDevice,
634 stream_config.stream_id_));
635
636 float ave_time = 0;
637
638 auto launch_kernel = [&](auto has_main_k_block_loop_) {
639 const auto kernel =
641 GroupKernelArg,
642 AElementwiseOperation,
643 BElementwiseOperation,
644 AccElementwiseOperation,
645 B1ElementwiseOperation,
646 CElementwiseOperation,
647 has_main_k_block_loop_>;
648
650 stream_config,
651 kernel,
652 dim3(arg.grid_size_),
653 dim3(BlockSize),
654 0,
656 arg.group_count_,
657 arg.a_element_op_,
658 arg.b_element_op_,
659 arg.acc_element_op_,
660 arg.b1_element_op_,
661 arg.c_element_op_);
662 };
663
664 // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
665 // to concern Gemm0's loop
666 if(all_has_main_k_block_loop)
667 {
668 ave_time = launch_kernel(integral_constant<bool, true>{});
669 }
670 else if(!some_has_main_k_block_loop)
671 {
672 ave_time = launch_kernel(integral_constant<bool, false>{});
673 }
674 else
675 {
676 throw std::runtime_error("wrong! all gemm problems have to simultaneously meet "
677 "has_main_k_block_loop or no_main_k_block_loop");
678 }
679
680 return ave_time;
681 }
682
683 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
684 {
685 if(get_warp_size() == 64)
686 {
687 if constexpr(MXdlPerWave64 > 0)
688 {
689 return RunImp<GridwiseGemm64>(arg, stream_config);
690 }
691 }
692 else
693 {
694 if constexpr(MXdlPerWave32 > 0)
695 {
696 return RunImp<GridwiseGemm32>(arg, stream_config);
697 }
698 }
699 return 0;
700 }
701
702 // polymorphic
703 float Run(const BaseArgument* p_arg,
704 const StreamConfig& stream_config = StreamConfig{}) override
705 {
706 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
707 }
708 };
709
// Compile-time validity query for this device operation.
// TODO: properly implement this check; for now every configuration is reported as valid.
static constexpr bool IsValidCompilationParameter()
{
    constexpr bool always_valid = true;
    return always_valid;
}
715
716 static bool IsSupportedArgument(const Argument& arg)
717 {
719 {
720 return false;
721 }
722 // TODO ANT: Check if tensor specialization & strides mismatch
723
724 bool all_has_main_k_block_loop = true;
725 bool some_has_main_k_block_loop = false;
726
727 for(std::size_t i = 0; i < arg.group_count_; i++)
728 {
729 const auto& kernel_arg = arg.group_kernel_args_[i];
730 const auto& device_arg = arg.group_device_args_[i];
731
732 // Check if C permute dimension matches GEMM + GEMM shape
733 const index_t c_m = device_arg.c_grid_desc_m_n_.GetLength(I0);
734 const index_t c_gemm1n = device_arg.c_grid_desc_m_n_.GetLength(I1);
735 const index_t a_m = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
736 const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
737 if(!(c_m == a_m && c_gemm1n == b1_gemm1n))
738 {
739 return false;
740 }
741
742 // Check if having main loop
743 const auto K = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) *
744 kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
746 all_has_main_k_block_loop &= y;
747 some_has_main_k_block_loop |= y;
748
749 // Note: we need raw lengths since threadwise copy can not handle vector load when
750 // part of vector is out of bounds
751 const auto MzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[0];
752 const auto NzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[1];
753 const auto KzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[2];
754 const auto Gemm1NzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[3];
755
756 // Check scalar per vector requirement
757 const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
758 const auto b_extent_lowest = BBlockTransferSrcVectorDim == 2 ? KzRaw : NzRaw;
759 const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? NzRaw : Gemm1NzRaw;
760 const auto c_extent_lowest = Gemm1NzRaw;
761
762 if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
763 b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
764 b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
765 c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
766 {
767 return false;
768 }
769
770 // Check vector load/store requirement
771 const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2
772 ? device_arg.a_mz_kz_strides_[1]
773 : device_arg.a_mz_kz_strides_[0];
774 const auto b_stride_lowest = BBlockTransferSrcVectorDim == 2
775 ? device_arg.b_nz_kz_strides_[1]
776 : device_arg.b_nz_kz_strides_[0];
777 const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2
778 ? device_arg.b1_nz_kz_strides_[1]
779 : device_arg.b1_nz_kz_strides_[0];
780 const auto c_stride_lowest =
781 device_arg.c_mz_gemm1nz_strides_[1]; // cshuffle assumes lowest dim in Gemm1Ns to be
782 // contiguous
783
784 if(!(a_stride_lowest == 1 || b_stride_lowest == 1 || b1_stride_lowest == 1 ||
785 c_stride_lowest == 1))
786 {
787 return false;
788 }
789
790 bool valid = false;
791 if(get_warp_size() == 64)
792 {
793 if constexpr(MXdlPerWave64 > 0)
794 {
795 valid = GridwiseGemm64::CheckValidity(kernel_arg.a_grid_desc_ak0_m_ak1_,
796 kernel_arg.b_grid_desc_bk0_n_bk1_,
797 kernel_arg.b1_grid_desc_bk0_n_bk1_,
798 device_arg.c_grid_desc_m_n_,
799 kernel_arg.block_2_ctile_map_);
800 }
801 }
802 else
803 {
804 if constexpr(MXdlPerWave32 > 0)
805 {
806 valid = GridwiseGemm32::CheckValidity(kernel_arg.a_grid_desc_ak0_m_ak1_,
807 kernel_arg.b_grid_desc_bk0_n_bk1_,
808 kernel_arg.b1_grid_desc_bk0_n_bk1_,
809 device_arg.c_grid_desc_m_n_,
810 kernel_arg.block_2_ctile_map_);
811 }
812 }
813 if(!valid)
814 return false;
815 }
816
817 // all gemm problems have to simultaneously meet has_main_k_block_loop or
818 // no_main_k_block_loop
819 if(!(all_has_main_k_block_loop || !some_has_main_k_block_loop))
820 {
821 return false;
822 }
823
824 return true;
825 }
826
827 // polymorphic
828 bool IsSupportedArgument(const BaseArgument* p_arg) override
829 {
830 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
831 }
832
833 static auto MakeArgument(std::vector<const void*> p_a_vec,
834 std::vector<const void*> p_b_vec,
835 std::vector<const void*> p_b1_vec,
836 std::vector<void*> p_c_vec,
837 std::vector<std::vector<const void*>> p_acc0_biases_vec,
838 std::vector<std::vector<const void*>> p_acc1_biases_vec,
839 std::vector<ProblemDesc> problem_desc_vec,
840 AElementwiseOperation a_element_op,
841 BElementwiseOperation b_element_op,
842 AccElementwiseOperation acc_element_op,
843 B1ElementwiseOperation b1_element_op,
844 CElementwiseOperation c_element_op)
845 {
846 return Argument{p_a_vec,
847 p_b_vec,
848 p_b1_vec,
849 p_c_vec,
850 p_acc0_biases_vec,
851 p_acc1_biases_vec,
852 problem_desc_vec,
853 a_element_op,
854 b_element_op,
855 acc_element_op,
856 b1_element_op,
857 c_element_op};
858 }
859
860 static auto MakeInvoker() { return Invoker{}; }
861
862 // polymorphic
863 std::unique_ptr<BaseArgument>
864 MakeArgumentPointer(std::vector<const void*> p_a_vec,
865 std::vector<const void*> p_b_vec,
866 std::vector<const void*> p_b1_vec,
867 std::vector<void*> p_c_vec,
868 std::vector<std::vector<const void*>> p_acc0_biases_vec,
869 std::vector<std::vector<const void*>> p_acc1_biases_vec,
870 std::vector<ProblemDesc> problem_desc_vec,
871 AElementwiseOperation a_element_op,
872 BElementwiseOperation b_element_op,
873 AccElementwiseOperation acc_element_op,
874 B1ElementwiseOperation b1_element_op,
875 CElementwiseOperation c_element_op) override
876 {
877 return std::make_unique<Argument>(p_a_vec,
878 p_b_vec,
879 p_b1_vec,
880 p_c_vec,
881 p_acc0_biases_vec,
882 p_acc1_biases_vec,
883 problem_desc_vec,
884 a_element_op,
885 b_element_op,
886 acc_element_op,
887 b1_element_op,
888 c_element_op);
889 }
890
891 // polymorphic
892 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
893 {
894 return std::make_unique<Invoker>(Invoker{});
895 }
896
897 // polymorphic
898 std::string GetTypeString() const override
899 {
900 auto str = std::stringstream();
901
902 // clang-format off
903 str << "DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle"
904 << "<"
905 << BlockSize << ", "
906 << MPerBlock << ", "
907 << NPerBlock << ", "
908 << KPerBlock << ", "
909 << AK1 << ", "
910 << BK1 << ", "
911 << MPerBlock << ", "
912 << Gemm1NPerBlock << ", "
913 << Gemm1KPerBlock << ", "
914 << B1K1 << ", "
915 << getGemmSpecializationString(GemmSpec) << ", "
916 << "ASpec" << getTensorSpecializationString(ASpec) << ", "
917 << "B0Spec" << getTensorSpecializationString(BSpec) << ", "
918 << "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
919 << "CSpec" << getTensorSpecializationString(CSpec) << ", "
920 << getMaskingSpecializationString(MaskingSpec) << ", "
921 << MPerXDL << ", "
922 << NPerXDL << ", "
923 << MXdlPerWave << ", "
924 << NXdlPerWave << ", "
925 << ABlockTransferSrcScalarPerVector << ", "
926 << BBlockTransferSrcScalarPerVector << ", "
927 << CShuffleMXdlPerWavePerShuffle << ", "
928 << CShuffleNXdlPerWavePerShuffle
929 << ">";
930 // clang-format on
931
932 return str.str();
933 }
934
935 size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
936 {
937 return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GroupKernelArg);
938 }
939};
940
941} // namespace device
942} // namespace tensor_operation
943} // namespace ck
#define CK_CONSTANT_ADDRESS_SPACE
Definition ck.hpp:23
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
std::string getMaskingSpecializationString(const MaskingSpecialization &s)
Definition masking_specialization.hpp:17
MaskingSpecialization
Definition masking_specialization.hpp:11
@ MaskDisabled
Definition masking_specialization.hpp:12
@ MaskOutUpperTriangle
Definition masking_specialization.hpp:13
TensorSpecialization
Definition tensor_specialization.hpp:11
GemmSpecialization
Definition gemm_specialization.hpp:11
__global__ void kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1(const void CK_CONSTANT_ADDRESS_SPACE *group_kernel_args, const index_t group_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const AccElementwiseOperation acc_element_op, const B1ElementwiseOperation b1_element_op, const CElementwiseOperation c_element_op)
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:37
std::string getTensorSpecializationString(const TensorSpecialization &s)
Definition tensor_specialization.hpp:16
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE * cast_pointer_to_constant_address_space(T *p)
Definition amd_address_space.hpp:35
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
integral_constant< index_t, N > Number
Definition number.hpp:12
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
LoopScheduler
Definition loop_scheduler.hpp:15
@ Default
Definition loop_scheduler.hpp:16
int64_t long_index_t
Definition ck.hpp:300
__device__ T * cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE *p)
Definition amd_address_space.hpp:24
Definition ck/stream_config.hpp:10
Gridwise gemm + softmax + gemm fusion.
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:87
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle >::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
__host__ static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:293
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle >::CheckValidity
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 &b1_grid_desc_bk0_n_bk1, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:231
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle >::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))> CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:319
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle >::CalculateHasMainKBlockLoop
__host__ static __device__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:285
Definition block_to_ctile_map.hpp:872
Definition utility/sequence.hpp:43
Definition utility/integral_constant.hpp:20
static auto MakeB0GridDescriptor_N_K(const std::vector< index_t > &b0_gs_ns_ks_lengths_vec, const std::vector< index_t > &b0_gs_ns_ks_strides_vec)
Definition transform_contraction_to_gemm.hpp:198
static auto MakeAGridDescriptor_G_M_K(const std::vector< index_t > &a_gs_ms_ks_lengths_vec, const std::vector< index_t > &a_gs_ms_ks_strides_vec)
Definition transform_contraction_to_gemm.hpp:154
static auto MakeB0GridDescriptor_G_N_K(const std::vector< index_t > &b0_gs_ns_ks_lengths_vec, const std::vector< index_t > &b0_gs_ns_ks_strides_vec)
Definition transform_contraction_to_gemm.hpp:193
__host__ static __device__ constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K &b1_grid_desc_n_k, const Number &B1K1)
Definition transform_contraction_to_gemm.hpp:248
static auto MakeAGridDescriptor_M_K(const std::vector< index_t > &a_gs_ms_ks_lengths_vec, const std::vector< index_t > &a_gs_ms_ks_strides_vec)
Definition transform_contraction_to_gemm.hpp:159
static auto MakeCGridDescriptor_G_M_N(const std::vector< index_t > &c_gs_ms_os_lengths_vec, const std::vector< index_t > &c_gs_ms_os_strides_vec)
Definition transform_contraction_to_gemm.hpp:274
static auto MakeB1GridDescriptor_G_N_K(const std::vector< index_t > &b1_gs_os_ns_lengths_vec, const std::vector< index_t > &b1_gs_os_ns_strides_vec)
Definition transform_contraction_to_gemm.hpp:233
static auto MakeB1GridDescriptor_N_K(const std::vector< index_t > &b1_gs_os_ns_lengths_vec, const std::vector< index_t > &b1_gs_os_ns_strides_vec)
Definition transform_contraction_to_gemm.hpp:238
static auto MakeCGridDescriptor_M_N(const std::vector< index_t > &c_gs_ms_os_lengths_vec, const std::vector< index_t > &c_gs_ms_os_strides_vec)
Definition transform_contraction_to_gemm.hpp:279
Definition device_base.hpp:197
void * p_workspace_
Definition device_base.hpp:204
Definition masking_specialization.hpp:57
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:313
ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K &a_grid_desc_g_m_k, const BGridDesc_G_N_K &b_grid_desc_g_n_k, const B1GridDesc_G_N_K &b1_grid_desc_g_n_k, const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:314
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:325
__host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:330
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:335
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:340
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:466
std::vector< GroupDeviceArg > group_device_args_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:593
AccElementwiseOperation acc_element_op_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:600
index_t grid_size_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:596
std::size_t group_count_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:595
CElementwiseOperation c_element_op_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:602
Argument(std::vector< const void * > p_a_vec, std::vector< const void * > p_b_vec, std::vector< const void * > p_b1_vec, std::vector< void * > p_c_vec, std::vector< std::vector< const void * > > p_acc0_biases_vec, std::vector< std::vector< const void * > > p_acc1_biases_vec, std::vector< ProblemDesc > problem_desc_vec, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, AccElementwiseOperation acc_element_op, B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op)
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:467
B1ElementwiseOperation b1_element_op_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:601
AElementwiseOperation a_element_op_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:598
BElementwiseOperation b_element_op_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:599
std::vector< GroupKernelArg > group_kernel_args_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:592
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:449
std::vector< index_t > b_nz_kz_strides_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:455
std::vector< index_t > a_mz_kz_strides_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:454
std::vector< index_t > c_mz_gemm1nz_strides_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:457
std::vector< index_t > raw_lengths_mz_nz_kz_gemm1nz_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:451
CGridDesc_M_N c_grid_desc_m_n_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:460
std::vector< index_t > b1_nz_kz_strides_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:456
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:421
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:431
index_t block_start_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:445
C0MatrixMask c0_matrix_mask_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:440
CDataType * p_c_grid_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:426
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:429
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:430
index_t num_blocks_per_batch_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:436
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:437
const ADataType * p_a_grid_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:423
GridwiseGemm64::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:433
const BDataType * p_b_grid_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:424
index_t block_end_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:445
const B1DataType * p_b1_grid_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:425
Block2CTileMap block_2_ctile_map_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:443
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:607
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:703
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:683
DeviceOp::Argument Argument
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:608
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:611
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:207
decltype(MakeAGridDescriptor_AK0_M_AK1({}, {})) AGridDesc_AK0_M_AK1
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:290
static auto MakeArgument(std::vector< const void * > p_a_vec, std::vector< const void * > p_b_vec, std::vector< const void * > p_b1_vec, std::vector< void * > p_c_vec, std::vector< std::vector< const void * > > p_acc0_biases_vec, std::vector< std::vector< const void * > > p_acc1_biases_vec, std::vector< ProblemDesc > problem_desc_vec, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, AccElementwiseOperation acc_element_op, B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op)
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:833
decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})) AGridDesc_G_M_K
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:294
std::string GetTypeString() const override
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:898
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:828
static constexpr index_t NumAcc1Bias
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:217
static constexpr index_t NumAcc0Bias
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:216
static constexpr bool IsValidCompilationParameter()
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:710
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:716
std::unique_ptr< BaseArgument > MakeArgumentPointer(std::vector< const void * > p_a_vec, std::vector< const void * > p_b_vec, std::vector< const void * > p_b1_vec, std::vector< void * > p_c_vec, std::vector< std::vector< const void * > > p_acc0_biases_vec, std::vector< std::vector< const void * > > p_acc1_biases_vec, std::vector< ProblemDesc > problem_desc_vec, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, AccElementwiseOperation acc_element_op, B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op) override
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:864
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:892
static constexpr auto I1
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:252
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector< index_t > &a_gs_ms_ks_lengths_vec, const std::vector< index_t > &a_gs_ms_ks_strides_vec)
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:264
static constexpr auto I0
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:251
C0MatrixMask_impl< decltype(make_MaskOutPredicate())> C0MatrixMask
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:310
decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})) B1GridDesc_G_N_K
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:296
static auto MakeInvoker()
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:860
decltype(MakeBGridDescriptor_BK0_N_BK1({}, {})) BGridDesc_BK0_N_BK1
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:291
decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})) CGridDesc_G_M_N
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:297
static constexpr auto I2
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:253
OffsettedBlockToCTileMap< typename GridwiseGemm64::DefaultBlock2CTileMap > Block2CTileMap
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:418
GridwiseGemmBase< math::max(MXdlPerWave64, 1)> GridwiseGemm64
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:415
static constexpr auto make_MaskOutPredicate()
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:299
DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle DeviceOp
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:232
TransformBatchedContractionContractionToBatchedGemmGemm< Sequence< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO >, Sequence< MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock >, GemmSpec, ASpec, BSpec, B1Spec, CSpec > Transform
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:255
static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector< index_t > &b_gs_ns_ks_lengths_vec, const std::vector< index_t > &b_gs_ns_ks_strides_vec)
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:272
decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {})) B1GridDesc_BK0_N_BK1
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:292
typename DeviceGroupedGemmSoftmaxGemmPermute< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, ADataType, BDataType, B1DataType, CDataType, Acc0BiasDataType, Acc1BiasDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, MaskingSpec >::ProblemDesc ProblemDesc
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:233
static constexpr auto MXdlPerWave64
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:208
static auto MakeB1GridDescriptor_BK0_N_BK1(const std::vector< index_t > &b1_gs_gemm1ns_gemm1ks_lengths_vec, const std::vector< index_t > &b1_gs_gemm1ns_gemm1ks_strides_vec)
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:281
size_t GetWorkSpaceSize(const BaseArgument *p_arg) const override
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:935
GridwiseGemmBase< MXdlPerWave32 > GridwiseGemm32
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:416
static constexpr auto MXdlPerWave32
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:210
decltype(Transform::MakeCGridDescriptor_M_N({}, {})) CGridDesc_M_N
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:293
decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})) BGridDesc_G_N_K
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:295
GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle > GridwiseGemmBase
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:354
Definition device_grouped_gemm_softmax_gemm_permute.hpp:34
Definition masking_specialization.hpp:29
Definition masking_specialization.hpp:43