28template <
typename GridwiseGemm,
29 typename BatchedGemmArg,
30 bool HasMainKBlockLoop,
35#if CK_USE_LAUNCH_BOUNDS
40#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
41 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
43 __shared__
char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
45 const index_t g_idx = blockIdx.z % karg.Batch;
46 const index_t k_idx = blockIdx.z / karg.Batch;
48 const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
49 const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
50 const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
51 const auto b_scale_batch_offset =
52 karg.compute_ptr_offset_of_batch.GetSacleBPtrOffset(g_idx);
54 auto splitk_batch_offset =
typename GridwiseGemm::SplitKBatchOffset(karg, k_idx);
56 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
57 karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset,
58 karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset,
59 karg.p_c_grid + c_batch_offset + splitk_batch_offset.c_reduce_offset,
60 karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_k_split_offset,
69template <
typename GridwiseGemm,
70 typename BatchedGemmArg,
71 bool HasMainKBlockLoop,
76#if CK_USE_LAUNCH_BOUNDS
81#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
82 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
86 __shared__
char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
87 __shared__
char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
89 const index_t g_idx = blockIdx.z % karg.Batch;
90 const index_t k_idx = blockIdx.z / karg.Batch;
92 const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
93 const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
94 const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
95 const auto b_scale_batch_offset =
96 karg.compute_ptr_offset_of_batch.GetSacleBPtrOffset(g_idx);
98 auto splitk_batch_offset =
typename GridwiseGemm::SplitKBatchOffset(karg, k_idx);
100 GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
101 karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset,
102 karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset,
103 karg.p_c_grid + c_batch_offset + splitk_batch_offset.c_reduce_offset,
104 karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_k_split_offset,
114namespace tensor_operation {
117template <
typename ALayout,
122 typename BScaleDataType,
124 typename GemmAccDataType,
125 typename CShuffleDataType,
126 typename AElementwiseOperation,
127 typename BElementwiseOperation,
128 typename CElementwiseOperation,
142 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
143 typename ABlockTransferThreadClusterArrangeOrder,
144 typename ABlockTransferSrcAccessOrder,
145 index_t ABlockTransferSrcVectorDim,
146 index_t ABlockTransferSrcScalarPerVector,
147 index_t ABlockTransferDstScalarPerVector_AK1,
148 bool ABlockLdsExtraM,
149 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
150 typename BBlockTransferThreadClusterArrangeOrder,
151 typename BBlockTransferSrcAccessOrder,
152 index_t BBlockTransferSrcVectorDim,
153 index_t BBlockTransferSrcScalarPerVector,
154 index_t BBlockTransferDstScalarPerVector_BK1,
155 bool BBlockLdsExtraN,
156 index_t CShuffleMXdlPerWavePerShuffle,
157 index_t CShuffleNXdlPerWavePerShuffle,
158 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
159 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
162 typename ComputeTypeA = CDataType,
163 typename ComputeTypeB = ComputeTypeA,
164 bool PermuteA =
false,
165 bool PermuteB =
false>
176 AElementwiseOperation,
177 BElementwiseOperation,
178 CElementwiseOperation>
185 template <index_t NXdlPerWave_>
195 AElementwiseOperation,
196 BElementwiseOperation,
197 CElementwiseOperation,
211 ABlockTransferThreadClusterLengths_AK0_M_AK1,
212 ABlockTransferThreadClusterArrangeOrder,
213 ABlockTransferSrcAccessOrder,
214 ABlockTransferSrcVectorDim,
215 ABlockTransferSrcScalarPerVector,
216 ABlockTransferDstScalarPerVector_AK1,
219 BBlockTransferThreadClusterLengths_BK0_N_BK1,
220 BBlockTransferThreadClusterArrangeOrder,
221 BBlockTransferSrcAccessOrder,
222 BBlockTransferSrcVectorDim,
223 BBlockTransferSrcScalarPerVector,
224 BBlockTransferDstScalarPerVector_BK1,
227 CShuffleMXdlPerWavePerShuffle,
228 CShuffleNXdlPerWavePerShuffle,
229 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
230 CShuffleBlockTransferScalarPerVector_NPerBlock,
260 : BatchStrideA_(BatchStrideA),
261 BatchStrideB_(BatchStrideB),
262 BatchStrideC_(BatchStrideC),
263 BatchStrideScaleB_(BatchStrideScaleB)
269 return g_idx *
static_cast<long_index_t>(BatchStrideA_);
279 return g_idx *
static_cast<long_index_t>(BatchStrideC_);
283 return g_idx *
static_cast<long_index_t>(BatchStrideScaleB_);
293 template <
typename Gr
idwiseGemm>
300 const BDataType* p_b_grid_,
301 CDataType* p_c_grid_,
313 const BScaleDataType* p_b_scale_grid_,
316 AElementwiseOperation a_element_op_,
317 BElementwiseOperation b_element_op_,
318 CElementwiseOperation c_element_op_)
336 BatchStrideA_, BatchStrideB_, BatchStrideC_, BatchStrideScaleB_)
345 template <
typename Gr
idwiseGemm>
350 if(stream_config.log_level_ > 0)
355 if(!GridwiseGemm::CheckValidity(arg))
357 throw std::runtime_error(
"wrong! GridwiseGemm has invalid setting");
361 std::tie(gdx, gdy, gdz) =
362 GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.
Batch * arg.KBatch);
366 index_t k_grain = arg.KBatch * KPerBlock;
367 index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
369 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
371 const auto Run = [&](
const auto& kernel) {
372 if(stream_config.flush_cache)
374 DeviceArgument arg_ = arg;
376 const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
377 arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
378 const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
379 arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
381 auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
383 auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
387 arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
388 rotating_mem.Print();
390 auto run_flush_cache = [&]() {
397 hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
399 arg_.M * arg_.N *
sizeof(CDataType),
400 stream_config.stream_id_));
415 hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
417 arg.M * arg.N *
sizeof(CDataType),
418 stream_config.stream_id_));
421 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
425 constexpr index_t minimum_occupancy =
428 MPerBlock * NPerBlock * KPerBlock *
sizeof(ADataType) <= 128 * 128 * 64 * 2)
433 if(has_main_k_block_loop)
465 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::One)
476 else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
489 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
491 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Two)
504 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
506 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
520 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
522 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
536 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
538 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
552 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
554 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Six)
567 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
569 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
585 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::One)
596 else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
609 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
611 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Two)
624 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
626 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
640 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
642 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
656 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
658 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
672 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
674 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Six)
687 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
689 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
709 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
734 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
762 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
787 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
856 using Argument32 = ArgumentBase<GridwiseGemm32>;
868 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
929 const BDataType* p_b,
942 const BScaleDataType* p_b_scale,
945 AElementwiseOperation a_element_op,
946 BElementwiseOperation b_element_op,
947 CElementwiseOperation c_element_op)
988 const void* p_b_scale,
991 AElementwiseOperation a_element_op,
992 BElementwiseOperation b_element_op,
993 CElementwiseOperation c_element_op)
override
995 return std::make_unique<Argument>(
static_cast<const ADataType*
>(p_a),
996 static_cast<const BDataType*
>(p_b),
997 static_cast<CDataType*
>(p_c),
1009 static_cast<const BScaleDataType*
>(p_b_scale),
1020 return std::make_unique<Invoker>(
Invoker{});
1026 auto str = std::stringstream();
1028 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
1032 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
1040 str <<
"DeviceGemmXdlUniversal"
1043 << std::string(ALayout::name)[0]
1044 << std::string(BLayout::name)[0]
1045 << std::string(CLayout::name)[0]
1048 << BlockSize <<
", "
1050 << MPerBlock<<
"x"<<NPerBlock<<
"x"<<KPerBlock <<
", "
1052 << MPerXDL<<
"x"<<NPerXDL <<
", "
1054 << MXdlPerWave<<
"x" << NXdlPerWave<<
", "
1056 << ABlockTransferSrcScalarPerVector<<
"x"<<BBlockTransferSrcScalarPerVector<<
", "
1057 <<
"BlkGemmPipelineScheduler: "
1058 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] <<
", "
1059 <<
"BlkGemmPipelineVersion: "
1060 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] <<
", "
1061 <<
"BlkGemmPipelinePrefetchStages: "
1062 << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
void flush_icache()
Definition flush_cache.hpp:383
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition flush_cache.hpp:398
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ One
Definition blkgemmpipe_scheduler.hpp:37
@ Seven
Definition blkgemmpipe_scheduler.hpp:43
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Four
Definition blkgemmpipe_scheduler.hpp:40
@ Two
Definition blkgemmpipe_scheduler.hpp:38
@ Full
Definition blkgemmpipe_scheduler.hpp:49
@ Three
Definition blkgemmpipe_scheduler.hpp:39
@ Five
Definition blkgemmpipe_scheduler.hpp:41
@ Six
Definition blkgemmpipe_scheduler.hpp:42
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
constexpr bool is_same_v
Definition type.hpp:283
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
__global__ void kernel_batched_gemm_b_scale_xdl_cshuffle_v3(BatchedGemmArg karg)
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:38
int64_t long_index_t
Definition ck.hpp:300
__global__ void kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds(BatchedGemmArg karg)
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:79
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
bool is_bf16_atomic_supported()
Definition host_utility/device_prop.hpp:108
Definition ck/stream_config.hpp:10
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::Argument::Argument __host__ Argument(const ADataType *p_a_grid_, const BDataType *p_b_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t k_batch_, bool is_reduce_=false, AElementwiseOperation a_element_op=AElementwiseOperation{}, BElementwiseOperation b_element_op=BElementwiseOperation{}, CElementwiseOperation c_element_op=CElementwiseOperation{})
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:717
"Universal" GEMM kernel with SplitK support.
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:247
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CheckValidity static __host__ constexpr bool CheckValidity(const Argument &karg)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:1202
Definition data_type.hpp:187
Definition device_base.hpp:197
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:255
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideC, index_t BatchStrideScaleB)
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:256
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:272
__host__ __device__ constexpr long_index_t GetSacleBPtrOffset(index_t g_idx) const
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:281
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:267
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:277
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:295
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:297
ArgumentBase(const ADataType *p_a_grid_, const BDataType *p_b_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t StrideScaleB_, index_t BatchStrideA_, index_t BatchStrideB_, index_t BatchStrideC_, index_t BatchStrideScaleB_, const BScaleDataType *p_b_scale_grid_, index_t Batch_, index_t KBatch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_)
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:299
index_t Batch
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:296
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:344
float RunImp(const ArgumentBase< GridwiseGemm > &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:346
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:843
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:865
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:179
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:1018
static constexpr index_t APackedSize
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:240
static auto MakeInvoker()
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:971
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:181
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:919
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t StrideScaleB, index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideC, index_t BatchStrideScaleB, const void *p_b_scale, index_t Batch, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:974
static constexpr index_t BPackedSize
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:247
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t StrideScaleB, index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideC, index_t BatchStrideScaleB, const BScaleDataType *p_b_scale, index_t Batch, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:928
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:238
static constexpr auto NXdlPerWave32
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:182
static constexpr bool IsValidCompilationParameter()
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:872
GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB > GridwiseGemmBase
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:186
index_t GetKPerBlock() override
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:924
std::string GetTypeString() const override
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:1024
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:237
bool GetPermuteB() override
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:926
ArgumentBase< GridwiseGemm64 > Argument
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:340
static bool IsSupportedArgument(const Argument &arg)
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:878
Definition device_batched_gemm.hpp:60
Definition flush_cache.hpp:299