18 template <
typename Problem>
23 constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
24 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
25 constexpr index_t KPerBlockAQ = KPerBlock / Problem::QuantGroupSize::kK;
27 static_assert(std::is_same_v<AQLayout, ck_tile::tensor_layout::gemm::RowMajor>);
28 return GetABQGlobalVectorLoadSize<Problem, AQDataType, MPerBlock, KPerBlockAQ>();
31 template <
typename Problem>
35 using BlockGemmShape =
typename Problem::BlockGemmShape;
37 constexpr index_t BlockSize = Problem::kBlockSize;
38 constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
39 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
40 constexpr index_t KPerBlockAQ = KPerBlock / Problem::QuantGroupSize::kK;
42 constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
43 using WarpTile =
typename Problem::BlockGemmShape::WarpTile;
45 typename Problem::ComputeDataType,
46 typename Problem::CDataType,
52 static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
53 if constexpr(PreshuffleQuant)
59 MPerBlock / WarpGemm::kM,
65 return TileEncodingPattern::make_2d_static_tile_distribution();
69 if constexpr(Problem::TransposeC)
71 using TileEncodingPatternTransposeC =
78 return TileEncodingPatternTransposeC::make_2d_static_tile_distribution();
91 return TileEncodingPattern::make_2d_static_tile_distribution();
96 template <
typename Problem>
99 using BlockWarps =
typename Problem::BlockGemmShape::BlockWarps;
100 using WarpTile =
typename Problem::BlockGemmShape::WarpTile;
102 static_assert(Problem::QuantGroupSize::kK % WarpTile::at(
I2) == 0,
103 "KPerWarpGemm must be a multiple of QuantGroupSize::kK!");
106 typename Problem::ComputeDataType,
107 typename Problem::CDataType,
111 Problem::TransposeC>;
112 static_assert(std::is_same_v<typename Problem::ComputeDataType, fp8_t> ||
113 std::is_same_v<typename Problem::ComputeDataType, bf8_t>);
114 static_assert(std::is_same_v<typename Problem::CDataType, float>);
116 typename Problem::BDataType,
117 typename Problem::CDataType,
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition warp_gemm_dispatcher.hpp:182
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_HOST_DEVICE constexpr auto integer_least_multiple(X x, Y y)
Definition tile/core/numeric/math.hpp:155
int32_t index_t
Definition integer.hpp:9
Definition block_universal_gemm_as_aquant_bs_cr.hpp:56
Definition block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp:18
Definition gemm_aquant_pipeline_ag_bg_cr_policy.hpp:12
UniversalGemmPipelineAgBgCrPolicy Base
Definition gemm_aquant_pipeline_ag_bg_cr_policy.hpp:13
static CK_TILE_HOST_DEVICE constexpr auto GetBlockGemm()
Definition gemm_aquant_pipeline_ag_bg_cr_policy.hpp:97
static CK_TILE_HOST_DEVICE constexpr auto MakeAQDramTileDistribution()
Definition gemm_aquant_pipeline_ag_bg_cr_policy.hpp:32
static CK_TILE_HOST_DEVICE constexpr auto GetVectorSizeAQ()
Definition gemm_aquant_pipeline_ag_bg_cr_policy.hpp:19
static constexpr auto I1
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:50
static constexpr auto I2
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:51
static constexpr auto I0
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:49
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:693
Definition gemm_group_quant_utils.hpp:124
Definition gemm_group_quant_utils.hpp:57