generic_2d_block_shape.hpp Source File

generic_2d_block_shape.hpp Source File

Composable Kernel: generic_2d_block_shape.hpp Source File
generic_2d_block_shape.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6namespace ck_tile {
7
8/*
9// clang-format off
10
11 4-level descriptor: BlockTile -> WarpPerBlock -> WarpTile -> Vector
12
13 Block_N (Warp_N * WarpPerBlock_N * Repeat_N )
14 +<----------------------< Repeat_N(2)>--------------------->+
15 | |
16 +<-- <WarpPerBlock_N(2)> -->+
17 Warp_N
18 +--------------+--------------+--------------+--------------+----+----------------+
19 Warp_M | warp_0 | warp_1 | | ^ ^
20 +--------------+--------------+ | <WarpPerBlock_M(2)> |
21 | warp_2 | warp_3 | | v
22 +--------------+--------------+--------------+--------------+----+ Block_M
23 | | |
24 + + |
25 | | | v
26 +--------------+--------------+--------------+--------------+ +
27
28 each Warp-tile (e.g. 16 thrd per row)
29
30 Vector_N (contiguous pixels each thrd holds along N, or vector size)
31 +-----------+-----------+-----------+-----------+-----------+
32 | thrd_0 | thrd_1 | thrd_2 | thrd_3 | ... Vector_M
33 +-----------+-----------+-----------+-----------+-----------+
34 | thrd_16 | thrd_17 | thrd_18 | thrd_19 | ...
35 +-----------+-----------+-----------+-----------+-----------+
36// clang-format on
37*/
// Compile-time 2D block-tile shape descriptor.
// All three template arguments are compile-time seq<M, N> types read through
// their static at(number<i>{}) accessor (index 0 = M, index 1 = N):
//   BlockTile_      - tile size covered by one block
//   ThreadPerBlock_ - number of threads laid out along M and N
//   Vector_         - contiguous pixels (elements) each thread handles along M and N
// The members below derive per-block warp counts, per-warp tile sizes and
// repeat counts from these three inputs.
// NOTE(review): this listing skips source lines 41, 92-93, 96-98, 109-110 and
// 116. Per the definition index, WarpPerBlock_M/N live at 92-93, BlockSize and
// Warp_M/N at 96-98, ThreadPerWarp_M/N at 109-110; line 41 is presumably the
// struct-name line and 116 the return of GetBlockSize. Those definitions are
// not visible here.
38template <typename BlockTile_, // block size, seq<M, N>
39 typename ThreadPerBlock_, // num threads along seq<M, N>
40 typename Vector_> // contiguous pixels (vector size) along seq<M, N>
42{
43 // block tile size along seq<M, N>
44 static constexpr index_t Block_M = BlockTile_::at(number<0>{});
45 static constexpr index_t Block_N = BlockTile_::at(number<1>{});
// number of threads in the block along seq<M, N>
46 static constexpr index_t ThreadPerBlock_M = ThreadPerBlock_::at(number<0>{});
47 static constexpr index_t ThreadPerBlock_N = ThreadPerBlock_::at(number<1>{});
48
49 // vector size along seq<M, N>
50 static constexpr index_t Vector_M = Vector_::at(number<0>{});
51 static constexpr index_t Vector_N = Vector_::at(number<1>{});
52
53 // num warps along M within each block (GetWarpPerBlock_N handles N)
// isHostWave32 == true uses a warp size of 32; otherwise get_warp_size() is
// used (presumably so host code can size wave32 launches -- confirm).
// Requires ThreadPerBlock_M * ThreadPerBlock_N to be a multiple of the warp size.
// Whenever the static_asserts hold (and, for ThreadPerBlock_N > warp_size,
// ThreadPerBlock_N is a multiple of warp_size), both branches evaluate to
// ThreadPerBlock_M:
//   N <= warp: (M * N / warp) * (warp / N) == M
//   N >  warp: (M * N / warp) / (N / warp) == M
// NOTE(review): for ThreadPerBlock_N < warp_size the result exceeds
// total_warps, i.e. each row of ThreadPerBlock_N threads is counted as one
// "warp" slot -- confirm this is the intended meaning.
54 template <bool isHostWave32>
55 static constexpr index_t GetWarpPerBlock_M()
56 {
57 constexpr index_t warp_size = isHostWave32 ? 32 : get_warp_size();
// true when a whole row of ThreadPerBlock_N threads fits inside one warp
58 constexpr bool is_warp_per_row = ThreadPerBlock_N <= warp_size;
59 static_assert((ThreadPerBlock_M * ThreadPerBlock_N) % warp_size == 0);
60 constexpr index_t total_warps = (ThreadPerBlock_M * ThreadPerBlock_N) / warp_size;
61
62 if constexpr(is_warp_per_row)
63 {
// several rows share one warp: warp_size / ThreadPerBlock_N rows per warp
64 static_assert(warp_size % ThreadPerBlock_N == 0);
65 return total_warps * (warp_size / ThreadPerBlock_N);
66 }
67 else
68 {
// one row spans ThreadPerBlock_N / warp_size warps
69 // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
// NOTE(review): the disabled check above names ThreadPerBlock_M_, which is
// not declared in this struct. Unlike GetWarpPerBlock_N, this branch does not
// assert ThreadPerBlock_N % warp_size == 0, so a non-multiple truncates here.
70 return total_warps / (ThreadPerBlock_N / warp_size);
71 }
72 };
// NOTE(review): the ';' after the closing brace above is a redundant empty
// member declaration.
73
74 // num of warps along N within each block
// 1 when a whole row of ThreadPerBlock_N threads fits in one warp
// (ThreadPerBlock_N must divide the warp size); otherwise
// ThreadPerBlock_N / warp_size (ThreadPerBlock_N must be a multiple of it).
// isHostWave32 selects the warp size exactly as in GetWarpPerBlock_M.
75 template <bool isHostWave32>
76 static constexpr index_t GetWarpPerBlock_N()
77 {
78 constexpr index_t warp_size = isHostWave32 ? 32 : get_warp_size();
79 constexpr bool is_warp_per_row = ThreadPerBlock_N <= warp_size;
80 if constexpr(is_warp_per_row)
81 {
82 static_assert(warp_size % ThreadPerBlock_N == 0);
83 return 1;
84 }
85 else
86 {
87 static_assert(ThreadPerBlock_N % warp_size == 0);
88 return ThreadPerBlock_N / warp_size;
89 }
90 }
91
94
95 // warp size
// Each warp tile must be a whole number of vectors, and the block tile a whole
// number of (WarpPerBlock * Warp) tiles, along both M and N.
99 static_assert(Warp_M % Vector_M == 0);
100 static_assert(Warp_N % Vector_N == 0);
101 static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0);
102 static_assert(Block_N % (WarpPerBlock_N * Warp_N) == 0);
103
104 // repeats along seq<M, N>: number of (WarpPerBlock * Warp) tiles that cover the block tile
105 static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
106 static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
107
108 // num of threads along seq<M, N>, within each warp
111
// Presumably the number of threads per block for the selected warp size
// (32 when isHostWave32, else get_warp_size()).
// NOTE(review): the return expression (source line 116) is missing from this
// listing, so the exact formula is not visible here.
112 template <bool isHostWave32>
113 static constexpr index_t GetBlockSize()
114 {
115 constexpr index_t warp_size = isHostWave32 ? 32 : get_warp_size();
117 }
118};
119
120} // namespace ck_tile
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
int32_t index_t
Definition integer.hpp:9
Definition generic_2d_block_shape.hpp:42
static constexpr index_t GetWarpPerBlock_N()
Definition generic_2d_block_shape.hpp:76
static constexpr index_t Repeat_M
Definition generic_2d_block_shape.hpp:105
static constexpr index_t GetWarpPerBlock_M()
Definition generic_2d_block_shape.hpp:55
static constexpr index_t ThreadPerWarp_M
Definition generic_2d_block_shape.hpp:109
static constexpr index_t WarpPerBlock_N
Definition generic_2d_block_shape.hpp:93
static constexpr index_t GetBlockSize()
Definition generic_2d_block_shape.hpp:113
static constexpr index_t ThreadPerBlock_M
Definition generic_2d_block_shape.hpp:46
static constexpr index_t ThreadPerBlock_N
Definition generic_2d_block_shape.hpp:47
static constexpr index_t ThreadPerWarp_N
Definition generic_2d_block_shape.hpp:110
static constexpr index_t Block_M
Definition generic_2d_block_shape.hpp:44
static constexpr index_t Warp_N
Definition generic_2d_block_shape.hpp:98
static constexpr index_t BlockSize
Definition generic_2d_block_shape.hpp:96
static constexpr index_t Repeat_N
Definition generic_2d_block_shape.hpp:106
static constexpr index_t Block_N
Definition generic_2d_block_shape.hpp:45
static constexpr index_t Warp_M
Definition generic_2d_block_shape.hpp:97
static constexpr index_t WarpPerBlock_M
Definition generic_2d_block_shape.hpp:92
static constexpr index_t Vector_N
Definition generic_2d_block_shape.hpp:51
static constexpr index_t Vector_M
Definition generic_2d_block_shape.hpp:50