blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp Source File

blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp Source File

Composable Kernel: blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp Source File
blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Naive pipeline with lowest resource request per WGP
11// GlobalPrefetchStages: 1
12// LocalPreFillStages: 1
13// LocalPreFetchStages: 0
14// LocalSharedMemoryBuffer: 1
// "b_scale" variant of the v1 pipeline: B-side scale factors (one per NRepeat index n0 of
// each thread) are read from global memory alongside the A/B tiles. The partial product of
// every K tile is multiplied by its scale before being added to the C accumulator.
15
// Primary template, selected by the pipeline scheduler (BlkGemmPipelineVer). Presumably only
// the scheduler specializations below are meant to be instantiated.
// NOTE(review): listing lines 36-38 (the declared struct name / closing of this declaration)
// were dropped when this page was extracted — confirm against the real header.
// NOTE(review): the last parameter is spelled `KPacks` here but `KPack` in the specialization;
// primary-template parameter names are not referenced, so this is cosmetic only.
16template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17 index_t BlockSize,
18 typename ADataType,
19 typename BDataType,
20 typename ComputeDataType,
21 typename AccDataType,
22 typename ATileDesc,
23 typename BTileDesc,
24 typename AMmaTileDesc,
25 typename BMmaTileDesc,
26 index_t ABlockTransferSrcScalarPerVector,
27 index_t BBlockTransferSrcScalarPerVector,
28 index_t MPerBlock,
29 index_t NPerBlock,
30 index_t KPerBlock,
31 index_t MPerXDL,
32 index_t NPerXDL,
33 index_t MRepeat,
34 index_t NRepeat,
35 index_t KPacks>
39
// Scheduler specialization of the v1 B-scale pipeline. Its tile bookkeeping (thread
// descriptors, thread copies, xdlops_gemm, KRepeat, MMA K strides) is inherited from
// BlockwiseGemmXdlops_pipeline_base instantiated with the same parameters.
40template <index_t BlockSize,
41 typename ADataType,
42 typename BDataType,
43 typename ComputeDataType,
44 typename AccDataType,
45 typename ATileDesc,
46 typename BTileDesc,
47 typename AMmaTileDesc,
48 typename BMmaTileDesc,
49 index_t ABlockTransferSrcScalarPerVector,
50 index_t BBlockTransferSrcScalarPerVector,
51 index_t MPerBlock,
52 index_t NPerBlock,
53 index_t KPerBlock,
54 index_t MPerXDL,
55 index_t NPerXDL,
56 index_t MRepeat,
57 index_t NRepeat,
58 index_t KPack
59 // ,bool TransposeC //disable transposec right now...
60 >
// NOTE(review): listing line 61 (the `struct ...<scheduler,` specialization head) was dropped
// in extraction; the footer's `@ Intrawave` entry suggests this is the Intrawave
// specialization — confirm against the real header.
62 BlockSize,
63 ADataType,
64 BDataType,
65 ComputeDataType,
66 AccDataType,
67 ATileDesc,
68 BTileDesc,
69 AMmaTileDesc,
70 BMmaTileDesc,
71 ABlockTransferSrcScalarPerVector,
72 BBlockTransferSrcScalarPerVector,
73 MPerBlock,
74 NPerBlock,
75 KPerBlock,
76 MPerXDL,
77 NPerXDL,
78 MRepeat,
79 NRepeat,
80 KPack>
// NOTE(review): listing line 81 (presumably the `: BlockwiseGemmXdlops_pipeline_base<BlockSize,`
// base-class head) is missing; the arguments below match the `Base` alias in the footer.
82 ADataType,
83 BDataType,
84 ComputeDataType,
85 AccDataType,
86 ATileDesc,
87 BTileDesc,
88 AMmaTileDesc,
89 BMmaTileDesc,
90 ABlockTransferSrcScalarPerVector,
91 BBlockTransferSrcScalarPerVector,
92 MPerBlock,
93 NPerBlock,
94 KPerBlock,
95 MPerXDL,
96 NPerXDL,
97 MRepeat,
98 NRepeat,
99 KPack>
100
101{
// `Base` = BlockwiseGemmXdlops_pipeline_base<BlockSize, ADataType, ..., KPack> (per the footer);
// the alias head on listing line 102 was dropped in extraction.
103 ADataType,
104 BDataType,
105 ComputeDataType,
106 AccDataType,
107 ATileDesc,
108 BTileDesc,
109 AMmaTileDesc,
110 BMmaTileDesc,
111 ABlockTransferSrcScalarPerVector,
112 BBlockTransferSrcScalarPerVector,
113 MPerBlock,
114 NPerBlock,
115 KPerBlock,
116 MPerXDL,
117 NPerXDL,
118 MRepeat,
119 NRepeat,
120 KPack>;
121 using Base::I0;
122 using Base::KRepeat;
123 using Base::xdlops_gemm;
124
// NOTE(review): listing lines 125-135 and 137-138 were dropped; the footer lists further Base
// members used below (e.g. ComputeDataTypeBuf, a_block_desc_m0_m1_m2_k, b_block_desc_n0_n1_n2_k)
// that are presumably brought into scope there.
136
139
140 using Base::AMmaKStride;
141 using Base::BMmaKStride;
142
144
// Pipeline depth: one K tile prefetched from global memory and one LDS prefill.
// GlobalBufferNum = 1 (presumably a single staging buffer for global reads).
145 static constexpr index_t PrefetchStages = 1;
146 static constexpr index_t PrefillStages = 1;
147 static constexpr index_t GlobalBufferNum = 1;
148
// Returns true when there is more than the PrefetchStages (= 1) prefetched K tile, i.e. when
// Run() must be instantiated with HasMainLoop = true (num_loop > 1).
149 __host__ static constexpr bool BlockHasHotloop(index_t num_loop)
150 {
151 return num_loop > PrefetchStages;
152 }
153
// This pipeline has a single tail shape, so the tail is always TailNumber::Full;
// num_loop is deliberately unused.
154 __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
155 {
156 ignore = num_loop;
157 return TailNumber::Full;
158 }
159
// Run: executes the whole K loop of this pipeline for one block tile and applies the B scale.
//
// Schedule (as coded below):
//   prologue : global read of K tile 0 for A and B, then the first set of B scales (one per n0).
//              Tile 0 is written to the LDS block buffers and c_thread_buf is cleared.
//   main loop: only when HasMainLoop; runs num_loop - 1 iterations. Each iteration issues the
//              global read of the next tile, computes on the tile currently in the block
//              buffers, reads the next set of B scales, then writes the next tile to the
//              block buffers.
//   tail     : only when TailNum == TailNumber::Full; computes on the last tile.
//
// Scaling: for every (m0, n0), the KRepeat MFMA results of the current tile accumulate into
// c_thread_buf_per_scale, which is cleared first. Each of its GetRegSizePerXdlops() values is
// then multiplied by b_scale_thread_buf[n0] (converted to AccDataType). The result is added to
// c_thread_buf at c_thread_desc_ offset (m0, n0, t).
//
// Parameters:
//   A/B (grid desc, block desc, blockwise copy, grid buffer, block buffer, window step):
//     global -> block-buffer copy machinery. The step is applied once per K tile.
//   c_thread_buf: per-thread accumulator. It is cleared here and holds the scaled result.
//   b_scale_*: scale descriptors, thread copy and grid buffer. The step tuple's element 0 is
//     applied after each n0 read, element 1 after each full set of NRepeat reads.
//   num_loop: number of K tiles processed.
//   num_loop_per_scale: ignored. Scales are reloaded once per K tile, i.e. the code assumes
//     KPerBlock equals the scale block size along K.
//
// NOTE(review): the do/while body always executes at least once. HasMainLoop must therefore
// only be true when num_loop > 1 (cf. BlockHasHotloop); otherwise an extra tile is read and
// computed.
// NOTE(review): scales are written at b_scale_thread_desc index (n0, 0) but read back as
// b_scale_thread_buf[n0]. This assumes the descriptor maps (n0, 0) to offset n0 — confirm.
// NOTE(review): listing lines 210, 212, 215, 260, 263-266, 271-274, 284-285, 297, 329, 341,
// 344-347, 352-355 and 365-366 were dropped in extraction. The footer lists
// make_static_buffer() and block_sync_lds() as referenced here, presumably on those lines.
160 template <bool HasMainLoop,
161 TailNumber TailNum,
162 typename AGridDesc,
163 typename ABlockDesc,
164 typename ABlockTransfer,
165 typename AGridBuffer,
166 typename ABlockBuffer,
167 typename ABlockTransferStep,
168 typename BGridDesc,
169 typename BBlockDesc,
170 typename BBlockTransfer,
171 typename BGridBuffer,
172 typename BBlockBuffer,
173 typename BBlockTransferStep,
174 typename CThreadBuffer,
175 // BScale Thread Copy
176 typename BScaleGridBuffer,
177 typename BScaleGridDesc,
178 typename BScaleThreadDesc,
179 typename BScaleThreadTransfer,
180 typename BScaleThreadTransferStep>
181 __device__ void Run(
182 // ABlockCopy
183 const AGridDesc& a_grid_desc,
184 const ABlockDesc& a_block_desc,
185 ABlockTransfer& a_blockwise_copy,
186 const AGridBuffer& a_grid_buf,
187 ABlockBuffer& a_block_buf,
188 const ABlockTransferStep& a_block_copy_step,
189 // BBlockCopy
190 const BGridDesc& b_grid_desc,
191 const BBlockDesc& b_block_desc,
192 BBlockTransfer& b_blockwise_copy,
193 const BGridBuffer& b_grid_buf,
194 BBlockBuffer& b_block_buf,
195 const BBlockTransferStep& b_block_copy_step,
196 // CThread
197 CThreadBuffer& c_thread_buf,
198 // BScaleThreadCopy
199 const BScaleGridDesc& b_scale_grid_desc,
200 const BScaleThreadDesc& b_scale_thread_desc,
201 BScaleThreadTransfer& b_scale_thread_copy,
202 const BScaleGridBuffer& b_scale_grid_buf,
203 const BScaleThreadTransferStep& b_scale_thread_copy_step,
204 // num_loop
205 index_t num_loop,
206 index_t num_loop_per_scale) const
207 {
208 // assume kperblock = scaleblockk
209 ignore = num_loop_per_scale;
// Thread-local A, B and B-scale buffers, sized from their thread descriptors. The declaration
// heads sit on the dropped listing lines 210/212/215.
211 a_thread_desc_.GetElementSpaceSize());
213 b_thread_desc_.GetElementSpaceSize());
214
216 b_scale_thread_desc.GetElementSpaceSize());
217
218 // Global prefetch 1
219 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
220 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
221
222 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
223 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
224
// Read the first set of B scales, one per n0. The source window advances by step element 0
// after each read and by step element 1 after the whole set.
225 static_for<0, NRepeat, 1>{}([&](auto n0) {
226 b_scale_thread_copy.Run(b_scale_grid_desc,
227 b_scale_grid_buf,
228 b_scale_thread_desc,
229 make_tuple(n0, I0),
230 b_scale_thread_buf);
231
232 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
233 b_scale_thread_copy_step.At(Number<0>{}));
234 });
235 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
236 b_scale_thread_copy_step.At(Number<1>{}));
237
238 // Local prefill 1
239 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
240 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
241
242 // Initialize C
243 c_thread_buf.Clear();
244
// Scratch accumulator for one (m0, n0) xdlops result before it is scaled. It has the same
// type as c_thread_buf, but only the vector at I0 (elements t < GetRegSizePerXdlops()) is used.
// NOTE(review): presumably the unused remainder is optimized away — verify register usage.
245 auto c_thread_buf_per_scale = remove_cvref_t<decltype(c_thread_buf)>();
246
247 // main body
248 if constexpr(HasMainLoop)
249 {
250 index_t i = 0;
251 do
252 {
253 // -------------------------------------------------------------------------------------------
// Issue the global read of the next K tile and advance the A/B source windows.
254 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
255 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
256
257 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
258 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
259
// For each k, copy the current tile's A slice (every m0) and B slice (every n0) from the block
// buffers into a_thread_buf / b_thread_buf at (repeat, 0, k, 0).
261 static_for<0, KRepeat, 1>{}([&](auto k) {
262 static_for<0, MRepeat, 1>{}([&](auto m0) {
265 a_block_buf,
267 make_tuple(m0, I0, k, I0),
268 a_thread_buf);
269 });
270 static_for<0, NRepeat, 1>{}([&](auto n0) {
273 b_block_buf,
275 make_tuple(n0, I0, k, I0),
276 b_thread_buf);
277 });
278 });
279
// Per (m0, n0): accumulate KRepeat MFMAs into the scratch buffer, then add the scaled result
// into c_thread_buf.
280 static_for<0, MRepeat, 1>{}([&](auto m0) {
281 static_for<0, NRepeat, 1>{}([&](auto n0) {
282 c_thread_buf_per_scale.Clear();
283 static_for<0, KRepeat, 1>{}([&](auto k0) {
// Gather KPack A and B elements into a_thread_vec / b_thread_vec. Their declarations are on
// the dropped listing lines 284-285.
286
287 static_for<0, KPack, 1>{}([&](auto ik) {
288 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
289 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
290 make_tuple(m0, I0, k0, ik))>{}];
291 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
292 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
293 make_tuple(n0, I0, k0, ik))>{}];
294 });
295
// MFMA operand type: K1PerXdlops elements of ComputeDataTypeBuf. Listing line 297 is missing;
// the tail below shows the full form.
296 using mfma_input_type =
298 xdlops_gemm.K1PerXdlops>::type;
299
300 xdlops_gemm.template Run<>(
301 a_thread_vec.template AsType<mfma_input_type>(),
302 b_thread_vec.template AsType<mfma_input_type>(),
303 c_thread_buf_per_scale.GetVectorTypeReference(I0));
304 });
// Apply this tile's scale for n0 and accumulate into the running result.
305 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
306 constexpr index_t c_offset =
307 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
308 c_thread_buf(Number<c_offset>{}) +=
309 c_thread_buf_per_scale[Number<t>{}] *
310 type_convert<AccDataType>(b_scale_thread_buf[n0]);
311 });
312 });
313 });
314
// Read the B scales for the next K tile. The window steps are the same as in the prologue.
315 static_for<0, NRepeat, 1>{}([&](auto n0) {
316 b_scale_thread_copy.Run(b_scale_grid_desc,
317 b_scale_grid_buf,
318 b_scale_thread_desc,
319 make_tuple(n0, I0),
320 b_scale_thread_buf);
321
322 b_scale_thread_copy.MoveSrcSliceWindow(
323 b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{}));
324 });
325
326 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
327 b_scale_thread_copy_step.At(Number<1>{}));
328
// Write the next tile (read at the top of this iteration) into the block buffers.
330 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
331 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
332
333 i += 1;
334
335 } while(i < (num_loop - 1));
336 }
337
338 // tail
// Compute on the last K tile already in the block buffers. It uses the scales read by the last
// main-loop iteration, or by the prologue when HasMainLoop is false. No further global reads
// are issued.
339 if constexpr(TailNum == TailNumber::Full)
340 {
342 static_for<0, KRepeat, 1>{}([&](auto k) {
343 static_for<0, MRepeat, 1>{}([&](auto m0) {
346 a_block_buf,
348 make_tuple(m0, I0, k, I0),
349 a_thread_buf);
350 });
351 static_for<0, NRepeat, 1>{}([&](auto n0) {
354 b_block_buf,
356 make_tuple(n0, I0, k, I0),
357 b_thread_buf);
358 });
359 });
360
361 static_for<0, MRepeat, 1>{}([&](auto m0) {
362 static_for<0, NRepeat, 1>{}([&](auto n0) {
363 c_thread_buf_per_scale.Clear();
364 static_for<0, KRepeat, 1>{}([&](auto k0) {
367
368 static_for<0, KPack, 1>{}([&](auto ik) {
369 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
370 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
371 make_tuple(m0, I0, k0, ik))>{}];
372 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
373 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
374 make_tuple(n0, I0, k0, ik))>{}];
375 });
376
377 using mfma_input_type =
378 typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
379
380 xdlops_gemm.template Run<>(
381 a_thread_vec.template AsType<mfma_input_type>(),
382 b_thread_vec.template AsType<mfma_input_type>(),
383 c_thread_buf_per_scale.GetVectorTypeReference(I0));
384 });
385 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
386 constexpr index_t c_offset =
387 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
388 c_thread_buf(Number<c_offset>{}) +=
389 c_thread_buf_per_scale[Number<t>{}] *
390 type_convert<AccDataType>(b_scale_thread_buf[n0]);
391 });
392 });
393 });
394 }
395 }
396
397 protected:
// Thread-level descriptors and copies inherited from Base and used by Run().
398 using Base::a_thread_copy_;
399 using Base::a_thread_desc_;
400 using Base::b_thread_copy_;
401 using Base::b_thread_desc_;
402 using Base::c_thread_desc_;
403};
404
405} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
int32_t index_t
Definition ck.hpp:299
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Full
Definition blkgemmpipe_scheduler.hpp:49
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ BlockwiseGemmXdlops_pipeline_base(Tuple4 a_origin=CalculateAThreadOriginDataIndex(), Tuple4 b_origin=CalculateBThreadOriginDataIndex())
Constructor for BlockwiseGemmXdlops_pipeline_base.
Definition blockwise_gemm_pipeline_xdlops_base.hpp:222
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:280
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:239
static constexpr auto xdlops_gemm
Definition blockwise_gemm_pipeline_xdlops_base.hpp:54
conditional_t< std::is_same< ComputeDataType, ck::tf32_t >::value, float, ComputeDataType > ComputeDataTypeBuf
Definition blockwise_gemm_pipeline_xdlops_base.hpp:57
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k
Definition blockwise_gemm_pipeline_xdlops_base.hpp:360
__host__ static __device__ constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:266
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:294
static constexpr index_t AMmaKStride
Definition blockwise_gemm_pipeline_xdlops_base.hpp:60
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:253
__host__ __device__ constexpr auto & GetCThreadBuffer()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:111
static constexpr auto I0
Definition blockwise_gemm_pipeline_xdlops_base.hpp:36
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:160
static __device__ auto CalculateCThreadOriginDataIndex8D(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:189
static constexpr index_t KRepeat
Definition blockwise_gemm_pipeline_xdlops_base.hpp:64
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k
Definition blockwise_gemm_pipeline_xdlops_base.hpp:359
static constexpr index_t BMmaKStride
Definition blockwise_gemm_pipeline_xdlops_base.hpp:61
__host__ static __device__ constexpr auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:341
__host__ static __device__ constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:307
__host__ static __device__ constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:324
BlockwiseGemmXdlops_pipeline_base< BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack > Base
Definition blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp:102
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, const BScaleGridDesc &b_scale_grid_desc, const BScaleThreadDesc &b_scale_thread_desc, BScaleThreadTransfer &b_scale_thread_copy, const BScaleGridBuffer &b_scale_grid_buf, const BScaleThreadTransferStep &b_scale_thread_copy_step, index_t num_loop, index_t num_loop_per_scale) const
Definition blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp:181
Definition blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp:37
Definition functional2.hpp:33
Definition dtype_vector.hpp:10