device_gemm_xdl_cshuffle_v2.hpp Source File

device_gemm_xdl_cshuffle_v2.hpp Source File#

Composable Kernel: device_gemm_xdl_cshuffle_v2.hpp Source File
device_gemm_xdl_cshuffle_v2.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
18
19namespace ck {
20namespace tensor_operation {
21namespace device {
22
// Note: the inter-wave loop scheduler is rolled out to the c-shuffle version first, because the
// non-c-shuffle version currently has compiler issues with register spilling, which in turn cause
// validation failures.
// DeviceGemm_Xdl_CShuffleV2: an implementation of the type-erased DeviceGemm interface that
// launches the GridwiseGemm_xdl_cshuffle_v2 gridwise kernel.
// NOTE(review): source lines 67-68 are missing from this extract. GridwiseGemmBase and
// GetTypeString use the template parameters LoopSched (LoopScheduler) and PipelineVer
// (PipelineVersion), which were presumably declared there. They must carry defaults, because
// the later ComputeTypeA/ComputeTypeB parameters are defaulted. Confirm against the original header.
26template <typename ALayout,
27 typename BLayout,
28 typename CLayout,
29 typename ADataType,
30 typename BDataType,
31 typename CDataType,
32 typename GemmAccDataType,
33 typename CShuffleDataType,
34 typename AElementwiseOperation,
35 typename BElementwiseOperation,
36 typename CElementwiseOperation,
37 GemmSpecialization GemmSpec,
38 index_t NumGemmKPrefetchStage,
39 index_t BlockSize,
40 index_t MPerBlock,
41 index_t NPerBlock,
42 index_t KPerBlock,
43 index_t AK1,
44 index_t BK1,
45 index_t MPerXDL,
46 index_t NPerXDL,
47 index_t MXdlPerWave,
48 index_t NXdlPerWave,
49 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
50 typename ABlockTransferThreadClusterArrangeOrder,
51 typename ABlockTransferSrcAccessOrder,
52 index_t ABlockTransferSrcVectorDim,
53 index_t ABlockTransferSrcScalarPerVector,
54 index_t ABlockTransferDstScalarPerVector_AK1,
55 bool ABlockLdsExtraM,
56 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
57 typename BBlockTransferThreadClusterArrangeOrder,
58 typename BBlockTransferSrcAccessOrder,
59 index_t BBlockTransferSrcVectorDim,
60 index_t BBlockTransferSrcScalarPerVector,
61 index_t BBlockTransferDstScalarPerVector_BK1,
62 bool BBlockLdsExtraN,
63 index_t CShuffleMXdlPerWavePerShuffle,
64 index_t CShuffleNXdlPerWavePerShuffle,
65 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
66 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
69 typename ComputeTypeA = CDataType,
70 typename ComputeTypeB = ComputeTypeA>
71struct DeviceGemm_Xdl_CShuffleV2 : public DeviceGemm<ALayout,
72 BLayout,
73 CLayout,
74 ADataType,
75 BDataType,
76 CDataType,
77 AElementwiseOperation,
78 BElementwiseOperation,
79 CElementwiseOperation>
80{
// NOTE(review): source lines 81-82 are missing here. The definition index lists a `DeviceOp`
// alias at line 81 and the GET_NXDL_PER_WAVE_IMPL macro (device_base.hpp), which presumably
// provides the GetNXdlPerWave<bool>() used below.
// NXdlPerWave values used on wave64 (true) and wave32 (false) targets. IsSupportedArgument
// picks one based on get_warp_size() and rejects the configuration when the chosen value
// is not positive.
83 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
84 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
85
// Compile-time index constants (Number<N> is integral_constant<index_t, N>).
86 static constexpr auto I0 = Number<0>{};
87 static constexpr auto I1 = Number<1>{};
88 static constexpr auto I2 = Number<2>{};
89
90 // GridwiseGemm
91 template <index_t NXdlPerWave_>
93 ALayout,
94 BLayout,
95 CLayout,
96 ADataType,
97 BDataType,
98 GemmAccDataType,
99 CShuffleDataType,
100 CDataType,
101 AElementwiseOperation,
102 BElementwiseOperation,
103 CElementwiseOperation,
104 GemmSpec,
106 NumGemmKPrefetchStage,
107 BlockSize,
108 MPerBlock,
109 NPerBlock,
110 KPerBlock,
111 AK1,
112 BK1,
113 MPerXDL,
114 NPerXDL,
115 MXdlPerWave,
116 NXdlPerWave_,
117 ABlockTransferThreadClusterLengths_AK0_M_AK1,
118 ABlockTransferThreadClusterArrangeOrder,
119 ABlockTransferSrcAccessOrder,
120 ABlockTransferSrcVectorDim,
121 ABlockTransferSrcScalarPerVector,
122 ABlockTransferDstScalarPerVector_AK1,
123 false,
124 ABlockLdsExtraM,
125 BBlockTransferThreadClusterLengths_BK0_N_BK1,
126 BBlockTransferThreadClusterArrangeOrder,
127 BBlockTransferSrcAccessOrder,
128 BBlockTransferSrcVectorDim,
129 BBlockTransferSrcScalarPerVector,
130 BBlockTransferDstScalarPerVector_BK1,
131 false,
132 BBlockLdsExtraN,
133 CShuffleMXdlPerWavePerShuffle,
134 CShuffleNXdlPerWavePerShuffle,
135 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
136 CShuffleBlockTransferScalarPerVector_NPerBlock,
137 LoopSched,
138 PipelineVer,
139 ComputeTypeA,
140 ComputeTypeB>;
143
144 using Argument = typename GridwiseGemm64::Argument;
145
146 // Invoker
147 struct Invoker : public BaseInvoker
148 {
// Validates `arg` for the given GridwiseGemm instantiation, launches the kernel on a grid
// computed from (M, N) with BlockSize threads per block, and returns the time reported by
// launch_and_time_kernel (presumably an average in milliseconds; confirm in kernel_launch.hpp).
// Throws std::runtime_error when GridwiseGemm::CheckValidity(arg) fails.
149 template <typename GridwiseGemm>
150 float RunImp(const typename GridwiseGemm::Argument& arg,
151 const StreamConfig& stream_config = StreamConfig{})
152 {
// Dump the argument when logging is enabled.
153 if(stream_config.log_level_ > 0)
154 {
155 arg.Print();
156 }
157
158 if(!GridwiseGemm::CheckValidity(arg))
159 {
160 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
161 }
162
// Grid dimensions are derived from M and N only.
163 index_t gdx, gdy, gdz;
164 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
165
166 float ave_time = 0;
// K as seen by the kernel: AK0 * AK1. This is presumably arg.K rounded up to the kernel's
// K granularity; TODO confirm CalculateAK0 semantics in gridwise_gemm_xdl_cshuffle_v2.hpp.
167 const auto K = GridwiseGemm::CalculateAK0(arg.K) * AK1;
168
// A K-loop tail number of 3 selects a different kernel instantiation than all other cases.
// NOTE(review): the `kernel` declarations in both branches (source lines 171 and 177) are
// missing from this extract. They are presumably instantiations of
// kernel_gemm_xdl_cshuffle_v2<GridwiseGemm, ...>; confirm against the original header.
169 if(GridwiseGemm::CalculateKBlockLoopTailNum(K) == 3)
170 {
// The fifth argument (lds_byte) is 0: no dynamic LDS is requested at launch.
172 ave_time = launch_and_time_kernel(
173 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
174 }
175 else
176 {
178 ave_time = launch_and_time_kernel(
179 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
180 }
181
182 return ave_time;
183 }
184
186
187 // polymorphic
188 float Run(const BaseArgument* p_arg,
189 const StreamConfig& stream_config = StreamConfig{}) override
190 {
191 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
192 }
193 };
194
// Hook for rejecting invalid template-parameter combinations at compile time.
// TODO: properly implement this check; for now every instantiation is accepted.
static constexpr bool IsValidCompilationParameter()
{
    constexpr bool accepted = true;
    return accepted;
}
200
// Returns true only if `arg` can be run by this instance on the current device:
// K must be divisible by AK1 and BK1 unless a K-padding specialization is used, and the
// NXdlPerWave selected for the runtime warp size must be positive.
201 static bool IsSupportedArgument(const Argument& arg)
202 {
// NOTE(review): source line 203 (the condition guarding this early `return false`) is missing
// from this extract. The definition index references is_xdl_wmma_supported() from
// device_prop.hpp, which is presumably checked here; confirm against the original header.
204 {
205 return false;
206 }
// Without K padding, K must be an exact multiple of both AK1 and BK1.
207 if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
208 GemmSpec == GemmSpecialization::NKPadding ||
209 GemmSpec == GemmSpecialization::MNKPadding ||
210 GemmSpec == GemmSpecialization::KPadding))
211 {
212 return false;
213 }
214
// Pick the wave64 or wave32 kernel instantiation based on the warp size.
215 if(get_warp_size() == 64)
216 {
217 if constexpr(NXdlPerWave64 > 0)
218 {
// NOTE(review): source line 219 is missing. It presumably returns
// GridwiseGemm64::CheckValidity(arg); confirm.
220 }
221 }
222 else
223 {
224 if constexpr(NXdlPerWave32 > 0)
225 {
// NOTE(review): source line 226 (the start of this return statement, presumably
// `return GridwiseGemm32::CheckValidity(`) is missing. The reinterpret_cast below
// assumes the wave64 and wave32 Argument types share a layout; confirm.
227 reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg));
228 }
229 }
// Reached when the NXdlPerWave for the detected warp size is not positive.
230 return false;
231 }
232
233 // polymorphic
234 bool IsSupportedArgument(const BaseArgument* p_arg) override
235 {
236 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
237 }
238
239 static auto MakeArgument(const ADataType* p_a,
240 const BDataType* p_b,
241 CDataType* p_c,
242 index_t M,
243 index_t N,
244 index_t K,
245 index_t StrideA,
246 index_t StrideB,
247 index_t StrideC,
248 AElementwiseOperation,
249 BElementwiseOperation,
250 CElementwiseOperation)
251 {
252 return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
253 }
254
255 static auto MakeInvoker() { return Invoker{}; }
256
257 // polymorphic
258 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
259 const void* p_b,
260 void* p_c,
261 index_t M,
262 index_t N,
263 index_t K,
264 index_t StrideA,
265 index_t StrideB,
266 index_t StrideC,
267 AElementwiseOperation,
268 BElementwiseOperation,
269 CElementwiseOperation) override
270 {
271 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
272 static_cast<const BDataType*>(p_b),
273 static_cast<CDataType*>(p_c),
274 M,
275 N,
276 K,
277 StrideA,
278 StrideB,
279 StrideC);
280 }
281
282 // polymorphic
283 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
284 {
285 return std::make_unique<Invoker>(Invoker{});
286 }
287
288 // polymorphic
289 std::string GetTypeString() const override
290 {
291 auto str = std::stringstream();
292
293 std::map<LoopScheduler, std::string> LoopSchedToString{
294 {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
295
296 std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
297 {PipelineVersion::v2, "v2"}};
298
299 // clang-format off
300 str << "DeviceGemm_Xdl_CShuffleV2"
301 << "<"
302 << getGemmSpecializationString(GemmSpec) << ", "
303 << BlockSize << ", "
304 << MPerBlock << ", "
305 << NPerBlock << ", "
306 << KPerBlock << ", "
307 << AK1 << ", "
308 << BK1 << ", "
309 << MPerXDL << ", "
310 << NPerXDL << ", "
311 << MXdlPerWave << ", "
312 << NXdlPerWave << ", "
313 << ABlockTransferSrcScalarPerVector << ", "
314 << BBlockTransferSrcScalarPerVector << ", "
315 << CShuffleMXdlPerWavePerShuffle << ", "
316 << CShuffleNXdlPerWavePerShuffle
317 << ">"
318 << " LoopScheduler: "
319 << LoopSchedToString[LoopSched] << ", "
320 << "PipelineVersion: "
321 << PipelineVersionToString[PipelineVer];
322 // clang-format on
323
324 return str.str();
325 }
326};
327
328} // namespace device
329} // namespace tensor_operation
330} // namespace ck
#define INVOKER_RUN3_IMPL
Definition device_base.hpp:114
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
integral_constant< index_t, N > Number
Definition number.hpp:12
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
__global__ void kernel_gemm_xdl_cshuffle_v2(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_v2.hpp:26
LoopScheduler
Definition loop_scheduler.hpp:15
@ Default
Definition loop_scheduler.hpp:16
@ Interwave
Definition loop_scheduler.hpp:17
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v2
Definition gridwise_gemm_pipeline_selector.hpp:20
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_xdl_cshuffle_v2.hpp:126
Definition device_base.hpp:197
Definition device_gemm_xdl_cshuffle_v2.hpp:148
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_xdl_cshuffle_v2.hpp:188
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_xdl_cshuffle_v2.hpp:150
Definition device_gemm_xdl_cshuffle_v2.hpp:80
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation)
Definition device_gemm_xdl_cshuffle_v2.hpp:239
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_xdl_cshuffle_v2.hpp:141
GridwiseGemm_xdl_cshuffle_v2< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB > GridwiseGemmBase
Definition device_gemm_xdl_cshuffle_v2.hpp:92
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_xdl_cshuffle_v2.hpp:234
typename GridwiseGemm64::Argument Argument
Definition device_gemm_xdl_cshuffle_v2.hpp:144
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_xdl_cshuffle_v2.hpp:83
static constexpr auto I0
Definition device_gemm_xdl_cshuffle_v2.hpp:86
std::string GetTypeString() const override
Definition device_gemm_xdl_cshuffle_v2.hpp:289
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation) override
Definition device_gemm_xdl_cshuffle_v2.hpp:258
static constexpr auto I1
Definition device_gemm_xdl_cshuffle_v2.hpp:87
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_xdl_cshuffle_v2.hpp:195
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_xdl_cshuffle_v2.hpp:283
static auto MakeInvoker()
Definition device_gemm_xdl_cshuffle_v2.hpp:255
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_xdl_cshuffle_v2.hpp:201
DeviceGemm_Xdl_CShuffleV2 DeviceOp
Definition device_gemm_xdl_cshuffle_v2.hpp:81
static constexpr auto I2
Definition device_gemm_xdl_cshuffle_v2.hpp:88
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_xdl_cshuffle_v2.hpp:142
static constexpr auto NXdlPerWave32
Definition device_gemm_xdl_cshuffle_v2.hpp:84
Definition device_gemm.hpp:22