blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp Source File

blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp Source File

Composable Kernel: blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp Source File
blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Compute-optimized pipeline. Stage summary:
11// GlobalPrefetchStages: 2
12// LocalPreFillStages: 1
13// LocalPreFetchStages: 1
14// LocalSharedMemoryBuffer: 1 (only A is staged through LDS; B is read straight into per-thread buffers)
15
// Primary template of the pipeline, parameterized by the block-GEMM scheduler
// variant plus the tile/XDL geometry shared with BlockwiseGemmXdlops_pipeline_base.
// NOTE(review): source lines 36-38 (presumably the struct name and an empty body)
// are missing from this listing; only the specialization that follows carries the
// implementation — confirm against the original header.
16template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17 index_t BlockSize,
18 typename ADataType,
19 typename BDataType,
20 typename ComputeDataType,
21 typename AccDataType,
22 typename ATileDesc,
23 typename BTileDesc,
24 typename AMmaTileDesc,
25 typename BMmaTileDesc,
26 index_t ABlockTransferSrcScalarPerVector,
27 index_t BBlockTransferSrcScalarPerVector,
28 index_t MPerBlock,
29 index_t NPerBlock,
30 index_t KPerBlock,
31 index_t MPerXDL,
32 index_t NPerXDL,
33 index_t MRepeat,
34 index_t NRepeat,
35 index_t KPacks>
39
// Specialization that implements the pipeline. It derives from
// BlockwiseGemmXdlops_pipeline_base with the same tile/XDL parameters, minus the
// scheduler argument.
// NOTE(review): source lines 61-62 (the struct name and the selected
// BlockGemmPipelineScheduler value) are missing from this listing; the tooltips
// reference BlockGemmPipelineScheduler::Intrawave — presumably that value, confirm.
40template <index_t BlockSize,
41 typename ADataType,
42 typename BDataType,
43 typename ComputeDataType,
44 typename AccDataType,
45 typename ATileDesc,
46 typename BTileDesc,
47 typename AMmaTileDesc,
48 typename BMmaTileDesc,
49 index_t ABlockTransferSrcScalarPerVector,
50 index_t BBlockTransferSrcScalarPerVector,
51 index_t MPerBlock,
52 index_t NPerBlock,
53 index_t KPerBlock,
54 index_t MPerXDL,
55 index_t NPerXDL,
56 index_t MRepeat,
57 index_t NRepeat,
58 index_t KPack
59 // ,bool TransposeC // TransposeC is disabled for now
60 >
63 BlockSize,
64 ADataType,
65 BDataType,
66 ComputeDataType,
67 AccDataType,
68 ATileDesc,
69 BTileDesc,
70 AMmaTileDesc,
71 BMmaTileDesc,
72 ABlockTransferSrcScalarPerVector,
73 BBlockTransferSrcScalarPerVector,
74 MPerBlock,
75 NPerBlock,
76 KPerBlock,
77 MPerXDL,
78 NPerXDL,
79 MRepeat,
80 NRepeat,
81 KPack> : BlockwiseGemmXdlops_pipeline_base<BlockSize,
82 ADataType,
83 BDataType,
84 ComputeDataType,
85 AccDataType,
86 ATileDesc,
87 BTileDesc,
88 AMmaTileDesc,
89 BMmaTileDesc,
90 ABlockTransferSrcScalarPerVector,
91 BBlockTransferSrcScalarPerVector,
92 MPerBlock,
93 NPerBlock,
94 KPerBlock,
95 MPerXDL,
96 NPerXDL,
97 MRepeat,
98 NRepeat,
99 KPack>
100
101{
// NOTE(review): source line 102 is missing here; per the `Base` tooltip it opens
// `using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,` with the same
// arguments as the base-class list above.
103 ADataType,
104 BDataType,
105 ComputeDataType,
106 AccDataType,
107 ATileDesc,
108 BTileDesc,
109 AMmaTileDesc,
110 BMmaTileDesc,
111 ABlockTransferSrcScalarPerVector,
112 BBlockTransferSrcScalarPerVector,
113 MPerBlock,
114 NPerBlock,
115 KPerBlock,
116 MPerXDL,
117 NPerXDL,
118 MRepeat,
119 NRepeat,
120 KPack>;
// Re-export the base-class members used below.
// NOTE(review): source lines 129-140 and 144 are missing from this listing; the
// tooltips list further base members (e.g. GetCThreadBuffer,
// CalculateCThreadOriginDataIndex) that are presumably re-exported there.
121 using Base::A_K1;
122 using Base::B_K1;
123 using Base::I0;
124 using Base::I1;
125 using Base::KRepeat;
126 using Base::xdlops_gemm;
127 using typename Base::HotLoopInstList;
128
141
142 using Base::AMmaKStride;
143 using Base::BMmaKStride;
145 using Base::WaveSize;
146
// Pipeline depth constants. PrefetchStages is the threshold BlockHasHotloop compares
// num_loop against. GlobalBufferNum presumably matches the two A global-read scratch
// slots (I0/I1) that Run alternates between.
147 static constexpr index_t PrefetchStages = 2;
148 static constexpr index_t PrefillStages = 1;
149 static constexpr index_t GlobalBufferNum = 2;
150
// Builds the A LDS tile view consumed by the per-thread A copy. It reads the M0/M1/M2
// lengths of the incoming (M0, M1, M2, K) descriptor and splits K into
// K0 = KRepeat, K1 = WaveSize / NPerXDL and K2 = KPack.
// NOTE(review): source lines 161 and 163-169 (the descriptor transform itself) are
// missing from this listing; the tooltips reference transform_tensor_descriptor,
// make_pass_through_transform and make_unmerge_transform, presumably used there to
// produce the 6-D (M0, M1, M2, K0, K1, K2) view — confirm.
151 template <typename TileDesc_M0_M1_M2_K>
152 __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
153 {
154 constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
155 constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
156 constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
157 constexpr index_t K2 = KPack;
// NOTE(review): the lane count along K for the *A* tile is derived from NPerXDL;
// this equals WaveSize / MPerXDL only when MPerXDL == NPerXDL — confirm intended.
158 constexpr index_t K1 = WaveSize / NPerXDL;
159 constexpr index_t K0 = KRepeat;
160
162 TileDesc_M0_M1_M2_K{},
170 }
171
// 6-D A LDS tile descriptor used as the source descriptor of AThreadCopy (per the
// AThreadCopy tooltip).
// NOTE(review): the initializer (source line 173) is missing from this listing;
// presumably MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k) from the base — confirm.
172 static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
174
175 __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
176 {
177 return num_loop > PrefetchStages;
178 }
179
180 __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
181 {
182 return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
183 }
184
// Emits __builtin_amdgcn_sched_group_barrier hints that interleave MFMA instructions
// with the hot loop's memory instructions. The masks used below are annotated in
// place: 0x008 MFMA, 0x020 VMEM read, 0x200 DS write, 0x100 DS read. All groups use
// sync ID 0.
// NOTE(review): HotLoopInstList is the shared base-class instruction model (see the
// `using typename Base::HotLoopInstList` above). Run() issues B global reads and
// MFMAs for both the primary and the `_up` B operand, so these counts may cover only
// one of the two streams — confirm against the hidden loop headers.
185 __device__ static constexpr auto HotLoopScheduler()
186 {
187 constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
188 constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
189 constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
190
191 // B global: each group pairs one MFMA with one VMEM read. NOTE(review): the loop header (source line 192) is missing from this listing; presumably it iterates num_buffer_load_inst_b times — confirm.
193 ignore = i;
194 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
195 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
196 });
197
198 // A global: each group is MFMA, DS write, MFMA, VMEM read. NOTE(review): the loop header (source line 199) is missing from this listing; presumably it iterates num_buffer_load_inst_a times — confirm.
200 ignore = i;
201 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
202 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
203 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
204 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
205 });
206
207 // A local: one MFMA per two DS reads; with integer division an odd num_ds_read_inst_a leaves one DS read outside these groups.
208 static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) {
209 ignore = i;
210 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
211 __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read
212 });
213 }
214
// Runs the whole K loop of the block GEMM for two B operands that share one A operand.
//
// A path: global -> per-thread scratch slot (a_blockwise_copy.RunRead) -> LDS
// (RunWrite into a_block_buf) -> a_thread_buf (per-thread copy from LDS).
// B path: b_blockwise_copy / b_blockwise_copy_up read straight from global memory into
// the per-thread b_thread_bufs / b_thread_bufs_up (b_block_buf is unused). Each K block
// is then converted into b_thread_dequant_bufs / b_thread_dequant_bufs_up.
// Every (m0, n0, k0) step issues two xdlops_gemm.Run calls with the same A vector:
// one accumulates into c_thread_buf (B from b_grid_buf), the other into c_thread_buf_up
// (B from b_grid_buf_up). Both B copies share b_grid_desc and b_block_copy_step.
// NOTE(review): presumably the gate and up projections of a fused ("gufusion") MLP —
// confirm with the caller.
//
// HasMainLoop: presumably BlockHasHotloop(num_loop), computed by the caller. When true,
//   each do/while pass runs two K iterations: LoopFunc(I0, I1), then LoopFunc(I1, I0).
// TailNum: presumably BlockLoopTailNum(num_loop). TailNumber::Even finishes with two
//   MFMA rounds; any other value finishes with one.
// num_loop: number of K iterations. Each iteration advances every source window by
//   one block copy step.
// c_thread_buf / c_thread_buf_up are cleared before accumulation starts.
//
// NOTE(review): this is a rendered listing with several source lines missing (the
// embedded numbering skips, e.g. 249, 251, 254, 342, 388, 414 and the per-step
// a_thread_vec / b_thread_vec / b_thread_vec_up declarations). They appear to be the
// lines containing hyperlinked identifiers. The comments below describe only the
// visible lines.
215 template <bool HasMainLoop,
216 TailNumber TailNum,
217 typename AGridDesc,
218 typename ABlockDesc,
219 typename ABlockTransfer,
220 typename AGridBuffer,
221 typename ABlockBuffer,
222 typename ABlockTransferStep,
223 typename BGridDesc,
224 typename BBlockTransfer,
225 typename BGridBuffer,
226 typename BBlockBuffer,
227 typename BBlockTransferStep,
228 typename CThreadBuffer>
229 __device__ void Run(const AGridDesc& a_grid_desc,
230 const ABlockDesc& a_block_desc,
231 ABlockTransfer& a_blockwise_copy,
232 const AGridBuffer& a_grid_buf,
233 ABlockBuffer& a_block_buf,
234 const ABlockTransferStep& a_block_copy_step,
235 const BGridDesc& b_grid_desc,
236 BBlockTransfer& b_blockwise_copy,
237 BBlockTransfer& b_blockwise_copy_up,
238 const BGridBuffer& b_grid_buf,
239 const BGridBuffer& b_grid_buf_up,
240 BBlockBuffer& b_block_buf,
241 const BBlockTransferStep& b_block_copy_step,
242 CThreadBuffer& c_thread_buf,
243 CThreadBuffer& c_thread_buf_up,
244 index_t num_loop) const
245
246 {
// B never goes through LDS in this pipeline, so its LDS block buffer is unused.
247 ignore = b_block_buf;
248 __builtin_amdgcn_sched_barrier(0);
// Per-thread buffers (their declaration heads are missing here): the A staging buffer,
// the raw B buffer and the dequantized B buffer. The B ones are then double-buffered
// (index I0/I1) for both B operands.
250 a_thread_desc_.GetElementSpaceSize());
252 b_thread_desc_.GetElementSpaceSize());
253
255 b_thread_desc_.GetElementSpaceSize());
256
257 StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
258 StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs_up;
259 constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
260
261 StaticallyIndexedArray<decltype(b_thread_dequant_buf), Number<2>{}> b_thread_dequant_bufs;
262 StaticallyIndexedArray<decltype(b_thread_dequant_buf), Number<2>{}>
263 b_thread_dequant_bufs_up;
264
265 // Global prefetch A1 B1: first A block into scratch slot I0; first block of both B operands into b_thread_bufs(I0) / b_thread_bufs_up(I0); then advance all three source windows.
266 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
267 b_blockwise_copy.Run(b_grid_desc,
268 b_grid_buf,
270 b_block_origin_idx,
271 b_thread_bufs(I0));
272 b_blockwise_copy_up.Run(b_grid_desc,
273 b_grid_buf_up,
275 b_block_origin_idx,
276 b_thread_bufs_up(I0));
277
278 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
279 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
280 b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
281 __builtin_amdgcn_sched_barrier(0);
282
283 // Local prefill A1: scratch slot I0 -> LDS.
284 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
285
286 // Global prefetch A2: second A block into scratch slot I0 (its previous contents were just written to LDS).
287 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
288 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
289
290 // Local prefetch A1: LDS -> a_thread_buf, one KPack-wide slice per (m0, k0). NOTE(review): the copy-call lines are missing from this listing.
292 static_for<0, MRepeat, 1>{}([&](auto m0) {
293 static_for<0, KRepeat, 1>{}([&](auto k0) {
295 make_tuple(m0, I0, I0, k0, I0, I0),
296 a_block_buf,
298 make_tuple(m0, I0, I0, k0, I0, I0),
299 a_thread_buf);
300 });
301 });
302 // B VGPR->VGPR dequant: b_thread_bufs(I0) -> b_thread_dequant_bufs(I0), likewise for _up. NOTE(review): the copy-call lines are missing; per the tooltips BThreadDequantCopy is a static-to-static copy with PassThrough (BDataType -> ComputeDataType), so no separate scale factor is visible here.
304 b_block_origin_idx,
305 b_thread_bufs(I0),
307 make_tuple(I0, I0, I0, I0),
308 b_thread_dequant_bufs(I0));
310 b_block_origin_idx,
311 b_thread_bufs_up(I0),
313 make_tuple(I0, I0, I0, I0),
314 b_thread_dequant_bufs_up(I0));
315
316 // Initialize C
317 c_thread_buf.Clear();
318 c_thread_buf_up.Clear();
319
320 __builtin_amdgcn_sched_barrier(0);
321
322 // main body: each do/while pass runs two K iterations with the buffer roles swapped.
323 if constexpr(HasMainLoop)
324 {
325 index_t i = 0;
326 do
327 {
// One K iteration.
// mfma_reg_buf: index consumed now. It selects the dequantized B used by the MFMAs and
//   the A scratch slot written to LDS.
// local_read_buf: index refilled for the next iteration. It selects the B global read,
//   the A global read and the B dequant target.
328 auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
329 b_blockwise_copy.Run(b_grid_desc,
330 b_grid_buf,
332 b_block_origin_idx,
333 b_thread_bufs(local_read_buf));
334 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
335 b_blockwise_copy_up.Run(b_grid_desc,
336 b_grid_buf_up,
338 b_block_origin_idx,
339 b_thread_bufs_up(local_read_buf));
340 b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
341
// NOTE(review): source line 342 is missing; presumably an LDS barrier so that all reads
// of the current A block finish before it is overwritten — confirm.
343 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf);
344
345 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
346 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
347
// MFMAs for the current K block: A from a_thread_buf, B from the dequantized buffers
// selected by mfma_reg_buf, both accumulators at the same c_offset.
348 static_for<0, MRepeat, 1>{}([&](auto m0) {
349 static_for<0, NRepeat, 1>{}([&](auto n0) {
350 static_for<0, KRepeat, 1>{}([&](auto k0) {
354
355 static_for<0, KPack, 1>{}([&](auto ik) {
356 a_thread_vec.template AsType<ComputeDataType>()(ik) =
357 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
358 make_tuple(m0, I0, I0, k0, I0, ik))>{}];
359 b_thread_vec.template AsType<ComputeDataType>()(ik) =
360 b_thread_dequant_bufs[mfma_reg_buf]
361 [Number<b_thread_desc_.CalculateOffset(
362 make_tuple(n0, I0, k0, ik))>{}];
363 b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
364 b_thread_dequant_bufs_up
365 [mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
366 make_tuple(n0, I0, k0, ik))>{}];
367 });
368 using mfma_input_type =
369 typename vector_type<ComputeDataType,
370 xdlops_gemm.K1PerXdlops>::type;
371
372 constexpr index_t c_offset =
373 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
374
375 xdlops_gemm.Run(
376 a_thread_vec.template AsType<mfma_input_type>(),
377 b_thread_vec.template AsType<mfma_input_type>(),
378 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
379
380 xdlops_gemm.Run(
381 a_thread_vec.template AsType<mfma_input_type>(),
382 b_thread_vec_up.template AsType<mfma_input_type>(),
383 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
384 });
385 });
386 });
387
389
// Refill a_thread_buf from LDS for the next iteration and dequantize the B block
// fetched at the top of this iteration.
// NOTE(review): source line 388 is missing; presumably an LDS barrier (block_sync_lds is
// among the referenced helpers) — confirm.
390 static_for<0, MRepeat, 1>{}([&](auto m0) {
391 static_for<0, KRepeat, 1>{}([&](auto k0) {
393 make_tuple(m0, I0, I0, k0, I0, I0),
394 a_block_buf,
396 make_tuple(m0, I0, I0, k0, I0, I0),
397 a_thread_buf);
398 });
399 });
400 // B VGPR->VGPR dequant: b_thread_bufs(local_read_buf) -> b_thread_dequant_bufs(local_read_buf), likewise for _up.
402 b_block_origin_idx,
403 b_thread_bufs(local_read_buf),
405 make_tuple(I0, I0, I0, I0),
406 b_thread_dequant_bufs(local_read_buf));
408 b_block_origin_idx,
409 b_thread_bufs_up(local_read_buf),
411 make_tuple(I0, I0, I0, I0),
412 b_thread_dequant_bufs_up(local_read_buf));
413
// NOTE(review): source line 414 is missing; presumably the HotLoopScheduler() call — confirm.
415 __builtin_amdgcn_sched_barrier(0);
416 };
417
418 LoopFunc(I0, I1);
419 LoopFunc(I1, I0);
420
421 i += 2;
// The loop exits with two K iterations left for even num_loop and one for odd num_loop,
// matching the Even / other tail below.
// NOTE(review): for odd num_loop, one more A global read than there are K blocks is issued
// overall (the prologue reads two A blocks). The last read is never consumed, so this
// presumably relies on out-of-range reads being safe — confirm.
422 } while(i < (num_loop - 2));
423 }
424 // tail: TailNumber::Even -> two remaining K iterations; otherwise one.
425 if constexpr(TailNum == TailNumber::Even)
426 {
// Fetch the last B block of both operands into buffer I1 (no window move afterwards).
427 b_blockwise_copy.Run(b_grid_desc,
428 b_grid_buf,
430 b_block_origin_idx,
431 b_thread_bufs(I1));
432
433 b_blockwise_copy_up.Run(b_grid_desc,
434 b_grid_buf_up,
436 b_block_origin_idx,
437 b_thread_bufs_up(I1));
438
// Write the last A block to LDS. On every path that reaches here the last RunRead used
// scratch slot I0, but this call passes no slot index.
// NOTE(review): it relies on the callee's default slot (presumably 0); passing I0
// explicitly would match the other RunWrite calls. Source line 439 (presumably an LDS
// barrier) is missing from this listing.
440 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
441
// MFMAs for the second-to-last K block (dequantized B in buffer I0).
442 static_for<0, MRepeat, 1>{}([&](auto m0) {
443 static_for<0, NRepeat, 1>{}([&](auto n0) {
444 static_for<0, KRepeat, 1>{}([&](auto k0) {
448
449 static_for<0, KPack, 1>{}([&](auto ik) {
450 a_thread_vec.template AsType<ComputeDataType>()(ik) =
451 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
452 make_tuple(m0, I0, I0, k0, I0, ik))>{}];
453 b_thread_vec.template AsType<ComputeDataType>()(ik) =
454 b_thread_dequant_bufs[I0][Number<b_thread_desc_.CalculateOffset(
455 make_tuple(n0, I0, k0, ik))>{}];
456 b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
457 b_thread_dequant_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
458 make_tuple(n0, I0, k0, ik))>{}];
459 });
460
461 using mfma_input_type =
462 typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
463
464 constexpr index_t c_offset =
465 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
466
467 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
468 b_thread_vec.template AsType<mfma_input_type>(),
469 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
470 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
471 b_thread_vec_up.template AsType<mfma_input_type>(),
472 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
473 });
474 });
475 });
476
478
// Load the last A block from LDS and dequantize the last B block (buffer I1).
// NOTE(review): source line 477 (presumably an LDS barrier) is missing from this listing.
479 static_for<0, MRepeat, 1>{}([&](auto m0) {
480 static_for<0, KRepeat, 1>{}([&](auto k0) {
482 make_tuple(m0, I0, I0, k0, I0, I0),
483 a_block_buf,
485 make_tuple(m0, I0, I0, k0, I0, I0),
486 a_thread_buf);
487 });
488 });
489 // B VGPR->VGPR dequant: b_thread_bufs(I1) -> b_thread_dequant_bufs(I1), likewise for _up.
491 b_block_origin_idx,
492 b_thread_bufs(I1),
494 make_tuple(I0, I0, I0, I0),
495 b_thread_dequant_bufs(I1));
496
498 b_block_origin_idx,
499 b_thread_bufs_up(I1),
501 make_tuple(I0, I0, I0, I0),
502 b_thread_dequant_bufs_up(I1));
503 __builtin_amdgcn_sched_barrier(0);
504
// MFMAs for the last K block (dequantized B in buffer I1).
505 static_for<0, MRepeat, 1>{}([&](auto m0) {
506 static_for<0, NRepeat, 1>{}([&](auto n0) {
507 static_for<0, KRepeat, 1>{}([&](auto k0) {
511
512 static_for<0, KPack, 1>{}([&](auto ik) {
513 a_thread_vec.template AsType<ComputeDataType>()(ik) =
514 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
515 make_tuple(m0, I0, I0, k0, I0, ik))>{}];
516 b_thread_vec.template AsType<ComputeDataType>()(ik) =
517 b_thread_dequant_bufs[I1][Number<b_thread_desc_.CalculateOffset(
518 make_tuple(n0, I0, k0, ik))>{}];
519 b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
520 b_thread_dequant_bufs_up[I1][Number<b_thread_desc_.CalculateOffset(
521 make_tuple(n0, I0, k0, ik))>{}];
522 });
523
524 using mfma_input_type =
525 typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
526
527 constexpr index_t c_offset =
528 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
529
530 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
531 b_thread_vec.template AsType<mfma_input_type>(),
532 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
533 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
534 b_thread_vec_up.template AsType<mfma_input_type>(),
535 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
536 });
537 });
538 });
539 // No sched_barrier here on purpose: the last MFMA block may be scheduled into the epilogue to
540 // hide potential LDS-shuffle latency.
541 // __builtin_amdgcn_sched_barrier(0);
542 }
543 else
544 {
// Odd tail: a single remaining K block, already in a_thread_buf and in dequantized
// buffer I0.
545 static_for<0, MRepeat, 1>{}([&](auto m0) {
546 static_for<0, NRepeat, 1>{}([&](auto n0) {
547 static_for<0, KRepeat, 1>{}([&](auto k0) {
551
552 static_for<0, KPack, 1>{}([&](auto ik) {
553 a_thread_vec.template AsType<ComputeDataType>()(ik) =
554 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
555 make_tuple(m0, I0, I0, k0, I0, ik))>{}];
556 b_thread_vec.template AsType<ComputeDataType>()(ik) =
557 b_thread_dequant_bufs[I0][Number<b_thread_desc_.CalculateOffset(
558 make_tuple(n0, I0, k0, ik))>{}];
559 b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
560 b_thread_dequant_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
561 make_tuple(n0, I0, k0, ik))>{}];
562 });
563
564 using mfma_input_type =
565 typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
566
567 constexpr index_t c_offset =
568 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
569
570 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
571 b_thread_vec.template AsType<mfma_input_type>(),
572 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
573 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
574 b_thread_vec_up.template AsType<mfma_input_type>(),
575 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
576 });
577 });
578 });
579 }
580 }
581
582 protected:
583 // MRepeat MWave MLane KRepeat KLane KPack
584 // KRepeat -> MRepeat -> MWave -> KLane -> MLane -> KPack
587
// NOTE(review): several declaration lines (e.g. 585-586, 588, 590, 592-593, 598-602) are
// missing from this listing. Per the tooltips, AThreadCopy is
// ThreadwiseTensorSliceTransfer_v4<ADataType, ComputeDataType,
// decltype(a_block_desc_m0_m1_m2_k0_k1_k2), decltype(a_thread_desc_),
// Sequence<1, 1, 1, 1, 1, KPack>, Sequence<0, 1, 2, 3, 4, 5>, 5, A_K1, A_K1>, which copies one
// KPack slice of A from the LDS tile view into the per-thread layout. The visible lines
// below are the tail of that alias.
589 ComputeDataType,
591 decltype(a_thread_desc_),
594 5,
595 A_K1,
596 A_K1>;
597
599
602
// Default-constructed B tile descriptor. Per the tooltips, it is used as both the source
// and destination descriptor of BThreadDequantCopy.
603 static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
604
606
// Tail of BThreadDequantCopy. Per the tooltips it is
// ThreadwiseTensorSliceTransfer_StaticToStatic<BDataType, ComputeDataType, ..., PassThrough,
// Sequence<NRepeat, 1, KRepeat, KPack>, Sequence<1, 2, 0, 3>, 3, KPack>. The last three
// arguments are presumably the access order, the vector dimension and the elements per
// vector — confirm.
// NOTE(review): source lines 607-613 and 618-619 are missing from this listing.
614 Sequence<1, 2, 0, 3>,
615 3,
616 KPack>;
617
620};
621
622} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ BlockwiseGemmXdlops_pipeline_base(Tuple4 a_origin=CalculateAThreadOriginDataIndex(), Tuple4 b_origin=CalculateBThreadOriginDataIndex())
Constructor for BlockwiseGemmXdlops_pipeline_base.
Definition blockwise_gemm_pipeline_xdlops_base.hpp:222
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:280
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:239
static constexpr auto c_thread_desc_
Definition blockwise_gemm_pipeline_xdlops_base.hpp:378
static constexpr auto xdlops_gemm
Definition blockwise_gemm_pipeline_xdlops_base.hpp:54
static constexpr auto I1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:37
__host__ static __device__ constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:266
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:294
static constexpr index_t AMmaKStride
Definition blockwise_gemm_pipeline_xdlops_base.hpp:60
static __device__ auto CalculateAThreadOriginDataIndex6D()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:136
static constexpr index_t WaveSize
Definition blockwise_gemm_pipeline_xdlops_base.hpp:46
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:253
static constexpr index_t B_K1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:51
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops > HotLoopInstList
Definition blockwise_gemm_pipeline_xdlops_base.hpp:82
__host__ __device__ constexpr auto & GetCThreadBuffer()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:111
static constexpr auto I0
Definition blockwise_gemm_pipeline_xdlops_base.hpp:36
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:160
static __device__ auto CalculateCThreadOriginDataIndex8D(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:189
static constexpr index_t KRepeat
Definition blockwise_gemm_pipeline_xdlops_base.hpp:64
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k
Definition blockwise_gemm_pipeline_xdlops_base.hpp:359
static constexpr index_t A_K1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:50
static constexpr index_t BMmaKStride
Definition blockwise_gemm_pipeline_xdlops_base.hpp:61
__host__ static __device__ constexpr auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:341
__host__ static __device__ constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:307
__host__ static __device__ constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:324
ThreadwiseTensorSliceTransfer_v4< ADataType, ComputeDataType, decltype(a_block_desc_m0_m1_m2_k0_k1_k2), decltype(a_thread_desc_), Sequence< 1, 1, 1, 1, 1, KPack >, Sequence< 0, 1, 2, 3, 4, 5 >, 5, A_K1, A_K1 > AThreadCopy
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp:588
BlockwiseGemmXdlops_pipeline_base< BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack > Base
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp:102
ThreadwiseTensorSliceTransfer_StaticToStatic< BDataType, ComputeDataType, decltype(b_block_desc_n0_n1_k0_k1), decltype(b_block_desc_n0_n1_k0_k1), tensor_operation::element_wise::PassThrough, Sequence< Number< NRepeat >{}, I1, Number< KRepeat >{}, Number< KPack >{}>, Sequence< 1, 2, 0, 3 >, 3, KPack > BThreadDequantCopy
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp:607
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, BBlockTransfer &b_blockwise_copy, BBlockTransfer &b_blockwise_copy_up, const BGridBuffer &b_grid_buf, const BGridBuffer &b_grid_buf_up, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, CThreadBuffer &c_thread_buf_up, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp:229
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp:37
Definition utility/sequence.hpp:43
Threadwise data transfer.
Definition threadwise_tensor_slice_transfer.hpp:1720
Definition threadwise_tensor_slice_transfer.hpp:1260
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
Definition dtype_vector.hpp:10