gemm_bquant_pipeline_ag_bg_cr_policy.hpp Source File

gemm_bquant_pipeline_ag_bg_cr_policy.hpp Source File

Composable Kernel: gemm_bquant_pipeline_ag_bg_cr_policy.hpp Source File
gemm_bquant_pipeline_ag_bg_cr_policy.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
8
9namespace ck_tile {
10
12{
// Re-export the compile-time index constants I0, I1, I2 from the base policy
// (Base = UniversalGemmPipelineAgBgCrPolicy). They are used below as indices into
// WarpTile::at(...) to read the warp-tile M, N and K extents.
14 using Base::I0;
15 using Base::I1;
16 using Base::I2;
17
// GetVectorSizeBQ: returns the global-load vector width for the BQ tensor, as
// computed by GetABQGlobalVectorLoadSize.
// The BQ tile extent per block is derived from the block GEMM shape divided by
// the quantization group size:
//   NPerBlockBQ = kN / QuantGroupSize::kN
//   KPerBlockBQ = kK / QuantGroupSize::kK
// Only a column-major BQ layout is accepted (compile-time check).
// NOTE(review): BQLayout and BQDataType are declared on source lines 21-22,
// which are missing from this listing. Presumably they are aliases of
// Problem::BQLayout / Problem::BQDataType; confirm against the header.
// NOTE(review): the integer divisions assume the quant group sizes divide the
// block tile evenly. Nothing in this function checks that.
18 template <typename Problem>
19 CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeBQ()
20 {
23 constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
24 constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN;
25 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
26 constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK;
27
28 static_assert(std::is_same_v<BQLayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
29 return GetABQGlobalVectorLoadSize<Problem, BQDataType, NPerBlockBQ, KPerBlockBQ>();
30 }
31
// MakeBQDramTileDistribution: builds the static tile distribution used for the
// BQ tile.
// It is parameterised by:
//   - the warp GEMM,
//   - the block size,
//   - the BQ tile extents (KPerBlockBQ, NPerBlockBQ),
//   - QuantGroupSize::kN.
// It returns TileEncodingPattern::make_2d_static_tile_distribution().
// NOTE(review): the signature line (source line 33, per the definition list:
// `static CK_TILE_HOST_DEVICE constexpr auto MakeBQDramTileDistribution()`) and
// source line 35 are missing from this listing.
32 template <typename Problem>
34 {
// NOTE(review): BlockGemmShape is not referenced by the visible lines. The
// shape is read as Problem::BlockGemmShape::... below instead.
36 using BlockGemmShape = typename Problem::BlockGemmShape;
37
38 constexpr index_t BlockSize = Problem::kBlockSize;
// BQ extents per block: the block tile divided by the quantization group size.
39 constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
40 constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN;
41 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
42 constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK;
43 using WarpTile = typename Problem::BlockGemmShape::WarpTile;
// Same WarpGemmDispatcher arguments as in GetBlockGemm. Keep the two in sync,
// presumably so that the BQ distribution matches the warp GEMM actually used.
44 using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
45 typename Problem::ComputeDataType,
46 typename Problem::CDataType,
47 WarpTile::at(I0),
48 WarpTile::at(I1),
49 WarpTile::at(I2),
50 Problem::TransposeC>;
51
// Only column-major BQ is supported. BQLayout is declared on a line that is
// missing from this listing.
// NOTE(review): the encoding-pattern template name (source line 54) is missing
// below. Note that its extent arguments are passed K first, then N.
52 static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
53 using TileEncodingPattern =
55 WarpGemm,
56 BlockSize,
57 KPerBlockBQ,
58 NPerBlockBQ,
59 Problem::QuantGroupSize::kN>;
60
61 return TileEncodingPattern::make_2d_static_tile_distribution();
62 }
63
// GetBlockGemm: selects the block-level GEMM policy for the BQuant pipeline.
// Compile-time requirements:
//   - QuantGroupSize::kK is a multiple of the warp GEMM K extent
//     (WarpTile::at(I2)), i.e. one quantization group covers a whole number of
//     warp-GEMM K steps;
//   - ComputeDataType is fp8_t or bf8_t;
//   - CDataType, which is passed as the warp GEMM accumulator type, is float.
// The block GEMM policy combines A/B/C data types, the block warp layout and
// the chosen WarpGemm.
64 template <typename Problem>
65 CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
66 {
67 using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
68 using WarpTile = typename Problem::BlockGemmShape::WarpTile;
69
// The condition is QuantGroupSize::kK % KPerWarpGemm == 0, so the quant group
// must be a multiple of KPerWarpGemm. The previous message stated the inverse.
70 static_assert(Problem::QuantGroupSize::kK % WarpTile::at(I2) == 0,
71 "QuantGroupSize::kK must be a multiple of KPerWarpGemm!");
72
// Same WarpGemmDispatcher arguments as MakeBQDramTileDistribution; keep in sync.
73 using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
74 typename Problem::ComputeDataType,
75 typename Problem::CDataType,
76 WarpTile::at(I0),
77 WarpTile::at(I1),
78 WarpTile::at(I2),
79 Problem::TransposeC>;
80 static_assert(std::is_same_v<typename Problem::ComputeDataType, fp8_t> ||
81 std::is_same_v<typename Problem::ComputeDataType, bf8_t>, "ComputeDataType must be fp8_t or bf8_t!");
82 static_assert(std::is_same_v<typename Problem::CDataType, float>, "CDataType (warp GEMM accumulator type) must be float!");
83 using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
84 typename Problem::BDataType,
85 typename Problem::CDataType,
86 BlockWarps,
87 WarpGemm>;
// NOTE(review): source line 88 (the return statement that builds the block GEMM
// from Problem and BlockGemmPolicy) is missing from this listing.
89 }
90};
91
92} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition warp_gemm_dispatcher.hpp:182
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
int32_t index_t
Definition integer.hpp:9
Definition block_universal_gemm_as_bs_bquant_cr.hpp:56
Definition block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp:18
Definition gemm_bquant_pipeline_ag_bg_cr_policy.hpp:12
UniversalGemmPipelineAgBgCrPolicy Base
Definition gemm_bquant_pipeline_ag_bg_cr_policy.hpp:13
static CK_TILE_HOST_DEVICE constexpr auto GetBlockGemm()
Definition gemm_bquant_pipeline_ag_bg_cr_policy.hpp:65
static CK_TILE_HOST_DEVICE constexpr auto MakeBQDramTileDistribution()
Definition gemm_bquant_pipeline_ag_bg_cr_policy.hpp:33
static CK_TILE_HOST_DEVICE constexpr auto GetVectorSizeBQ()
Definition gemm_bquant_pipeline_ag_bg_cr_policy.hpp:19
static constexpr auto I1
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:50
static constexpr auto I2
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:51
static constexpr auto I0
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:49
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:693
Definition gemm_group_quant_utils.hpp:176