static_encoding_pattern.hpp Source File

static_encoding_pattern.hpp Source File#

Composable Kernel: static_encoding_pattern.hpp Source File
static_encoding_pattern.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
70
71#pragma once
72
81
82namespace ck_tile {
83
106
110
123template <index_t BlockSize,
124 index_t YPerTile,
125 index_t XPerTile,
126 index_t VecSize,
127 tile_distribution_pattern DistributionPattern,
128 index_t NumWaveGroups = 1>
132
133// Thread raked
134template <index_t BlockSize,
135 index_t YPerTile,
136 index_t XPerTile,
137 index_t VecSize,
138 index_t NumWaveGroups>
140 YPerTile,
141 XPerTile,
142 VecSize,
144 NumWaveGroups>
146{
147 // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
148 static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
149 static constexpr index_t warp_size = get_warp_size();
150 static constexpr index_t num_warps = BlockSize / get_warp_size();
151 static constexpr index_t LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size);
152 static constexpr index_t X1 = VecSize > LargestVec ? LargestVec : VecSize;
153 static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
154
155 // # of rows in Y dim accessed by single wavefront in one iteration
156 static constexpr index_t Y1 = warp_size / X0;
157 static_assert(X0 * Y1 == warp_size, "X0 * Y1 must cover whole wavefront!");
158
159 static constexpr index_t Y0 = num_warps / NumWaveGroups;
160 // YPerWarp = YPerTile / Y0;
161 // Y2 = YPerWarp / Y1;
162 static constexpr index_t Y2 = YPerTile / (Y1 * Y0); // # of iters within wavefront
163
164 static_assert(X0 * Y1 * Y0 * NumWaveGroups == BlockSize,
165 "X0 * warp_ys * Y0 must cover whole workgroup!");
166 static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
167
169 {
170 if constexpr(NumWaveGroups != 1)
171 {
176 tuple<sequence<0>, sequence<0, 0>>, // -> <Y0>, <Y1, X0>
178 sequence<1, 1>>{}); // -> <Y2, X1>
179 }
180 else
181 {
186 tuple<sequence<0>, sequence<1, 0>>, // -> <Y0>, <Y1, X0>
188 sequence<2, 1>>{}); // -> <Y2, X1>
189 }
190 }
191
193 {
194 if constexpr(NumWaveGroups != 1)
195 {
200 tuple<sequence<0>, sequence<0, 0>>, // -> <Y0>, <Y1, X0>
202 sequence<1, 1>>{}); // -> <X1, Y2>
203 }
204 else
205 {
210 tuple<sequence<0>, sequence<1, 0>>, // -> <Y0>, <Y1, X0>
212 sequence<1, 2>>{}); // -> <X1, Y2>
213 }
214 }
215};
216
217// Warp raked
218template <index_t BlockSize,
219 index_t YPerTile,
220 index_t XPerTile,
221 index_t VecSize,
222 index_t NumWaveGroups>
224 YPerTile,
225 XPerTile,
226 VecSize,
228 NumWaveGroups>
230{
231
232 static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
233 static constexpr index_t warp_size = get_warp_size();
234 static constexpr index_t num_warps = BlockSize / get_warp_size();
235 static constexpr index_t LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size);
236 static constexpr index_t X1 = VecSize > LargestVec ? LargestVec : VecSize;
237 static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
238
239 static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
240 static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");
241
242 static constexpr index_t Y0 = num_warps;
243 static_assert(X0 * Y2 * Y0 == BlockSize, "X0 * Y2 * Y1 must cover whole workgroup!");
244
245 static constexpr index_t Y1 = YPerTile / (Y2 * Y0); // # of iters within wavefront
246 static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
247
258
269};
270
271// Block raked
272template <index_t BlockSize,
273 index_t YPerTile,
274 index_t XPerTile,
275 index_t VecSize,
276 index_t NumWaveGroups>
278 YPerTile,
279 XPerTile,
280 VecSize,
282 NumWaveGroups>
284{
285
286 // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
287 static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
288 static constexpr index_t warp_size = get_warp_size();
289 static constexpr index_t num_warps = BlockSize / get_warp_size();
290 static constexpr index_t LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size);
291 static constexpr index_t X1 = VecSize > LargestVec ? LargestVec : VecSize;
292 static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
293 static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
294 static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");
295 static constexpr index_t Y1 = num_warps;
296 static_assert(X0 * Y2 * Y1 == BlockSize, "X0 * Y2 * Y1 must cover whole workgroup!");
297 static constexpr index_t Y0 = YPerTile / (Y2 * Y1); // # of iters
298 static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
299
310
321};
322
323// Helper function to convert enum to string
325{
326 switch(pattern)
327 {
328 case tile_distribution_pattern::thread_raked: return "thread_raked";
329 case tile_distribution_pattern::warp_raked: return "warp_raked";
330 case tile_distribution_pattern::block_raked: return "block_raked";
331 default: return "unknown";
332 }
333}
334
335template <index_t BlockSize,
336 index_t YPerTile,
337 index_t XPerTile,
338 index_t VecSize,
339 tile_distribution_pattern DistributionPattern,
340 index_t NumWaveGroups>
342 YPerTile,
343 XPerTile,
344 VecSize,
345 DistributionPattern,
346 NumWaveGroups>&)
347{
348 using PatternType = tile_distribution_encoding_pattern_2d<BlockSize,
349 YPerTile,
350 XPerTile,
351 VecSize,
352 DistributionPattern,
353 NumWaveGroups>;
354
355 printf("tile_distribution_encoding_pattern_2d<BlockSize:%d, YPerTile:%d, XPerTile:%d, "
356 "VecSize:%d, %s>: ",
357 BlockSize,
358 YPerTile,
359 XPerTile,
360 VecSize,
361 tile_distribution_pattern_to_string(DistributionPattern));
362 printf("{<Y0, Y1, Y2>: <%d, %d, %d>, <X0, X1>: <%d, %d>}\n",
363 PatternType::Y0,
364 PatternType::Y1,
365 PatternType::Y2,
366 PatternType::X0,
367 PatternType::X1);
368}
369
370} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
tile_distribution_pattern
Enumeration describing static tile distribution patterns.
Definition static_encoding_pattern.hpp:89
@ block_raked
Block raked pattern - aka linear.
Definition static_encoding_pattern.hpp:104
@ thread_raked
Thread raked pattern.
Definition static_encoding_pattern.hpp:94
@ warp_raked
Warp raked pattern.
Definition static_encoding_pattern.hpp:99
constexpr const char * tile_distribution_pattern_to_string(tile_distribution_pattern pattern)
Definition static_encoding_pattern.hpp:324
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
Definition tile/core/container/sequence.hpp:49
static CK_TILE_HOST_DEVICE constexpr auto make_shuffled_2d_static_tile_distribution()
Definition static_encoding_pattern.hpp:192
static CK_TILE_HOST_DEVICE constexpr auto make_shuffled_2d_static_tile_distribution()
Definition static_encoding_pattern.hpp:259
static CK_TILE_HOST_DEVICE constexpr auto make_2d_static_tile_distribution()
Definition static_encoding_pattern.hpp:248
static CK_TILE_HOST_DEVICE constexpr auto make_shuffled_2d_static_tile_distribution()
Definition static_encoding_pattern.hpp:311
static CK_TILE_HOST_DEVICE constexpr auto make_2d_static_tile_distribution()
Definition static_encoding_pattern.hpp:300
Class creating 2D static tile distribution with different load/store patterns.
Definition static_encoding_pattern.hpp:130
Definition static_encoding_pattern.hpp:108
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192