gemm_pipeline_ag_bg_cr_comp_v6_default_policy.hpp Source File

gemm_pipeline_ag_bg_cr_comp_v6_default_policy.hpp Source File

Composable Kernel: gemm_pipeline_ag_bg_cr_comp_v6_default_policy.hpp Source File
gemm_pipeline_ag_bg_cr_comp_v6_default_policy.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
10
11namespace ck_tile {
12// Default policy for GemmPipelineAGmemBGmemCregComputeV6, except the block gemm method, it shares
13// the same vector size implementation, SmemSize, Global memory tile distiribution as the
14// UniversalGemm Pipeline Policy.
15// Default policy class should not be templated, put template on
16// member functions instead.
18 : public UniversalGemmBasePolicy<GemmPipelineAgBgCrCompV6DefaultPolicy>
19{
20 template <typename Problem>
21 CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
22 {
23 using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
24 using WarpTile = typename Problem::BlockGemmShape::WarpTile;
25
26 constexpr index_t vector_size =
27 DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType);
28 constexpr index_t thread_elements = WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size();
29 constexpr auto wg_attr_num_access =
31 : vector_size == thread_elements ? WGAttrNumAccessEnum::Single
32 : vector_size * 2 == thread_elements ? WGAttrNumAccessEnum::Double
33 : vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad
35
36 using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
37 typename Problem::BDataType,
38 typename Problem::CDataType,
39 WarpTile::at(I0),
40 WarpTile::at(I1),
41 WarpTile::at(I2),
42 Problem::TransposeC,
43 false,
44 false,
45 wg_attr_num_access>;
46
47 using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::ADataType,
48 typename Problem::BDataType,
49 typename Problem::CDataType,
50 BlockWarps,
51 WarpGemm>;
52
54 }
55};
56} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
namespace ck_tile
Definition tile/core/algorithm/cluster_descriptor.hpp:13
@ Invalid
Definition warp_gemm_attribute_mfma.hpp:17
@ Single
Definition warp_gemm_attribute_mfma.hpp:14
@ Double
Definition warp_gemm_attribute_mfma.hpp:15
@ Quad
Definition warp_gemm_attribute_mfma.hpp:16
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition warp_gemm_dispatcher.hpp:182
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
int32_t index_t
Definition integer.hpp:9
constexpr int DS_READ_TR_SIZE()
Definition load_tile_transpose.hpp:20
BlockGemmARegBRegCRegV1CustomPolicy
Definition block_gemm_areg_breg_creg_v1_custom_policy.hpp:16
BlockGemmARegBRegCRegV1
Definition block_gemm_areg_breg_creg_v1.hpp:18
GemmPipelineAgBgCrCompV6DefaultPolicy
Definition gemm_pipeline_ag_bg_cr_comp_v6_default_policy.hpp:19
static CK_TILE_HOST_DEVICE constexpr auto GetBlockGemm()
Definition gemm_pipeline_ag_bg_cr_comp_v6_default_policy.hpp:21
UniversalGemmBasePolicy
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:34
static constexpr auto I1
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:50
static constexpr bool is_a_load_tr
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:44
static constexpr bool is_b_load_tr
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:46
static constexpr auto I2
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:51
static constexpr auto I0
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:49