17 template <
typename Problem>
21 constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
22 constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
26 constexpr auto DataTypeSize =
sizeof(ADataType);
27 constexpr auto MLdsLayer =
28 (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize);
32 number<kMPerBlock / MLdsLayer>{},
41 number<kKPerBlock / kKPack * MLdsLayer>{})),
47 a_lds_block_desc_permuted,
56 a_lds_block_desc_xk0_mnldslayer_mn_xk1,
63 return a_lds_block_desc;
66 template <
typename Problem>
69 constexpr index_t smem_size_a =
sizeof(
typename Problem::ADataType) *
74 template <
typename Problem>
82 template <
typename Problem>
85 return Problem::VectorLoadSize /
sizeof(
typename Problem::ADataType);
88 template <
typename Problem>
91 using TileShape =
typename Problem::BlockGemmShape;
97 if constexpr(TileShape::WarpTile::at(
I1) == 32)
99 return TileShape::WarpTile::at(
I2) * scale / 2;
103 static_assert(TileShape::WarpTile::at(
I1) == 16);
104 return TileShape::WarpTile::at(
I2) * scale / 4;
108 template <
typename Problem>
114 constexpr index_t BlockSize = Problem::kBlockSize;
116 constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
117 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
119 if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
121 constexpr index_t M1 = Problem::VectorLoadSize /
sizeof(ADataType);
122 constexpr index_t M0 = MPerBlock / M1;
123 constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
124 static_assert(total_pixels % M1 == 0);
125 constexpr index_t K3 = total_pixels / M1;
127 static_assert(KPack % K3 == 0);
128 constexpr index_t K2 = KPack / K3;
133 static_assert(KPerBlock == K0 * K1 * K2 * K3);
145 constexpr index_t K2_m = K2 / K1;
147 static_assert(KPerBlock == K0 * K1 * K2_m * K3);
159 constexpr index_t K1 = Problem::VectorLoadSize /
sizeof(ADataType);
160 constexpr index_t K0 = KPerBlock / K1;
166 static_assert(M2 != 0,
"M2 is zero, which will lead to a division by zero error.");
167 static_assert(M1 != 0,
"M1 is zero, which will lead to a division by zero error.");
168 constexpr index_t M0 = MPerBlock / (M2 * M1);
169 static_assert(M0 * M1 * M2 == MPerBlock,
170 "Incorrect M0, M2, M1 configuration! "
171 "M0, M1, M2 must cover whole MPerBlock!");
184 constexpr index_t M1 = MPerBlock / (M2 * M0);
185 static_assert(M0 * M1 * M2 == MPerBlock,
186 "Incorrect M0, M1, M2 configuration! "
187 "M0, M1, M2 must cover whole MPerBlock!");
199 template <
typename Problem>
202 using TileShape =
typename Problem::BlockGemmShape;
204 constexpr index_t BlockSize = Problem::kBlockSize;
206 constexpr index_t WaveNum = BlockSize / WaveSize;
209#if defined(__gfx11__)
210 constexpr index_t KRepeatInWave = 2;
212 constexpr index_t KRepeatInWave = 1;
214 constexpr index_t KThdPerWave = WaveSize / KRepeatInWave;
215 constexpr index_t KWavePerBlk = 1;
217 static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad,
"wrong");
219 constexpr index_t NBPerLoad = 1;
220 constexpr index_t NThdPerWave = 1;
224 constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
239 template <
typename Problem>
244 static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
245 constexpr index_t kBlockSize = Problem::kBlockSize;
246 constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
247 constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
249 constexpr index_t M1 = Problem::VectorLoadSize /
sizeof(ADataType);
250 constexpr index_t M0 = kMPerBlock / M1;
251 constexpr index_t total_pixels = kMPerBlock * kKPerBlock / kBlockSize;
252 static_assert(total_pixels % M1 == 0);
253 constexpr index_t K3 = total_pixels / M1;
255 static_assert(kKPack % K3 == 0);
256 constexpr index_t K2 = kKPack / K3;
258 if constexpr(warp_size >= (K2 * M0))
260 constexpr index_t K1 = warp_size / (K2 * M0);
261 constexpr index_t K0 = kBlockSize / warp_size;
274 constexpr index_t K2_m = K2 / K1;
276 static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
287 template <
typename Problem>
290 using BlockWarps =
typename Problem::BlockGemmShape::BlockWarps;
291 using WarpTile =
typename Problem::BlockGemmShape::WarpTile;
293 std::conditional_t<std::is_same_v<typename Problem::BDataType, ck_tile::pk_int4_t>,
294 typename Problem::ADataType,
295 typename Problem::BDataType>;
298 typename Problem::CDataType,
302 Problem::TransposeC>;
304 using BlockWeightPreshufflePolicy =
306 typename Problem::BDataType,
307 typename Problem::CDataType,
324 template <
typename Problem>
328 using WG_ =
typename BlockGemm::WG;
330 constexpr bool TransposeC = Problem::TransposeC;
331 using CLayout =
typename Problem::CLayout;
332 using CWarpDstr =
typename WG_::CWarpDstr;
335 if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
337 if constexpr(TransposeC)
341 constexpr index_t NDimY = CWarpDstr::NDimY;
342 constexpr auto c_warp_y_lengths =
343 CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
344 static_assert(WG_::WarpGemmAttribute::Impl::kCM1PerLane ==
351 return WG_::WarpGemmAttribute::Impl::kCNLane / WG_::kN;
355 else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
357 if constexpr(TransposeC)
360 return WG_::WarpGemmAttribute::Impl::kCNLane / WG_::kN;
366 constexpr index_t NDimY = CWarpDstr::NDimY;
367 constexpr auto c_warp_y_lengths =
368 CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
369 static_assert(WG_::WarpGemmAttribute::Impl::kCM1PerLane ==
376 static_assert(
false,
"Unsupported CLayout!");
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition warp_gemm_dispatcher.hpp:182
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
CK_TILE_HOST_DEVICE constexpr auto make_unmerge_transform(const UpLengths &up_lengths, bool_constant< Use24BitIntegerCalculation >=bool_constant< false >{})
Definition coordinate_transform.hpp:1622
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1662
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_wp_asmem_bsmem_creg_v1_custom_policy.hpp:18
Definition block_wp_asmem_bsmem_creg_v1.hpp:16
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:34
static constexpr auto I1
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:50
static constexpr auto I2
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:51
static constexpr auto I0
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:49
Definition wp_pipeline_agmem_bgmem_creg_base_policy.hpp:13
static CK_TILE_HOST_DEVICE constexpr auto MakeADramTileDistribution()
Definition wp_pipeline_agmem_bgmem_creg_base_policy.hpp:109
static CK_TILE_HOST_DEVICE constexpr auto GetSmemPackA()
Definition wp_pipeline_agmem_bgmem_creg_base_policy.hpp:83
static CK_TILE_HOST_DEVICE constexpr auto MakeShuffledARegBlockDistribution()
Definition wp_pipeline_agmem_bgmem_creg_base_policy.hpp:240
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeA()
Definition wp_pipeline_agmem_bgmem_creg_base_policy.hpp:67
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition wp_pipeline_agmem_bgmem_creg_base_policy.hpp:75
static CK_TILE_HOST_DEVICE constexpr auto GetKBPerLoad()
Definition wp_pipeline_agmem_bgmem_creg_base_policy.hpp:89
UniversalGemmBasePolicy< UniversalWeightPreshufflePipelineAgBgCrPolicy > BasePolicy
Definition wp_pipeline_agmem_bgmem_creg_base_policy.hpp:14
static CK_TILE_HOST_DEVICE constexpr auto GetVectorSizeC()
Get the vector store size for C tensor.
Definition wp_pipeline_agmem_bgmem_creg_base_policy.hpp:325
static CK_TILE_HOST_DEVICE constexpr auto GetBlockWeightPreshuffle()
Definition wp_pipeline_agmem_bgmem_creg_base_policy.hpp:288
static CK_TILE_HOST_DEVICE constexpr auto MakeALdsBlockDescriptor()
Definition wp_pipeline_agmem_bgmem_creg_base_policy.hpp:18
static CK_TILE_DEVICE constexpr auto MakeBFlatDramTileDistribution()
Definition wp_pipeline_agmem_bgmem_creg_base_policy.hpp:200
Definition tile/core/container/sequence.hpp:49
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192