default_2d_epilogue.hpp Source File

default_2d_epilogue.hpp Source File#

Composable Kernel: default_2d_epilogue.hpp Source File
default_2d_epilogue.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
9
10namespace ck_tile {
11
12// this epilogue just stores out an M*N matrix, row major
13
// Compile-time settings for the plain 2D epilogue. The members below re-export the
// template arguments under fixed names so the epilogue can read them as Problem::<name>.
// NOTE(review): the embedded numbering skips 19-20 and 22-23, so the declaration of
// MemoryOperation_, the struct header and two members are absent from this listing; the
// trailing symbol index places the AccDataType and ODataType aliases on 22 and 23.
14template <typename AccDataType_,
15 typename ODataType_,
16 bool kPadM_,
17 bool kPadN_,
18 bool UseRawStore_ = true,
21{
// kPadM / kPadN / UseRawStore are read together by the epilogue when it chooses between
// the *_raw and the regular store/update helpers.
24 static constexpr bool kPadM = kPadM_;
25 static constexpr bool kPadN = kPadN_;
26 static constexpr bool UseRawStore = UseRawStore_;
// MemoryOperation_ is not declared in the visible parameter list (see the note above).
27 static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
// The plain problem carries no D tensors.
28 static constexpr index_t NumDTensor = 0;
29};
30
// Problem descriptor for the GEMM flavour of the 2D epilogue. It records the block tile
// extents, the per-XDL extents, the C-transposed flag and the number of D tensors, and
// checks that the D data-type list and the D layout list have the same length.
// NOTE(review): embedded lines 48-49 and 56-61 are absent from this listing, so the
// MemoryOperation_ parameter, the struct header and the type aliases used further down
// (DsDataType, DsLayout, ...) are not visible here.
31template <typename AsDataType_,
32 typename BsDataType_,
33 typename DsDataType_,
34 typename AccDataType_,
35 typename ODataType_,
36 typename DsLayout_,
37 typename CLayout_,
38 typename CDElementwise_,
39 index_t kM_,
40 index_t kN_,
41 bool kPadM_,
42 bool kPadN_,
43 index_t kMPerXdl_,
44 index_t kNPerXdl_,
45 index_t kKPerXdl_,
46 bool isCTransposed_,
47 bool UseRawStore_ = true,
// Tail of a template-argument list whose head is on a missing line — presumably the
// arguments forwarded to the base problem type; confirm against the real header.
50 ODataType_,
51 kPadM_,
52 kPadN_,
53 UseRawStore_,
54 MemoryOperation_>
55{
// Block tile extents, taken from kM_ / kN_.
62 static constexpr index_t kMPerBlock = kM_;
63 static constexpr index_t kNPerBlock = kN_;
// Per-XDL extents, re-exported unchanged.
64 static constexpr index_t kMPerXdl = kMPerXdl_;
65 static constexpr index_t kNPerXdl = kNPerXdl_;
66 static constexpr index_t kKPerXdl = kKPerXdl_;
// NOTE(review): isCTransposed_ is a bool template parameter but is stored as index_t
// here; consider bool, after confirming nothing depends on the integer type.
67 static constexpr index_t isCTransposed = isCTransposed_;
68
// One D tensor per entry of DsDataType.
69 static constexpr index_t NumDTensor = DsDataType::size();
70
// DsDataType and DsLayout must describe the same number of D tensors.
71 static_assert(NumDTensor == DsLayout::size(),
72 "The size of DsDataType and DsLayout should be the same");
73};
74
// Epilogue functor that writes an accumulator tile to the output window, optionally after
// combining it with D tiles through Problem::CDElementwise.
// NOTE(review): embedded lines 76, 78-80, 100, 108, 112, 139 and 144 are absent from this
// listing (struct header, type aliases such as ODataType, the condition that selects
// store vs. update, and the trailing arguments of generate_tuple / generate_tie).
75template <typename Problem_, typename Policy_ = void>
77{
// Settings mirrored from the problem type.
81 static constexpr bool kPadM = Problem::kPadM;
82 static constexpr bool kPadN = Problem::kPadN;
83 static constexpr bool UseRawStore = Problem::UseRawStore;
84 static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;
85
// This epilogue reports a shared-memory requirement of 0.
86 CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
87
88 // TODO: this function assumes the store-out vector size equals the last dimension size
89 // of OAccTile; how do we fix this?
//
// Parameters:
//   o_dram_window_tmp - destination window handed to the store/update helpers
//   o_acc_tile        - accumulator tile; cast to ODataType before it is written
//   ds_dram_windows   - D windows, or std::nullptr_t to skip the elementwise step
//   (unnamed void*)   - defaults to nullptr and is not read in the visible body
90 template <typename ODramWindowTmp, typename OAccTile, typename DsDramWindows>
91 CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
92 const OAccTile& o_acc_tile,
93 const DsDramWindows& ds_dram_windows,
94 void* = nullptr) const
95 {
// Casts the given tile to ODataType and writes it to o_dram_window_tmp with one of four
// helpers: raw vs. regular is decided below, store vs. update by a condition on a line
// that is missing from this listing (presumably a test on MemoryOperation — confirm).
96 const auto storeOrUpdateTile = [&](const auto& o_tile) {
97 // TODO: this is ugly
// The raw helpers are used only when UseRawStore is set and at least one pad flag is set.
98 if constexpr(UseRawStore && (kPadM || kPadN))
99 {
101 {
102 store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
103 }
104 else
105 {
106 update_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
107 }
// Embedded line 108 is missing here; the symbol index lists buffer_store_fence, which may
// belong on it — confirm.
109 }
110 else
111 {
113 {
114 store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
115 }
116 else
117 {
118 update_tile(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
119 }
120 }
121 };
122
// Elementwise path: taken only when D windows were supplied and the problem declares at
// least one D tensor.
123 if constexpr(!std::is_same_v<DsDramWindows, std::nullptr_t> && Problem::NumDTensor >= 1)
124 {
// Result tile type: whatever load_tile yields for a (kMPerBlock, kNPerBlock) window over
// the first D tensor's view, using the accumulator tile's distribution.
125 using elementwise_result_t = decltype(load_tile(
126 make_tile_window(ds_dram_windows[number<0>{}].get_bottom_tensor_view(),
127 make_tuple(Problem::kMPerBlock, Problem::kNPerBlock),
128 ds_dram_windows[number<0>{}].get_window_origin(),
129 o_acc_tile.get_tile_distribution())));
130
131 elementwise_result_t elementwise_result;
132
// Load every D window with the accumulator tile's distribution.
133 const auto d_tensor_tuple = generate_tuple(
134 [&](auto idx) {
135 const auto d_tile_window =
136 make_tile_window(ds_dram_windows[idx], o_acc_tile.get_tile_distribution());
137 return load_tile(d_tile_window);
138 },
140
// Tuple of references: (elementwise_result, o_acc_tile, D tiles...).
141 const auto c_d_tuple = concat_tuple_of_reference(
142 tie(elementwise_result, o_acc_tile),
143 generate_tie([&](auto idx) -> const auto& { return d_tensor_tuple[idx]; },
145
// Apply the CDElementwise functor across that tuple; presumably it writes into
// elementwise_result, which is what gets stored next — confirm.
146 tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_d_tuple);
147
148 storeOrUpdateTile(elementwise_result);
149 }
150 else
151 {
// No D tensors: write the accumulator tile directly.
152 storeOrUpdateTile(o_acc_tile);
153 }
154 }
155};
156
// GEMM epilogue layered on Default2DEpilogue. The declarations below set up the type
// aliases and constants that the vector-size queries further down rely on (WG, CWarpDstr,
// kMPerXdl, ...).
// NOTE(review): many embedded lines are absent from this listing (160-166, 169-170,
// 173-174, 176-177, 179, 182-185, 191-193, 197), so several alias declarations appear
// here only as fragments.
157template <typename Problem_, typename Policy_ = void>
158struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
159{
167
// Start of the AsDataTypeTuple alias; its remaining arguments are on missing lines.
168 using AsDataTypeTuple = std::conditional_t<ADataTypeIsTuple,
171
// Start of the BsDataTypeTuple alias; its remaining arguments are on missing lines.
172 using BsDataTypeTuple = std::conditional_t<BDataTypeIsTuple,
175
178 // Used for weight-only quantization kernel, B would be dequantized to the same data type as A
// Right-hand side of an alias whose left-hand side is on a missing line: ADataType when
// BDataType is pk_int4_t, otherwise BDataType.
180 std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
181
// Per-XDL extents and the transposed flag, mirrored from the problem type.
186 static constexpr index_t kMPerXdl = Problem::kMPerXdl;
187 static constexpr index_t kNPerXdl = Problem::kNPerXdl;
188 static constexpr index_t kKPerXdl = Problem::kKPerXdl;
// NOTE(review): held as index_t although it is used as a boolean condition later on.
189 static constexpr index_t isCTransposed = Problem::isCTransposed;
190
// Fragment of a template-argument list (presumably the WG alias — confirm).
194 kMPerXdl,
195 kNPerXdl,
196 kKPerXdl,
198
// Warp-level C distribution type exposed by WG.
199 using CWarpDstr = typename WG::CWarpDstr;
200
201 CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
202 {
203 // N is contiguous dimension
204 if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
205 {
206 if constexpr(isCTransposed)
207 {
208 // In this case each thread has multiple consecutive elements in
209 // N dimension, however consecutive threads' elements have stride.
210 constexpr index_t NDimY = CWarpDstr::NDimY;
211 constexpr auto c_warp_y_lengths =
212 CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
213 static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
214 c_warp_y_lengths.get(number<NDimY - 1>{}));
215 return c_warp_y_lengths.get(number<NDimY - 1>{});
216 }
217 else
218 {
219 // In this case each thread has just a single item in Ndim
220 return (WG::WarpGemmAttribute::Impl::kCNLane *
221 WG::WarpGemmAttribute::Impl::kBNBlock) /
222 WG::kN;
223 }
224 }
225 // M is contiguous dimension
226 else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
227 {
228 if constexpr(isCTransposed)
229 {
230 // In this case each thread has just a single item in Mdim
231 return (WG::WarpGemmAttribute::Impl::kCNLane *
232 WG::WarpGemmAttribute::Impl::kAMBlock) /
233 WG::kN;
234 }
235 else
236 {
237 // In this case each thread has multiple consecutive elements in
238 // M dimension, however consecutive threads' elements have stride.
239 constexpr index_t NDimY = CWarpDstr::NDimY;
240 constexpr auto c_warp_y_lengths =
241 CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
242 static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
243 c_warp_y_lengths.get(number<NDimY - 1>{}));
244 return c_warp_y_lengths.get(number<NDimY - 1>{});
245 }
246 }
247 else
248 {
249 static_assert(false, "Unsupported CLayout!");
250 }
251 }
252
253 template <index_t I>
254 CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeD([[maybe_unused]] number<I> index)
255 {
256 return GetVectorSizeC();
257 }
258};
259
260} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition warp_gemm_dispatcher.hpp:182
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
constexpr tuple< Args &... > tie(Args &... args) noexcept
Definition tile/core/container/tuple.hpp:376
memory_operation_enum
Definition arch.hpp:56
@ set
Definition arch.hpp:57
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition type_traits.hpp:67
CK_TILE_HOST_DEVICE constexpr auto concat_tuple_of_reference(const tuple< X &... > &tx, const tuple< Y &... > &ty)
Definition tile/core/container/tuple.hpp:443
CK_TILE_DEVICE auto tile_elementwise_inout_unpack(const InElementFunc &in_element_func, const Tuple &t, std::index_sequence< I... >)
Template function that "unpacks" a tuple and applies an element-wise operation.
Definition tile_elementwise.hpp:71
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
CK_TILE_HOST_DEVICE constexpr auto generate_tie(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:435
CK_TILE_DEVICE void buffer_store_fence(index_t cnt=0)
Definition tile/core/arch/amd_buffer_addressing.hpp:1063
CK_TILE_DEVICE void store_tile_raw(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:46
CK_TILE_DEVICE void update_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition update_tile.hpp:22
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
CK_TILE_DEVICE void update_tile_raw(tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Definition update_tile.hpp:68
Definition default_2d_epilogue.hpp:77
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition default_2d_epilogue.hpp:86
remove_cvref_t< typename Problem::ODataType > ODataType
Definition default_2d_epilogue.hpp:80
CK_TILE_DEVICE auto operator()(ODramWindowTmp &o_dram_window_tmp, const OAccTile &o_acc_tile, const DsDramWindows &ds_dram_windows, void *=nullptr) const
Definition default_2d_epilogue.hpp:91
static constexpr bool kPadN
Definition default_2d_epilogue.hpp:82
remove_cvref_t< Default2DProblem > Problem
Definition default_2d_epilogue.hpp:78
static constexpr bool kPadM
Definition default_2d_epilogue.hpp:81
remove_cvref_t< typename Problem::AccDataType > AccDataType
Definition default_2d_epilogue.hpp:79
static constexpr bool UseRawStore
Definition default_2d_epilogue.hpp:83
static constexpr memory_operation_enum MemoryOperation
Definition default_2d_epilogue.hpp:84
Definition default_2d_epilogue.hpp:21
remove_cvref_t< UnquantYDataType > ODataType
Definition default_2d_epilogue.hpp:23
remove_cvref_t< AccDataType > AccDataType
Definition default_2d_epilogue.hpp:22
static constexpr memory_operation_enum MemoryOperation
Definition default_2d_epilogue.hpp:27
static constexpr index_t NumDTensor
Definition default_2d_epilogue.hpp:28
Definition default_2d_epilogue.hpp:159
remove_cvref_t< typename Problem::BsDataType > BsDataType
Definition default_2d_epilogue.hpp:162
remove_cvref_t< typename Problem::AccDataType > AccDataType
Definition default_2d_epilogue.hpp:163
static CK_TILE_HOST_DEVICE constexpr auto GetVectorSizeC()
Definition default_2d_epilogue.hpp:201
static constexpr index_t kMPerXdl
Definition default_2d_epilogue.hpp:186
static constexpr bool ADataTypeIsTuple
Definition default_2d_epilogue.hpp:165
static constexpr index_t kNPerXdl
Definition default_2d_epilogue.hpp:187
static constexpr index_t kKPerXdl
Definition default_2d_epilogue.hpp:188
std::conditional_t< ADataTypeIsTuple, remove_cvref_t< AsDataType >, remove_cvref_t< tuple< AsDataType > > > AsDataTypeTuple
Definition default_2d_epilogue.hpp:168
remove_cvref_t< typename Problem::CDElementwise > CDElementwise
Definition default_2d_epilogue.hpp:184
static constexpr index_t isCTransposed
Definition default_2d_epilogue.hpp:189
std::conditional_t< std::is_same_v< BDataType, pk_int4_t >, ADataType, BDataType > BTypeToUse
Definition default_2d_epilogue.hpp:179
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsDataTypeTuple > > ADataType
Definition default_2d_epilogue.hpp:176
remove_cvref_t< typename Problem::CLayout > CLayout
Definition default_2d_epilogue.hpp:185
remove_cvref_t< typename Problem::ODataType > ODataType
Definition default_2d_epilogue.hpp:164
remove_cvref_t< typename Problem::DsDataType > DsDataType
Definition default_2d_epilogue.hpp:182
remove_cvref_t< typename Problem::AsDataType > AsDataType
Definition default_2d_epilogue.hpp:161
WarpGemmDispatcher< ADataType, BTypeToUse, AccDataType, kMPerXdl, kNPerXdl, kKPerXdl, isCTransposed > WG
Definition default_2d_epilogue.hpp:191
std::conditional_t< BDataTypeIsTuple, remove_cvref_t< BsDataType >, remove_cvref_t< tuple< BsDataType > > > BsDataTypeTuple
Definition default_2d_epilogue.hpp:172
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsDataTypeTuple > > BDataType
Definition default_2d_epilogue.hpp:177
static constexpr bool BDataTypeIsTuple
Definition default_2d_epilogue.hpp:166
typename WG::CWarpDstr CWarpDstr
Definition default_2d_epilogue.hpp:199
remove_cvref_t< typename Problem::DsLayout > DsLayout
Definition default_2d_epilogue.hpp:183
remove_cvref_t< Problem_ > Problem
Definition default_2d_epilogue.hpp:160
static CK_TILE_HOST_DEVICE constexpr auto GetVectorSizeD(number< I > index)
Definition default_2d_epilogue.hpp:254
Definition default_2d_epilogue.hpp:55
remove_cvref_t< DsLayout_ > DsLayout
Definition default_2d_epilogue.hpp:61
static constexpr index_t kKPerXdl
Definition default_2d_epilogue.hpp:66
remove_cvref_t< DsDataType_ > DsDataType
Definition default_2d_epilogue.hpp:59
remove_cvref_t< AsDataType_ > AsDataType
Definition default_2d_epilogue.hpp:56
static constexpr index_t kMPerXdl
Definition default_2d_epilogue.hpp:64
static constexpr index_t kNPerBlock
Definition default_2d_epilogue.hpp:63
static constexpr index_t kNPerXdl
Definition default_2d_epilogue.hpp:65
static constexpr index_t kMPerBlock
Definition default_2d_epilogue.hpp:62
remove_cvref_t< CLayout_ > CLayout
Definition default_2d_epilogue.hpp:58
remove_cvref_t< BsDataType_ > BsDataType
Definition default_2d_epilogue.hpp:57
remove_cvref_t< CDElementwise_ > CDElementwise
Definition default_2d_epilogue.hpp:60
static constexpr index_t isCTransposed
Definition default_2d_epilogue.hpp:67
static constexpr index_t NumDTensor
Definition default_2d_epilogue.hpp:69