device_gemm_xdl_splitk_c_shuffle.hpp Source File

device_gemm_xdl_splitk_c_shuffle.hpp Source File

Composable Kernel: device_gemm_xdl_splitk_c_shuffle.hpp Source File
device_gemm_xdl_splitk_c_shuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
18
19namespace ck {
20namespace tensor_operation {
21namespace device {
22
23template <typename ADataType,
24 typename BDataType,
25 typename CDataType,
26 typename AccDataType,
27 typename ALayout,
28 typename BLayout,
29 typename CLayout,
30 typename AElementwiseOperation,
31 typename BElementwiseOperation,
32 typename CElementwiseOperation,
33 GemmSpecialization GemmSpec,
34 ck::index_t BlockSize,
35 ck::index_t MPerBlock,
36 ck::index_t NPerBlock,
37 ck::index_t K0PerBlock,
38 ck::index_t K1,
39 ck::index_t MPerXDL,
40 ck::index_t NPerXDL,
41 ck::index_t MXdlPerWave,
42 ck::index_t NXdlPerWave,
43 typename ABlockTransferThreadClusterLengths_K0_M_K1,
44 typename ABlockTransferThreadClusterArrangeOrder,
45 typename ABlockTransferSrcAccessOrder,
46 ck::index_t ABlockTransferSrcVectorDim,
47 ck::index_t ABlockTransferSrcScalarPerVector,
48 ck::index_t ABlockTransferDstScalarPerVector_K1,
49 bool ABlockLdsAddExtraM,
50 typename BBlockTransferThreadClusterLengths_K0_N_K1,
51 typename BBlockTransferThreadClusterArrangeOrder,
52 typename BBlockTransferSrcAccessOrder,
53 ck::index_t BBlockTransferSrcVectorDim,
54 ck::index_t BBlockTransferSrcScalarPerVector,
55 ck::index_t BBlockTransferDstScalarPerVector_K1,
56 bool BBlockLdsAddExtraN,
57 index_t CShuffleMRepeatPerShuffle,
58 index_t CShuffleNRepeatPerShuffle,
59 typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
60 index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
61 typename ComputeType = CDataType,
64 typename LDSTypeA = ComputeType,
65 typename LDSTypeB = ComputeType>
66
68 BLayout,
69 CLayout,
70 ADataType,
71 BDataType,
72 CDataType,
73 AElementwiseOperation,
74 BElementwiseOperation,
75 CElementwiseOperation,
76 ComputeType>
77{
// Per-wave-size values of NXdlPerWave, selected with a boolean template flag.
// NOTE(review): GetNXdlPerWave is not defined in the visible listing (line 78 is absent); the
// page index attaches GET_NXDL_PER_WAVE_IMPL to NXdlPerWave64, so it presumably comes from
// that macro -- confirm in device_base.hpp.
79 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
80 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
81
// Compile-time index constants 0..3.
82 static constexpr auto I0 = Number<0>{};
83 static constexpr auto I1 = Number<1>{};
84 static constexpr auto I2 = Number<2>{};
85 static constexpr auto I3 = Number<3>{};
86
// Fixed at 1 here rather than being a template parameter of the device operation.
87 // TODO: should be exposed as Tparams.
88 static constexpr index_t NumGemmKPrefetchStage = 1;
89
// Both operands use the same compute type.
90 using ComputeTypeA = ComputeType;
91 using ComputeTypeB = ComputeType;
92
// Alias template that instantiates the gridwise GEMM for a given NXdlPerWave_ value; every
// other template argument is taken from the enclosing device operation.
// NOTE(review): this listing is lossy. Line 94 (the `using GridwiseGemmBase = ...<` head),
// line 107 and lines 138-139 are absent. The page index shows the complete alias as
// GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<...> with NumGemmKPrefetchStage between
// GemmSpec and MPerBlock, and ComputeTypeA, ComputeTypeB between PipelineVer and LDSTypeA.
// The same index lists GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)> and
// GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32> on the (also absent) lines 142-143.
93 template <index_t NXdlPerWave_>
95 BlockSize,
96 ADataType,
97 BDataType,
98 AccDataType,
99 CDataType,
100 ALayout,
101 BLayout,
102 CLayout,
103 AElementwiseOperation,
104 BElementwiseOperation,
105 CElementwiseOperation,
106 GemmSpec,
108 MPerBlock,
109 NPerBlock,
110 K0PerBlock,
111 MPerXDL,
112 NPerXDL,
113 K1,
114 MXdlPerWave,
115 NXdlPerWave_,
116 ABlockTransferThreadClusterLengths_K0_M_K1,
117 ABlockTransferThreadClusterArrangeOrder,
118 ABlockTransferSrcAccessOrder,
119 ABlockTransferSrcVectorDim,
120 ABlockTransferSrcScalarPerVector,
121 ABlockTransferDstScalarPerVector_K1,
122 false, // AThreadTransferSrcResetCoordinateAfterRun,
123 ABlockLdsAddExtraM,
124 BBlockTransferThreadClusterLengths_K0_N_K1,
125 BBlockTransferThreadClusterArrangeOrder,
126 BBlockTransferSrcAccessOrder,
127 BBlockTransferSrcVectorDim,
128 BBlockTransferSrcScalarPerVector,
129 BBlockTransferDstScalarPerVector_K1,
130 false, // BThreadTransferSrcResetCoordinateAfterRun,
131 BBlockLdsAddExtraN,
132 CShuffleMRepeatPerShuffle,
133 CShuffleNRepeatPerShuffle,
134 CBlockTransferScalarPerVector_NWaveNPerXDL,
135 CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
136 LoopSched,
137 PipelineVer,
140 LDSTypeA,
141 LDSTypeB>;
144
// Host-side argument bundle for this device operation.
// Extends GridwiseGemm64::Argument with the three element-wise operation functors, which the
// base argument does not receive.
145 struct Argument : public GridwiseGemm64::Argument
146 {
    // Forwards the first fourteen parameters, unchanged and in the same order, to the
    // GridwiseGemm64::Argument constructor, then stores copies of the three element-wise
    // operations in the members below. The constructor body itself is empty.
    // The padded sizes (MPadded_, NPadded_, KPadded_, K0Padded_) are taken as given; this
    // constructor does not compute or check them.
147 Argument(const ADataType* p_a_grid_,
148 const BDataType* p_b_grid_,
149 CDataType* p_c_grid_,
150 index_t M_,
151 index_t N_,
152 index_t K_,
153 index_t StrideA_,
154 index_t StrideB_,
155 index_t StrideC_,
156 index_t MPadded_,
157 index_t NPadded_,
158 index_t KPadded_,
159 index_t K0Padded_,
160 index_t k_batch_,
161 AElementwiseOperation a_element_op_,
162 BElementwiseOperation b_element_op_,
163 CElementwiseOperation c_element_op_)
164 : GridwiseGemm64::Argument(p_a_grid_,
165 p_b_grid_,
166 p_c_grid_,
167 M_,
168 N_,
169 K_,
170 StrideA_,
171 StrideB_,
172 StrideC_,
173 MPadded_,
174 NPadded_,
175 KPadded_,
176 K0Padded_,
177 k_batch_),
178 a_element_op(a_element_op_),
179 b_element_op(b_element_op_),
180 c_element_op(c_element_op_)
181 {
182 }
183
    // Element-wise operations; RunImp passes them to the kernel launch.
184 AElementwiseOperation a_element_op;
185 BElementwiseOperation b_element_op;
186 CElementwiseOperation c_element_op;
187 };
188
190
191 // Invoker
192 struct Invoker : public BaseInvoker
193 {
194
195 void Print(const Argument& karg) { karg.Print(); }
196
// Launches the split-K GEMM kernel for the given GridwiseGemm instantiation and returns the
// time reported by launch_and_time_kernel.
//
// arg           : device-level argument (pointers, sizes, strides, padded sizes, k_batch and
//                 the three element-wise operations).
// stream_config : stream and logging configuration; the argument is printed when
//                 log_level_ > 0.
// Throws std::runtime_error when GridwiseGemm::CheckValidity rejects the argument.
//
// NOTE(review): this listing is lossy. In each of the four kernel selections below the line
// naming the kernel template and several of its template arguments are absent. The page
// index references kernel_gemm_xdlops_v2r4r2_simplified and the Set / AtomicAdd enumerators,
// so presumably kbatch == 1 selects Set and kbatch > 1 selects AtomicAdd -- confirm against
// the original header.
197 template <typename GridwiseGemm>
198 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
199 {
200 if(stream_config.log_level_ > 0)
201 {
202 Print(arg);
203 }
204
    // Rebuild the gridwise argument for the selected GridwiseGemm from the stored fields.
205 typename GridwiseGemm::Argument karg(arg.p_a_grid,
206 arg.p_b_grid,
207 arg.p_c_grid,
208 arg.M,
209 arg.N,
210 arg.K,
211 arg.StrideA,
212 arg.StrideB,
213 arg.StrideC,
214 arg.MPadded,
215 arg.NPadded,
216 arg.KPadded,
217 arg.K0Padded,
218 arg.k_batch);
219 const auto kbatch = karg.k_batch;
220 if(!GridwiseGemm::CheckValidity(karg))
221 {
222 throw std::runtime_error(
223 "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid "
224 "setting");
225 }
226
    // Grid dimensions come from the default block-to-C-tile map for (M, N, k_batch).
227 const auto b2c_map = DefaultBlock2CTileMap{};
228 index_t gdx, gdy, gdz;
229 ck::tie(gdx, gdy, gdz) = b2c_map.CalculateGridSize(karg.M, karg.N, karg.k_batch);
230 const auto K0Padded = karg.K0Padded;
231
232 const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0Padded);
233
234 float ave_time = 0;
235
    // Shared launch helper: captures everything by reference and writes ave_time.
236 const auto Run = [&](const auto& kernel) {
    // For kbatch > 1 the C buffer is zeroed on the launch stream before the kernel runs.
    // NOTE(review): the status returned by hipMemsetAsync is only converted to a string by
    // hipGetErrorString and that string is discarded, so a failed memset goes unreported.
    // NOTE(review): the byte count is M * N * sizeof(CDataType) and does not involve
    // StrideC -- confirm C is always packed on this path. If karg.M and karg.N are index_t
    // (int32_t per the page index), M * N is evaluated in 32 bits before being widened by
    // sizeof, which can overflow for large problems.
237 if(kbatch > 1)
238 hipGetErrorString(hipMemsetAsync(karg.p_c_grid,
239 0,
240 karg.M * karg.N * sizeof(CDataType),
241 stream_config.stream_id_));
242
    // One-dimensional thread block of BlockSize threads; the literal 0 is the lds_byte
    // argument of launch_and_time_kernel.
243 ave_time =
244 launch_and_time_kernel(stream_config,
245 kernel,
246 dim3(gdx, gdy, gdz),
247 dim3(BlockSize),
248 0,
249 static_cast<typename GridwiseGemm::Argument>(karg),
250 b2c_map,
251 arg.a_element_op,
252 arg.b_element_op,
253 arg.c_element_op);
254 };
255
    // Four kernel variants: {main K0 loop, no main K0 loop} x {kbatch == 1, kbatch != 1}.
    // The visible boolean template argument (true / false) follows has_main_k0_block_loop.
256 if(has_main_k0_block_loop)
257 {
258 if(kbatch == 1)
259 {
260 const auto kernel =
262 true,
265 AElementwiseOperation,
266 BElementwiseOperation,
267 CElementwiseOperation>;
268
269 Run(kernel);
270 }
271 else
272 {
273 const auto kernel =
275 true,
278 AElementwiseOperation,
279 BElementwiseOperation,
280 CElementwiseOperation>;
281
282 Run(kernel);
283 }
284 }
285 else
286 {
287 if(kbatch == 1)
288 {
289 const auto kernel =
291 false,
294 AElementwiseOperation,
295 BElementwiseOperation,
296 CElementwiseOperation>;
297
298 Run(kernel);
299 }
300 else
301 {
302 const auto kernel =
304 false,
307 AElementwiseOperation,
308 BElementwiseOperation,
309 CElementwiseOperation>;
310
311 Run(kernel);
312 }
313 }
314
315 return ave_time;
316 }
317
319
320 // polymorphic
321 float Run(const BaseArgument* p_arg,
322 const StreamConfig& stream_config = StreamConfig{}) override
323 {
324 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
325 }
326 };
327
// Compile-time validity hook for the template parameters.
// Currently a placeholder that accepts every parameter combination.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool is_valid = true;
    return is_valid;
}
333
// Reports whether this device operation can run the given argument on the current device.
//
// NOTE(review): this listing is lossy. Lines 337 and 341 (the conditions of the two early
// `return false` blocks) and lines 349 and 356 (the start of the statements inside the two
// `if constexpr` branches) are absent. The page index references is_gfx11_supported() and
// is_xdl_wmma_supported(), which presumably form the two early-exit conditions -- confirm
// against the original header.
334 static bool IsSupportedArgument(const Argument& karg)
335 {
336 // gfx11 doesn't support float atomic
338 {
339 return false;
340 }
342 {
343 return false;
344 }
    // Dispatch on the warp size; each branch exists only when the corresponding
    // NXdlPerWave value is positive.
345 if(get_warp_size() == 64)
346 {
347 if constexpr(NXdlPerWave64 > 0)
348 {
350 }
351 }
352 else
353 {
354 if constexpr(NXdlPerWave32 > 0)
355 {
    // NOTE(review): karg is reinterpreted as GridwiseGemm32::Argument although Argument
    // derives from GridwiseGemm64::Argument; this relies on both argument types sharing a
    // layout -- confirm.
357 reinterpret_cast<const typename GridwiseGemm32::Argument&>(karg));
358 }
359 }
    // Reached when the active `if constexpr` branch is compiled out (or, presumably, never
    // when that branch returns).
360 return false;
361 }
362
363 // polymorphic
364 bool IsSupportedArgument(const BaseArgument* p_arg) override
365 {
366 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
367 }
368
// Builds an Argument by value from typed pointers, problem sizes, strides, the three
// element-wise operations and the split-K batch count.
//
// NOTE(review): this listing is lossy. Lines 392-395 are absent. The Argument constructor
// takes MPadded_, NPadded_, KPadded_ and K0Padded_ between StrideC_ and k_batch_, so those
// four lines presumably supply the padded sizes -- confirm against the original header.
369 static auto MakeArgument(const ADataType* p_a,
370 const BDataType* p_b,
371 CDataType* p_c,
372 index_t M,
373 index_t N,
374 index_t K,
375 index_t StrideA,
376 index_t StrideB,
377 index_t StrideC,
378 AElementwiseOperation a_element_op,
379 BElementwiseOperation b_element_op,
380 CElementwiseOperation c_element_op,
381 index_t KBatch)
382 {
    // Note the reordering: KBatch is the last parameter here but precedes the element-wise
    // operations in the Argument constructor.
383 return Argument(p_a,
384 p_b,
385 p_c,
386 M,
387 N,
388 K,
389 StrideA,
390 StrideB,
391 StrideC,
396 KBatch,
397 a_element_op,
398 b_element_op,
399 c_element_op);
400 }
401
402 static auto MakeInvoker() { return Invoker{}; }
403
404 // polymorphic
// Type-erased factory: casts the untyped pointers to ADataType / BDataType / CDataType with
// static_cast and returns a heap-allocated Argument owned through a BaseArgument pointer.
// KBatch defaults to 1.
//
// NOTE(review): this listing is lossy. Lines 428-431 are absent. The Argument constructor
// takes MPadded_, NPadded_, KPadded_ and K0Padded_ between StrideC_ and k_batch_, so those
// four lines presumably supply the padded sizes -- confirm against the original header.
405 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
406 const void* p_b,
407 void* p_c,
408 index_t M,
409 index_t N,
410 index_t K,
411 index_t StrideA,
412 index_t StrideB,
413 index_t StrideC,
414 AElementwiseOperation a_element_op,
415 BElementwiseOperation b_element_op,
416 CElementwiseOperation c_element_op,
417 ck::index_t KBatch = 1) override
418 {
    // The casts are unchecked: the caller is responsible for passing buffers of the
    // element types this instantiation was built for.
419 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
420 static_cast<const BDataType*>(p_b),
421 static_cast<CDataType*>(p_c),
422 M,
423 N,
424 K,
425 StrideA,
426 StrideB,
427 StrideC,
432 KBatch,
433 a_element_op,
434 b_element_op,
435 c_element_op);
436 }
437
438 // polymorphic
439 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
440 {
441 return std::make_unique<Invoker>(Invoker{});
442 }
443
444 // polymorphic
445 std::string GetTypeString() const override
446 {
447 auto str = std::stringstream();
448
449 std::map<LoopScheduler, std::string> LoopSchedToString{
450 {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
451
452 std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
453 {PipelineVersion::v2, "v2"}};
454
455 str << GridwiseGemm64::GetTypeString() << " LoopScheduler: " << LoopSchedToString[LoopSched]
456 << ", PipelineVersion: " << PipelineVersionToString[PipelineVer];
457
458 return str.str();
459 }
460};
461
462} // namespace device
463} // namespace tensor_operation
464} // namespace ck
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__global__ void kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg, const Block2CTileMap &b2c_map, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op)
Definition gridwise_gemm_xdlops_v2r4r2.hpp:33
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
integral_constant< index_t, N > Number
Definition number.hpp:12
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
LoopScheduler
Definition loop_scheduler.hpp:15
@ Default
Definition loop_scheduler.hpp:16
@ Interwave
Definition loop_scheduler.hpp:17
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v2
Definition gridwise_gemm_pipeline_selector.hpp:20
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_xdlops_v2r4r2.hpp:106
Definition device_base.hpp:197
Definition device_gemm_splitk.hpp:26
Definition device_gemm_xdl_splitk_c_shuffle.hpp:146
AElementwiseOperation a_element_op
Definition device_gemm_xdl_splitk_c_shuffle.hpp:184
BElementwiseOperation b_element_op
Definition device_gemm_xdl_splitk_c_shuffle.hpp:185
CElementwiseOperation c_element_op
Definition device_gemm_xdl_splitk_c_shuffle.hpp:186
Argument(const ADataType *p_a_grid_, const BDataType *p_b_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t MPadded_, index_t NPadded_, index_t KPadded_, index_t K0Padded_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_)
Definition device_gemm_xdl_splitk_c_shuffle.hpp:147
Definition device_gemm_xdl_splitk_c_shuffle.hpp:193
void Print(const Argument &karg)
Definition device_gemm_xdl_splitk_c_shuffle.hpp:195
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_xdl_splitk_c_shuffle.hpp:198
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_xdl_splitk_c_shuffle.hpp:321
Definition device_gemm_xdl_splitk_c_shuffle.hpp:77
ComputeType ComputeTypeA
Definition device_gemm_xdl_splitk_c_shuffle.hpp:90
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, index_t KBatch)
Definition device_gemm_xdl_splitk_c_shuffle.hpp:369
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_xdl_splitk_c_shuffle.hpp:142
static constexpr auto I1
Definition device_gemm_xdl_splitk_c_shuffle.hpp:83
static bool IsSupportedArgument(const Argument &karg)
Definition device_gemm_xdl_splitk_c_shuffle.hpp:334
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_xdl_splitk_c_shuffle.hpp:328
static constexpr auto I0
Definition device_gemm_xdl_splitk_c_shuffle.hpp:82
typename GridwiseGemm64::DefaultBlock2CTileMap DefaultBlock2CTileMap
Definition device_gemm_xdl_splitk_c_shuffle.hpp:189
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_xdl_splitk_c_shuffle.hpp:439
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CBlockTransferScalarPerVector_NWaveNPerXDL, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB > GridwiseGemmBase
Definition device_gemm_xdl_splitk_c_shuffle.hpp:94
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, ck::index_t KBatch=1) override
Definition device_gemm_xdl_splitk_c_shuffle.hpp:405
static constexpr auto I2
Definition device_gemm_xdl_splitk_c_shuffle.hpp:84
static constexpr auto NXdlPerWave32
Definition device_gemm_xdl_splitk_c_shuffle.hpp:80
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_xdl_splitk_c_shuffle.hpp:364
static constexpr auto I3
Definition device_gemm_xdl_splitk_c_shuffle.hpp:85
std::string GetTypeString() const override
Definition device_gemm_xdl_splitk_c_shuffle.hpp:445
static auto MakeInvoker()
Definition device_gemm_xdl_splitk_c_shuffle.hpp:402
ComputeType ComputeTypeB
Definition device_gemm_xdl_splitk_c_shuffle.hpp:91
static constexpr index_t NumGemmKPrefetchStage
Definition device_gemm_xdl_splitk_c_shuffle.hpp:88
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_xdl_splitk_c_shuffle.hpp:143
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_xdl_splitk_c_shuffle.hpp:79