gridwise_gemm_dl_multiple_d.hpp Source File

gridwise_gemm_dl_multiple_d.hpp Source File#

Composable Kernel: gridwise_gemm_dl_multiple_d.hpp Source File
gridwise_gemm_dl_multiple_d.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
18
19namespace ck {
20
21template <index_t BlockSize,
22 typename FloatAB,
23 typename FloatAcc,
24 typename DsDataType,
25 typename FloatC,
26 typename AElementwiseOperation,
27 typename BElementwiseOperation,
28 typename CDEElementwiseOperation,
29 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
30 typename AGridDesc_K0_M_K1,
31 typename BGridDesc_K0_N_K1,
32 typename CGridDesc_M_N,
33 index_t MPerBlock,
34 index_t NPerBlock,
35 index_t K0PerBlock,
36 index_t K1Value,
37 index_t M1PerThreadM111,
38 index_t N1PerThreadN111,
39 index_t KPerThread,
40 typename M11N11ThreadClusterM110Xs,
41 typename M11N11ThreadClusterN110Xs,
42 typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
43 typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
44 typename ABlockTransferThreadClusterArrangeOrder,
45 typename ABlockTransferSrcAccessOrder,
46 typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
47 typename ABlockTransferSrcVectorTensorContiguousDimOrder,
48 typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
49 typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
50 typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
51 typename BBlockTransferThreadClusterArrangeOrder,
52 typename BBlockTransferSrcAccessOrder,
53 typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
54 typename BBlockTransferSrcVectorTensorContiguousDimOrder,
55 typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
56 typename CThreadTransferSrcDstAccessOrder,
57 index_t CThreadTransferSrcDstVectorDim,
58 index_t CThreadTransferDstScalarPerVector>
60{
61 static constexpr index_t NumDTensor = DsDataType::Size();
62
63 static constexpr auto I0 = Number<0>{};
64 static constexpr auto I1 = Number<1>{};
65 static constexpr auto I2 = Number<2>{};
66 static constexpr auto I3 = Number<3>{};
67
68 // K1 should be Number<...>
69 static constexpr auto K1 = Number<K1Value>{};
70
71 // ck::Tuple<const D0DataType*, const D1DataType*, ...>
// Builds a tuple holding one null `const DDataType*` per generated index, where
// DDataType is the i-th entry of DsDataType with cv/ref qualifiers removed.
// In the visible code only the resulting type is consumed: DsGridPointer is
// declared as decltype of this call.
// NOTE(review): the size argument passed to generate_tuple (listing line 80) is
// not visible here — presumably Number<NumDTensor>{}; confirm.
72 static constexpr auto MakeDsGridPointer()
73 {
74 return generate_tuple(
75 [&](auto i) {
76 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
77
// The pointer value is a placeholder; only its type matters to decltype users.
78 return static_cast<const DDataType*>(nullptr);
79 },
81 }
82
// Returns the number of bytes of LDS (shared memory) requested for one block.
// The element-space sizes of the A and B block descriptors are each rounded up
// to a multiple of K1 elements, summed, doubled, and scaled by sizeof(FloatAB).
// The factor of 2 matches Run(), which places two A buffers (even/odd) followed
// by two B buffers (even/odd) in the shared block.
// NOTE(review): the length/alignment arguments of both descriptors sit on
// listing lines that are not visible here (91 and 96) — confirm they agree with
// the block descriptors constructed in Run().
83 __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
84 {
85 // TODO: change this. I think it needs multi-dimensional alignment
// Alignment unit, in elements, used for rounding below.
86 constexpr auto max_lds_align = K1;
87
88 // TODO: check alignment
89 // A matrix in LDS memory, dst of blockwise copy
90 constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned(
92
93 // TODO: check alignment
94 // B matrix in LDS memory, dst of blockwise copy
95 constexpr auto b_block_desc_k_n = make_naive_tensor_descriptor_aligned(
97
98 // TODO: check alignment
99 // LDS allocation for A and B: be careful of alignment
100 constexpr auto a_block_aligned_space_size =
101 math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align);
102
103 constexpr auto b_block_aligned_space_size =
104 math::integer_least_multiple(b_block_desc_k_n.GetElementSpaceSize(), max_lds_align);
105
// Two copies of each tile (double buffering), converted from elements to bytes.
106 return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
107 }
108
// Checks whether the given A/B/C grid descriptors can be handled by this kernel.
//
// Parameters:
//   a_grid_desc_k0_m_k1 - A descriptor; lengths read as (K0, M, K1) via I0/I1/I2.
//   b_grid_desc_k0_n_k1 - B descriptor; lengths read as (K0, N, K1) via I0/I1/I2.
//   c_grid_desc_m_n     - C descriptor; lengths read as (M, N) via I0/I1.
//
// Returns true only if all of the following hold:
//   * the element space of A, B and C, in bytes, is each at most 2^31;
//   * M and N of A/B equal the two lengths of C;
//   * A and B have the same K0, and both have a third length equal to K1;
//   * M, N and K0 are exact multiples of MPerBlock, NPerBlock and K0PerBlock.
// As the TODO below notes, the blockwise/threadwise copy constraints are not
// checked here.
109 __host__ __device__ static constexpr bool
110 CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
111 const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
112 const CGridDesc_M_N& c_grid_desc_m_n)
113 {
// 2^31 bytes, held in a 64-bit integer so the shift does not overflow.
114 constexpr long_index_t TwoGB = (long_index_t{1} << 31);
115
// Reject any tensor whose element space, in bytes, exceeds the limit above.
116 if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
117 b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
118 c_grid_desc_m_n.GetElementSpaceSize() * sizeof(FloatC) <= TwoGB))
119 {
120 return false;
121 }
122
123 const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
124 const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
125 const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
126
127 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
128
// First group: shapes of A, B and C agree. Second group: every dimension is
// tiled exactly by the per-block sizes (no partial tiles are accepted).
129 return (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
130 K0 == b_grid_desc_k0_n_k1.GetLength(I0) &&
131 K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
132 K1 == b_grid_desc_k0_n_k1.GetLength(I2)) &&
133 (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0);
134 }
135
// Returns the number of MPerBlock x NPerBlock tiles covering an M x N output:
// (M / MPerBlock) * (N / NPerBlock).
// Both divisions are integer divisions, so any remainder is dropped; CheckValidity
// only accepts M and N that are exact multiples of the per-block sizes.
136 __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
137 {
138 const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
139
140 return grid_size;
141 }
142
// Returns whether (K0 + K0PerBlock) / (2 * K0PerBlock) is greater than 1, using
// integer division. When K0 is an exact multiple of K0PerBlock this is true
// exactly when K0 spans at least three K0PerBlock-sized steps.
// The result corresponds to the HasMainKBlockLoop flag that Run() uses to decide
// whether to execute its do-while main loop.
143 __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0)
144 {
145 const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;
146
147 return has_main_k_block_loop;
148 }
149
// Returns true when the number of K0PerBlock-sized steps in K0 (K0 / K0PerBlock,
// integer division) is even.
// The result corresponds to the HasDoubleTailKBlockLoop flag in Run(): when true
// the tail section processes two remaining steps, otherwise one.
150 __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
151 {
152 const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0;
153
154 return has_double_tail_k_block_loop;
155 }
156
// Derives the A grid descriptor returned as a_grid_desc_k0_m0_m1_k1 from the
// (K0, M, K1) descriptor passed in, via transform_tensor_descriptor.
// M1 is fixed to MPerBlock and M0 = M / M1.
// NOTE(review): the transform arguments (listing lines 168-172) are not visible
// here; judging by the names, M is presumably split into (M0, M1) while K0 and
// K1 are kept — confirm against the full source.
157 __host__ __device__ static constexpr auto
158 MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1)
159 {
160 const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
161 const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
162
// M1 is a compile-time constant; M0 is the quotient M / M1.
163 const auto M1 = Number<MPerBlock>{};
164 const auto M0 = M / M1;
165
166 const auto a_grid_desc_k0_m0_m1_k1 =
167 transform_tensor_descriptor(a_grid_desc_k0_m_k1,
173
174 return a_grid_desc_k0_m0_m1_k1;
175 }
176
// Derives the B grid descriptor returned as b_grid_desc_k0_n0_n1_k1 from the
// (K0, N, K1) descriptor passed in, via transform_tensor_descriptor.
// N1 is fixed to NPerBlock and N0 = N / N1.
// NOTE(review): the transform arguments (listing lines 188-192) are not visible
// here; judging by the names, N is presumably split into (N0, N1) while K0 and
// K1 are kept — confirm against the full source.
177 __host__ __device__ static constexpr auto
178 MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1)
179 {
180 const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0);
181 const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
182
// N1 is a compile-time constant; N0 is the quotient N / N1.
183 const auto N1 = Number<NPerBlock>{};
184 const auto N0 = N / N1;
185
186 const auto b_grid_desc_k0_n0_n1_k1 =
187 transform_tensor_descriptor(b_grid_desc_k0_n_k1,
193
194 return b_grid_desc_k0_n0_n1_k1;
195 }
196
197 // E desc for destination in blockwise copy
// Derives a descriptor named (M0, M10, M11, N0, N10, N11) from an (M, N)
// descriptor. The split sizes are:
//   M1  = MPerBlock,                       N1  = NPerBlock
//   M0  = M / M1,                          N0  = N / N1
//   M11 = prod(M11N11ThreadClusterM110Xs) * M1PerThreadM111
//   N11 = prod(M11N11ThreadClusterN110Xs) * N1PerThreadN111
//   M10 = M1 / M11,                        N10 = N1 / N11
// The function is templated on the descriptor type so that it can also be
// applied to each D descriptor (see MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11).
// NOTE(review): only the N-side unmerge transform is visible in this listing;
// the M-side transform and the dimension-id arguments (listing lines 223, 225,
// 226) are missing — presumably M is unmerged into (M0, M10, M11) in the same way.
198 template <typename CGridDesc_M_N_>
199 __host__ __device__ static constexpr auto
200 MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N_& c_grid_desc_m_n)
201 {
202 const auto M = c_grid_desc_m_n.GetLength(I0);
203 const auto N = c_grid_desc_m_n.GetLength(I1);
204
205 constexpr auto M1 = Number<MPerBlock>{};
206 constexpr auto N1 = Number<NPerBlock>{};
207
208 const auto M0 = M / M1;
209 const auto N0 = N / N1;
210
// Product of the thread-cluster lengths (reduced with multiplies, starting
// from I1) times the per-thread length gives the M11 / N11 extent.
211 constexpr auto M11 =
212 Number<container_reduce(M11N11ThreadClusterM110Xs{}, math::multiplies{}, I1) *
213 M1PerThreadM111>{};
214 constexpr auto N11 =
215 Number<container_reduce(M11N11ThreadClusterN110Xs{}, math::multiplies{}, I1) *
216 N1PerThreadN111>{};
217
218 constexpr auto M10 = M1 / M11;
219 constexpr auto N10 = N1 / N11;
220
221 const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor(
222 c_grid_desc_m_n,
224 make_unmerge_transform(make_tuple(N0, N10, N11))),
227
228 return c_grid_desc_m0_m10_m11_n0_n10_n11;
229 }
230
231 // Ds desc for source in blockwise copy
// Applies MakeCGridDescriptor_M0_M10_M11_N0_N10_N11 to each element
// ds_grid_desc_m_n[i] and collects the results with generate_tuple, so every
// D descriptor gets the same transformation as the C descriptor.
// NOTE(review): the tuple-size argument (listing line 238) is not visible here —
// presumably Number<NumDTensor>{}; confirm.
232 template <typename DsGridDesc_M_N>
233 __host__ __device__ static constexpr auto
234 MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(const DsGridDesc_M_N& ds_grid_desc_m_n)
235 {
236 return generate_tuple(
237 [&](auto i) { return MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(ds_grid_desc_m_n[i]); },
239 }
240 // return block_id to C matrix tile idx (m0, n0) mapping
// Constructs the default mapping object from the (M, N) C grid descriptor.
// NOTE(review): the type being constructed (listing line 244) is not visible in
// this listing; only its last constructor argument, c_grid_desc_m_n, is shown.
// Confirm the map type and its template arguments against the full source.
241 __host__ __device__ static constexpr auto
242 MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
243 {
245 c_grid_desc_m_n);
246 }
247
248 using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
249 using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
251 decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
252
253 using DsGridPointer = decltype(MakeDsGridPointer());
254
255 template <typename DsGridDesc_M0_M10_M11_N0_N10_N11,
256 bool HasMainKBlockLoop,
257 bool HasDoubleTailKBlockLoop,
258 typename Block2CTileMap>
259 __device__ static void
260 Run(const FloatAB* __restrict__ p_a_grid,
261 const FloatAB* __restrict__ p_b_grid,
262 DsGridPointer p_ds_grid,
263 FloatC* __restrict__ p_c_grid,
264 void* __restrict__ p_shared_block,
265 const AElementwiseOperation&,
266 const BElementwiseOperation&,
267 const CDEElementwiseOperation& cde_element_op,
268 const AGridDesc_K0_M0_M1_K1& a_grid_desc_k0_m0_m1_k1,
269 const BGridDesc_K0_N0_N1_K1& b_grid_desc_k0_n0_n1_k1,
270 const DsGridDesc_M0_M10_M11_N0_N10_N11& ds_grid_desc_m0_m10_m11_n0_n10_n11,
271 const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11,
272 const Block2CTileMap& block_2_ctile_map,
275 {
276 const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
277 p_a_grid, a_grid_desc_k0_m0_m1_k1.GetElementSpaceSize());
278 const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
279 p_b_grid, b_grid_desc_k0_n0_n1_k1.GetElementSpaceSize());
281 p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());
282
283 // divide block work by [M, N]
284 const auto c_m0_n0_block_cluster_idx =
285 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
286
287 // HACK: this forces the index data into SGPRs
288 const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
289 const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
290
291 if(!block_2_ctile_map.ValidCTileIndex(
292 make_tuple(im0, in0),
293 make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0),
294 c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3))))
295 {
296 return;
297 }
298
299 // TODO: change this. I think it needs multi-dimensional alignment
300 constexpr auto max_lds_align = K1;
301
302 // TODO: check alignment
303 // A matrix in LDS memory, dst of blockwise copy
304 // be careful of LDS alignment
305 constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
306 make_tuple(Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);
307
308 // TODO: check alignment
309 // B matrix in LDS memory, dst of blockwise copy
310 // be careful of LDS alignment
311 constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned(
312 make_tuple(Number<K0PerBlock>{}, I1, Number<NPerBlock>{}, K1), max_lds_align);
313
314 // TODO: check alignment
315 // A matrix in LDS memory, for blockwise GEMM
316 constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
317 make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
318
319 // TODO: check alignment
320 // B matrix in LDS memory, for blockwise GEMM
321 constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
322 make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
323
324 static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
325 a_k0_m_k1_block_desc.GetElementSpaceSize() &&
326 b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() ==
327 b_k0_n_k1_block_desc.GetElementSpaceSize() &&
328 "wrong!");
329
330 // A matrix blockwise copy
331 auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
332 BlockSize,
334 Sequence<K0PerBlock, 1, MPerBlock, K1.value>,
335 ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
336 ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
337 ABlockTransferThreadClusterArrangeOrder,
338 FloatAB,
339 FloatAB,
340 remove_reference_t<decltype(a_grid_desc_k0_m0_m1_k1)>,
341 decltype(a_block_desc_k0_m0_m1_k1),
342 ABlockTransferSrcAccessOrder,
344 ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths
345 ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, // DstVectorTensorLengths
346 ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
347 Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
348 false,
349 true>(a_grid_desc_k0_m0_m1_k1,
350 make_multi_index(0, im0, 0, 0),
351 a_block_desc_k0_m0_m1_k1,
352 make_multi_index(0, 0, 0, 0));
353
354 // B matrix blockwise copy
355 auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
356 BlockSize,
358 Sequence<K0PerBlock, 1, NPerBlock, K1.value>,
359 BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
360 BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
361 BBlockTransferThreadClusterArrangeOrder,
362 FloatAB,
363 FloatAB,
364 remove_reference_t<decltype(b_grid_desc_k0_n0_n1_k1)>,
365 decltype(b_block_desc_k0_n0_n1_k1),
366 BBlockTransferSrcAccessOrder,
368 BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths
369 BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, // DstVectorTensorLengths
370 BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
371 Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
372 false,
373 true>(b_grid_desc_k0_n0_n1_k1,
374 make_multi_index(0, in0, 0, 0),
375 b_block_desc_k0_n0_n1_k1,
376 make_multi_index(0, 0, 0, 0));
377
378 // GEMM definition
379 // c_mtx += transpose(a_mtx) * b_mtx
380 // a_mtx[K0PerBlock, MPerBlock] is in LDS
381 // b_mtx[K0PerBlock, NPerBlock] is in LDS
382 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
383 // register
384 const auto blockwise_gemm =
386 BlockSize,
387 FloatAB,
388 FloatAB,
389 FloatAcc,
390 decltype(a_k0_m_k1_block_desc),
391 decltype(b_k0_n_k1_block_desc),
392 M1PerThreadM111,
393 N1PerThreadN111,
394 KPerThread,
395 M11N11ThreadClusterM110Xs,
396 M11N11ThreadClusterN110Xs,
397 M1PerThreadM111,
398 N1PerThreadN111>{};
399
400 constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
401 decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
402
403 constexpr auto c_thread_desc_m10_m11_n10_n11 = make_naive_tensor_descriptor_packed(
404 sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
405
406 // LDS allocation for A and B: be careful of alignment
407 constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
408 a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);
409
410 constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
411 b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);
412
413 FloatAB* p_a_block_double = static_cast<FloatAB*>(p_shared_block);
414 FloatAB* p_b_block_double =
415 static_cast<FloatAB*>(p_shared_block) + 2 * a_block_aligned_space_size;
416
417 // register allocation for output
419 c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize());
420
421 // Initialize C
422 c_thread_buf.Clear();
423
424 constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
425 constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
426
427 auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
428 p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
429 auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
430 p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
431
432 auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
433 p_a_block_double + a_block_aligned_space_size,
434 a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
435 auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
436 p_b_block_double + b_block_aligned_space_size,
437 b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
438
439 // LDS double buffer: preload data into LDS
440 {
441 a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
442 b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
443
444 a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
445 b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
446 }
447
448 if constexpr(HasMainKBlockLoop)
449 {
450 const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0);
451
452 index_t k_block_data_begin = 0;
453
454 // LDS double buffer: main body
455 // use Do-While loop instead of For loop to simplify control flow
456 do
457 {
458 // even iteration
459 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
460 a_block_slice_copy_step);
461 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
462 b_block_slice_copy_step);
463
464 // LDS double buffer: load next data from device mem
465 a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
466 b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
467
469
470 // LDS double buffer: GEMM on current data
471 blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11,
472 a_block_even_buf,
473 b_block_even_buf,
474 c_thread_buf);
475
476 // LDS double buffer: store next data to LDS
477 a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
478 b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);
479
480 // odd iteration
481 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
482 a_block_slice_copy_step);
483 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
484 b_block_slice_copy_step);
485
486 // LDS double buffer: load next data from device mem
487 a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
488 b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
489
491
492 // LDS double buffer: GEMM on current data
493 blockwise_gemm.Run(
494 c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
495
496 // LDS double buffer: store next data to LDS
497 a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
498 b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
499
500 k_block_data_begin += 2 * K0PerBlock;
501 } while(k_block_data_begin < K0 - 2 * K0PerBlock);
502 }
503
504 // LDS double buffer: tail
505 if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
506 {
507 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step);
508 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step);
509
511
512 // LDS double buffer: load last data from device mem
513 a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
514 b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
515
516 // LDS double buffer: GEMM on 2nd-last data
517 blockwise_gemm.Run(
518 c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
519
520 // LDS double buffer: store last data to LDS
521 a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
522 b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);
523
525
526 // LDS double buffer: GEMM on last data
527 blockwise_gemm.Run(
528 c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
529 }
530 else // if has 1 iteration left
531 {
532 __syncthreads();
533
534 // LDS double buffer: GEMM on last data
535 blockwise_gemm.Run(
536 c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
537 }
538
539 // output: register to global memory
540 {
541 constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 =
544 Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
546 I1,
549
550 const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
551 blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
553
554 const auto ds_grid_buf = generate_tuple(
555 [&](auto i) {
557 p_ds_grid[i], ds_grid_desc_m0_m10_m11_n0_n10_n11[i].GetElementSpaceSize());
558 },
560
561 auto ds_thread_buf = generate_tuple(
562 [&](auto i) {
563 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
564
566 DDataType,
567 c_m10_m11_n10_n11_thread_tensor_lengths[I3],
568 true>{};
569 },
571
572 auto ds_threadwise_copy = generate_tuple(
573 [&](auto i) {
574 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
575
577 DDataType,
578 DDataType,
579 decltype(ds_grid_desc_m0_m10_m11_n0_n10_n11[i]),
580 decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
581 Sequence<I1,
582 I1,
583 I1,
584 I1,
585 I1,
587 CThreadTransferSrcDstAccessOrder,
588 CThreadTransferSrcDstVectorDim,
589 CThreadTransferDstScalarPerVector,
590 1,
591 false>(ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
593 c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
594 c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
595 in0,
596 c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
597 c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]));
598 },
600
604 // load d matrix data
605 static_for<0, NumDTensor, 1>{}([&](auto i) {
606 ds_threadwise_copy(i).Run(ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
607 ds_grid_buf[i],
608 c_thread_desc_m0_m10_m11_n0_n10_n11,
609 make_tuple(I0, I0, I0, I0, I0, I0),
610 ds_thread_buf(i));
611 });
612 // apply the CDE element-wise op
614 [&](auto i) {
615 // get reference to src data
616 const auto src_data_refs = generate_tie(
617 // return type should be lvalue
618 [&](auto iSrc) -> const auto& {
619 return ds_thread_buf[iSrc][i];
620 },
622
623 // get reference to dst data
624 constexpr index_t c_offset =
625 c_thread_desc_m0_m10_m11_n0_n10_n11.CalculateOffset(
626 make_tuple(0, m10, m11, 0, n10, i));
627 auto dst_data_refs = generate_tie(
628 // return type should be lvalue
629 [&](auto) -> auto& { return c_thread_buf(Number<c_offset>{}); },
630 Number<2>{});
631
632 unpack2(cde_element_op, dst_data_refs, src_data_refs);
633 });
634
635 static_for<0, NumDTensor, 1>{}([&](auto i) {
636 ds_threadwise_copy(i).MoveSrcSliceWindow(
637 ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
638 make_multi_index(0, 0, 0, 0, 1, 0));
639 });
640 });
641 static_for<0, NumDTensor, 1>{}([&](auto i) {
642 ds_threadwise_copy(i).MoveSrcSliceWindow(
643 ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
645 0, 0, 1, 0, -c_m10_m11_n10_n11_thread_tensor_lengths[I2], 0));
646 });
647 });
648 static_for<0, NumDTensor, 1>{}([&](auto i) {
649 ds_threadwise_copy(i).MoveSrcSliceWindow(
650 ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
652 0, 1, -c_m10_m11_n10_n11_thread_tensor_lengths[I1], 0, 0, 0));
653 });
654 });
655
657 FloatAcc,
658 FloatC,
659 decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
660 decltype(c_grid_desc_m0_m10_m11_n0_n10_n11),
662 Sequence<1,
663 c_m10_m11_n10_n11_thread_tensor_lengths[I0],
664 c_m10_m11_n10_n11_thread_tensor_lengths[I1],
665 1,
666 c_m10_m11_n10_n11_thread_tensor_lengths[I2],
667 c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
668 CThreadTransferSrcDstAccessOrder,
669 CThreadTransferSrcDstVectorDim,
670 CThreadTransferDstScalarPerVector,
671 CGlobalMemoryDataOperation,
672 1,
673 true>{c_grid_desc_m0_m10_m11_n0_n10_n11,
675 c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
676 c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
677 in0,
678 c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
679 c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]),
681 .Run(c_thread_desc_m0_m10_m11_n0_n10_n11,
682 make_tuple(I0, I0, I0, I0, I0, I0),
683 c_thread_buf,
684 c_grid_desc_m0_m10_m11_n0_n10_n11,
685 c_grid_buf);
686 }
687 }
688};
689
690} // namespace ck
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence< Is... >)
Definition utility/container_helper.hpp:380
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr auto container_reduce(const Container &x, Reduce reduce, Init init, Number< IBegin >=Number< 0 >{}, Number< IEnd >=Number< Container::Size()>{}, Number< IStep >=Number< 1 >{})
Definition utility/container_helper.hpp:111
__host__ __device__ constexpr auto make_naive_tensor_descriptor_aligned(const Tuple< Lengths... > &lengths, Align align)
Definition tensor_descriptor_helper.hpp:132
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto unpack2(F &&f, X &&x, Y &&y)
Definition functional4.hpp:55
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
Definition block_to_ctile_map.hpp:617
Definition blockwise_tensor_slice_transfer_v5r1.hpp:37
Definition gridwise_gemm_dl_multiple_d.hpp:60
__host__ static __device__ constexpr auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_gemm_dl_multiple_d.hpp:242
__host__ static __device__ constexpr auto MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1 &b_grid_desc_k0_n_k1)
Definition gridwise_gemm_dl_multiple_d.hpp:178
__host__ static __device__ constexpr auto MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1)
Definition gridwise_gemm_dl_multiple_d.hpp:158
__host__ static __device__ constexpr index_t CalculateGridSize(index_t M, index_t N)
Definition gridwise_gemm_dl_multiple_d.hpp:136
__host__ static __device__ constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
Definition gridwise_gemm_dl_multiple_d.hpp:150
__host__ static __device__ constexpr bool CalculateHasMainKBlockLoop(index_t K0)
Definition gridwise_gemm_dl_multiple_d.hpp:143
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 &b_grid_desc_k0_n_k1, const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_gemm_dl_multiple_d.hpp:110
__host__ static __device__ constexpr index_t GetSharedMemoryNumberOfByte()
Definition gridwise_gemm_dl_multiple_d.hpp:83
static __device__ void Run(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, DsGridPointer p_ds_grid, FloatC *__restrict__ p_c_grid, void *__restrict__ p_shared_block, const AElementwiseOperation &, const BElementwiseOperation &, const CDEElementwiseOperation &cde_element_op, const AGridDesc_K0_M0_M1_K1 &a_grid_desc_k0_m0_m1_k1, const BGridDesc_K0_N0_N1_K1 &b_grid_desc_k0_n0_n1_k1, const DsGridDesc_M0_M10_M11_N0_N10_N11 &ds_grid_desc_m0_m10_m11_n0_n10_n11, const CGridDesc_M0_M10_M11_N0_N10_N11 &c_grid_desc_m0_m10_m11_n0_n10_n11, const Block2CTileMap &block_2_ctile_map, integral_constant< bool, HasMainKBlockLoop >, integral_constant< bool, HasDoubleTailKBlockLoop >)
Definition gridwise_gemm_dl_multiple_d.hpp:260
__host__ static __device__ constexpr auto MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(const DsGridDesc_M_N &ds_grid_desc_m_n)
Definition gridwise_gemm_dl_multiple_d.hpp:234
static constexpr auto MakeDsGridPointer()
Definition gridwise_gemm_dl_multiple_d.hpp:72
__host__ static __device__ constexpr auto MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N_ &c_grid_desc_m_n)
Definition gridwise_gemm_dl_multiple_d.hpp:200
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:16
Definition threadwise_tensor_slice_transfer.hpp:39
__device__ void Run(const SrcDesc &, const SrcSliceOriginIdx &, const SrcBuffer &src_buf, const DstDesc &dst_desc, DstBuffer &dst_buf)
Definition threadwise_tensor_slice_transfer.hpp:66
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition utility/integral_constant.hpp:20
Definition utility/math.hpp:34
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340