gridwise_gemm_xdlops_skip_b_lds_v1.hpp Source File

gridwise_gemm_xdlops_skip_b_lds_v1.hpp Source File#

Composable Kernel: gridwise_gemm_xdlops_skip_b_lds_v1.hpp Source File
gridwise_gemm_xdlops_skip_b_lds_v1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
16
17namespace ck {
18
19template <typename GridwiseGemm,
20 typename FloatAB,
21 typename FloatC,
22 typename AGridDesc_K0_M_K1,
23 typename BGridDesc_K0_N_K1,
24 typename CGridDesc_M_N,
25 typename AElementwiseOperation,
26 typename BElementwiseOperation,
27 typename CElementwiseOperation,
28 typename Block2CTileMap,
29 bool HasMainK0BlockLoop>
30__global__ void
31#if CK_USE_LAUNCH_BOUNDS
33#endif
34 kernel_gemm_xdlops_skip_b_lds_v1(const FloatAB* __restrict__ p_a_grid,
35 const FloatAB* __restrict__ p_b_grid,
36 FloatC* __restrict__ p_c_grid,
37 const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
38 const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
39 const CGridDesc_M_N c_grid_desc_m_n,
40 const AElementwiseOperation a_element_op,
41 const BElementwiseOperation b_element_op,
42 const CElementwiseOperation c_element_op,
43 const Block2CTileMap block_2_ctile_map)
44{
45#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
46 defined(__gfx12__)
47 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
48 {
49 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
50
51 auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
52 GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
53
54 auto b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
55 GridwiseGemm::MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(b_grid_desc_k0_n_k1);
56
57 GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
58 p_b_grid,
59 p_c_grid,
60 p_shared,
61 a_grid_desc_k0_m_k1,
62 b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
63 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
64 a_element_op,
65 b_element_op,
66 c_element_op,
67 block_2_ctile_map);
68 }
69#else
70 ignore = p_a_grid;
71 ignore = p_b_grid;
72 ignore = p_c_grid;
73 ignore = a_grid_desc_k0_m_k1;
74 ignore = b_grid_desc_k0_n_k1;
75 ignore = c_grid_desc_m_n;
76 ignore = a_element_op;
77 ignore = b_element_op;
78 ignore = c_element_op;
79 ignore = block_2_ctile_map;
80#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
81}
82
83template <index_t BlockSize,
84 typename FloatAB,
85 typename FloatAcc,
86 typename FloatC,
87 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
88 typename AGridDesc_K0_M_K1,
89 typename BGridDesc_K0_N_K1,
90 typename CGridDesc_M_N,
91 typename AElementwiseOperation,
92 typename BElementwiseOperation,
93 typename CElementwiseOperation,
94 index_t MPerBlock,
95 index_t NPerBlock,
96 index_t K0PerBlock,
97 index_t MPerXdl,
98 index_t NPerXdl,
99 index_t K1Value,
100 index_t MXdlPerWave,
101 index_t NXdlPerWave,
102 typename ABlockTransferThreadClusterLengths_K0_M_K1,
103 typename ABlockTransferThreadClusterArrangeOrder,
104 typename ABlockTransferSrcAccessOrder,
105 index_t ABlockTransferSrcVectorDim,
106 index_t ABlockTransferSrcScalarPerVector,
107 index_t ABlockTransferDstScalarPerVector_K1,
108 bool AThreadTransferSrcResetCoordinateAfterRun,
109 bool ABlockLdsExtraM,
110 index_t BBlockTransferSrcScalarPerVector,
111 bool BThreadTransferSrcResetCoordinateAfterRun,
112 index_t BBlockBufferSize,
113 typename CThreadTransferSrcDstAccessOrder,
114 index_t CThreadTransferSrcDstVectorDim,
115 index_t CThreadTransferDstScalarPerVector>
117{
// Compile-time index constants; used below as arguments to GetLength() and as
// subscripts on multi-indices.
118 static constexpr auto I0 = Number<0>{};
119 static constexpr auto I1 = Number<1>{};
120 static constexpr auto I2 = Number<2>{};
121 static constexpr auto I3 = Number<3>{};
122 static constexpr auto I4 = Number<4>{};
123 static constexpr auto I5 = Number<5>{};
124 static constexpr auto I6 = Number<6>{};
125 static constexpr auto I7 = Number<7>{};
126
127 // K1 should be Number<...>
128 static constexpr auto K1 = Number<K1Value>{};
129
// Wave counts along M and N derived from the block tile and the per-wave XDL
// tiling; WaveSize is BlockSize divided by the total wave count.
130 static constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
131 static constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
132 static constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
133
// NOTE(review): xdlops_gemm is not declared anywhere in the visible listing
// (numbered line 134 is absent) - confirm its definition before relying on K0PerThread.
135 static constexpr index_t K0PerThread = K0PerBlock / xdlops_gemm.K0PerXdlops;
136
138
// Returns the compile-time descriptor of the A block tile that is the destination
// of the blockwise copy. The layout is selected by ABlockLdsExtraM and uses K1 as
// the alignment value.
139 __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
140 {
141 constexpr auto max_lds_align = K1;
142
143 // A matrix in LDS memory, dst of blockwise copy
144 constexpr auto a_block_desc_k0_m_k1 = [&]() {
// NOTE(review): the statements inside both branches are truncated in this listing
// (numbered lines 147-149 and 153-154 are absent); only the trailing max_lds_align
// argument of the else-branch is visible. Confirm the two layouts against the header.
145 if constexpr(ABlockLdsExtraM)
146 {
150 }
151 else
152 {
155 max_lds_align);
156 }
157 }();
158
159 return a_block_desc_k0_m_k1;
160 }
161
162 __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
163 {
164 // LDS allocation for A and B: be careful of alignment
165 constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
166
167 constexpr auto max_lds_align = K1;
168
169 constexpr auto a_block_space_size_aligned =
170 math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
171
172 return (a_block_space_size_aligned) * sizeof(FloatAB);
173 }
174
175 template <
176 InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set>
177 __device__ static bool constexpr IsValidCompilationParameter()
178 {
179 return ck::tensor_operation::device::IsValidGemmCompilationParameter<
180 BlockSize,
181 MPerBlock,
182 NPerBlock,
183 MPerXdl,
184 NPerXdl,
185 MXdlPerWave,
186 NXdlPerWave,
187 FloatC,
188 CGlobalMemoryDataOperation>();
189 }
190 // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
191 __host__ __device__ static constexpr bool
192 CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
193 const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
194 const CGridDesc_M_N& c_grid_desc_m_n,
195 index_t M01,
196 index_t N01)
197 {
198 static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
199 "wrong! K1 need to be known at compile-time");
200
201 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
202 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
203 "Invalid tuning param!");
204
205 const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
206 const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
207 const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
208
209 if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
210 K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
211 K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
212 return false;
213
214 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
215 return false;
216
217 // 2-stage prefetch currently only support even number of K0 loop
218 // TODO: add support for odd number of K0 loop
219 if(!((K0 / K0PerBlock) % BBlockBufferSize == 0))
220 {
221 return false;
222 }
223
224 // check M01, N01
225 constexpr auto M1 = Number<MPerBlock>{};
226 constexpr auto N1 = Number<NPerBlock>{};
227
228 const auto M0 = M / M1;
229 const auto N0 = N / N1;
230
231 if(!(M0 % M01 == 0 && N0 % N01 == 0))
232 return false;
233
234 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
235 return true;
236 }
237
238 __host__ __device__ static constexpr index_t
239 CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
240 {
241 const auto M = c_grid_desc_m_n.GetLength(I0);
242 const auto N = c_grid_desc_m_n.GetLength(I1);
243
244 const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
245
246 return grid_size;
247 }
248
249 // TODO move this function into GEMM-pipeline class
250 __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
251 {
252 const bool has_main_k0_block_loop = (K0 / (BBlockBufferSize * K0PerBlock)) > 1;
253
254 return has_main_k0_block_loop;
255 }
256
// Re-expresses the (K0, N, K1) B grid descriptor in the 8-dimensional form named
// in the function name and returns the transformed descriptor.
257 __host__ __device__ static constexpr auto
258 MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1)
259 {
260 const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0);
261 const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
262
// NOTE(review): several lines of this call are absent from the listing (numbered
// 265, 267, 269-271), so the transform kinds and dimension-id sequences are not
// visible. The visible tuples give the split lengths:
//   K0 -> (K0 / K0PerBlock, xdlops_gemm.K0PerXdlops, K0PerThread)
//   N  -> (N / (NXdlPerWave * NWaves * NPerXdl), NXdlPerWave, NWaves, NPerXdl)
// presumably via unmerge transforms - confirm against the header.
263 const auto b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1 = transform_tensor_descriptor(
264 b_grid_desc_k0_n_k1,
266 make_tuple(K0 / K0PerBlock, xdlops_gemm.K0PerXdlops, K0PerThread)),
268 N / (NXdlPerWave * NWaves * NPerXdl), NXdlPerWave, NWaves, NPerXdl)),
272 return b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1;
273 }
274
// Maps the calling thread's local 1-D id through a single-stage tensor adaptor
// and returns the resulting multi-index. Run() reads components I1 and I2 of the
// result.
275 __device__ static auto GetWaveIdx()
276 {
277 const index_t thread_id = get_thread_local_1d_id();
278
// NOTE(review): the adaptor's arguments (numbered lines 280-282) are absent from
// this listing, so the exact decomposition of thread_id is not visible here.
279 constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
283
284 return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
285 }
286
// Maps thread_id through a single-stage tensor adaptor and returns the resulting
// multi-index. Run() passes component I2 of GetWaveIdx() as thread_id and reads
// components I0 and I1 of the result.
287 __device__ static auto GetWaveKNIdx(const index_t thread_id)
288 {
// NOTE(review): the adaptor's arguments (numbered lines 290-292) are absent from
// this listing, so the exact decomposition of thread_id is not visible here.
289 constexpr auto wave_threadid_to_nk_idx_adaptor = make_single_stage_tensor_adaptor(
293
294 return wave_threadid_to_nk_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
295 }
296
// Expands the (M, N) C grid descriptor into the 8-dimensional form named in the
// function name. The work is delegated to BlockwiseGemm, which is parameterized
// with the A block descriptor and B thread descriptor built locally below.
297 __host__ __device__ static constexpr auto
298 MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
299 {
300 constexpr auto max_lds_align = K1;
301
302 // A matrix in LDS memory, dst of blockwise copy
// NOTE(review): this lambda appears to repeat GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
// parts of both branches (numbered lines 306-308 and 312) are absent from this
// listing, so the equivalence cannot be confirmed here.
303 constexpr auto a_block_desc_k0_m_k1 = [&]() {
304 if constexpr(ABlockLdsExtraM)
305 {
309 }
310 else
311 {
313 make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
314 }
315 }();
316
317 // B matrix threadwise copy
// NOTE(review): the start of this initializer (numbered line 319) is absent from
// the listing; the visible arguments are the trailing seven of the eight lengths.
318 constexpr auto b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
320 I1,
321 Number<K0PerThread>{}, // K0PerThread
322 I1, // NBlockId
323 Number<NXdlPerWave>{}, // repeat
324 I1, // waves
325 I1, // NPerXdlops
326 Number<K1>{}));
327
// NOTE(review): the line that opens the BlockwiseGemm alias (numbered line 328)
// is absent; the template arguments below belong to it.
329 BlockSize,
330 FloatAB,
331 FloatAcc,
332 decltype(a_block_desc_k0_m_k1),
333 decltype(b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3),
334 MPerBlock,
335 NPerBlock,
336 K0PerBlock,
337 MPerXdl,
338 NPerXdl,
339 MXdlPerWave,
340 NXdlPerWave,
341 K1>;
342
343 return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
344 }
345
346 // return block_id to C matrix tile idx (m0, n0) mapping
// The mapping is built by chaining two adaptors: a 1-D block id is first related
// to (M00, N00, M01, N01) through a merge transform, and that 4-D index is then
// related to (m0, n0).
// M01 and N01 are used as divisors; callers are expected to pass values that
// evenly divide the tile counts (CheckValidity tests this).
347 __host__ __device__ static constexpr auto
348 MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
349 {
350 const auto M = c_grid_desc_m_n.GetLength(I0);
351 const auto N = c_grid_desc_m_n.GetLength(I1);
352
353 constexpr auto M1 = Number<MPerBlock>{};
354 constexpr auto N1 = Number<NPerBlock>{};
355
// tile counts along M and N
356 const auto M0 = M / M1;
357 const auto N0 = N / N1;
358
// number of M01-sized / N01-sized groups of tiles
359 const auto M00 = M0 / M01;
360 const auto N00 = N0 / N01;
361
// NOTE(review): the initializer of this adaptor (numbered lines 363-367) is absent
// from the listing.
362 const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
368
// NOTE(review): only the merge-transform argument of this adaptor is visible
// (numbered lines 370, 372-373 are absent).
369 const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
371 make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
374
375 const auto cblockid_to_m0_n0_block_cluster_adaptor =
376 chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
377 cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
378
379 return cblockid_to_m0_n0_block_cluster_adaptor;
380 }
381
// Type of the default mapping, obtained with M01 = N01 = 1; used as the default
// for Run()'s Block2CTileMap template parameter.
382 using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
383
// Device-side body of the GEMM for one workgroup.
//
// Visible flow:
//   1. wrap the raw pointers in global-memory buffers and locate this block's
//      tile origin via block_2_ctile_map
//   2. set up the blockwise copy for A (destination: the tile in p_shared) and the
//      threadwise copy for B (destination: the per-thread b_thread_buf slots)
//   3. preload A and B, clear the accumulator, then run the optional main loop
//      (HasMainK0BlockLoop) followed by the tail
//   4. write the accumulator (c_thread_buf) to the C grid buffer through a
//      threadwise copy that applies c_element_op
//
// p_shared must provide at least GetSharedMemoryNumberOfByte() bytes (that is how
// the calling kernel sizes it). b_element_op is accepted but not applied.
//
// NOTE(review): a number of lines are absent from this listing (e.g. numbered
// 405, 426, 428-430, 452, 457, 469, 474, 498, 526, 544, 557, 581, 583, 597, 609,
// 613); comments below describe only what the remaining lines show.
384 template <bool HasMainK0BlockLoop,
385 typename BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3,
386 typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
387 typename Block2CTileMap = DefaultBlock2CTileMap>
388 __device__ static void
389 Run(const FloatAB* __restrict__ p_a_grid,
390 const FloatAB* __restrict__ p_b_grid,
391 FloatC* __restrict__ p_c_grid,
392 void* __restrict__ p_shared,
393 const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
394 const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3& b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
395 const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
396 const AElementwiseOperation& a_element_op,
397 const BElementwiseOperation& b_element_op,
398 const CElementwiseOperation& c_element_op,
399 const Block2CTileMap& block_2_ctile_map)
400 {
// global-memory views over the three matrices, sized by their descriptors
401 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
402 p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
403 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
404 p_b_grid, b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize());
406 p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
407
408 const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
409
410 // divide block work by [M, N]
411 const auto block_work_idx =
412 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
413
414 // HACK: this forces m/n_block_data_idx_on_grid into SGPR
415 const index_t m_block_data_idx_on_grid =
416 __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
417
418 const index_t n_block_data_idx_on_grid =
419 __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
420
421 // A matrix in LDS memory, dst of blockwise copy
422 constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
423
424 // A matrix blockwise copy
// source window starts at (0, m_block_data_idx_on_grid, 0) in the A grid,
// destination window at (0, 0, 0) in the A block tile
425 auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
427 AElementwiseOperation,
431 ABlockTransferThreadClusterLengths_K0_M_K1,
432 ABlockTransferThreadClusterArrangeOrder,
433 FloatAB,
434 FloatAB,
435 decltype(a_grid_desc_k0_m_k1),
436 decltype(a_block_desc_k0_m_k1),
437 ABlockTransferSrcAccessOrder,
439 ABlockTransferSrcVectorDim,
440 2,
441 ABlockTransferSrcScalarPerVector,
442 ABlockTransferDstScalarPerVector_K1,
443 1,
444 1,
445 AThreadTransferSrcResetCoordinateAfterRun,
446 true,
447 1>(a_grid_desc_k0_m_k1,
448 make_multi_index(0, m_block_data_idx_on_grid, 0),
449 a_element_op,
450 a_block_desc_k0_m_k1,
451 make_multi_index(0, 0, 0),
453
// b_element_op is intentionally unused in this function
454 ignore = b_element_op;
455 // B matrix threadwise copy
456 constexpr auto b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
458 I1,
459 Number<K0PerThread>{}, // K0PerThread
460 I1, // NBlockId
461 Number<NXdlPerWave>{}, // repeat
462 I1, // waves
463 I1, // NPerXdlops
464 Number<K1>{}));
465
// tuple of per-thread B buffers, each sized by the B thread descriptor; the slots
// are indexed with Number<ii> below
466 auto b_thread_buf = generate_tuple(
467 [&](auto i) {
468 ignore = i;
470 FloatAB,
471 b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize(),
472 true>{};
473 },
475
476 const auto wave_id = GetWaveIdx();
477 const auto wave_k_n_id = GetWaveKNIdx(wave_id[I2]);
478
479#if 0
480 const index_t block_id = get_block_1d_id();
481 const index_t thread_id = get_thread_local_1d_id();
482 printf("block id: %d m blockid: %d n block id: %d ,thread id: %d, wave id :{%d %d %d} "
483 "kn id: {%d %d}\n",
484 block_id,
485 block_work_idx[I0],
486 block_work_idx[I1],
487 thread_id,
488 wave_id[I0],
489 wave_id[I1],
490 wave_id[I2],
491 wave_k_n_id[I0],
492 wave_k_n_id[I1]);
493 printf("mfma thread k per xdlops: %d K0PerThread: %d HasMainK0BlockLoop: %d K0: %d \t",
494 xdlops_gemm.K0PerXdlops, K0PerThread, HasMainK0BlockLoop, b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLength(I0));
495#endif
496
// B source origin for this thread is built from the block's N tile index and the
// wave / lane indices computed above
497 auto b_threadwise_copy =
499 FloatAB,
500 decltype(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3),
501 decltype(b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3),
502 Sequence<I1,
503 I1,
505 I1,
507 I1,
508 I1,
509 Number<K1>{}>,
511 7,
512 BBlockTransferSrcScalarPerVector,
513 BThreadTransferSrcResetCoordinateAfterRun,
514 true>(
515 b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
517 0, wave_k_n_id[I0], 0, block_work_idx[I1], 0, wave_id[I1], wave_k_n_id[I1], 0));
518
519 // GEMM definition
520 // c_mtx += transpose(a_mtx) * b_mtx
521 // a_mtx[K0PerBlock, MPerBlock] is in LDS
522 // b_mtx[K0PerBlock, NPerBlock] is held in per-thread buffers (b_thread_buf), not in LDS
523 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
524 // register
525 // sanity check
527 BlockSize,
528 FloatAB,
529 FloatAcc,
530 decltype(a_block_desc_k0_m_k1),
531 decltype(b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3),
532 MPerBlock,
533 NPerBlock,
534 K0PerBlock,
535 MPerXdl,
536 NPerXdl,
537 MXdlPerWave,
538 NXdlPerWave,
539 K1>{};
540
541 auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
542
543 // LDS allocation for A
545 static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
546
547 // gridwise GEMM pipeline
// A advances by K0PerBlock * BBlockBufferSize along K0 per step; B advances by one
// along its first dimension per buffer slot
548 constexpr auto a_block_slice_copy_step =
549 make_multi_index(K0PerBlock * BBlockBufferSize, 0, 0);
550 constexpr auto b_thread_slice_copy_step = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);
551 // preload data to register and LDS
552 {
553 // Read
554 a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
555 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
556
// fill b_thread_buf slot ii, then advance the B source window
// NOTE(review): the loop construct that supplies ii (numbered line 557) is absent
// from this listing
558 b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
559 b_grid_buf,
560 b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
561 make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
562 b_thread_buf(Number<ii>{}));
563 b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
564 b_thread_slice_copy_step);
565 });
566
567 // Initialize C
568 c_thread_buf.Clear();
569 // a data write to lds
570 a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
571 // main body
572 if constexpr(HasMainK0BlockLoop)
573 {
// loop count, made uniform across the wave via readfirstlane; the do-while below
// runs K0BlockMainLoop - 1 times and the tail handles the last step
574 index_t K0BlockMainLoop =
575 __builtin_amdgcn_readfirstlane(K0 / (BBlockBufferSize * K0PerBlock));
576 index_t i = 0;
577 do
578 {
// issue the global read of the next A tile before consuming the current one
579 a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
580 blockwise_gemm.ResetABlockStartWindow();
582
// per slot ii: multiply the current A tile by b_thread_buf slot ii, advance the A
// block window, then refill slot ii from global B
// NOTE(review): the synchronization and loop-construct lines around this body
// (numbered 581, 583, 597) are absent from this listing - confirm barrier
// placement against the header
584 blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<ii>{}), c_thread_buf);
585 blockwise_gemm.MoveABlockSliceWindow();
586 s_nop();
587
588 b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
589 b_grid_buf,
590 b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
591 make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
592 b_thread_buf(Number<ii>{}));
593 b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
594 b_thread_slice_copy_step);
595 });
596
// publish the A tile read at the top of this iteration
598 a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
599 // move a and b window
600 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
601 a_block_slice_copy_step);
602
603 i += 1;
604 } while(i < (K0BlockMainLoop - 1));
605 }
606
607 // tail
// consume the last A tile and the last set of B slots; no further loads are issued
608 {
610
611 blockwise_gemm.ResetABlockStartWindow();
612
614 blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<ii>{}), c_thread_buf);
615 blockwise_gemm.MoveABlockSliceWindow();
616 });
617 }
618 }
619
620 // output: register to global memory
621 {
622 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
623 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
624
625 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
626 blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
627
// lengths of the eight C block dimensions, used to decompose the thread's
// M and N positions below
628 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
629 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
630 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
631 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
632 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
633 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
634 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
635 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
636
637 // calculate origin of thread output tensor on global memory
638 // blockwise GEMM c matrix starting index
639 const auto c_thread_mtx_on_block =
640 blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
641
// absolute (m, n) of this thread's first output element = block origin + offset
// within the block
642 const index_t m_thread_data_on_grid =
643 m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
644
645 const index_t n_thread_data_on_grid =
646 n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
647
// split the flat m index into (M0, M1, M2, M3, M4) components
648 const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
650 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
653
654 const auto m_thread_data_on_grid_idx =
655 m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
656 make_multi_index(m_thread_data_on_grid));
657
// same decomposition for the flat n index
// NOTE(review): the adaptor's arguments (numbered lines 659-661) are absent from
// this listing
658 const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
662
663 const auto n_thread_data_on_grid_idx =
664 n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
665 make_multi_index(n_thread_data_on_grid));
666
// threadwise copy from the accumulator to the C grid; the destination origin
// interleaves the m and n components in M0, N0, M1, N1, M2, M3, M4, N2 order
667 auto c_thread_copy =
669 FloatC,
670 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
671 decltype(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2),
672 CElementwiseOperation,
674 CThreadTransferSrcDstAccessOrder,
675 CThreadTransferSrcDstVectorDim,
676 CThreadTransferDstScalarPerVector,
677 CGlobalMemoryDataOperation,
678 1,
679 true>{
680 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
681 make_multi_index(m_thread_data_on_grid_idx[I0],
682 n_thread_data_on_grid_idx[I0],
683 m_thread_data_on_grid_idx[I1],
684 n_thread_data_on_grid_idx[I1],
685 m_thread_data_on_grid_idx[I2],
686 m_thread_data_on_grid_idx[I3],
687 m_thread_data_on_grid_idx[I4],
688 n_thread_data_on_grid_idx[I2]),
689 c_element_op};
690
691 c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
692 make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
693 c_thread_buf,
694 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
695 c_grid_buf);
696 }
697 }
698};
699
700} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0 &adaptor0, const TensorAdaptor1 &adaptor1)
Definition tensor_description/tensor_adaptor.hpp:245
__device__ void s_nop()
Definition synchronization.hpp:61
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__global__ void kernel_gemm_xdlops_skip_b_lds_v1(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, const CGridDesc_M_N c_grid_desc_m_n, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map)
Definition gridwise_gemm_xdlops_skip_b_lds_v1.hpp:34
__host__ __device__ constexpr auto make_naive_tensor_descriptor_aligned(const Tuple< Lengths... > &lengths, Align align)
Definition tensor_descriptor_helper.hpp:132
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
typename remove_cv< T >::type remove_cv_t
Definition type.hpp:295
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
Definition blockwise_gemm_xdlops_skip_b_lds.hpp:27
Definition gridwise_gemm_xdlops_skip_b_lds_v1.hpp:117
static __device__ void Run(const ADataType *__restrict__ p_a_grid, const ADataType *__restrict__ p_b_grid, CDataType *__restrict__ p_c_grid, void *__restrict__ p_shared, const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1, const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 &b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 &c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CElementwiseOperation &c_element_op, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_gemm_xdlops_skip_b_lds_v1.hpp:389
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:16
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition xdlops_gemm.hpp:1821
Definition is_known_at_compile_time.hpp:14
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340