20 template <
typename Problem>
25 constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(
I0);
26 constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(
I1);
27 if constexpr(MPerXdl == 16 && NPerXdl == 16)
30 constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
31 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
49 a_lds_block_desc_permuted,
56 return a_lds_block_desc;
60 constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
61 constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
77 return a_lds_block_desc;
81 constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
82 constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
86 constexpr auto DataTypeSize =
sizeof(ADataType);
87 constexpr auto MLdsLayer =
88 (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize);
92 number<kMPerBlock / MLdsLayer>{},
101 number<kKPerBlock / kKPack * MLdsLayer>{})),
107 a_lds_block_desc_permuted,
116 a_lds_block_desc_xk0_mnldslayer_mn_xk1,
123 return a_lds_block_desc;
136 template <
typename Problem,
typename DataType, index_t MNPerBlock, index_t XPerTile>
139 constexpr index_t BlockSize = Problem::kBlockSize;
140 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
141 constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize;
146 if constexpr(XPerTile % (PackedSize * 32 /
sizeof(DataType)) == 0 &&
147 elements_per_thread % (PackedSize * 32 /
sizeof(DataType)) == 0 &&
150 return (PackedSize * 32 /
sizeof(DataType));
152 else if constexpr(XPerTile % (PackedSize * 16 /
sizeof(DataType)) == 0 &&
153 elements_per_thread % (PackedSize * 16 /
sizeof(DataType)) == 0)
155 return (PackedSize * 16 /
sizeof(DataType));
157 else if constexpr(XPerTile % (PackedSize * 8 /
sizeof(DataType)) == 0 &&
158 elements_per_thread % (PackedSize * 8 /
sizeof(DataType)) == 0)
160 return (PackedSize * 8 /
sizeof(DataType));
162 else if constexpr(
sizeof(DataType) >= PackedSize * 4 &&
163 XPerTile % (PackedSize * 4 /
sizeof(DataType)) == 0 &&
164 elements_per_thread % (PackedSize * 4 /
sizeof(DataType)) == 0)
166 return (PackedSize * 4 /
sizeof(DataType));
168 else if constexpr(
sizeof(DataType) >= PackedSize * 2 &&
169 XPerTile % (PackedSize * 2 /
sizeof(DataType)) == 0 &&
170 elements_per_thread % (PackedSize * 2 /
sizeof(DataType)) == 0)
172 return (PackedSize * 2 /
sizeof(DataType));
180 template <
typename Problem>
185 constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
186 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
188 if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
198 template <
typename Problem>
203 constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
204 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
206 if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
216 template <
typename Problem>
219 constexpr index_t smem_size_a =
sizeof(
typename Problem::ADataType) *
224 template <
typename Problem>
232 template <
typename Problem>
235 return Problem::VectorLoadSize /
sizeof(
typename Problem::ADataType);
238 template <
typename Problem>
241 using TileShape =
typename Problem::BlockGemmShape;
242 if constexpr(TileShape::WarpTile::at(
I1) == 32)
244 return TileShape::WarpTile::at(
I2) / 2;
248 static_assert(TileShape::WarpTile::at(
I1) == 16);
249 return TileShape::WarpTile::at(
I2) / 4;
253 template <
typename Problem>
256 using TileShape =
typename Problem::BlockGemmShape;
259 static_assert(TileShape::BlockWarps::at(
I0) == 1,
"requires Wave_M == 1");
261 constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(
I0);
262 constexpr index_t KPerXdl = Problem::BlockGemmShape::WarpTile::at(
I2);
264 constexpr int Repeat = TileShape::BlockWarps::at(
number<1>{});
267 constexpr int KPerThread = KPerXdl / KLane;
269 constexpr int MaxVecSize = 16 /
sizeof(ADataType);
270 constexpr int KItemsPerLoad =
min(MaxVecSize, KPerThread);
271 constexpr int KFragment = KPerThread / KItemsPerLoad;
283 template <
typename Problem>
289 constexpr index_t BlockSize = Problem::kBlockSize;
291 constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
292 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
296 if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
298 constexpr index_t M1 = Problem::VectorLoadSize /
sizeof(ADataType) * APackedSize;
299 constexpr index_t M0 = MPerBlock / M1;
300 constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
301 static_assert(total_pixels % M1 == 0);
302 constexpr index_t K3 = total_pixels / M1;
304 static_assert(KPack % K3 == 0);
305 constexpr index_t K2 = KPack / K3;
310 static_assert(KPerBlock == K0 * K1 * K2 * K3);
322 constexpr index_t K2_m = K2 / K1;
324 static_assert(KPerBlock == K0 * K1 * K2_m * K3);
336 constexpr index_t K1 = Problem::VectorLoadSize /
sizeof(ADataType) * APackedSize;
337 constexpr index_t K0 = KPerBlock / K1;
343 static_assert(M2 != 0,
"M2 is zero, which will lead to a division by zero error.");
344 static_assert(M1 != 0,
"M1 is zero, which will lead to a division by zero error.");
345 constexpr index_t M0 = MPerBlock / (M2 * M1);
346 static_assert(M0 * M1 * M2 == MPerBlock,
347 "Incorrect M0, M2, M1 configuration! "
348 "M0, M1, M2 must cover whole MPerBlock!");
362 constexpr index_t M1 = MPerBlock / M0;
376 template <
typename Problem>
381 constexpr index_t BlockSize = Problem::kBlockSize;
384 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
386 constexpr index_t K1 = 16 /
sizeof(ADataType);
387 constexpr index_t K0 = KPerBlock / K1;
390 static_assert(M2 != 0,
"M2 is zero, which will lead to a division by zero error.");
391 static_assert(M1 != 0,
"M1 is zero, which will lead to a division by zero error.");
406 template <
typename Problem>
409 using TileShape =
typename Problem::BlockGemmShape;
411 constexpr index_t BlockSize = Problem::kBlockSize;
413 constexpr index_t WaveNum = BlockSize / WaveSize;
417 constexpr index_t MaxVecSize = 16 /
sizeof(
typename Problem::BDataType);
418 constexpr index_t KItemsPerLoad =
min(KBPerLoad, MaxVecSize);
419 constexpr index_t KFragment = KBPerLoad / KItemsPerLoad;
420 static_assert(KFragment * KItemsPerLoad == KBPerLoad);
422 constexpr index_t KThdPerWave = WaveSize;
423 constexpr index_t KWavePerBlk = 1;
424 static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad,
"wrong");
425 static_assert(TileShape::BlockWarps::at(
number<2>{}) == 1,
"Requires K_Warp == 1");
427 constexpr index_t NBPerLoad = 1;
428 constexpr index_t NThdPerWave = 1;
432 constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
449 template <
typename Problem>
454 static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
455 constexpr index_t kBlockSize = Problem::kBlockSize;
456 constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
457 constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
459 constexpr index_t M1 = Problem::VectorLoadSize /
sizeof(ADataType);
460 constexpr index_t M0 = kMPerBlock / M1;
461 constexpr index_t total_pixels = kMPerBlock * kKPerBlock / kBlockSize;
462 static_assert(total_pixels % M1 == 0);
463 constexpr index_t K3 = total_pixels / M1;
465 static_assert(kKPack % K3 == 0);
466 constexpr index_t K2 = kKPack / K3;
468 if constexpr(warp_size >= (K2 * M0))
470 constexpr index_t K1 = warp_size / (K2 * M0);
471 constexpr index_t K0 = kBlockSize / warp_size;
484 constexpr index_t K2_m = K2 / K1;
486 static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
497 template <
typename Problem>
501 using BlockWarps =
typename Problem::BlockGemmShape::BlockWarps;
502 using WarpTile =
typename Problem::BlockGemmShape::WarpTile;
504 typename Problem::BDataType,
505 typename Problem::CDataType,
509 Problem::TransposeC>;
512 typename Problem::ADataType,
515 typename Problem::BDataType,
516 typename Problem::CDataType,
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition warp_gemm_dispatcher.hpp:182
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
CK_TILE_HOST_DEVICE constexpr auto make_unmerge_transform(const UpLengths &up_lengths, bool_constant< Use24BitIntegerCalculation >=bool_constant< false >{})
Definition coordinate_transform.hpp:1622
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1662
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1609
CK_TILE_HOST_DEVICE constexpr T min(T x)
Definition tile/core/numeric/math.hpp:210
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp:18
Definition block_flatmm_asmem_bsmem_creg_v1.hpp:16
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:14
static constexpr auto I2
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:17
static CK_TILE_HOST_DEVICE constexpr auto GetKBPerLoad()
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:239
static CK_TILE_HOST_DEVICE constexpr auto MakeShuffledARegBlockDistribution()
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:450
static CK_TILE_HOST_DEVICE constexpr auto GetSmemPackA()
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:233
static CK_TILE_HOST_DEVICE constexpr auto GetGlobalVectorLoadSize()
Get the maximum global memory vector load size.
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:137
static CK_TILE_HOST_DEVICE constexpr auto MakeALDS_WarpTileDistribution()
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:254
static CK_TILE_HOST_DEVICE constexpr auto GetVectorSizeA()
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:181
static CK_TILE_HOST_DEVICE constexpr auto MakeADramTileDistribution()
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:284
static CK_TILE_HOST_DEVICE constexpr auto GetBlockFlatmm()
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:498
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:225
static constexpr auto I0
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:15
static CK_TILE_HOST_DEVICE constexpr auto MakeADramDistribution()
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:377
static CK_TILE_HOST_DEVICE constexpr auto GetVectorSizeB()
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:199
static CK_TILE_HOST_DEVICE constexpr auto MakeALdsBlockDescriptor()
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:21
static CK_TILE_HOST_DEVICE constexpr auto MakeBFlatDramTileDistribution()
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:407
static constexpr auto I1
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:16
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeA()
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:217
Definition tile/core/numeric/numeric.hpp:81
static constexpr int PackedSize
Definition tile/core/numeric/numeric.hpp:82
Definition tile/core/container/sequence.hpp:49
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192