gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp Source File

gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp Source File

Composable Kernel: gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp Source File
gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
10
11namespace ck_tile {
12// Default policy for GemmPipelineAgBgCrCompAsync
13// Customized methods: MakeALdsBlockDescriptor, MakeBLdsBlockDescriptor
14// GetBlockGemm implementation is copied from GemmPipelineAgBgCrCompV4DefaultPolicy
16 : public UniversalGemmBasePolicy<GemmPipelineAgBgCrCompAsyncDefaultPolicy>
17{
20
// MakeALdsBlockDescriptor: builds the tensor descriptor used for the A tile in LDS.
// NOTE(review): this listing extract dropped source line 22 (the signature). Per the
// symbol index at the end of this listing, line 22 is
//   static CK_TILE_HOST_DEVICE constexpr auto MakeALdsBlockDescriptor()
// Source lines 32-34, 43-45, 48 and 50-54 are also missing, so the lengths/strides
// handed to make_naive_tensor_descriptor and the final call in the else-branch are
// not visible here.
21 template <typename Problem>
23 {
// Block-tile extents taken from Problem::BlockGemmShape (kM and kK).
24 constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
25 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
// Compile-time branch on is_a_load_tr<Problem>.
// NOTE(review): presumably "A is read from LDS with a transposing load" - confirm
// against UniversalGemmBasePolicy, where is_a_load_tr is defined.
26 if constexpr(is_a_load_tr<Problem>)
27 {
28 // TODO: better LDS descriptor for performance
29 // This branch is reusing the logic from
30 // UniversalGemmBasePolicy::MakeALdsBlockDescriptor
// The naive descriptor is returned as-is in this branch (no further transform).
// Its leading arguments (source lines 32-34) are missing from this extract.
31 constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( //
35 number<1>{});
36 return a_lds_block_desc_0;
37 }
38 else
39 {
// KPack comes from the base-policy helper GetSmemPackA<Problem>().
40 constexpr index_t KPack = GetSmemPackA<Problem>();
41
// Intermediate naive descriptor; its leading arguments (source lines 43-45) are missing.
42 constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
46 number<1>{});
47
// a_lds_block_desc_0 is the first argument of a call whose opening line (48) and
// remaining arguments (50-54) are missing from this extract.
// NOTE(review): the symbol index lists transform_tensor_descriptor,
// make_pass_through_transform and make_merge_transform - presumably used here; confirm
// against the real header.
49 a_lds_block_desc_0,
55 }
56 }
57
// MakeBLdsBlockDescriptor: builds the tensor descriptor used for the B tile in LDS.
// NOTE(review): this listing extract dropped source line 59 (the signature). Per the
// symbol index at the end of this listing, line 59 is
//   static CK_TILE_HOST_DEVICE constexpr auto MakeBLdsBlockDescriptor()
// Source lines 69-71, 80-82, 85 and 87-91 are also missing, so the descriptor
// lengths/strides and the final call in the else-branch are not visible here.
58 template <typename Problem>
60 {
// Block-tile extents taken from Problem::BlockGemmShape (kN and kK).
61 constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
62 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
// Compile-time branch on is_b_load_tr<Problem>.
// NOTE(review): presumably "B is read from LDS with a transposing load" - confirm
// against UniversalGemmBasePolicy, where is_b_load_tr is defined.
63 if constexpr(is_b_load_tr<Problem>)
64 {
65 // TODO: better LDS descriptor for performance
66 // This branch is reusing the logic from
67 // UniversalGemmBasePolicy::MakeBLdsBlockDescriptor
// The descriptor is returned as-is in this branch. Its initializer (source lines
// 69-71, including the callee name) is missing from this extract.
68 constexpr auto b_lds_block_desc_0 =
72 number<1>{});
73 return b_lds_block_desc_0;
74 }
75 else
76 {
// KPack comes from the base-policy helper GetSmemPackB<Problem>().
77 constexpr index_t KPack = GetSmemPackB<Problem>();
78
// Intermediate naive descriptor; its leading arguments (source lines 80-82) are missing.
79 constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
83 number<1>{});
84
// b_lds_block_desc_0 is the first argument of a call whose opening line (85) and
// remaining arguments (87-91) are missing from this extract.
// NOTE(review): presumably the same transform_tensor_descriptor pattern as the A-side
// else-branch - confirm against the real header.
86 b_lds_block_desc_0,
92 }
93 }
94
// GetBlockGemm: picks the warp-level GEMM type (WarpGemm) and the block-GEMM policy
// type (BlockGemmPolicy) for the given Problem.
// NOTE(review): this listing extract dropped source lines 105, 109 and 128 (the first
// arm and the fallback of the wg_attr_num_access conditional chain, and the return
// statement), so the returned expression is not visible here.
95 template <typename Problem>
96 CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
97 {
98 using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
99 using WarpTile = typename Problem::BlockGemmShape::WarpTile;
100
// vector_size = DS_READ_TR_SIZE() / sizeof(ComputeDataType).
// NOTE(review): this is an element count only if DS_READ_TR_SIZE() is a size in
// bytes - TODO confirm in load_tile_transpose.hpp.
101 constexpr index_t vector_size =
102 DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType);
// thread_elements = WarpTile::at(I1) * WarpTile::at(I2) divided by the warp size.
// at(I1) and at(I2) are the same values passed below in the NPerWave and KPerWave
// positions of WarpGemmDispatcher.
103 constexpr index_t thread_elements = WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size();
// Conditional chain: the visible arms select Single / Double / Quad when
// thread_elements equals 1x / 2x / 4x vector_size.
// NOTE(review): the first arm (line 105) and the fallback (line 109) are missing; the
// symbol index lists an `Invalid` enumerator next to Single/Double/Quad - presumably
// the fallback; confirm against the real header.
104 constexpr auto wg_attr_num_access =
106 : vector_size == thread_elements ? WGAttrNumAccessEnum::Single
107 : vector_size * 2 == thread_elements ? WGAttrNumAccessEnum::Double
108 : vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad
110
// Argument roles below follow the WarpGemmDispatcher parameter list shown in the symbol
// index: AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA,
// UseStructuredSparsity, AttrNumAccess.
111 using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
112 typename Problem::BDataType,
113 typename Problem::CDataType, // AccDataType
114 WarpTile::at(I0), // MPerWave
115 WarpTile::at(I1), // NPerWave
116 WarpTile::at(I2), // KPerWave
117 Problem::TransposeC,
118 false, // SwizzleA
119 false, // UseStructuredSparsity
120 wg_attr_num_access>;
121
// Policy type that bundles the A/B/C data types, the BlockWarps layout and the
// WarpGemm chosen above.
122 using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::ADataType,
123 typename Problem::BDataType,
124 typename Problem::CDataType,
125 BlockWarps,
126 WarpGemm>;
127
// Source line 128 (the return statement) is missing from this extract.
// NOTE(review): the symbol index references block_gemm_areg_breg_creg_v1.hpp -
// presumably the returned block GEMM is built from Problem and BlockGemmPolicy; confirm.
129 }
130};
131} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
@ Invalid
Definition warp_gemm_attribute_mfma.hpp:17
@ Single
Definition warp_gemm_attribute_mfma.hpp:14
@ Double
Definition warp_gemm_attribute_mfma.hpp:15
@ Quad
Definition warp_gemm_attribute_mfma.hpp:16
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition warp_gemm_dispatcher.hpp:182
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
@ warp_raked
Warp raked pattern.
Definition static_encoding_pattern.hpp:99
int32_t index_t
Definition integer.hpp:9
constexpr int DS_READ_TR_SIZE()
Definition load_tile_transpose.hpp:20
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_gemm_areg_breg_creg_v1_custom_policy.hpp:16
Definition block_gemm_areg_breg_creg_v1.hpp:18
Definition gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp:17
static CK_TILE_HOST_DEVICE constexpr auto MakeBLdsBlockDescriptor()
Definition gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp:59
static constexpr auto BTileAccessPattern
Definition gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp:19
static CK_TILE_HOST_DEVICE constexpr auto GetBlockGemm()
Definition gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp:96
static constexpr auto ATileAccessPattern
Definition gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp:18
static CK_TILE_HOST_DEVICE constexpr auto MakeALdsBlockDescriptor()
Definition gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp:22
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:34
static constexpr auto I1
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:50
static CK_TILE_HOST_DEVICE constexpr auto GetSmemPackA()
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:645
static CK_TILE_HOST_DEVICE constexpr auto GetSmemPackB()
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:653
static constexpr bool is_a_load_tr
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:44
static constexpr bool is_b_load_tr
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:46
static constexpr auto I2
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:51
static constexpr auto I0
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:49
Definition tile/core/container/sequence.hpp:49