gridwise_ab_transfer_wave_tiles.hpp Source File

gridwise_ab_transfer_wave_tiles.hpp Source File

Composable Kernel: gridwise_ab_transfer_wave_tiles.hpp Source File
gridwise_ab_transfer_wave_tiles.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
8#include "ck/utility/math.hpp"
9
10namespace ck {
11
// Wave-tile based A/B global->LDS transfer helper. The struct name sits on line 22,
// which is not shown in this listing.
// Template parameters, as far as the visible code uses them:
//   BlockSize / WaveSize  : threads per block / lanes per wave; their ratio is NumberOfWaves.
//   MNPerBlock, KPerBlock : block tile extents along MN and K.
//   MNPerWmma             : MN extent of one tile (see MNRepeat_).
//   KPack                 : K extent of one tile (see KRepeat_ and GetKDimension).
//   ABK1Value             : forwarded as a template argument to ThreadGroupTransferGlobal.
//   ABLayout, ABMajorLayout, LDSTypeAB : not referenced on the visible lines -- TODO document.
12template <typename ABLayout,
13 typename ABMajorLayout,
14 typename LDSTypeAB,
15 index_t BlockSize,
16 index_t MNPerBlock,
17 index_t KPerBlock,
18 index_t MNPerWmma,
19 index_t KPack,
20 index_t ABK1Value,
21 index_t WaveSize>
23{
// NOTE(review): the static_assert condition (line 24) is elided from this listing;
// judging by its message, it rejects pk_i4_t data.
25 "wave tile transfer method does not support pk_i4_t");
26 static constexpr auto I0 = Number<0>{};
27 static constexpr auto I1 = Number<1>{};
28 static constexpr auto I2 = Number<2>{};
29 static constexpr auto I3 = Number<3>{};
30
// Number of sub-tile rows each tile is split into (the "MNKRow = 0-1" dimension
// used by MakeGridDescriptor and GetGridLaneIdx).
31 static constexpr index_t MNKRow = 2;
32
34
35 // Tiles distribution for global memory loading
36 // Notes: support for non-power-of-2 sizes needs to be reviewed later on
37 // The tiles are distributed along the non-contiguous matrix dimension
38 // Example 4 waves A row-major MPerBlock = 64, KPerBlock = 64
39 // MRepeat = 1, KRepeat = 4
40 // -------------
41 // |W0| | | |
42 // -------------
43 // |W1| | | |
44 // -------------
45 // |W2| | | |
46 // -------------
47 // |W3| | | |
48 // -------------
49 // Example 4 waves A column-major MPerBlock = 64, KPerBlock = 64
50 // MRepeat = 4, KRepeat = 1
51 // -------------
52 // |W0|W1|W2|W3|
53 // -------------
54 // | | | | |
55 // -------------
56 // | | | | |
57 // -------------
58 // | | | | |
59 // -------------
60 static constexpr index_t NumberOfWaves = BlockSize / WaveSize;
// Candidate wave count along MN. Let T = MNPerBlock / MNPerWmma be the number of
// MN tiles. Use min(T, NumberOfWaves) when it divides T evenly; otherwise fall
// back to 2 if T is even, else 1. KMajorWaves_ applies the same rule to the
// T = KPerBlock / KPack K tiles. Both are presumably combined into MNWaves_ /
// KWaves_ on elided lines 70-74 -- TODO confirm.
// NOTE(review): if T == 0 (block extent smaller than one tile), min(T, ...) is 0
// and the modulo divides by zero at compile time.
61 static constexpr index_t MNMajorWaves_ =
62 MNPerBlock / MNPerWmma % std::min(MNPerBlock / MNPerWmma, NumberOfWaves) == 0
63 ? std::min(MNPerBlock / MNPerWmma, NumberOfWaves)
64 : (MNPerBlock / MNPerWmma % 2 == 0 ? 2 : 1);
65 static constexpr index_t KMajorWaves_ =
66 KPerBlock / KPack % std::min(KPerBlock / KPack, NumberOfWaves) == 0
67 ? std::min(KPerBlock / KPack, NumberOfWaves)
68 : (KPerBlock / KPack % 2 == 0 ? 2 : 1);
69
71
// MNWaves_ / KWaves_ form the wave grid; their definitions are partly elided here.
// KRepeat_ and MNRepeat_ count how many tiles each wave covers per dimension.
// NOTE(review): these are integer divisions and silently truncate when the block
// extents are not multiples of (waves * tile extent).
72 static constexpr index_t MNWaves_ =
75 static constexpr index_t KRepeat_ = KPerBlock / (KWaves_ * KPack);
76 static constexpr index_t MNRepeat_ = MNPerBlock / (MNWaves_ * MNPerWmma);
77
// Builds the global-memory descriptor used by the wave-tile transfer from a base
// MN x K descriptor ("MN_K" in the comment below).
// Step 1: split the base descriptor into MN tiles and K tiles. Number<KPack> is
//   visible; the remaining transform arguments are elided from this listing.
// Step 2: split each tile into MNKRow sub-tile rows.
// Step 3: freeze the innermost vector dimension.
// Steps 2 and 3 use a different split depending on ABDoTranspose, so that both
// layouts produce the same global indices.
// Template params: PadMN and PadK must both be false (static_assert below).
// Params: base_desc is the base descriptor. sizeMN and sizeK are presumably the
//   MN and K extents; their uses are on elided lines -- TODO confirm. The unnamed
//   index_t parameters cannot be referenced and are therefore unused.
// Returns: the resulting descriptor. The return statements are elided from this
//   listing -- TODO confirm.
78 template <bool PadMN, bool PadK, typename GridDescriptorBase>
79 __host__ __device__ static auto MakeGridDescriptor(GridDescriptorBase& base_desc,
80 index_t sizeMN,
81 index_t,
82 index_t sizeK,
83 index_t,
84 index_t,
85 index_t)
86 {
87 // Notes: padding is currently not supported
88 static_assert(!PadMN && !PadK, "padding is currently not supported");
89
90 // Divide the base descriptor MN_K into tiles
91 const auto ab_grid_desc_mntiles_ktiles = transform_tensor_descriptor(
92 base_desc,
97 Number<KPack>{}))),
100
101 // The distinction is needed to get the same global indices for both layouts
102 // Divide each tile into 2 16x8 subtiles
103 // MNTiles - KTiles - MNKRow - LaneLocal - VectorSize
104 // MNKRow = 0-1
105 // LaneLocal = 0-15
106 // VectorSize must be 8
// NOTE(review): the 16x8 / 0-15 / 8 figures appear to assume MNPerWmma == 16 and
// KPack == 16 (the split below uses KPack / MNKRow or MNPerWmma / MNKRow) --
// TODO confirm for other tile sizes.
107 if constexpr(!ABDoTranspose)
108 {
109 const auto ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1 =
111 ab_grid_desc_mntiles_ktiles,
118 make_tuple(Number<MNKRow>{}, Number<KPack / MNKRow>{}))),
121
122 // Freeze VectorSize to first element of the loading chunk (for convenience)
123 // Swap MNPerWmma and MNKRow for consistency with transpose descriptor
125 ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1,
137 }
138 else
139 {
140 const auto ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1 =
142 ab_grid_desc_mntiles_ktiles,
148 make_tuple(Number<MNKRow>{}, Number<MNPerWmma / MNKRow>{})),
152
153 // Freeze VectorSize to first element of the loading chunk (for convenience)
155 ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1,
167 }
168 }
169
// Returns the LDS-side (block) descriptor that GetBlockTransfer passes as the
// destination descriptor. The layout is described in the comments below; most of
// the construction is elided from this listing.
170 __device__ static constexpr auto GetBlockDescriptor()
171 {
172 // LDS memory layouts:
173 // lanes within tiles stored contiguously in chunks of 8 elements
174 // tiles are then stored first in K dimension
175 // MNTiles - KTiles - MNKRow - LaneLocal - VectorSize
// NOTE(review): despite its name, a_grid_desc_mraw_kraw describes the LDS block
// layout, not a grid descriptor. Consider renaming it.
176 const auto a_grid_desc_mraw_kraw = [&]() {
187 I1));
188 }();
189
190 // Freeze VectorSize to first element of the chunk (for convenience)
192 a_grid_desc_mraw_kraw,
200 }
201
// Decomposes this thread's id within the block into a multi-index, using a
// single-stage tensor adaptor whose transforms are elided from this listing.
// GetBlockTransfer reads element 0 as the MN wave index and element 1 as the
// K wave index.
202 __device__ static auto GetWaveIdx()
203 {
204 const index_t thread_id = ThisThreadBlock::GetThreadId();
205
206 constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
210
211 return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
212 }
213
// Maps this lane (__lane_id()) to its coordinates in the block (LDS) descriptor.
// GetBlockTransfer uses element 0 as lane_group_block and element 1 as
// lane_local_id_block. The adaptor's transforms are elided from this listing.
214 __device__ static auto GetBlockLaneIdx()
215 {
216 const index_t lane_id = __lane_id();
217
// Sub-tile extent in lanes: KPack when transposing, MNPerWmma otherwise.
218 constexpr index_t LanesPerSubTile = ABDoTranspose ? KPack : MNPerWmma;
219
220 constexpr auto laneid_to_block_lane_idx_adaptor = make_single_stage_tensor_adaptor(
224
225 return laneid_to_block_lane_idx_adaptor.CalculateBottomIndex(make_multi_index(lane_id));
226 }
227
// Maps this lane (__lane_id()) to its coordinates in the grid (global) descriptor.
// GetBlockTransfer uses element 0 as lane_group_grid and element 1 as
// lane_local_id_grid.
// The lanes are arranged as SubTilesRow x SubTilesCol sub-tiles of LanesPerSubTile
// lanes each. The dimension order depends on ABDoTranspose (see dims_tuple).
228 template <typename ABDataType>
229 __device__ static auto GetGridLaneIdx()
230 {
231 const index_t lane_id = __lane_id();
232
233 constexpr index_t SubTilesRow = MNKRow;
// 4 / sizeof(ABDataType): 4 for 1-byte types, 2 for 2-byte types, 1 for 4-byte types.
// NOTE(review): a type wider than 4 bytes yields 0 here, and LanesPerSubTile then
// divides by zero at compile time. A static_assert(sizeof(ABDataType) <= 4) would
// give a clearer error.
234 constexpr index_t SubTilesCol = 4 / sizeof(ABDataType);
235 constexpr index_t LanesPerSubTile =
236 ABDoTranspose ? KPack / SubTilesCol : MNPerWmma / SubTilesCol;
237 constexpr auto dims_tuple = ABDoTranspose
238 ? make_tuple(SubTilesCol, SubTilesRow, LanesPerSubTile)
239 : make_tuple(SubTilesRow, SubTilesCol, LanesPerSubTile);
240
241 constexpr auto laneid_to_grid_lane_idx_adaptor =
245
246 const auto indices =
247 laneid_to_grid_lane_idx_adaptor.CalculateBottomIndex(make_multi_index(lane_id));
248
// Both branches return
//   (sub-tile row, sub-tile column * LanesPerSubTile + lane within sub-tile).
// The transpose branch swaps indices[I0] and indices[I1] because dims_tuple lists
// SubTilesCol first in that case. This assumes the adaptor (arguments elided)
// yields indices in dims_tuple order -- TODO confirm.
249 if constexpr(!ABDoTranspose)
250 {
251 return make_multi_index(indices[I0], indices[I1] * LanesPerSubTile + indices[I2]);
252 }
253 else
254 {
255 return make_multi_index(indices[I1], indices[I0] * LanesPerSubTile + indices[I2]);
256 }
257 }
258
// Constructs the ThreadGroupTransferGlobal for this thread. It uses
// grid_descriptor[I0] as the global-memory descriptor and block_descriptor as the
// LDS descriptor, and passes ab_element_op through.
// Per-thread origin in the grid descriptor:
//   (block_mn_id * (MNRepeat_ * MNWaves_) + wave MN index, wave K index,
//    grid lane group, grid lane-local id)
//   i.e. the block's first MN tile plus this wave's MN offset.
// Per-thread origin in the block descriptor:
//   (wave MN index, wave K index, block lane group, block lane-local id)
// Params: block_mn_id is this block's index along MN.
// Only a single A/B tensor and a single global buffer are supported (static_asserts).
// NOTE(review): ABDataType is defined on elided line 276, presumably from
// ABsDataType -- TODO confirm.
259 template <typename GridDescriptor,
260 typename BlockDescriptor,
261 typename ABsDataType,
262 typename ABElementwiseOperation,
263 index_t GlobalBufferNum>
264 __device__ static auto GetBlockTransfer(GridDescriptor& grid_descriptor,
265 BlockDescriptor& block_descriptor,
266 ABElementwiseOperation& ab_element_op,
267 const index_t block_mn_id)
268 {
269 // Note: GlobalBufferNum is currently not used but it will be needed
270 // once we add other pipelines. It is currently needed only for
271 // consistency with the thread tiles approach
// NOTE(review): GlobalBufferNum is in fact checked by the static_assert below;
// only a value of 1 is accepted.
272 static_assert(GlobalBufferNum == 1, "single global buffer is only supported");
273 constexpr index_t NumABTensor = ABsDataType::Size();
274 static_assert(NumABTensor == 1, "multiAB currently not supported");
275
277
278 const auto wave_idx = GetWaveIdx();
279 index_t wave_idK = wave_idx[I1];
280 index_t wave_idMN = wave_idx[I0];
281
282 const auto grid_lane_id = GetGridLaneIdx<ABDataType>();
283 index_t lane_group_grid = grid_lane_id[I0];
284 index_t lane_local_id_grid = grid_lane_id[I1];
285
286 const auto block_lane_id = GetBlockLaneIdx();
287 index_t lane_group_block = block_lane_id[I0];
288 index_t lane_local_id_block = block_lane_id[I1];
289
290 return ThreadGroupTransferGlobal<decltype(grid_descriptor[I0]),
291 BlockDescriptor,
292 ABDataType,
293 ABDataType,
294 ABElementwiseOperation,
298 ABK1Value,
300 grid_descriptor[I0],
301 block_descriptor,
302 make_multi_index(block_mn_id * (MNRepeat_ * MNWaves_) + wave_idMN,
303 wave_idK,
304 lane_group_grid,
305 lane_local_id_grid),
306 make_multi_index(wave_idMN, wave_idK, lane_group_block, lane_local_id_block),
307 ab_element_op);
308 }
309
// Builds the descriptor the pipelines use to read LDS tiles into registers.
// MNRepeat and MNWaves are supplied by the caller; how they are used is on lines
// elided from this listing.
310 template <index_t MNRepeat, index_t MNWaves>
311 __host__ __device__ static constexpr auto MakeWmmaTileDescriptor()
312 {
313 // This is a block descriptor used to read LDS memory into registers
314 // It's defined in a way consistent with the existing implementation to
315 // avoid changes in the pipelines
327 I1));
328 }
329
// Returns the step applied to the grid descriptor window between iterations.
// Per the comment below, it is consumed by MoveSrcSliceWindow. The returned
// value is on an elided line -- TODO document its components.
330 __device__ static constexpr auto GetBlockStep()
331 {
332 // Grid descriptor step (MoveSrcSliceWindow)
334 }
335
336 template <typename GridDescriptor>
337 __device__ static constexpr index_t GetKDimension(const GridDescriptor& grid_desc)
338 {
339 return grid_desc.GetLength(I1) * KPack;
340 }
341};
342
343} // namespace ck
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
Definition gridwise_ab_transfer_wave_tiles.hpp:23
static __device__ auto GetWaveIdx()
Definition gridwise_ab_transfer_wave_tiles.hpp:202
__host__ static __device__ constexpr auto MakeWmmaTileDescriptor()
Definition gridwise_ab_transfer_wave_tiles.hpp:311
__host__ static __device__ auto MakeGridDescriptor(GridDescriptorBase &base_desc, index_t sizeMN, index_t, index_t sizeK, index_t, index_t, index_t)
Definition gridwise_ab_transfer_wave_tiles.hpp:79
static constexpr index_t MNRepeat_
Definition gridwise_ab_transfer_wave_tiles.hpp:76
static __device__ constexpr auto GetBlockDescriptor()
Definition gridwise_ab_transfer_wave_tiles.hpp:170
static __device__ auto GetGridLaneIdx()
Definition gridwise_ab_transfer_wave_tiles.hpp:229
static constexpr auto I2
Definition gridwise_ab_transfer_wave_tiles.hpp:28
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition gridwise_ab_transfer_wave_tiles.hpp:33
static constexpr index_t KWaves_
Definition gridwise_ab_transfer_wave_tiles.hpp:74
static __device__ auto GetBlockTransfer(GridDescriptor &grid_descriptor, BlockDescriptor &block_descriptor, ABElementwiseOperation &ab_element_op, const index_t block_mn_id)
Definition gridwise_ab_transfer_wave_tiles.hpp:264
static __device__ constexpr index_t GetKDimension(const GridDescriptor &grid_desc)
Definition gridwise_ab_transfer_wave_tiles.hpp:337
static constexpr index_t KMajorWaves_
Definition gridwise_ab_transfer_wave_tiles.hpp:65
static constexpr index_t MNMajorWaves_
Definition gridwise_ab_transfer_wave_tiles.hpp:61
static constexpr auto I1
Definition gridwise_ab_transfer_wave_tiles.hpp:27
static constexpr auto I3
Definition gridwise_ab_transfer_wave_tiles.hpp:29
static __device__ constexpr auto GetBlockStep()
Definition gridwise_ab_transfer_wave_tiles.hpp:330
static constexpr index_t MNKRow
Definition gridwise_ab_transfer_wave_tiles.hpp:31
static constexpr auto I0
Definition gridwise_ab_transfer_wave_tiles.hpp:26
static constexpr bool ABDoTranspose
Definition gridwise_ab_transfer_wave_tiles.hpp:70
static constexpr index_t MNWaves_
Definition gridwise_ab_transfer_wave_tiles.hpp:72
static constexpr index_t KRepeat_
Definition gridwise_ab_transfer_wave_tiles.hpp:75
static constexpr index_t NumberOfWaves
Definition gridwise_ab_transfer_wave_tiles.hpp:60
static __device__ auto GetBlockLaneIdx()
Definition gridwise_ab_transfer_wave_tiles.hpp:214
Definition utility/sequence.hpp:43
static __device__ index_t GetThreadId()
Definition thread_group.hpp:19
Definition thread_group_tensor_slice_transfer_global.hpp:26
Definition data_type.hpp:187