mx_flatmm_kernel.hpp Source File

mx_flatmm_kernel.hpp Source File

Composable Kernel: mx_flatmm_kernel.hpp Source File
mx_flatmm_kernel.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <string>
8
9#include "ck_tile/core.hpp"
11
13
14namespace ck_tile {
15
16template <typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
17struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_>
18{
20
// Compile-time configuration of the kernel.
// Thread-block size and persistence mode are forwarded unchanged from the pipeline type.
31 static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize;
32 static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel;
33
36 // The type below is actually the accumulation data type - the output of block GEMM.
// NOTE(review): the declaration this comment refers to is not present in this listing.
38
// Warp-tile extents read from BlockGemmShape::WarpTile: element 0 is used as the M extent,
// element 1 as the N extent.
39 static constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(number<0>{});
40 static constexpr int NThreadPerXdl = BlockGemmShape::WarpTile::at(number<1>{});
// NOTE(review): 64 is presumably the wavefront size, which would make this the number of
// lanes left for the K direction once MThreadPerXdl lanes cover M - confirm.
41 static constexpr int KThreadPerXdl = 64 / MThreadPerXdl;
42
45
// Packing factors forwarded from the pipeline; they size the scale tensors and scale windows
// built by the view/window helpers of this struct.
46 static constexpr int MXdlPack = FlatmmPipeline::MXdlPack;
47 static constexpr int NXdlPack = FlatmmPipeline::NXdlPack;
48 static constexpr int KXdlPack = FlatmmPipeline::KXdlPack;
49
// Number of auxiliary D tensors, taken from the size of the DsDataType list.
50 static constexpr index_t NumDTensor = DsDataType::size();
51
// Compile-time indices used to address the elements of the view/window tuples
// (I0..I5 are passed to `.at(...)` throughout this struct).
52 static constexpr auto I0 = number<0>();
53 static constexpr auto I1 = number<1>();
54 static constexpr auto I2 = number<2>();
55 static constexpr auto I3 = number<3>();
56 static constexpr auto I4 = number<4>();
57 static constexpr auto I5 = number<5>();
58
// Every D tensor needs both a layout and a data type, so the two lists must have equal length.
59 static_assert(DsLayout::size() == DsDataType::size(),
60 "The size of DsLayout and DsDataType should be the same");
61 // using KernelArgs = FlatmmKernelArgs<DsLayout::size()>;
62
63 [[nodiscard]] CK_TILE_HOST static const std::string GetName()
64 {
65 // clang-format off
66 return concat('_', "mx_flatmm_gemm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
67 // clang-format on
68 }
69
70 template <class ScaleM, class ScaleN>
71 CK_TILE_HOST static constexpr auto
72 GridSize(const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs)
73 {
74 if constexpr(UsePersistentKernel)
75 {
76 hipDeviceProp_t prop;
77 int deviceId = 0; // default device
78
79 constexpr int block_size = MXFlatmmKernel::BlockSize().x;
80 int dync_smem_size = 0;
81 int maxActiveBlocksPerCU = 0;
82
83 if(hipGetDeviceProperties(&prop, deviceId) != hipSuccess)
84 throw std::runtime_error(std::string("hipGetDeviceProperties failed: ") +
85 hipGetErrorName(hipGetLastError()));
86
87 if(hipOccupancyMaxActiveBlocksPerMultiprocessor(
88 &maxActiveBlocksPerCU,
89 reinterpret_cast<void*>(
90 kentry<1, MXFlatmmKernel, remove_cvref_t<decltype(kargs)>>),
91 block_size,
92 dync_smem_size) != hipSuccess)
93 throw std::runtime_error(
94 std::string("hipOccupancyMaxActiveBlocksPerMultiprocessor failed: ") +
95 hipGetErrorName(hipGetLastError()));
96
97 const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
98 const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
99
100 // std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
101 // << ", persistent_block_size: " << persistent_block_size
102 // << ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;
103
104 if(kargs.k_batch != 1)
105 throw std::runtime_error("Wrong! k_batch != 1 not supported in persistent kernel");
106 return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch);
107 }
108 else
109 {
110 return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
111 }
112 }
113
114 using SplitKBatchOffset = typename Underlying::SplitKBatchOffset;
115
// MakeGemmTensorViews: builds the six tensor views used by the rest of the kernel and returns
// them as a tuple in the order (A, flat B, tuple of D views, E, A-scale, B-scale).
// NOTE(review): this listing omits several original lines (the function-name line carrying the
// `a_ptr` parameter, and the opening line of most view-construction calls), so a number of
// statements below appear truncated.
116 template <memory_operation_enum DstInMemOp = memory_operation_enum::set, class KernelArgs>
117 CK_TILE_DEVICE static auto
119 const BDataType* b_flat_ptr,
120 const std::array<const void*, NumDTensor>& ds_ptr,
121 EDataType* e_ptr,
122 const KernelArgs& kargs,
123 const SplitKBatchOffset& splitk_batch_offset)
124 {
// A view: lengths are (M, splitted_k) for row-major A and (splitted_k, M) otherwise;
// strides are (kargs.stride_A, 1) in both cases.
125 const auto& a_tensor_view = [&]() {
126 if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
127 {
129 a_ptr,
130 make_tuple(kargs.M, splitk_batch_offset.splitted_k),
131 make_tuple(kargs.stride_A, 1),
132 number<FlatmmPipeline::GetVectorSizeA()>{},
133 number<1>{});
134 }
135 else
136 {
138 a_ptr,
139 make_tuple(splitk_batch_offset.splitted_k, kargs.M),
140 make_tuple(kargs.stride_A, 1),
141 number<FlatmmPipeline::GetVectorSizeA()>{},
142 number<1>{});
143 }
144 }();
145
// Extents of the flat B tensor: kFlatK = K * WarpTile[1], kFlatN = N * K / kFlatK.
// NOTE(review): kFlatN is algebraically N / WarpTile[1]; the intermediate product N * K may
// overflow if these operands are 32-bit - confirm the supported problem sizes.
146 index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1);
147 index_t kFlatN = kargs.N * kargs.K / kFlatK;
148
// Flat B view: lengths (kFlatN, kFlatK), strides (kFlatK, 1).
149 const auto& b_flat_tensor_view = [&]() {
151 b_flat_ptr,
152 make_tuple(kFlatN, kFlatK),
153 make_tuple(kFlatK, 1),
154 number<FlatmmPipeline::GetVectorSizeB()>{},
155 number<1>{});
156 }();
157
// One view per D tensor: the untyped pointer ds_ptr[i] is cast to the i-th element type of
// DsDataType; lengths are (M, N) for row-major D and (N, M) otherwise.
158 const auto& ds_tensor_view = generate_tuple(
159 [&](auto i) {
160 using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
161 using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
162 if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
163 {
165 static_cast<const DDataType_*>(ds_ptr[i]),
166 make_tuple(kargs.M, kargs.N),
167 make_tuple(kargs.stride_Ds[i], 1),
168 number<EpiloguePipeline::GetVectorSizeD(i)>{},
169 number<1>{});
170 }
171 else
172 {
174 static_cast<const DDataType_*>(ds_ptr[i]),
175 make_tuple(kargs.N, kargs.M),
176 make_tuple(kargs.stride_Ds[i], 1),
177 number<EpiloguePipeline::GetVectorSizeD(i)>{},
178 number<1>{});
179 }
180 },
182
// E (output) view: the row-major case uses the epilogue's vector size, the other case uses a
// vector size of 1 (see the TODO below).
183 // TODO: enable vector write for C in ColMajor
184 const auto& e_tensor_view = [&]() {
185 if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
186 {
188 e_ptr,
189 make_tuple(kargs.M, kargs.N),
190 make_tuple(kargs.stride_E, 1),
191 number<EpiloguePipeline::GetVectorSizeC()>{},
192 number<1>{});
193 }
194 else
195 {
197 e_ptr,
198 make_tuple(kargs.N, kargs.M),
199 make_tuple(kargs.stride_E, 1),
200 number<1>{},
201 number<1>{});
202 }
203 }();
204
205 auto scale_a = kargs.scale_m_ptr;
206 auto scale_b = kargs.scale_n_ptr;
207
// Extents of the packed scale tensors. M and N are rounded up (integer_divide_ceil) while the
// K extent uses truncating division.
// NOTE(review): 32 looks like the K granularity of one scale value (cf. the commented-out
// GranularityK) - confirm.
// `const auto&&` binds to the returned temporaries; their lifetime lasts until the end of this
// function body.
208 static constexpr int BlockScaleSize = 32; // decltype(scale_n)::GranularityK;
209 const auto&& scale_packs_m = integer_divide_ceil(kargs.M, (MXdlPack * MThreadPerXdl));
210 const auto&& scale_packs_n = integer_divide_ceil(kargs.N, (NXdlPack * NThreadPerXdl));
211 const auto&& scale_packs_k = kargs.K / BlockScaleSize / (KXdlPack * KThreadPerXdl);
212
213 // A scale tensor view
214 const auto& scale_a_tensor_view = [&]() {
215 // Pack 2x2 e8m0 over M/K dimension into 1 int32_t to trigger dword width load
// Packed 4-D descriptor (scale_packs_m, scale_packs_k, KThreadPerXdl, MThreadPerXdl); the
// scale pointer is reinterpreted as int32_t so each element covers one packed group.
// NOTE(review): the transform arguments are missing from this listing.
216 const auto scale_a_naive_desc = make_naive_tensor_descriptor_packed(
217 make_tuple(scale_packs_m, scale_packs_k, KThreadPerXdl, MThreadPerXdl));
218 const auto scale_a_desc = transform_tensor_descriptor(
219 scale_a_naive_desc,
224
226 reinterpret_cast<const int32_t*>(scale_a.ptr), scale_a_desc);
227 }();
228
229 // B scale tensor view
// Same construction as the A scale view, with the N-side extents.
230 const auto& scale_b_tensor_view = [&]() {
231 const auto scale_b_navie_desc = make_naive_tensor_descriptor_packed(
232 make_tuple(scale_packs_n, scale_packs_k, KThreadPerXdl, NThreadPerXdl));
233 const auto scale_b_desc = transform_tensor_descriptor(
234 scale_b_navie_desc,
239
241 reinterpret_cast<const int32_t*>(scale_b.ptr), scale_b_desc);
242 }();
243
// Tuple order matters: callers address the elements with I0..I5.
244 return make_tuple(a_tensor_view,
245 b_flat_tensor_view,
246 ds_tensor_view,
247 e_tensor_view,
248 scale_a_tensor_view,
249 scale_b_tensor_view);
250 }
251
// MakeGemmPadViews: takes the 6-element view tuple and returns a tuple of the same arity in
// which the A, D and E views have been passed through pad_tensor_view, while the flat B view
// (I1) and the two scale views (I4, I5) are forwarded unchanged.
// NOTE(review): the tile-length and padding arguments of every pad_tensor_view call are
// missing from this listing, so the calls below appear truncated.
252 template <typename TensorView>
253 CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
254 {
// A: padded, with layout-dependent arguments.
255 const auto& a_pad_view = [&]() {
256 const auto& a_tensor_view = views.at(I0);
257 if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
258 {
259 return pad_tensor_view(a_tensor_view,
263 }
264 else
265 {
266 return pad_tensor_view(a_tensor_view,
270 }
271 }();
272
// Flat B is not padded here.
273 const auto& b_flat_tensor_view = views.at(I1);
274
// D tensors: each element of the D view tuple is padded according to its own layout.
275 const auto& ds_pad_view = generate_tuple(
276 [&](auto i) {
277 const auto& d_tensor_view = views.at(I2);
278 using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
279 if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
280 {
281 return pad_tensor_view(d_tensor_view[i],
285 }
286 else
287 {
288 return pad_tensor_view(d_tensor_view[i],
292 }
293 },
295
296 // TODO: enable vector write for C in ColMajor
297 const auto& e_pad_view = [&]() {
298 const auto& e_tensor_view = views.at(I3);
299 if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
300 {
301 return pad_tensor_view(e_tensor_view,
305 }
306 else
307 {
308 return pad_tensor_view(e_tensor_view,
312 }
313 }();
314
// Same element order as the input tuple.
315 return make_tuple(
316 a_pad_view, b_flat_tensor_view, ds_pad_view, e_pad_view, views.at(I4), views.at(I5));
317 }
318
// MakeGemmTileWindows: creates one tile window per view for the output tile whose origin is
// (i_m, i_n), and returns them as a tuple in the order
// (A, flat B, tuple of D windows, E, A-scale, B-scale).
// NOTE(review): the window-length arguments of the make_tile_window calls are missing from
// this listing, so the calls below appear truncated.
319 template <typename PadView>
320 CK_TILE_DEVICE static auto
321 MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
322 {
323 const auto& a_pad_view = views.at(I0);
324 const auto& b_flat_pad_view = views.at(I1);
325 const auto& ds_pad_view = views.at(I2);
326 const auto& e_pad_view = views.at(I3);
327
// A window: origin is {i_m, 0} for row-major A and {0, i_m} otherwise.
328 const auto& a_block_window = [&]() {
329 if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
330 {
331 return make_tile_window(a_pad_view,
334 {i_m, 0});
335 }
336 else
337 {
338 return make_tile_window(a_pad_view,
341 {0, i_m});
342 }
343 }();
344
// Flat B window: the first origin coordinate is i_n divided by WarpTile[1].
345 const auto& b_flat_block_window =
346 make_tile_window(b_flat_pad_view,
349 {static_cast<int>(i_n / BlockGemmShape::WarpTile::at(I1)), 0});
350
// D windows: origin is {i_m, i_n} for row-major D and {i_n, i_m} otherwise.
351 const auto ds_block_window = generate_tuple(
352 [&](auto i) {
353 using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
354 if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
355 {
356 return make_tile_window(ds_pad_view[i],
359 {i_m, i_n});
360 }
361 else
362 {
363 return make_tile_window(ds_pad_view[i],
366 {i_n, i_m});
367 }
368 },
370
// E window: origin is always {i_m, i_n}.
// NOTE(review): unlike the D windows, the origin is not swapped for a non-row-major ELayout -
// confirm this is intended for column-major output.
371 auto e_block_window = make_tile_window(
372 e_pad_view,
374 {i_m, i_n});
375
// NOTE(review): presumably the K granularity of one scale value - confirm.
376 static constexpr int BlockScaleSize = 32;
377
// Scale windows: the second window length is KPerBlock / (BlockScaleSize * KXdlPack); the
// origin's first coordinate is the tile origin divided by the M/N packing factor.
378 auto scale_a_block_window = make_tile_window(
379 views.at(I4),
381 number<TilePartitioner::KPerBlock / (BlockScaleSize * KXdlPack)>{}),
382 {i_m / MXdlPack, 0});
383
384 auto scale_b_block_window = make_tile_window(
385 views.at(I5),
387 number<TilePartitioner::KPerBlock / (BlockScaleSize * KXdlPack)>{}),
388 {i_n / NXdlPack, 0});
389
// Tuple order matters: callers address the elements with I0..I5.
390 return make_tuple(a_block_window,
391 b_flat_block_window,
392 ds_block_window,
393 e_block_window,
394 scale_a_block_window,
395 scale_b_block_window);
396 }
397
// RunFlatmm: processes one output tile. It builds the tensor views, pad views and tile
// windows, runs the FlatmmPipeline to produce the tile result, and then runs the
// EpiloguePipeline on the E window.
//
// a_ptr / b_flat_ptr / ds_ptr / e_ptr : operand and output pointers
// smem_ptr_ping / smem_ptr_pong       : two scratch buffers handed to the pipeline; the ping
//                                       buffer is also handed to the epilogue
// kargs                               : kernel arguments (scale pointers are read here)
// splitk_batch_offset                 : supplies splitted_k, which determines the loop count
// block_idx_m / block_idx_n           : origin of the tile, forwarded to MakeGemmTileWindows
// NOTE(review): the line opening the view-construction call (original line 413) is missing
// from this listing.
398 template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
399 CK_TILE_DEVICE static void
400 RunFlatmm(const ADataType* a_ptr,
401 const BDataType* b_flat_ptr,
402 const std::array<const void*, NumDTensor>& ds_ptr,
403 EDataType* e_ptr,
404 void* smem_ptr_ping,
405 void* smem_ptr_pong,
406 const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs,
407 const SplitKBatchOffset& splitk_batch_offset,
408 const index_t block_idx_m,
409 const index_t block_idx_n)
410 {
411 // Create Gemm tensor views, pad views and tile windows
412 const auto& gemm_tensor_views_tuple =
414 a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
415 const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
416 auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
417
// Loop count of the main pipeline, derived from this batch's share of K.
418 const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
419
420 // Run GEMM cooperatively by whole workgroup.
421 const auto& a_block_window = gemm_tile_windows.at(I0);
422 const auto& b_flat_block_window = gemm_tile_windows.at(I1);
423 const auto& d_block_window = gemm_tile_windows.at(I2);
424 const auto& scale_a_block_window = gemm_tile_windows.at(I4);
425 const auto& scale_b_block_window = gemm_tile_windows.at(I5);
426
427 static_assert(ScaleM::GranularityK == ScaleN::GranularityK // have the same granK
428 || ScaleM::GranularityMN == -1 // or ScaleA is disabled
429 || ScaleN::GranularityMN == -1, // or ScaleB is disabled
430 "ScaleM and ScaleN should have the same GranularityK");
// Scaling is applied in the epilogue when either scale is enabled (GranularityMN != -1) with
// GranularityK == 0.
431 constexpr bool DoEpiScale =
432 (ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) || // per token
433 (ScaleN::GranularityMN != -1 && ScaleN::GranularityK == 0); // per channel
434
// Re-create the A window over the same view, lengths and origin, now with the pipeline's A
// tile distribution attached.
435 auto a_block_window_with_distr =
436 ck_tile::make_tile_window(a_block_window.get_bottom_tensor_view(),
437 a_block_window.get_window_lengths(),
438 a_block_window.get_window_origin(),
439 FlatmmPipeline::GetADramTileDistribution());
440 const auto& c_block_tile = FlatmmPipeline{}(a_block_window_with_distr,
441 b_flat_block_window,
442 scale_a_block_window,
443 scale_b_block_window,
444 num_loop,
445 smem_ptr_ping,
446 smem_ptr_pong);
447
448 // Run Epilogue Pipeline
449 if constexpr(DoEpiScale)
450 {
// Scaled epilogue: additionally receives the scale pointers advanced to this tile's origin.
451 auto& c_block_window = gemm_tile_windows.at(I3);
452 EpiloguePipeline{}(c_block_window,
453 c_block_tile,
454 d_block_window,
455 smem_ptr_ping,
456 kargs.scale_m_ptr + block_idx_m,
457 kargs.scale_n_ptr + block_idx_n);
458 }
// Unscaled epilogue: runs on every warp with the default scheduler, otherwise only on warp 0.
459 else if(UseDefaultScheduler || (get_warp_id() == 0))
460 {
461 // Run Epilogue Pipeline
462 auto& c_block_window = gemm_tile_windows.at(I3);
463 EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
464 }
465 }
466
// Kernel entry point. Starting from tile `partition_idx` (default: blockIdx.x), it computes
// one output tile; in persistent mode it then advances by gridDim.x and repeats until all
// `total_work_tile_cnt` tiles are consumed. In non-persistent mode the loop body runs once.
// NOTE(review): the last line of the `if constexpr` condition (original line 494) and the line
// opening the RunFlatmm call (original line 497) are missing from this listing.
467 template <class ScaleM, class ScaleN>
468 CK_TILE_DEVICE void operator()(FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()> kargs,
469 int partition_idx = blockIdx.x) const
470 {
471 int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
472
473 do
474 {
// Map the linear tile index to (iM, iN) and scale by the block tile sizes to get the tile
// origin. NOTE(review): amd_wave_read_first_lane presumably makes the value wave-uniform -
// confirm.
475 const auto [iM, iN] =
476 TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
477 const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
478 const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
479
480 const SplitKBatchOffset splitk_batch_offset(kargs);
481 // options
// The split-K offsets are divided by the packed size before being added to the typed pointers.
482 const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) +
483 splitk_batch_offset.a_k_split_offset / APackedSize;
484 const BDataType* b_flat_ptr = static_cast<const BDataType*>(kargs.b_ptr) +
485 splitk_batch_offset.b_k_split_offset / BPackedSize;
486 EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
487
488 // allocate LDS
489 __shared__ char smem_ptr_ping[Underlying::GetSmemPingSize()];
490 __shared__ char smem_ptr_pong[Underlying::GetSmemPongSize()];
491
// Compile-time guard: the atomic_add / odd-vector-size combination is rejected by the
// static_assert in the else branch.
492 if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
493 EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
495 {
// NOTE(review): presumably forwarded as RunFlatmm's UseDefaultScheduler template argument -
// confirm against the missing call line.
496 constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
498 b_flat_ptr,
499 kargs.ds_ptr,
500 e_ptr,
501 smem_ptr_ping,
502 smem_ptr_pong,
503 kargs,
504 splitk_batch_offset,
505 i_m,
506 i_n);
507 }
508 else
509 {
// NOTE(review): a literal `static_assert(false, ...)` in a discarded branch is only guaranteed
// to be accepted by compilers implementing the C++23 relaxation - confirm toolchain support.
510 static_assert(false,
511 "Unimplemented: atomic_add with odd vector size for fp16/bf16");
512 }
// Advance to the next tile assigned to this block.
513 partition_idx += gridDim.x;
514 } while(UsePersistentKernel && partition_idx < total_work_tile_cnt);
515 }
516};
517
518} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor_packed(const tuple< Lengths... > &lengths, number< GuaranteedLastDimensionVectorLength >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:371
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
__global__ void kentry(Args... args)
Definition tile/host/kernel_launch.hpp:22
@ atomic_add
Definition arch.hpp:58
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
CK_TILE_DEVICE index_t get_warp_id(bool_constant< ReturnSgpr >={})
Definition arch.hpp:104
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
std::string gemm_prec_str()
Definition utils.hpp:31
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
int32_t int32_t
Definition integer.hpp:10
CK_TILE_HOST_DEVICE constexpr T min(T x)
Definition tile/core/numeric/math.hpp:210
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition flatmm_kernel.hpp:229
Definition flatmm_kernel.hpp:249
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemPongSize()
Definition flatmm_kernel.hpp:356
static CK_TILE_HOST constexpr auto BlockSize()
Definition flatmm_kernel.hpp:330
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemPingSize()
Definition flatmm_kernel.hpp:352
Definition mx_flatmm_kernel.hpp:18
remove_cvref_t< typename FlatmmPipeline::CLayout > ELayout
Definition mx_flatmm_kernel.hpp:28
static CK_TILE_HOST const std::string GetName()
Definition mx_flatmm_kernel.hpp:63
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition mx_flatmm_kernel.hpp:25
FlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ > Underlying
Definition mx_flatmm_kernel.hpp:19
static constexpr index_t NumDTensor
Definition mx_flatmm_kernel.hpp:50
remove_cvref_t< typename EpiloguePipeline::DsLayout > DsLayout
Definition mx_flatmm_kernel.hpp:29
remove_cvref_t< typename FlatmmPipeline::BLayout > BLayout
Definition mx_flatmm_kernel.hpp:27
static constexpr int NThreadPerXdl
Definition mx_flatmm_kernel.hpp:40
static constexpr int NXdlPack
Definition mx_flatmm_kernel.hpp:47
static CK_TILE_DEVICE auto MakeGemmTensorViews(const ADataType *a_ptr, const BDataType *b_flat_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset)
Definition mx_flatmm_kernel.hpp:118
static constexpr auto I2
Definition mx_flatmm_kernel.hpp:54
remove_cvref_t< typename EpiloguePipeline::ODataType > EDataType
Definition mx_flatmm_kernel.hpp:37
static constexpr bool UsePersistentKernel
Definition mx_flatmm_kernel.hpp:32
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n)
Definition mx_flatmm_kernel.hpp:321
static constexpr auto I4
Definition mx_flatmm_kernel.hpp:56
static constexpr auto I1
Definition mx_flatmm_kernel.hpp:53
static constexpr int MXdlPack
Definition mx_flatmm_kernel.hpp:46
static constexpr auto I0
Definition mx_flatmm_kernel.hpp:52
remove_cvref_t< typename FlatmmPipeline::ALayout > ALayout
Definition mx_flatmm_kernel.hpp:26
static CK_TILE_HOST constexpr auto GridSize(const FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> &kargs)
Definition mx_flatmm_kernel.hpp:72
static constexpr int MThreadPerXdl
Definition mx_flatmm_kernel.hpp:39
remove_cvref_t< typename FlatmmPipeline::ADataType > ADataType
Definition mx_flatmm_kernel.hpp:34
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition mx_flatmm_kernel.hpp:21
static constexpr auto I3
Definition mx_flatmm_kernel.hpp:55
typename Underlying::SplitKBatchOffset SplitKBatchOffset
Definition mx_flatmm_kernel.hpp:114
CK_TILE_DEVICE void operator()(FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> kargs, int partition_idx=blockIdx.x) const
Definition mx_flatmm_kernel.hpp:468
remove_cvref_t< typename EpiloguePipeline::DsDataType > DsDataType
Definition mx_flatmm_kernel.hpp:30
remove_cvref_t< typename FlatmmPipeline::BDataType > BDataType
Definition mx_flatmm_kernel.hpp:35
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views)
Definition mx_flatmm_kernel.hpp:253
static constexpr index_t KernelBlockSize
Definition mx_flatmm_kernel.hpp:31
remove_cvref_t< MXFlatmmPipeline_ > FlatmmPipeline
Definition mx_flatmm_kernel.hpp:22
static CK_TILE_DEVICE void RunFlatmm(const ADataType *a_ptr, const BDataType *b_flat_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, void *smem_ptr_ping, void *smem_ptr_pong, const FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Definition mx_flatmm_kernel.hpp:400
static constexpr int KXdlPack
Definition mx_flatmm_kernel.hpp:48
static constexpr auto I5
Definition mx_flatmm_kernel.hpp:57
static constexpr int KThreadPerXdl
Definition mx_flatmm_kernel.hpp:41
remove_cvref_t< typename MXFlatmmPipeline_::BlockGemmShape > BlockGemmShape
Definition mx_flatmm_kernel.hpp:23
static constexpr int BPackedSize
Definition mx_flatmm_kernel.hpp:44
static constexpr int APackedSize
Definition mx_flatmm_kernel.hpp:43
Definition type_traits.hpp:115
static constexpr int PackedSize
Definition tile/core/numeric/numeric.hpp:82
Definition tile/core/container/sequence.hpp:49