device_batched_gemm_xdl.hpp Source File

device_batched_gemm_xdl.hpp Source File#

Composable Kernel: device_batched_gemm_xdl.hpp Source File
device_batched_gemm_xdl.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
18
19namespace ck {
20namespace tensor_operation {
21namespace device {
22
23/*
24 * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
25 *
26 * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
27 * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
28 * strided batches, but we can easily extend to other layouts. The returned offset can be either \p
29 * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
30 * limitations.
31 *
32 * \tparam Block2CTileMap Block2CTileMap::CalculateBottomIndex() takes in id of a workgroup and
33 * returns the 2D index of the tile that it computes. \see
34 * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
35 *
36 * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
37 * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
38 * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
39 * device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link
40 * DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
41 * pointer offset into \p ComputePtrOffsetOfStridedBatch.
42 *
43 * \note \p Block2CTileMap allows customized mapping between a workgroup and the C-tile it computes.
44 * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion) to
45 * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
46 *
47 */
// Kernel entry point for batched GEMM: every workgroup works out which batch it belongs to from
// its 1D block id, offsets the A/B/C base pointers for that batch and calls GridwiseGemm::Run.
// karg is taken by value and carries the problem sizes, strides, base pointers, Batch and the
// batch pointer-offset helper.
48template <typename DeviceOp, typename GridwiseGemm, bool HasMainKBlockLoop>
49__global__ void
// NOTE(review): the attribute guarded by this #if is not visible in this listing (the embedded
// numbering jumps from 50 to 52) - confirm against the real header before editing.
50#if CK_USE_LAUNCH_BOUNDS
52#endif
53 kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg)
54{
// The body is only compiled for gfx9 / gfx11 / gfx12 device targets; elsewhere the kernel is empty.
55#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
56 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
57 {
// The grid is split evenly between batches: blocks [g * n, (g + 1) * n) serve batch g, where
// n = grid size / Batch. The host side multiplies the per-GEMM grid size by Batch to match.
58 const index_t num_blocks_per_batch =
59 __builtin_amdgcn_readfirstlane(get_grid_size() / karg.Batch);
60 const index_t g_idx =
61 __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
62
// Per-batch offsets are widened to long_index_t before use.
// NOTE(review): readfirstlane is presumably used to mark these values as wave-uniform - confirm.
63 const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
64 static_cast<long_index_t>(karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
65 const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
66 static_cast<long_index_t>(karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
67 const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
68 static_cast<long_index_t>(karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
69
// Shared-memory scratch buffer; its size is dictated by the GridwiseGemm instantiation.
70 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
71
// The grid descriptors are built from the sizes/strides in karg only, so the same descriptors are
// used for every batch; only the base pointers differ between batches.
72 const auto a_grid_desc_k0_m_k1 =
73 amd_wave_read_first_lane(GridwiseGemm::MakeAGridDescriptor_K0_M_K1(
74 karg.M, karg.MPadded, karg.K, karg.K0, karg.StrideA));
75 const auto b_grid_desc_k0_n_k1 =
76 amd_wave_read_first_lane(GridwiseGemm::MakeBGridDescriptor_K0_N_K1(
77 karg.K, karg.N, karg.NPadded, karg.K0, karg.StrideB));
78 const auto c_grid_desc_m_n = amd_wave_read_first_lane(GridwiseGemm::MakeCGridDescriptor_M_N(
79 karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC));
80
// Offsets are applied with pointer arithmetic on the typed pointers, i.e. they count elements.
81 GridwiseGemm::template Run<HasMainKBlockLoop>(karg.p_a_grid + a_batch_offset,
82 karg.p_b_grid + b_batch_offset,
83 karg.p_c_grid + c_batch_offset,
84 p_shared,
85 a_grid_desc_k0_m_k1,
86 b_grid_desc_k0_n_k1,
87 c_grid_desc_m_n);
88 }
89#else
// Unsupported target: consume the argument so it does not count as unused.
90 ignore = karg;
91#endif
92}
93
94template <typename ADataType,
95 typename BDataType,
96 typename CDataType,
97 typename AccDataType,
98 typename ALayout,
99 typename BLayout,
100 typename CLayout,
101 typename AElementwiseOperation,
102 typename BElementwiseOperation,
103 typename CElementwiseOperation,
104 ck::index_t BlockSize,
105 ck::index_t MPerBlock,
106 ck::index_t NPerBlock,
107 ck::index_t K0PerBlock,
108 ck::index_t K1,
109 ck::index_t MPerXDL,
110 ck::index_t NPerXDL,
111 ck::index_t MXdlPerWave,
112 ck::index_t NXdlPerWave,
113 typename ABlockTransferThreadClusterLengths_K0_M_K1,
114 typename ABlockTransferThreadClusterArrangeOrder,
115 typename ABlockTransferSrcAccessOrder,
116 ck::index_t ABlockTransferSrcVectorDim,
117 ck::index_t ABlockTransferSrcScalarPerVector,
118 ck::index_t ABlockTransferDstScalarPerVector_K1,
119 bool ABlockLdsAddExtraM,
120 typename BBlockTransferThreadClusterLengths_K0_N_K1,
121 typename BBlockTransferThreadClusterArrangeOrder,
122 typename BBlockTransferSrcAccessOrder,
123 ck::index_t BBlockTransferSrcVectorDim,
124 ck::index_t BBlockTransferSrcScalarPerVector,
125 ck::index_t BBlockTransferDstScalarPerVector_K1,
126 bool BBlockLdsAddExtraN,
127 ck::index_t CThreadTransferSrcDstVectorDim,
128 ck::index_t CThreadTransferDstScalarPerVector,
129 ck::index_t NumGemmKPrefetchStage = 1,
133 BLayout,
134 CLayout,
135 ADataType,
136 BDataType,
137 CDataType,
138 AElementwiseOperation,
139 BElementwiseOperation,
140 CElementwiseOperation>
141{
// NXdlPerWave values for the two wave sizes. IsSupportedArgument only takes the wave64 path when
// NXdlPerWave64 > 0 and the wave32 path when NXdlPerWave32 > 0.
// NOTE(review): GetNXdlPerWave is not defined in the visible part of this listing (presumably it
// comes from a macro expanded just above) - confirm the meaning of its bool template argument.
143 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
144 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
145
// Compile-time integer constants 0, 1 and 2.
146 static constexpr auto I0 = Number<0>{};
147 static constexpr auto I1 = Number<1>{};
148 static constexpr auto I2 = Number<2>{};
149
// The K1 template parameter wrapped as a compile-time Number.
150 static constexpr auto K1Number = Number<K1>{};
151
153 {
155 index_t BatchStrideB,
156 index_t BatchStrideC)
157 : BatchStrideA_(BatchStrideA), BatchStrideB_(BatchStrideB), BatchStrideC_(BatchStrideC)
158 {
159 }
160
161 __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
162 {
163 return g_idx * static_cast<long_index_t>(BatchStrideA_);
164 }
165
166 __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
167 {
168 return g_idx * static_cast<long_index_t>(BatchStrideB_);
169 }
170
171 __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
172 {
173 return g_idx * static_cast<long_index_t>(BatchStrideC_);
174 }
175
176 private:
177 index_t BatchStrideA_;
178 index_t BatchStrideB_;
179 index_t BatchStrideC_;
180 };
181
182 // GridwiseGemm
183 template <index_t NXdlPerWave_>
185 BlockSize,
186 ADataType, // TODO: distinguish A/B datatype
187 AccDataType,
188 CDataType,
190 ALayout,
191 BLayout,
192 CLayout,
193 AElementwiseOperation,
194 BElementwiseOperation,
195 CElementwiseOperation,
197 MPerBlock,
198 NPerBlock,
199 K0PerBlock,
200 MPerXDL,
201 NPerXDL,
202 K1,
203 MXdlPerWave,
204 NXdlPerWave_,
205 ABlockTransferThreadClusterLengths_K0_M_K1,
206 ABlockTransferThreadClusterArrangeOrder,
207 ABlockTransferSrcAccessOrder,
208 ABlockTransferSrcVectorDim,
209 ABlockTransferSrcScalarPerVector,
210 ABlockTransferDstScalarPerVector_K1,
211 false, // AThreadTransferSrcResetCoordinateAfterRun,
212 ABlockLdsAddExtraM,
213 BBlockTransferThreadClusterLengths_K0_N_K1,
214 BBlockTransferThreadClusterArrangeOrder,
215 BBlockTransferSrcAccessOrder,
216 BBlockTransferSrcVectorDim,
217 BBlockTransferSrcScalarPerVector,
218 BBlockTransferDstScalarPerVector_K1,
219 false, // BThreadTransferSrcResetCoordinateAfterRun,
220 BBlockLdsAddExtraN,
222 CThreadTransferSrcDstVectorDim,
223 CThreadTransferDstScalarPerVector,
224 NumGemmKPrefetchStage,
225 LoopSched,
226 PipelineVer>;
229
231
232 // Argument
// Kernel argument for the batched GEMM. It extends the GridwiseGemm Problem (built from M, N, K
// and the three row/column strides) with the A/B/C base pointers, the batch count and the helper
// that maps a batch index to per-batch pointer offsets.
233 struct Argument : public Problem, public BaseArgument
234 {
// p_a_grid_ / p_b_grid_ / p_c_grid_: base pointers of A, B (read-only) and C (written).
// M_, N_, K_, StrideA_, StrideB_, StrideC_: forwarded unchanged to the Problem base.
// BatchStrideA / BatchStrideB / BatchStrideC: distance between consecutive batches, handed to
// compute_ptr_offset_of_batch.
// Batch_: number of batches.
235 Argument(const ADataType* p_a_grid_,
236 const BDataType* p_b_grid_,
237 CDataType* p_c_grid_,
238 index_t M_,
239 index_t N_,
240 index_t K_,
241 index_t StrideA_,
242 index_t StrideB_,
243 index_t StrideC_,
244 index_t BatchStrideA,
245 index_t BatchStrideB,
246 index_t BatchStrideC,
247 index_t Batch_)
248 : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
249 p_a_grid{p_a_grid_},
250 p_b_grid{p_b_grid_},
251 p_c_grid{p_c_grid_},
252 Batch(Batch_),
253 compute_ptr_offset_of_batch{BatchStrideA, BatchStrideB, BatchStrideC}
254 {
255 }
256
257 const ADataType* p_a_grid;
258 const BDataType* p_b_grid;
259 CDataType* p_c_grid;
// NOTE(review): the declarations of Batch and compute_ptr_offset_of_batch, both initialised by
// the constructor above, are not visible in this listing (numbering jumps from 259 to 262).
262 };
263
264 // Invoker
265 struct Invoker : public BaseInvoker
266 {
268
// Validates the problem for the given GridwiseGemm instantiation, sizes the launch grid and
// launches the batched GEMM kernel through launch_and_time_kernel.
// Returns the value produced by launch_and_time_kernel.
// Throws std::runtime_error when GridwiseGemm::CheckValidity rejects the problem.
269 template <typename GridwiseGemm>
270 float RunImp(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
271 {
272 if(stream_config.log_level_ > 0)
273 {
274 karg.Print();
275 }
276
// Rebuild a Problem of this GridwiseGemm's own type from the sizes and strides stored in karg.
277 typename GridwiseGemm::Problem arg(
278 karg.M, karg.N, karg.K, karg.StrideA, karg.StrideB, karg.StrideC);
279 if(!GridwiseGemm::CheckValidity(arg))
280 {
281 throw std::runtime_error(
282 "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext has invalid setting");
283 }
284
// Grid for one GEMM, replicated Batch times along x. The kernel recovers the batch index as
// block id / (grid size / Batch).
// NOTE(review): Batch is not validated here; a zero Batch would give an empty grid - confirm
// that callers guarantee Batch >= 1.
285 auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N);
286 gdx *= karg.Batch;
287
288 float ave_time = 0;
289
// Two launch paths, chosen by whether K requires a main K-block loop.
// NOTE(review): the initialisers of `kernel` in both branches are not visible in this listing
// (lines 293 and 301 are missing); presumably they select kernel_batched_gemm_xdlops_v2r3 with
// HasMainKBlockLoop = true / false - confirm against the real header.
290 if(GridwiseGemm::CalculateHasMainKBlockLoop(karg.K))
291 {
292 const auto kernel =
294
// One workgroup of BlockSize threads per grid block; the dynamic LDS size argument is 0.
295 ave_time = launch_and_time_kernel(
296 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
297 }
298 else
299 {
300 const auto kernel =
302
303 ave_time = launch_and_time_kernel(
304 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
305 }
306
307 return ave_time;
308 }
309
311
312 // polymorphic
313 float Run(const BaseArgument* p_arg,
314 const StreamConfig& stream_config = StreamConfig{}) override
315 {
316 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
317 }
318 };
319
// Compile-time hook reporting whether this template instantiation is valid.
// Currently a placeholder that accepts every parameter combination.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool is_valid = true;
    return is_valid;
}
325
// Host-side check whether `problem` can be run by this operation on the current device.
// Returns false on either of two early-outs, otherwise the result of CheckValidity of the
// GridwiseGemm variant matching the wave size, and false if that variant is disabled
// (its NXdlPerWave is not > 0).
// NOTE(review): the conditions of both early-outs and the first line of the wave32 call are not
// visible in this listing (lines 328, 333 and 348 are missing) - confirm against the real header.
326 static bool IsSupportedArgument(const Problem& problem)
327 {
329 {
330 return false;
331 }
332 // temp disable on gfx11
334 {
335 return false;
336 }
337 if(get_warp_size() == 64)
338 {
339 if constexpr(NXdlPerWave64 > 0)
340 {
341 return GridwiseGemm64::CheckValidity(problem);
342 }
343 }
344 else
345 {
346 if constexpr(NXdlPerWave32 > 0)
347 {
// The problem is reinterpreted as the wave32 variant's Problem type.
// NOTE(review): this relies on both Problem types having the same layout - confirm.
349 reinterpret_cast<const typename GridwiseGemm32::Problem&>(problem));
350 }
351 }
// Reached when the variant for the active wave size is compiled out.
352 return false;
353 }
354
355 // polymorphic
356 bool IsSupportedArgument(const BaseArgument* p_arg) override
357 {
358 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
359 }
360
361 static auto MakeArgument(const ADataType* p_a,
362 const BDataType* p_b,
363 CDataType* p_c,
364 index_t M,
365 index_t N,
366 index_t K,
367 index_t StrideA,
368 index_t StrideB,
369 index_t StrideC,
370 index_t BatchStrideA,
371 index_t BatchStrideB,
372 index_t BatchStrideC,
373 index_t Batch)
374 {
375 return Argument{p_a,
376 p_b,
377 p_c,
378 M,
379 N,
380 K,
381 StrideA,
382 StrideB,
383 StrideC,
384 BatchStrideA,
385 BatchStrideB,
386 BatchStrideC,
387 Batch};
388 }
389
390 static auto MakeInvoker() { return Invoker{}; }
391
392 // polymorphic
393 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
394 const void* p_b,
395 void* p_c,
396 index_t M,
397 index_t N,
398 index_t K,
399 index_t StrideA,
400 index_t StrideB,
401 index_t StrideC,
402 index_t BatchStrideA,
403 index_t BatchStrideB,
404 index_t BatchStrideC,
405 index_t Batch,
406 AElementwiseOperation,
407 BElementwiseOperation,
408 CElementwiseOperation) override
409 {
410 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
411 static_cast<const BDataType*>(p_b),
412 static_cast<CDataType*>(p_c),
413 M,
414 N,
415 K,
416 StrideA,
417 StrideB,
418 StrideC,
419 BatchStrideA,
420 BatchStrideB,
421 BatchStrideC,
422 Batch);
423 }
424
425 // polymorphic
426 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
427 {
428 return std::make_unique<Invoker>(Invoker{});
429 }
430
431 // polymorphic
432 std::string GetTypeString() const override
433 {
434 auto str = std::stringstream();
435
436 std::map<LoopScheduler, std::string> LoopSchedToString{
437 {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
438
439 std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
440 {PipelineVersion::v2, "v2"}};
441
442 // clang-format off
443 str << "DeviceBatchedGemmXdl"
444 << "<"
445 << BlockSize << ", "
446 << MPerBlock << ", "
447 << NPerBlock << ", "
448 << K0PerBlock << ", "
449 << K1 << ", "
450 << MPerXDL << ", "
451 << NPerXDL << ", "
452 << MXdlPerWave << ", "
453 << NXdlPerWave << ", "
454 << ">"
455 << " NumGemmKPrefetchStage: "
456 << NumGemmKPrefetchStage << ", "
457 << "LoopScheduler: "
458 << LoopSchedToString[LoopSched] << ", "
459 << "PipelineVersion: "
460 << PipelineVersionToString[PipelineVer];
461 // clang-format on
462
463 return str.str();
464 }
465};
466
467} // namespace device
468} // namespace tensor_operation
469} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
@ MNKPadding
Definition gemm_specialization.hpp:20
__global__ void kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg)
Definition device_batched_gemm_xdl.hpp:53
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__device__ index_t get_grid_size()
Definition get_id.hpp:49
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
integral_constant< index_t, N > Number
Definition number.hpp:12
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition amd_wave_read_first_lane.hpp:100
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
LoopScheduler
Definition loop_scheduler.hpp:15
@ Default
Definition loop_scheduler.hpp:16
@ Interwave
Definition loop_scheduler.hpp:17
int64_t long_index_t
Definition ck.hpp:300
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v2
Definition gridwise_gemm_pipeline_selector.hpp:20
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_xdlops_v2r3.hpp:200
Definition gridwise_gemm_xdlops_v2r3.hpp:814
Definition utility/sequence.hpp:43
Definition device_base.hpp:197
Definition device_batched_gemm.hpp:25
Definition device_batched_gemm_xdl.hpp:234
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch
Definition device_batched_gemm_xdl.hpp:261
const BDataType * p_b_grid
Definition device_batched_gemm_xdl.hpp:258
const ADataType * p_a_grid
Definition device_batched_gemm_xdl.hpp:257
Argument(const ADataType *p_a_grid_, const BDataType *p_b_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideC, index_t Batch_)
Definition device_batched_gemm_xdl.hpp:235
index_t Batch
Definition device_batched_gemm_xdl.hpp:260
CDataType * p_c_grid
Definition device_batched_gemm_xdl.hpp:259
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
Definition device_batched_gemm_xdl.hpp:171
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideC)
Definition device_batched_gemm_xdl.hpp:154
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
Definition device_batched_gemm_xdl.hpp:161
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
Definition device_batched_gemm_xdl.hpp:166
Definition device_batched_gemm_xdl.hpp:266
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_batched_gemm_xdl.hpp:313
float RunImp(const Argument &karg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_gemm_xdl.hpp:270
DeviceBatchedGemmXdl::Argument Argument
Definition device_batched_gemm_xdl.hpp:267
Definition device_batched_gemm_xdl.hpp:141
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_batched_gemm_xdl.hpp:356
static constexpr auto I0
Definition device_batched_gemm_xdl.hpp:146
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_batched_gemm_xdl.hpp:426
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideC, index_t Batch, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation) override
Definition device_batched_gemm_xdl.hpp:393
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_batched_gemm_xdl.hpp:143
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_batched_gemm_xdl.hpp:227
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, ALayout, BLayout, CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpecialization::MNKPadding, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, Sequence< 2, 3, 0, 1, 7, 5, 4, 6 >, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector, NumGemmKPrefetchStage, LoopSched, PipelineVer > GridwiseGemmBase
Definition device_batched_gemm_xdl.hpp:184
static auto MakeInvoker()
Definition device_batched_gemm_xdl.hpp:390
std::string GetTypeString() const override
Definition device_batched_gemm_xdl.hpp:432
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_batched_gemm_xdl.hpp:228
typename GridwiseGemm64::Problem Problem
Definition device_batched_gemm_xdl.hpp:230
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideC, index_t Batch)
Definition device_batched_gemm_xdl.hpp:361
static constexpr auto I1
Definition device_batched_gemm_xdl.hpp:147
static constexpr auto NXdlPerWave32
Definition device_batched_gemm_xdl.hpp:144
static bool IsSupportedArgument(const Problem &problem)
Definition device_batched_gemm_xdl.hpp:326
static constexpr auto K1Number
Definition device_batched_gemm_xdl.hpp:150
static constexpr auto I2
Definition device_batched_gemm_xdl.hpp:148
static constexpr bool IsValidCompilationParameter()
Definition device_batched_gemm_xdl.hpp:320