block_gemm_areg_bsmem_creg_v1.hpp Source File

block_gemm_areg_bsmem_creg_v1.hpp Source File#

Composable Kernel: block_gemm_areg_bsmem_creg_v1.hpp Source File
block_gemm_areg_bsmem_creg_v1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
11// A is block distributed tensor
12// B is block window on shared memory
13// C is block distributed tensor
14template <typename Problem_, typename Policy_ = BlockGemmARegBSmemCRegV1DefaultPolicy>
16{
23
// Block size, forwarded unchanged from Problem::kBlockSize
// (presumably the number of threads per block -- confirm against Problem).
24 static constexpr index_t kBlockSize = Problem::kBlockSize;
25
26 // C += A * B
// Accumulating block GEMM. For every (kIter, mIter, nIter) combination a warp-sized
// slice of A and of C is taken from the block tensors, the matching B tile is loaded
// through a per-warp tile window, the warp GEMM WG is invoked, and the resulting C
// slice is written back into c_block_tensor.
//
// c_block_tensor     : accumulator; read and updated in place
// a_block_tensor_tmp : A operand; its thread buffer is copied into a local tensor
// b_block_window_tmp : window over B from which the per-warp windows are derived
27 template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
28 CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
29 const ABlockTensorTmp& a_block_tensor_tmp,
30 const BBlockWindowTmp& b_block_window_tmp) const
31 {
// The element types of all three operands must match the A/B/C data types
// this block GEMM was instantiated with.
32 static_assert(
33 std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
34 std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
35 std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
36 "wrong!");
37
// Block extents are read from the operands themselves: M and K from the lengths
// of A, N from dimension 0 of B's window lengths.
38 constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
39 constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
40 constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
41
// ...and they have to agree with the extents declared by BlockGemmShape.
42 static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
43 KPerBlock == BlockGemmShape::kK,
44 "wrong!");
45
// The policy returns a config whose element 0 is the warp GEMM (its type is WG)
// and whose elements 1 and 2 are used as the warp counts along M and N.
46 constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
47
48 using WG = remove_cvref_t<decltype(config.template at<0>())>;
49
50 constexpr index_t MWarp = config.template at<1>();
51 constexpr index_t NWarp = config.template at<2>();
52
// Trip counts of the static_for loops below, i.e. how many warp GEMMs are issued
// along each of M, N and K.
53 constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
54 constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
55 constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
56
// Distance by which a B window is moved between consecutive N / K iterations.
57 constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
58 constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
59
// Warp id reduced modulo NWarp; used below to offset this warp's B window.
60 const index_t iNWarp = get_warp_id() % NWarp;
61
// NOTE(review): the initializers of the two outer encodings below are not visible
// in this listing.
62 constexpr auto a_block_outer_dstr_encoding =
69
70 constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
77
// Block-level encodings are formed by embedding the warp GEMM's A / C warp
// encodings into the corresponding outer encodings.
78 constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
79 a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
80
81 constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
82 c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
83
84 constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
85
86 // construct A-block-tensor from A-block-tensor-tmp
87 // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
88 // distribution
89 auto a_block_tensor =
91
// Only the per-thread data is copied; per the FIXME above, nothing verifies that
// the two tensors actually share a distribution.
92 a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
93
94 // construct B-warp-window
// Same bottom tensor view as the incoming window, with the origin shifted by
// iNWarp * WG::kN along dimension 0 and distributed per WG::BWarpDstrEncoding.
95 auto b_warp_window_tmp = make_tile_window(
96 b_block_window_tmp.get_bottom_tensor_view(),
98 b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
99 make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
100
// The array-based variant below is compiled out (#if 0); the statically indexed
// variant in the #else branch is the one in use.
101#if 0 // FIXME: using array will cause register spill
102 array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
103 {b_warp_window_tmp}};
104
105 for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
106 {
107 for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
108 {
109 move_tile_window(b_warp_windows(nIter)(kIter),
110 {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
111 }
112 }
113#else
115 statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
116 NIterPerWarp>
117 b_warp_windows;
118
// One window per (nIter, kIter): start from a copy of the base warp window and
// move it by (nIter * NPerBlockPerIter, kIter * KPerBlockPerIter).
119 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
120 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
121 b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
122
123 move_tile_window(b_warp_windows(nIter)(kIter),
124 {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
125 });
126 });
127#endif
128
129 // check C-block-distribution
// The caller's C tensor must use exactly the encoding constructed above.
130 static_assert(
131 std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
132 remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
133 .get_static_tile_distribution_encoding())>>,
134 "wrong!");
135
136 using AWarpDstr = typename WG::AWarpDstr;
137 using CWarpDstr = typename WG::CWarpDstr;
138
139 using AWarpTensor = typename WG::AWarpTensor;
140 using CWarpTensor = typename WG::CWarpTensor;
141
// Y-lengths of a single warp tile, used as the slice lengths in the hot loop.
142 constexpr auto a_warp_y_lengths =
143 to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
144 constexpr auto c_warp_y_lengths =
145 to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
146
// All-zero Y indices, used as the slice origin within a warp tile.
147 constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
148 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
149
150 // hot loop:
// Loop order is K (outermost), then M, then N. The A slice is extracted once per
// (kIter, mIter) and reused for every nIter.
151 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
152 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
153 // read A warp tensor from A block tensor
154 AWarpTensor a_warp_tensor;
155
// Slice origin is (mIter, kIter, 0...), slice lengths are (1, 1, warp Y lengths).
156 a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
157 merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
158 merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
159
160 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
161 // read B warp tensor from B Block window
162 const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
163
164 // read C warp tensor from C block tensor
165 CWarpTensor c_warp_tensor;
166
// Slice origin is (mIter, nIter, 0...), slice lengths are (1, 1, warp Y lengths).
167 c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
168 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
169 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
170
171 // warp GEMM
172 WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
173
174 // write C warp tensor into C block tensor
// Written back to the same slice that was read above.
175 c_block_tensor.set_y_sliced_thread_data(
176 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
177 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
178 c_warp_tensor.get_thread_buffer());
179 });
180 });
181 });
182 }
183
// Creates and returns a static distributed tensor of CDataType. Its distribution is
// built from c_block_outer_dstr_encoding embedded with WG::CWarpDstrEncoding, the
// same recipe the accumulating operator() uses when it static_asserts the
// distribution of the C tensor it is given.
184 CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
185 {
// Unlike operator(), the extents come straight from BlockGemmShape; there are no
// operands to read them from.
186 constexpr index_t MPerBlock = BlockGemmShape::kM;
187 constexpr index_t NPerBlock = BlockGemmShape::kN;
188
// Config element 0 is the warp GEMM (type WG); elements 1 and 2 are used as the
// warp counts along M and N.
189 constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
190
191 using WG = remove_cvref_t<decltype(config.template at<0>())>;
192
193 constexpr index_t MWarp = config.template at<1>();
194 constexpr index_t NWarp = config.template at<2>();
195
// NOTE(review): MIterPerWarp / NIterPerWarp are presumably consumed by the outer
// encoding below, whose template arguments are not visible in this listing.
196 constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
197 constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
// The K iteration count is left commented out; nothing visible here uses it.
198 // constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
199
200 constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
207
208 constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
209 c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
210 constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
211 auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
212 return c_block_tensor;
213 }
214
215 // C = A * B
216 template <typename ABlockTensorTmp, typename BBlockWindowTmp>
217 CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
218 const BBlockWindowTmp& b_block_window_tmp) const
219 {
220 auto c_block_tensor = MakeCBlockTile();
221 operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp);
222 return c_block_tensor;
223 }
224};
225
226} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition tile_distribution_encoding.hpp:457
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE index_t get_warp_id(bool_constant< ReturnSgpr >={})
Definition arch.hpp:104
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_HOST_DEVICE constexpr auto merge_sequences(Seqs...)
Definition tile/core/container/sequence.hpp:826
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
array< index_t, N > multi_index
Definition tile/core/container/multi_index.hpp:17
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple< number< Is >... >)
Definition tile/core/container/sequence.hpp:1055
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition tile/core/container/sequence.hpp:1026
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
tuple_array< T, N > statically_indexed_array
Definition tile/core/container/statically_indexed_array.hpp:16
Definition block_gemm_areg_bsmem_creg_v1.hpp:16
remove_cvref_t< typename Problem::BDataType > BDataType
Definition block_gemm_areg_bsmem_creg_v1.hpp:20
remove_cvref_t< typename Problem::ADataType > ADataType
Definition block_gemm_areg_bsmem_creg_v1.hpp:19
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp &a_block_tensor_tmp, const BBlockWindowTmp &b_block_window_tmp) const
Definition block_gemm_areg_bsmem_creg_v1.hpp:217
CK_TILE_DEVICE void operator()(CBlockTensor &c_block_tensor, const ABlockTensorTmp &a_block_tensor_tmp, const BBlockWindowTmp &b_block_window_tmp) const
Definition block_gemm_areg_bsmem_creg_v1.hpp:28
remove_cvref_t< typename Problem::CDataType > CDataType
Definition block_gemm_areg_bsmem_creg_v1.hpp:21
remove_cvref_t< Problem_ > Problem
Definition block_gemm_areg_bsmem_creg_v1.hpp:17
remove_cvref_t< Policy_ > Policy
Definition block_gemm_areg_bsmem_creg_v1.hpp:18
static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
Definition block_gemm_areg_bsmem_creg_v1.hpp:184
remove_cvref_t< typename Problem::BlockGemmShape > BlockGemmShape
Definition block_gemm_areg_bsmem_creg_v1.hpp:22
static constexpr index_t kBlockSize
Definition block_gemm_areg_bsmem_creg_v1.hpp:24
A fixed-size array container similar to std::array with additional utilities.
Definition tile/core/container/array.hpp:43
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192