device_grouped_gemm_xdl.hpp Source File

device_grouped_gemm_xdl.hpp Source File

Composable Kernel: device_grouped_gemm_xdl.hpp Source File
device_grouped_gemm_xdl.hpp
Go to the documentation of this file.
1#pragma once
2// SPDX-License-Identifier: MIT
3// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
4
5#pragma once
6
7#include <iostream>
8#include <sstream>
9
11#include "ck/utility/env.hpp"
21
22namespace ck {
23namespace tensor_operation {
24namespace device {
25
26template <typename GridwiseGemm,
27 typename GemmDesc,
28 typename AElementwiseOperation,
29 typename BElementwiseOperation,
30 typename CDEElementwiseOperation,
31 bool HasMainKBlockLoop>
32__global__ void
33#if CK_USE_LAUNCH_BOUNDS
35#endif
36 kernel_grouped_gemm_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
37 const index_t group_count,
38 const AElementwiseOperation a_element_op,
39 const BElementwiseOperation b_element_op,
40 const CDEElementwiseOperation c_element_op)
41{
42#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
43 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
44 {
45 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
46
47 const index_t block_id = get_block_1d_id();
48
49 const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(
51
52 index_t left = 0;
53 index_t right = group_count;
54 index_t group_id = index_t((left + right) / 2);
55 while((!(block_id >= gemm_desc_ptr[group_id].BlockStart_ &&
56 block_id < gemm_desc_ptr[group_id].BlockEnd_)) &&
57 left <= right)
58 {
59 if(block_id < gemm_desc_ptr[group_id].BlockStart_)
60 {
61 right = group_id;
62 }
63 else
64 {
65 left = group_id;
66 }
67 group_id = index_t((left + right) / 2);
68 }
69
70 GridwiseGemm::template Run<HasMainKBlockLoop, InMemoryDataOperationEnum::Set>(
71 gemm_desc_ptr[group_id].a_ptr_,
72 gemm_desc_ptr[group_id].b_ptr_,
73 gemm_desc_ptr[group_id].ds_ptr_,
74 gemm_desc_ptr[group_id].e_ptr_,
75 p_shared,
76 a_element_op,
77 b_element_op,
78 c_element_op,
79 gemm_desc_ptr[group_id].a_grid_desc_ak0_m_ak1_,
80 gemm_desc_ptr[group_id].b_grid_desc_bk0_n_bk1_,
81 gemm_desc_ptr[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
82 gemm_desc_ptr[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
83 gemm_desc_ptr[group_id].block_2_etile_map_);
84 }
85#else
86 ignore = gemm_descs_const;
87 ignore = group_count;
88 ignore = a_element_op;
89 ignore = b_element_op;
90 ignore = c_element_op;
91#endif
92}
93
94template <typename ALayout,
95 typename BLayout,
96 typename DsLayout,
97 typename ELayout,
98 typename ADataType,
99 typename BDataType,
100 typename AccDataType,
101 typename CShuffleDataType,
102 typename DsDataType,
103 typename EDataType,
104 typename AElementwiseOperation,
105 typename BElementwiseOperation,
106 typename CDEElementwiseOperation,
107 GemmSpecialization GemmSpec,
108 ck::index_t NumPrefetch,
109 ck::index_t BlockSize,
110 ck::index_t MPerBlock,
111 ck::index_t NPerBlock,
112 ck::index_t KPerBlock,
113 ck::index_t AK1,
114 ck::index_t BK1,
115 ck::index_t MPerXDL,
116 ck::index_t NPerXDL,
117 ck::index_t MXdlPerWave,
118 ck::index_t NXdlPerWave,
119 typename ABlockTransferThreadClusterLengths_K0_M_K1,
120 typename ABlockTransferThreadClusterArrangeOrder,
121 typename ABlockTransferSrcAccessOrder,
122 ck::index_t ABlockTransferSrcVectorDim,
123 ck::index_t ABlockTransferSrcScalarPerVector,
124 ck::index_t ABlockTransferDstScalarPerVector_K1,
125 bool ABlockLdsExtraM,
126 typename BBlockTransferThreadClusterLengths_K0_N_K1,
127 typename BBlockTransferThreadClusterArrangeOrder,
128 typename BBlockTransferSrcAccessOrder,
129 ck::index_t BBlockTransferSrcVectorDim,
130 ck::index_t BBlockTransferSrcScalarPerVector,
131 ck::index_t BBlockTransferDstScalarPerVector_K1,
132 bool BBlockLdsExtraN,
133 index_t CShuffleMXdlPerWavePerShuffle,
134 index_t CShuffleNXdlPerWavePerShuffle,
135 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
136 index_t CDEBlockTransferScalarPerVector_NPerBlock,
138struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
139 BLayout,
140 DsLayout,
141 ELayout,
142 ADataType,
143 BDataType,
144 DsDataType,
145 EDataType,
146 AElementwiseOperation,
147 BElementwiseOperation,
148 CDEElementwiseOperation>
149{
150 using DeviceOp = DeviceGroupedGemm_Xdl;
152 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
153 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
154 static constexpr index_t NumDTensor = DsDataType::Size();
155
156 static constexpr auto I0 = Number<0>{};
157 static constexpr auto I1 = Number<1>{};
158 static constexpr auto I2 = Number<2>{};
159
160 static constexpr auto matrix_padder =
161 MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
162
163 static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
164 {
165 const auto a_grid_desc_mraw_kraw = [&]() {
166 if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
167 {
168 return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
169 make_tuple(StrideA, I1));
170 }
171 else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
172 {
173 return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
174 make_tuple(I1, StrideA));
175 }
176 }();
177
178 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
179 }
180
181 static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
182 {
183 const auto b_grid_desc_nraw_kraw = [&]() {
184 if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
185 {
186 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
187 make_tuple(I1, StrideB));
188 }
189 else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
190 {
191 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
192 make_tuple(StrideB, I1));
193 }
194 }();
195
196 return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
197 }
198
199 template <typename ELay>
200 static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
201 {
202 const auto e_grid_desc_mraw_nraw = [&]() {
203 if constexpr(is_same<tensor_layout::gemm::RowMajor, ELay>::value)
204 {
205 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
206 make_tuple(StrideE, I1));
207 }
208 else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELay>::value)
209 {
210 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
211 make_tuple(I1, StrideE));
212 }
213 }();
214
215 return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
216 }
217
218 static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
219 const std::array<index_t, NumDTensor>& NRaws,
220 const std::array<index_t, NumDTensor>& DsStride)
221 {
222 return generate_tuple(
223 [&](auto i) {
224 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
225
226 return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
227 },
228 Number<NumDTensor>{});
229 }
230
231 using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
232 using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
233 using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
234 using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
235
236 using ComputeDataType = ADataType;
237
238 // GridwiseGemm
239 template <index_t NXdlPerWave_>
240 using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle<
241 ADataType, // TODO: distinguish A/B datatype
242 BDataType,
243 ComputeDataType,
244 AccDataType,
245 CShuffleDataType,
246 DsDataType,
247 EDataType,
248 AElementwiseOperation,
249 BElementwiseOperation,
250 CDEElementwiseOperation,
251 NumPrefetch, // NumGemmKPrefetchStage
252 BlockSize,
253 MPerBlock,
254 NPerBlock,
255 KPerBlock,
256 AK1,
257 BK1,
258 MPerXDL,
259 NPerXDL,
260 MXdlPerWave,
261 NXdlPerWave_,
262 ABlockTransferThreadClusterLengths_K0_M_K1,
263 ABlockTransferThreadClusterArrangeOrder,
264 ABlockTransferSrcAccessOrder,
265 ABlockTransferSrcVectorDim,
266 ABlockTransferSrcScalarPerVector,
267 ABlockTransferDstScalarPerVector_K1,
268 false, // AThreadTransferSrcResetCoordinateAfterRun,
269 ABlockLdsExtraM,
270 BBlockTransferThreadClusterLengths_K0_N_K1,
271 BBlockTransferThreadClusterArrangeOrder,
272 BBlockTransferSrcAccessOrder,
273 BBlockTransferSrcVectorDim,
274 BBlockTransferSrcScalarPerVector,
275 BBlockTransferDstScalarPerVector_K1,
276 false, // BThreadTransferSrcResetCoordinateAfterRun,
277 BBlockLdsExtraN,
278 CShuffleMXdlPerWavePerShuffle,
279 CShuffleNXdlPerWavePerShuffle,
280 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
281 CDEBlockTransferScalarPerVector_NPerBlock,
282 LoopSched>;
283 using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
284 using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
285
286 using AGridDesc_AK0_M_AK1 =
287 remove_cvref_t<decltype(GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(
288 AGridDesc_M_K{}))>;
289 using BGridDesc_BK0_N_BK1 =
290 remove_cvref_t<decltype(GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(
291 BGridDesc_N_K{}))>;
292 using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
293 decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
294 DsGridDesc_M_N{}))>;
295 using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
296 decltype(GridwiseGemm64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
297 EGridDesc_M_N{}))>;
298
299 struct GroupedGemmBlock2ETileMap
300 {
301 using Block2ETileMap =
302 remove_cvref_t<decltype(GridwiseGemm64::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
303
304 GroupedGemmBlock2ETileMap()
305 {
306 block_2_etile_map_ = GridwiseGemm64::MakeDefaultBlock2ETileMap(EGridDesc_M_N{});
307 BlockStart_ = -1;
308 }
309
310 GroupedGemmBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n, ck::index_t BlockStart)
311 {
312 block_2_etile_map_ = GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
313 BlockStart_ = BlockStart;
314 }
315
316 template <typename TopIdx>
317 __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
318 {
319 return block_2_etile_map_.CalculateBottomIndex(
320 make_multi_index(idx_top[I0] - BlockStart_));
321 }
322
323 // it's actually E-Tile
324 template <typename CTileIdx, typename CTileDim>
325 __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
326 const CTileDim& c_tile_dim) const
327 {
328 return block_2_etile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
329 }
330
331 __host__ bool CheckValidity(const EGridDesc_M_N& e_grid_desc_m_n) const
332 {
333 return block_2_etile_map_.CheckValidity(e_grid_desc_m_n);
334 }
335
336 Block2ETileMap block_2_etile_map_;
337 ck::index_t BlockStart_;
338 };
339
340 struct GemmBiasTransKernelArg
341 {
342 // pointers
343 const ADataType* a_ptr_;
344 const BDataType* b_ptr_;
345 typename GridwiseGemm64::DsGridPointer ds_ptr_;
346 EDataType* e_ptr_;
347
348 // tensor descriptors for problem definiton
349 AGridDesc_M_K a_grid_desc_m_k_;
350 BGridDesc_N_K b_grid_desc_n_k_;
351 DsGridDesc_M_N ds_grid_desc_m_n_;
352 EGridDesc_M_N e_grid_desc_m_n_;
353
354 // tensor descriptors for block/thread-wise copy
355 AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
356 BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
357 DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
358 ds_grid_desc_mblock_mperblock_nblock_nperblock_;
359 EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
360
361 // block-to-e-tile map
362 GroupedGemmBlock2ETileMap block_2_etile_map_;
363 ck::index_t BlockStart_, BlockEnd_;
364 };
365
366 // Argument
367 struct Argument : public BaseArgument
368 {
369 template <typename GridwiseGemm, typename DsPointer, typename Block2ETileMap>
370 void init_gridwise_gemm_desc(const ADataType* a_ptr,
371 const BDataType* b_ptr,
372 DsPointer ds_ptr,
373 EDataType* e_ptr,
374 const AGridDesc_M_K& a_grid_desc_m_k,
375 const BGridDesc_N_K& b_grid_desc_n_k,
376 const DsGridDesc_M_N& ds_grid_desc_m_n,
377 const EGridDesc_M_N& e_grid_desc_m_n,
378 const Block2ETileMap& block_2_etile_map,
379 index_t BlockStart,
380 index_t BlockEnd)
381 {
382 // tensor descriptors for block/thread-wise copy
383 const auto a_grid_desc_ak0_m_ak1 =
384 GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
385
386 const auto b_grid_desc_bk0_n_bk1 =
387 GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
388
389 if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
390 b_grid_desc_n_k,
391 ds_grid_desc_m_n,
392 e_grid_desc_m_n,
393 block_2_etile_map))
394 {
395 // tensor descriptors for block/thread-wise copy
396 DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
397 ds_grid_desc_mblock_mperblock_nblock_nperblock;
398
399 static_for<0, NumDTensor, 1>{}([&](auto j) {
400 ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
401 GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
402 ds_grid_desc_m_n[j]);
403 });
404
405 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
406 GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
407 e_grid_desc_m_n);
408
409 gemm_desc_kernel_arg_.push_back(
410 GemmBiasTransKernelArg{a_ptr,
411 b_ptr,
412 ds_ptr,
413 e_ptr,
414 a_grid_desc_m_k,
415 b_grid_desc_n_k,
416 ds_grid_desc_m_n,
417 e_grid_desc_m_n,
418 a_grid_desc_ak0_m_ak1,
419 b_grid_desc_bk0_n_bk1,
420 ds_grid_desc_mblock_mperblock_nblock_nperblock,
421 e_grid_desc_mblock_mperblock_nblock_nperblock,
422 block_2_etile_map,
423 BlockStart,
424 BlockEnd});
425 }
426 };
427 Argument(std::vector<const void*>& p_As,
428 std::vector<const void*>& p_Bs,
429 std::vector<std::array<const void*, NumDTensor>>& p_Ds,
430 std::vector<void*>& p_Es,
431 std::vector<GemmDesc>& gemm_descs,
432 AElementwiseOperation a_element_op,
433 BElementwiseOperation b_element_op,
434 CDEElementwiseOperation c_element_op)
435 : a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op}
436 {
437 grid_size_ = 0;
438
439 group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());
440
441 if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&
442 group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) &&
443 group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
444 {
445 throw std::runtime_error("wrong! group_count_ != p_As/b/c.size");
446 }
447
448 gemm_desc_kernel_arg_.reserve(group_count_);
449
450 skipped_group_count_ = 0;
451
452 for(std::size_t i = 0; i < gemm_descs.size(); i++)
453 {
454 const index_t M = gemm_descs[i].M_;
455 const index_t N = gemm_descs[i].N_;
456 const index_t K = gemm_descs[i].K_;
457
458 a_mtx_mraw_kraw_.emplace_back(M, K);
459 b_mtx_nraw_kraw_.emplace_back(N, K);
460
461 if(M == 0)
462 {
463 skipped_group_count_++;
464 continue;
465 }
466
467 const index_t StrideA = gemm_descs[i].stride_A_;
468 const index_t StrideB = gemm_descs[i].stride_B_;
469 const index_t StrideC = gemm_descs[i].stride_C_;
470
471 // pointer
472 typename GridwiseGemm64::DsGridPointer p_ds_grid{};
473
474 static_for<0, NumDTensor, 1>{}([&](auto j) {
475 using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
476
477 p_ds_grid(j) = static_cast<const DDataType*>(p_Ds[i][j]);
478 });
479
480 // tensor descriptors for problem definiton
481 const auto a_grid_desc_m_k = DeviceOp::MakeAGridDescriptor_M_K(M, K, StrideA);
482 const auto b_grid_desc_n_k = DeviceOp::MakeBGridDescriptor_N_K(K, N, StrideB);
483
484 DsGridDesc_M_N ds_grid_desc_m_n;
485
486 static_for<0, NumDTensor, 1>{}([&](auto j) {
487 using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
488
489 ds_grid_desc_m_n(j) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
490 M, N, gemm_descs[i].stride_Ds_[j]);
491 });
492
493 const auto e_grid_desc_m_n =
494 DeviceOp::MakeEGridDescriptor_M_N<ELayout>(M, N, StrideC);
495
496 const index_t grid_size_grp =
497 GroupedGemmBlock2ETileMap(e_grid_desc_m_n, 0)
498 .block_2_etile_map_.CalculateGridSize(e_grid_desc_m_n);
499
500 const index_t BlockStart = grid_size_;
501 const index_t BlockEnd = grid_size_ + grid_size_grp;
502
503 grid_size_ += grid_size_grp;
504
505 // block-to-e-tile map
506 const auto block_2_etile_map =
507 GroupedGemmBlock2ETileMap(e_grid_desc_m_n, BlockStart);
508
509 if(get_warp_size() == 64)
510 {
511 if constexpr(NXdlPerWave64 > 0)
512 {
513 init_gridwise_gemm_desc<GridwiseGemm64>(
514 static_cast<const ADataType*>(p_As[i]),
515 static_cast<const BDataType*>(p_Bs[i]),
516 p_ds_grid,
517 static_cast<EDataType*>(p_Es[i]),
518 a_grid_desc_m_k,
519 b_grid_desc_n_k,
520 ds_grid_desc_m_n,
521 e_grid_desc_m_n,
522 block_2_etile_map,
523 BlockStart,
524 BlockEnd);
525 }
526 }
527 else
528 {
529 if constexpr(NXdlPerWave32 > 0)
530 {
531 init_gridwise_gemm_desc<GridwiseGemm32>(
532 static_cast<const ADataType*>(p_As[i]),
533 static_cast<const BDataType*>(p_Bs[i]),
534 p_ds_grid,
535 static_cast<EDataType*>(p_Es[i]),
536 a_grid_desc_m_k,
537 b_grid_desc_n_k,
538 ds_grid_desc_m_n,
539 e_grid_desc_m_n,
540 block_2_etile_map,
541 BlockStart,
542 BlockEnd);
543 }
544 }
545 }
546 }
547
548 // private:
549 index_t group_count_;
550 index_t skipped_group_count_;
551
552 AElementwiseOperation a_element_op_;
553 BElementwiseOperation b_element_op_;
554 CDEElementwiseOperation c_element_op_;
555
556 std::vector<GemmBiasTransKernelArg> gemm_desc_kernel_arg_;
557 std::vector<Tuple<index_t, index_t>> a_mtx_mraw_kraw_;
558 std::vector<Tuple<index_t, index_t>> b_mtx_nraw_kraw_;
559
560 index_t grid_size_;
561 void* gemm_kernel_host_args_;
562 };
563
564 // Invoker
565 struct Invoker : public BaseInvoker
566 {
567 using Argument = DeviceOp::Argument;
568
569 template <typename GridwiseGemm>
570 float RunImp(const Argument& arg,
571 const StreamConfig& stream_config = StreamConfig{},
572 hipStream_t cpy_stream = nullptr,
573 hipEvent_t cpy_event = nullptr)
574 {
575 bool has_main_k_block_loop = true;
576
577 for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
578 {
579 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
580 {
581 std::cout << "group: " << i << " arg.a_grid_desc_ak0_m_ak1_{"
582 << arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0)
583 << ", "
584 << arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I1)
585 << ", "
586 << arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2)
587 << "}";
588
589 std::cout << ", arg.b_grid_desc_bk0_n_bk1_{"
590 << arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I0)
591 << ", "
592 << arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I1)
593 << ", "
594 << arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I2)
595 << "}";
596
597 std::cout << ", arg.e_grid_desc_m_n_{ "
598 << arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I0) << ", "
599 << arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I1) << "}"
600 << std::endl;
601 }
602
603 if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_m_k_,
604 arg.gemm_desc_kernel_arg_[i].b_grid_desc_n_k_,
605 arg.gemm_desc_kernel_arg_[i].ds_grid_desc_m_n_,
606 arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_,
607 arg.gemm_desc_kernel_arg_[i].block_2_etile_map_))
608 {
609 throw std::runtime_error(
610 "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
611 }
612
613 const auto K = arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
614 arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2);
615
616 if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop)
617 {
618 throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
619 }
620 }
621
622 // If the user provides copy stream and copy event, we assume that they're also
623 // responsible for providing allocated host memory (eg. pinned) which
624 // would be used to copy kernel arguments to the device.
625 if(cpy_stream && cpy_event)
626 {
627 if(arg.gemm_kernel_host_args_ == nullptr)
628 {
629 std::ostringstream err;
630 err << "No memory has been allocated for gemm kernel host args "
631 << "when providing the copy stream and copy event! In " << __FILE__ << ":"
632 << __LINE__ << ", in function: " << __func__;
633 throw std::runtime_error(err.str());
634 }
635 hipGetErrorString(hipMemcpyAsync(arg.p_workspace_,
636 arg.gemm_kernel_host_args_,
637 arg.group_count_ * sizeof(GemmBiasTransKernelArg),
638 hipMemcpyHostToDevice,
639 cpy_stream));
640 hipGetErrorString(hipEventRecord(cpy_event, cpy_stream));
641 hipGetErrorString(hipEventSynchronize(cpy_event));
642 }
643 else // In this case CK owns memory allocated on host.
644 {
645 hipGetErrorString(hipMemcpyAsync(arg.p_workspace_,
646 arg.gemm_desc_kernel_arg_.data(),
647 arg.gemm_desc_kernel_arg_.size() *
648 sizeof(GemmBiasTransKernelArg),
649 hipMemcpyHostToDevice,
650 stream_config.stream_id_));
651 }
652
653 float ave_time = 0;
654
655 auto launch_kernel = [&](auto has_main_k_block_loop_) {
656 const auto kernel = kernel_grouped_gemm_xdl<GridwiseGemm,
657 GemmBiasTransKernelArg,
658 AElementwiseOperation,
659 BElementwiseOperation,
660 CDEElementwiseOperation,
661 has_main_k_block_loop_>;
662
664 stream_config,
665 kernel,
666 dim3(arg.grid_size_),
667 dim3(BlockSize),
668 0,
670 arg.gemm_desc_kernel_arg_.size(),
671 arg.a_element_op_,
672 arg.b_element_op_,
673 arg.c_element_op_);
674 };
675
676 if(has_main_k_block_loop)
677 {
678 ave_time = launch_kernel(integral_constant<bool, true>{});
679 }
680 else
681 {
682 ave_time = launch_kernel(integral_constant<bool, false>{});
683 }
684
685 return ave_time;
686 }
687
688 float Run(const Argument& arg,
689 const StreamConfig& stream_config = StreamConfig{},
690 hipStream_t cpy_stream = nullptr,
691 hipEvent_t cpy_event = nullptr)
692 {
693 if(get_warp_size() == 64)
694 {
695 if constexpr(NXdlPerWave64 > 0)
696 {
697 return RunImp<GridwiseGemm64>(arg, stream_config, cpy_stream, cpy_event);
698 }
699 }
700 else
701 {
702 if constexpr(NXdlPerWave32 > 0)
703 {
704 return RunImp<GridwiseGemm32>(arg, stream_config, cpy_stream, cpy_event);
705 }
706 }
707 return 0;
708 }
709
710 // polymorphic
711 float Run(const BaseArgument* p_arg,
712 const StreamConfig& stream_config = StreamConfig{}) override
713 {
714 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
715 }
716 };
717
// Returns true when this device operation can run `arg`:
//   - the leading hardware check passes;
//   - every group either produced a kernel argument record or was skipped;
//   - for non-default GemmSpec (padding), each group's raw A/B extent along the vectorized
//     dimension is divisible by the corresponding SrcScalarPerVector.
// The divisibility loop runs over all group_count_ groups, including skipped M == 0 groups.
718 static bool IsSupportedArgument(const Argument& arg)
719 {
// NOTE(review): the condition of the following `if` (original line 720) is missing from
// this listing. It is presumably an ck::is_xdl_wmma_supported<...>() hardware check;
// confirm against the real source.
721 {
722 return false;
723 }
// A group rejected by GridwiseGemm::CheckValidity left no record, so the counts disagree.
724 if((ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) +
725 arg.skipped_group_count_) != arg.group_count_)
726 {
727 return false;
728 }
729
730 bool supported = true;
731
732 // If we use padding we do not support vector loads for dimensions not divisible by
733 // vector load size.
734 if constexpr(GemmSpec != GemmSpecialization::Default)
735 {
736 // [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1}
737 // layout, thus we have to adapt it to the {M,K} or {N,K} layout.
738 const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0;
739 const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0;
740
741 for(index_t i = 0; i < arg.group_count_; ++i)
742 {
// Raw (M, K) / (N, K) extents recorded by Argument for group i.
743 const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number<a_raw_vector_dim>{});
744 const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number<b_raw_vector_dim>{});
745
746 supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0);
747 supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0);
748 }
749 }
750
751 return supported;
752 }
753
754 // polymorphic
755 bool IsSupportedArgument(const BaseArgument* p_arg) override
756 {
757 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
758 }
759
760 static auto MakeArgument(std::vector<const void*>& p_As,
761 std::vector<const void*>& p_Bs,
762 std::vector<std::array<const void*, NumDTensor>>& p_Ds,
763 std::vector<void*>& p_Es,
764 std::vector<GemmDesc> gemm_descs,
765 AElementwiseOperation a_element_op,
766 BElementwiseOperation b_element_op,
767 CDEElementwiseOperation c_element_op)
768 {
769 return Argument{
770 p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op};
771 }
772
773 static auto MakeInvoker() { return Invoker{}; }
774
775 // polymorphic
776 std::unique_ptr<BaseArgument>
777 MakeArgumentPointer(std::vector<const void*>& p_As,
778 std::vector<const void*>& p_Bs,
779 std::vector<std::array<const void*, NumDTensor>>& p_Ds,
780 std::vector<void*>& p_Es,
781 std::vector<GemmDesc>& gemm_descs,
782 AElementwiseOperation a_element_op,
783 BElementwiseOperation b_element_op,
784 CDEElementwiseOperation c_element_op) override
785 {
786 return std::make_unique<Argument>(
787 p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op);
788 }
789
790 // polymorphic
791 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
792 {
793 return std::make_unique<Invoker>(Invoker{});
794 }
795
796 // polymorphic
797 std::string GetTypeString() const override
798 {
799 auto str = std::stringstream();
800
801 // clang-format off
802 str << "DeviceGroupedGemm_Xdl"
803 << "<"
804 << BlockSize << ", "
805 << MPerBlock << ", "
806 << NPerBlock << ", "
807 << KPerBlock << ", "
808 << AK1 << ", "
809 << BK1 << ", "
810 << MPerXDL << ", "
811 << NPerXDL << ", "
812 << MXdlPerWave << ", "
813 << NXdlPerWave << ", "
814 << ABlockTransferSrcScalarPerVector << ", "
815 << BBlockTransferSrcScalarPerVector << ", "
816 << CShuffleMXdlPerWavePerShuffle << ", "
817 << CShuffleNXdlPerWavePerShuffle << ", "
818 << getGemmSpecializationString(GemmSpec)
819 << ">";
820 // clang-format on
821
822 return str.str();
823 }
824
825 size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
826 {
827 auto p_arg_ = dynamic_cast<const Argument*>(p_arg);
828 if(p_arg_)
829 {
830 return p_arg_->group_count_ * sizeof(GemmBiasTransKernelArg);
831 }
832 else
833 throw std::runtime_error("The argument pointer is not an object of "
834 "DeviceGroupedGemmMultipleDXdlCShuffle::Argument structure!");
835 }
836
837 size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
838 {
839 return GetWorkSpaceSize(p_arg);
840 }
841
842 void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
843 {
844 return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args);
845 }
846
847 size_t GetHostKernelArgSize(const BaseArgument* p_arg) const { return GetWorkSpaceSize(p_arg); }
848
849 //----------------------------------------------------------------------------------------------
859 void SetHostKernelArgsPointer(BaseArgument* p_arg, void* p_host_kernel_args) const
860 {
861 Argument* pArg_ = dynamic_cast<Argument*>(p_arg);
862 if(!pArg_)
863 {
864 throw std::runtime_error("Failed to cast argument pointer!");
865 }
866
867 pArg_->gemm_kernel_host_args_ = p_host_kernel_args;
868 std::copy(pArg_->gemm_desc_kernel_arg_.begin(),
869 pArg_->gemm_desc_kernel_arg_.end(),
870 static_cast<GemmBiasTransKernelArg*>(pArg_->gemm_kernel_host_args_));
871 }
872};
873
874} // namespace device
875} // namespace tensor_operation
876} // namespace ck
#define CK_CONSTANT_ADDRESS_SPACE
Definition ck.hpp:23
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE * cast_pointer_to_constant_address_space(T *p)
Definition amd_address_space.hpp:35
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
__device__ T * cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE *p)
Definition amd_address_space.hpp:24
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition device_grouped_gemm.hpp:99
Definition device_grouped_gemm.hpp:80
#define CK_ENV(name)
Definition utility/env.hpp:129