device_moe_gemm.hpp Source File

device_moe_gemm.hpp Source File#

Composable Kernel: device_moe_gemm.hpp Source File
device_moe_gemm.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
19
20namespace ck {
21namespace tensor_operation {
22namespace device {
23
24template <typename ALayout,
25 typename BLayout,
26 typename DsLayout,
27 typename CLayout,
28 typename ADataType,
29 typename BDataType,
30 typename DsDataType,
31 typename CDataType,
32 typename GemmAccDataType,
33 typename CShuffleDataType,
34 typename AElementwiseOperation,
35 typename BElementwiseOperation,
36 typename CElementwiseOperation,
37 GemmSpecialization GemmSpec,
38 index_t BlockSize,
39 index_t MPerBlock,
40 index_t NPerBlock,
41 index_t KPerBlock,
42 index_t AK1,
43 index_t BK1,
44 index_t MPerXDL,
45 index_t NPerXDL,
46 index_t MXdlPerWave,
47 index_t NXdlPerWave,
48 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
49 typename ABlockTransferThreadClusterArrangeOrder,
50 typename ABlockTransferSrcAccessOrder,
51 index_t ABlockTransferSrcVectorDim,
52 index_t ABlockTransferSrcScalarPerVector,
53 index_t ABlockTransferDstScalarPerVector_AK1,
54 bool ABlockLdsExtraM,
55 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
56 typename BBlockTransferThreadClusterArrangeOrder,
57 typename BBlockTransferSrcAccessOrder,
58 index_t BBlockTransferSrcVectorDim,
59 index_t BBlockTransferSrcScalarPerVector,
60 index_t BBlockTransferDstScalarPerVector_BK1,
61 bool BBlockLdsExtraN,
62 index_t CShuffleMXdlPerWavePerShuffle,
63 index_t CShuffleNXdlPerWavePerShuffle,
64 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
65 typename CDEShuffleBlockTransferScalarPerVectors,
68 index_t ActivationOP = 0,
69 bool NSwizzle = false,
70 bool IsInputGemm = true,
71 bool MulRoutedWeight = true,
72 bool PerTokenQuant = true,
73 typename IndexType = index_t,
74 typename ComputeTypeA = CDataType,
75 typename ComputeTypeB = ComputeTypeA,
76 typename LDSTypeA = ComputeTypeA,
77 typename LDSTypeB = ComputeTypeB>
79 BLayout,
80 DsLayout,
81 CLayout,
82 ADataType,
83 BDataType,
84 DsDataType,
85 CDataType,
86 AElementwiseOperation,
87 BElementwiseOperation,
88 CElementwiseOperation>
89{
91 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
92 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
93 static constexpr index_t NumDTensor = DsDataType::Size();
94 template <index_t NXdlPerWave_>
96 GridwiseMoeGemm<ALayout,
97 BLayout,
98 DsLayout,
99 CLayout,
100 ADataType,
101 BDataType,
102 GemmAccDataType,
103 CShuffleDataType,
104 DsDataType,
105 CDataType,
106 AElementwiseOperation,
107 BElementwiseOperation,
108 CElementwiseOperation,
109 GemmSpec,
110 BlockSize,
111 MPerBlock,
112 NPerBlock,
113 KPerBlock,
114 AK1,
115 BK1,
116 MPerXDL,
117 NPerXDL,
118 MXdlPerWave,
119 NXdlPerWave_,
120 ABlockTransferThreadClusterLengths_AK0_M_AK1,
121 ABlockTransferThreadClusterArrangeOrder,
122 ABlockTransferSrcAccessOrder,
123 ABlockTransferSrcVectorDim,
124 ABlockTransferSrcScalarPerVector,
125 ABlockTransferDstScalarPerVector_AK1,
126 false,
127 ABlockLdsExtraM,
128 BBlockTransferThreadClusterLengths_BK0_N_BK1,
129 BBlockTransferThreadClusterArrangeOrder,
130 BBlockTransferSrcAccessOrder,
131 BBlockTransferSrcVectorDim,
132 BBlockTransferSrcScalarPerVector,
133 BBlockTransferDstScalarPerVector_BK1,
134 false,
135 BBlockLdsExtraN,
136 CShuffleMXdlPerWavePerShuffle,
137 math::min(CShuffleNXdlPerWavePerShuffle, NXdlPerWave_),
138 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
139 CDEShuffleBlockTransferScalarPerVectors,
140 BlkGemmPipeSched,
141 BlkGemmPipelineVer,
142 ActivationOP,
143 NSwizzle,
144 IsInputGemm,
145 MulRoutedWeight,
146 PerTokenQuant,
147 IndexType,
148 ComputeTypeA,
149 ComputeTypeB,
150 LDSTypeA,
151 LDSTypeB>;
154
155 using Argument = typename GridwiseGemm64::Argument;
156
157 static constexpr index_t APackedSize = []() {
159 return 2;
160 else
161 return 1;
162 }();
163
164 static constexpr index_t BPackedSize = []() {
166 return 2;
167 else
168 return 1;
169 }();
170
171 int GetPreShuffleParameters() override { return NPerXDL; }
172
173 // Invoker
174 struct Invoker : public BaseInvoker
175 {
176 template <typename GridwiseGemm>
177 float RunImp(const typename GridwiseGemm::Argument& arg,
178 const StreamConfig& stream_config = StreamConfig{})
179 {
180 if(stream_config.log_level_ > 0)
181 {
182 arg.Print();
183 }
184
185 if(!GridwiseGemm::CheckValidity(arg))
186 {
187 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
188 }
189
190 index_t gdx, gdy, gdz;
191 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
192
193 float ave_time = 0;
194
195 index_t k_grain = arg.KBatch * KPerBlock;
196 index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
197
198 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
199
200 const auto RunKernel = [&](const auto& kernel) {
201 if(stream_config.flush_cache)
202 {
203
204 std::array<std::size_t, NumDTensor> DsSize;
205
206 auto arg_ = arg;
207
208 const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
209 arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
210 const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
211 arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
212
213 auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
214 sizeof(ADataType) / APackedSize;
215 auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
216 sizeof(BDataType) / BPackedSize;
217
218 const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
219 arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
220
221 static_for<0, NumDTensor, 1>{}([&](auto i) {
222 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
223 DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
224 });
225 ck::utility::RotatingMemWrapperMultiD<typename GridwiseGemm::Argument,
226 DsDataType>
227 rotating_mem(arg_,
228 stream_config.rotating_count,
229 size_a_buffer,
230 size_b_buffer,
231 DsSize);
232 rotating_mem.Print();
233
234 auto run_flush_cache = [&]() {
235 // flush icache
237 // rotating mem
238 rotating_mem.Next();
239 // clear c mem
240 if(arg_.KBatch > 1)
241 hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
242 0,
243 arg_.M * arg_.N * sizeof(CDataType),
244 stream_config.stream_id_));
245 };
246
248 stream_config,
249 run_flush_cache,
250 kernel,
251 dim3(gdx, gdy, gdz),
252 dim3(BlockSize),
253 0,
254 arg_);
255 }
256 else
257 {
258 if(arg.KBatch > 1)
259 hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
260 0,
261 arg.M * arg.N * sizeof(CDataType),
262 stream_config.stream_id_));
263
264 ave_time = launch_and_time_kernel(
265 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
266 }
267 };
268
269 constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) / BlockSize /
270 4 * (1 + GridwiseGemm::NWave);
271 constexpr auto estimated_reg_b = NPerBlock * KPerBlock * sizeof(BDataType) / BlockSize /
272 4 * (2) * (IsInputGemm ? 2 : 1);
273 constexpr auto estimated_reg_c = MPerBlock * NPerBlock * sizeof(GemmAccDataType) /
274 BlockSize / 4 * (IsInputGemm ? 2 : 1);
275 constexpr auto estimated_reg_total =
276 estimated_reg_a + estimated_reg_b + estimated_reg_c;
277
278 constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2;
279
280 constexpr auto MemoryDataOp =
282 if(has_main_k_block_loop)
283 {
284 // Tail number always full
285 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
286 {
287 {
288 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
289 {
290 const auto kernel = kernel_moe_gemm<GridwiseGemm,
291 true,
292 MemoryDataOp,
293 minimum_occupancy,
295 RunKernel(kernel);
296 }
297 else
298 {
299 const auto kernel = kernel_moe_gemm<GridwiseGemm,
300 true,
301 MemoryDataOp,
302 minimum_occupancy,
304 RunKernel(kernel);
305 }
306 }
307 }
308 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
309 BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
310 {
311 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
312 {
313 const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
314 true,
315 MemoryDataOp,
316 minimum_occupancy,
318 RunKernel(kernel);
319 }
320 else
321 {
322 const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
323 true,
324 MemoryDataOp,
325 minimum_occupancy,
327 RunKernel(kernel);
328 }
329 }
330 else
331 {
332 throw std::runtime_error("todo: only v1 & v2 support now");
333 }
334 }
335#if 1
336 else
337 {
338 // Tail number always 1
339 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
340 {
341 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
342 {
343 const auto kernel = kernel_moe_gemm<GridwiseGemm,
344 false,
345 MemoryDataOp,
346 minimum_occupancy,
348 RunKernel(kernel);
349 }
350 else
351 {
352 const auto kernel = kernel_moe_gemm<GridwiseGemm,
353 false,
354 MemoryDataOp,
355 minimum_occupancy,
357 RunKernel(kernel);
358 }
359 }
360 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
361 BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
362 {
363 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
364 {
365 const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
366 false,
367 MemoryDataOp,
368 minimum_occupancy,
370 RunKernel(kernel);
371 }
372 else
373 {
374 const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
375 false,
376 MemoryDataOp,
377 minimum_occupancy,
379 RunKernel(kernel);
380 }
381 }
382 else
383 {
384 throw std::runtime_error("todo: only v1 & v2 support now");
385 }
386 }
387#endif
388
389 return ave_time;
390 }
391
393
394 // polymorphic
395 float Run(const BaseArgument* p_arg,
396 const StreamConfig& stream_config = StreamConfig{}) override
397 {
398 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
399 }
400 };
401
// Compile-time hook for rejecting unsupported template-parameter
// combinations. No validation is performed yet, so every instantiation is
// reported as valid.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool accepted = true;
    return accepted;
}
407
408 static bool IsSupportedArgument(const Argument& arg)
409 {
410 // only impl kbatch 1 now
411 if(arg.KBatch > 1)
412 {
413 return false;
414 }
416 {
417 return false;
418 }
419 if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
420 {
421 return false;
422 }
423
424 if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
425 GemmSpec == GemmSpecialization::NKPadding ||
426 GemmSpec == GemmSpecialization::MNKPadding ||
427 GemmSpec == GemmSpecialization::KPadding))
428 {
429 return false;
430 }
431 if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
432 {
433 return false;
434 }
435 if(get_warp_size() == 64)
436 {
437 if constexpr(NXdlPerWave64 > 0)
438 {
440 }
441 }
442 else
443 {
444 if constexpr(NXdlPerWave32 > 0)
445 {
447 reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg));
448 }
449 }
450 return false;
451 }
452
453 // polymorphic
454 bool IsSupportedArgument(const BaseArgument* p_arg) override
455 {
456 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
457 }
458
459 static auto MakeArgument(const void* p_sorted_token_ids,
460 const void* p_sorted_expert_ids,
461 const void* p_max_token_id,
462 const void* p_a,
463 const void* p_b,
464 std::array<const void*, NumDTensor> p_ds,
465 void* p_c,
466 index_t NumTokens,
467 index_t TopK,
468 index_t M,
469 index_t N,
470 index_t K,
471 index_t StrideA,
472 index_t StrideB,
473 std::array<index_t, NumDTensor> StrideDs,
474 index_t StrideC,
475 index_t KBatch,
476 AElementwiseOperation a_element_op,
477 BElementwiseOperation b_element_op,
478 CElementwiseOperation c_element_op)
479 {
480 return Argument{static_cast<const index_t*>(p_sorted_token_ids),
481 static_cast<const index_t*>(p_sorted_expert_ids),
482 static_cast<const index_t*>(p_max_token_id),
483 static_cast<const ADataType*>(p_a),
484 static_cast<const BDataType*>(p_b),
485 p_ds,
486 static_cast<CDataType*>(p_c),
487 NumTokens,
488 TopK,
489 M,
490 N,
491 K,
492 StrideA,
493 StrideB,
494 StrideDs,
495 StrideC,
496 KBatch,
497 a_element_op,
498 b_element_op,
499 c_element_op};
500 }
501
502 static auto MakeInvoker() { return Invoker{}; }
503
504 // polymorphic
505 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
506 const void* p_b,
507 std::array<const void*, NumDTensor> p_ds,
508 void* p_c,
509 index_t M,
510 index_t N,
511 index_t K,
512 index_t StrideA,
513 index_t StrideB,
514 std::array<ck::index_t, NumDTensor> StrideDs,
515 index_t StrideC,
516 index_t KBatch,
517 AElementwiseOperation a_element_op,
518 BElementwiseOperation b_element_op,
519 CElementwiseOperation c_element_op) override
520 {
521 return std::make_unique<Argument>(nullptr,
522 nullptr,
523 nullptr,
524 static_cast<const ADataType*>(p_a),
525 static_cast<const BDataType*>(p_b),
526 p_ds,
527 static_cast<CDataType*>(p_c),
528 M, // randoms set, no use
529 0,
530 M,
531 N,
532 K,
533 StrideA,
534 StrideB,
535 StrideDs,
536 StrideC,
537 KBatch,
538 a_element_op,
539 b_element_op,
540 c_element_op);
541 }
542
543 // polymorphic
544 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
545 {
546 return std::make_unique<Invoker>(Invoker{});
547 }
548
549 // polymorphic
550 std::string GetTypeString() const override
551 {
552 auto str = std::stringstream();
553
554 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
557
558 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
560
561 // clang-format off
562 str << "DeviceMoeGEmm"
563 << "<"
564 << getGemmSpecializationString(GemmSpec) << ", "
565 << std::string(ALayout::name)[0]
566 << std::string(BLayout::name)[0]
567 << std::string(CLayout::name)[0]
568 << ">"
569 << " BlkSize: "
570 << BlockSize << ", "
571 << "BlkTile: "
572 << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
573 << "WaveTile: "
574 << MPerXDL<<"x"<<NPerXDL << ", "
575 << "WaveMap: "
576 << MXdlPerWave<<"x" << NXdlPerWave<<", "
577 << "VmemReadVec: "
578 << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
579 << "BlkGemmPipelineScheduler: "
580 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
581 << "BlkGemmPipelineVersion: "
582 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
583 << "BlkGemmPipelinePrefetchStages: "
584 << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
585 // clang-format on
586
587 return str.str();
588 }
589};
590
591} // namespace device
592} // namespace tensor_operation
593} // namespace ck
#define INVOKER_RUN3_IMPL
Definition device_base.hpp:114
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr T min(T x)
Definition utility/math.hpp:116
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
void flush_icache()
Definition flush_cache.hpp:383
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition flush_cache.hpp:398
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v1
Definition blkgemmpipe_scheduler.hpp:14
__global__ void kernel_moe_gemm(typename GridwiseGemm::Argument karg)
Definition gridwise_moe_gemm.hpp:46
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
constexpr bool is_same_v
Definition type.hpp:283
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
__global__ void kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_moe_gemm.hpp:84
bool is_bf16_atomic_supported()
Definition host_utility/device_prop.hpp:108
Definition ck/stream_config.hpp:10
Definition gridwise_moe_gemm.hpp:171
Definition data_type.hpp:187
Definition functional2.hpp:33
Definition device_base.hpp:197
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_moe_gemm.hpp:395
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_moe_gemm.hpp:177
Definition device_moe_gemm.hpp:89
int GetPreShuffleParameters() override
Definition device_moe_gemm.hpp:171
GridwiseMoeGemm< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, math::min(CShuffleNXdlPerWavePerShuffle, NXdlPerWave_), CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ActivationOP, NSwizzle, IsInputGemm, MulRoutedWeight, PerTokenQuant, IndexType, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB > GridwiseGemmBase
Definition device_moe_gemm.hpp:95
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_moe_gemm.hpp:544
static auto MakeArgument(const void *p_sorted_token_ids, const void *p_sorted_expert_ids, const void *p_max_token_id, const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t NumTokens, index_t TopK, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition device_moe_gemm.hpp:459
static constexpr bool IsValidCompilationParameter()
Definition device_moe_gemm.hpp:402
static constexpr index_t BPackedSize
Definition device_moe_gemm.hpp:164
static bool IsSupportedArgument(const Argument &arg)
Definition device_moe_gemm.hpp:408
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_moe_gemm.hpp:454
static constexpr auto NXdlPerWave32
Definition device_moe_gemm.hpp:92
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_moe_gemm.hpp:153
typename GridwiseGemm64::Argument Argument
Definition device_moe_gemm.hpp:155
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_moe_gemm.hpp:91
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_moe_gemm.hpp:152
static constexpr index_t NumDTensor
Definition device_moe_gemm.hpp:93
std::string GetTypeString() const override
Definition device_moe_gemm.hpp:550
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition device_moe_gemm.hpp:505
static auto MakeInvoker()
Definition device_moe_gemm.hpp:502
static constexpr index_t APackedSize
Definition device_moe_gemm.hpp:157
Definition flush_cache.hpp:174