gridwise_2d_multiple_reduction_threadwise.hpp Source File

gridwise_2d_multiple_reduction_threadwise.hpp Source File

Composable Kernel: gridwise_2d_multiple_reduction_threadwise.hpp Source File
gridwise_2d_multiple_reduction_threadwise.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
13
14namespace ck {
15
16template <typename GridwiseMultipleReduction,
17 index_t NumReduction,
18 typename InDataType,
19 typename OutDataTypePointerTuple,
20 typename AccDataType,
21 typename InGridDesc_M_K,
22 typename OutGridDesc_M_Tuple,
23 typename InElementwiseOperationTuple,
24 typename AccElementwiseOperationTuple>
25__global__ void
26kernel_multiple_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k,
27 const OutGridDesc_M_Tuple out_grid_desc_m_tuple,
28 const InElementwiseOperationTuple in_elementwise_op_tuple,
29 const AccElementwiseOperationTuple acc_elementwise_op_tuple,
31 const InDataType* const __restrict__ p_in_value_global,
33 OutDataTypePointerTuple p_out_value_global_tuple)
34{
35 GridwiseMultipleReduction::Run(in_grid_desc_m_k,
36 out_grid_desc_m_tuple,
37 in_elementwise_op_tuple,
38 acc_elementwise_op_tuple,
39 alpha_values,
40 p_in_value_global,
41 beta_values,
42 p_out_value_global_tuple);
43};
44
// Threadwise multiple reduction.
//
// Each thread reduces a slice of MThreadSliceSize rows of the 2-D input along
// dimension 1 (the length returned by in_grid_desc_m_k.GetLength(Number<1>{})).
// It writes one value per row to each of the NumReduction outputs. The slice
// start row is thread_global_1d_id * MThreadSliceSize.
//
// Template parameters (as used in the visible code):
//   NumReduction            - must equal the size of every tuple parameter
//                             (enforced by the static_assert below).
//   ReduceOperation         - supplies GetIdentityValue<T>(); passed to ThreadwiseReduce.
//   OutMemoryDataOperation  - in-memory op used when storing each result.
//   PropagateNan            - forwarded to ThreadwiseReduce.
//   MThreadSliceSize / KThreadSliceSize - per-thread tile of the input.
//   InSrcVectorDim / InSrcVectorSize    - vectorised read dimension and width;
//                             the slice size along that dim must be divisible
//                             by the vector width.
//   OutDstVectorSizeSeq     - per-reduction vector width for output load/store.
//
// NOTE(review): the embedded numbering jumps from 61 to 63. The line carrying
// the struct's declaration/name was dropped by the page extractor and must be
// restored from the original header.
45template <index_t NumReduction,
46 typename InDataType,
47 typename OutDataTypePointerTuple,
48 typename AccDataType,
49 typename InGridDesc_M_K,
50 typename OutGridDesc_M_Tuple,
51 typename ReduceOperation,
52 typename InElementwiseOperationTuple,
53 typename AccElementwiseOperationTuple,
54 InMemoryDataOperationEnum OutMemoryDataOperation,
55 bool PropagateNan,
56 index_t BlockSize,
57 index_t MThreadSliceSize,
58 index_t KThreadSliceSize,
59 index_t InSrcVectorDim,
60 index_t InSrcVectorSize,
61 typename OutDstVectorSizeSeq>
63{
// Vector reads must tile the thread slice exactly along the vectorised dimension.
64 static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
65 (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)),
66 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
67
// Every per-reduction tuple/sequence must have exactly NumReduction entries.
68 static_assert(NumReduction == OutDataTypePointerTuple::Size() &&
69 NumReduction == OutGridDesc_M_Tuple::Size() &&
70 NumReduction == OutDstVectorSizeSeq::Size() &&
71 NumReduction == InElementwiseOperationTuple::Size() &&
72 NumReduction == AccElementwiseOperationTuple::Size(),
73 "All tuple should have the same size as the number of Reductions!");
74
// True when reads are vectorised along M. Per the cross-reference index, this
// selects ThreadBufferDimAccessOrder = Sequence<1, 0> instead of Sequence<0, 1>.
75 static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
76
// NOTE(review): embedded lines 77-78, 80-83 and 85-87 are missing here. Per the
// cross-reference index they declare ThreadBufferDimAccessOrder,
// ThreadReduceSrcDesc_M_K (packed MThreadSliceSize x KThreadSliceSize),
// ThreadReduceDstDesc_M (packed MThreadSliceSize) and the ThreadwiseReduce
// alias, whose tail follows.
79
84
88 ReduceOperation,
89 PropagateNan>;
90
// NOTE(review): embedded line 91 (PassThroughOp alias) is missing. So is
// line 95 (Accumulation alias).
92
93 static constexpr auto I0 = Number<0>{};
94
96
// Run: reduce this thread's rows for all NumReduction reductions and write
// the results.
//  1. Every element of each reduction's accumulator is set to identityVal.
//  2. Loop over K in steps of KThreadSliceSize. Each step loads one
//     MThreadSliceSize x KThreadSliceSize tile and applies
//     in_elementwise_op_tuple[iR] to every element. It then reduces the tile
//     into accu_value_buf_tuple(iR) via ThreadwiseReduce::Reduce.
//  3. For each reduction:
//     - apply acc_elementwise_op_tuple[iR], then multiply by alpha_values[iR];
//     - if beta_values[iR] is non-zero, load the prior output and add
//       prior * beta_values[iR];
//     - store using OutMemoryDataOperation.
// NOTE(review): the alpha_values / beta_values parameters (embedded lines 101
// and 103) are missing from this listing. The cross-reference index shows both
// as Array<AccDataType, NumReduction>.
97 __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
98 const OutGridDesc_M_Tuple& out_grid_desc_m_tuple,
99 const InElementwiseOperationTuple& in_elementwise_op_tuple,
100 const AccElementwiseOperationTuple& acc_elementwise_op_tuple,
102 const InDataType* const __restrict__ p_in_value_global,
104 OutDataTypePointerTuple p_out_value_global_tuple)
105 {
106 const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
107
// The third argument is the InDataType identity value. Presumably this is the
// value returned for invalid/out-of-range reads (confirm against
// make_dynamic_buffer). That would let partial K tiles reduce correctly.
108 const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
109 p_in_value_global,
110 in_grid_desc_m_k.GetElementSpaceSize(),
111 ReduceOperation::template GetIdentityValue<InDataType>());
// One global output buffer per reduction.
// NOTE(review): embedded lines 114, 117, 119 are missing.
112 auto out_global_val_buf_tuple = generate_tuple(
113 [&](auto iR) {
115 p_out_value_global_tuple[iR], out_grid_desc_m_tuple[iR].GetElementSpaceSize());
116 },
118
120 in_thread_buf;
121
// Per-reduction copy of the input tile after the input element-wise op,
// MThreadSliceSize * KThreadSliceSize AccDataType values each.
122 auto in_thread_buf_tuple = generate_tuple(
123 [&](auto iR) {
124 (void)iR;
126 AccDataType,
127 MThreadSliceSize * KThreadSliceSize,
128 true>{};
129 },
131
// Per-reduction accumulators, all reset to the reduction identity below.
132 auto accu_value_buf_tuple = generate_tuple(
133 [&](auto iR) {
134 (void)iR;
136 },
138
139 static_for<0, NumReduction, 1>{}([&](auto iR) {
141 [&](auto J) { accu_value_buf_tuple(iR)(J) = identityVal; });
142 });
143
144 const index_t thread_global_1d_id = get_thread_global_1d_id();
145
146 const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
147
148 using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
149 constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
151
// Reader for this thread's input tile, starting at row
// thread_global_1d_id * MThreadSliceSize, column 0.
// NOTE(review): no check that this row is < M is visible here. Presumably the
// caller pads the descriptor or sizes the grid exactly — confirm.
152 auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
153 AccDataType,
154 InGridDesc_M_K,
155 decltype(thread_buffer_desc),
156 ThreadBufferLengths,
158 InSrcVectorDim,
159 InSrcVectorSize,
160 1,
161 false>(
162 in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
163
164 constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);
165
// do-while: the body runs at least once, even if toReduceLength <= 0. It
// processes ceil(toReduceLength / KThreadSliceSize) tiles in total.
166 index_t reducedLength = 0;
167 do
168 {
169 threadwise_src_load.Run(in_grid_desc_m_k,
170 in_global_val_buf,
171 thread_buffer_desc,
172 make_tuple(I0, I0),
173 in_thread_buf);
174
175 static_for<0, NumReduction, 1>{}([&](auto iR) {
177 // do element-wise pre-reduction operation
179 constexpr auto offset =
180 thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
181 in_elementwise_op_tuple[iR](in_thread_buf_tuple(iR)(Number<offset>{}),
182 in_thread_buf(Number<offset>{}));
183 });
184 });
185
186 ThreadwiseReduce::Reduce(in_thread_buf_tuple(iR), accu_value_buf_tuple(iR));
187 });
188
// Advance the read window by one tile along K.
189 threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
190
191 reducedLength += KThreadSliceSize;
192 } while(reducedLength < toReduceLength);
193
194 constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
195
// Epilogue, run once per reduction: elementwise op, alpha scale, optional beta
// blend with the existing output, then store.
196 static_for<0, NumReduction, 1>{}([&](auto iR) {
197 using OutDataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[iR])>;
199
201 acc_elementwise_op_tuple[iR](accu_value_buf_tuple(iR)(I),
202 accu_value_buf_tuple(iR)(I));
203
204 accu_value_buf_tuple(iR)(I) *= alpha_values[iR];
205 });
206
// The prior output is only read when beta is non-zero, so with beta == 0 the
// output memory is never loaded.
207 if(!float_equal_zero{}(beta_values[iR]))
208 {
210 priorDstValueBuf;
211
212 auto threadwise_dst_load =
214 OutDataType,
215 decltype(out_grid_desc_m_tuple[iR]),
216 decltype(reduced_data_desc),
219 0,
220 OutDstVectorSizeSeq::At(iR),
221 1,
222 false>(
223 out_grid_desc_m_tuple[iR],
224 make_multi_index(thread_global_1d_id * MThreadSliceSize));
225
226 threadwise_dst_load.Run(out_grid_desc_m_tuple[iR],
227 out_global_val_buf_tuple(iR),
228 reduced_data_desc,
229 make_tuple(I0),
230 priorDstValueBuf);
231
// The prior value is converted to AccDataType before scaling and adding.
233 accu_value_buf_tuple(iR)(I) +=
234 type_convert<AccDataType>(priorDstValueBuf[I]) * beta_values[iR];
235 });
236 };
237
// Write this thread's MThreadSliceSize results, using OutMemoryDataOperation
// as the destination memory op.
238 auto threadwise_dst_store =
240 OutDataType,
241 decltype(reduced_data_desc),
242 decltype(out_grid_desc_m_tuple[iR]),
246 0,
247 OutDstVectorSizeSeq::At(iR),
248 OutMemoryDataOperation,
249 1,
250 true>(
251 out_grid_desc_m_tuple[iR],
252 make_multi_index(thread_global_1d_id * MThreadSliceSize),
253 PassThroughOp{});
254
255 threadwise_dst_store.Run(reduced_data_desc,
256 make_tuple(I0),
257 accu_value_buf_tuple[iR],
258 out_grid_desc_m_tuple[iR],
259 out_global_val_buf_tuple(iR));
260 });
261 };
262};
263
264} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
__global__ void kernel_multiple_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k, const OutGridDesc_M_Tuple out_grid_desc_m_tuple, const InElementwiseOperationTuple in_elementwise_op_tuple, const AccElementwiseOperationTuple acc_elementwise_op_tuple, Array< AccDataType, NumReduction > alpha_values, const InDataType *const __restrict__ p_in_value_global, Array< AccDataType, NumReduction > beta_values, OutDataTypePointerTuple p_out_value_global_tuple)
Definition gridwise_2d_multiple_reduction_threadwise.hpp:26
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
__device__ index_t get_thread_global_1d_id()
Definition get_id.hpp:43
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
Definition utility/array.hpp:14
Definition gridwise_2d_multiple_reduction_threadwise.hpp:63
detail::AccumulateWithNanCheck< PropagateNan, ReduceOperation, AccDataType > Accumulation
Definition gridwise_2d_multiple_reduction_threadwise.hpp:95
typename conditional< reorder_thread_cluster, Sequence< 1, 0 >, Sequence< 0, 1 > >::type ThreadBufferDimAccessOrder
Definition gridwise_2d_multiple_reduction_threadwise.hpp:77
ThreadwiseReduction< AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M, ReduceOperation, PropagateNan > ThreadwiseReduce
Definition gridwise_2d_multiple_reduction_threadwise.hpp:85
static __device__ void Run(const InGridDesc_M_K &in_grid_desc_m_k, const OutGridDesc_M_Tuple &out_grid_desc_m_tuple, const InElementwiseOperationTuple &in_elementwise_op_tuple, const AccElementwiseOperationTuple &acc_elementwise_op_tuple, Array< AccDataType, NumReduction > alpha_values, const InDataType *const __restrict__ p_in_value_global, Array< AccDataType, NumReduction > beta_values, OutDataTypePointerTuple p_out_value_global_tuple)
Definition gridwise_2d_multiple_reduction_threadwise.hpp:97
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number< MThreadSliceSize >{}))) ThreadReduceDstDesc_M
Definition gridwise_2d_multiple_reduction_threadwise.hpp:82
static constexpr bool reorder_thread_cluster
Definition gridwise_2d_multiple_reduction_threadwise.hpp:75
tensor_operation::element_wise::PassThrough PassThroughOp
Definition gridwise_2d_multiple_reduction_threadwise.hpp:91
decltype(make_naive_tensor_descriptor_packed( make_tuple(Number< MThreadSliceSize >{}, Number< KThreadSliceSize >{}))) ThreadReduceSrcDesc_M_K
Definition gridwise_2d_multiple_reduction_threadwise.hpp:80
static constexpr auto I0
Definition gridwise_2d_multiple_reduction_threadwise.hpp:93
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:16
Definition reduction_functions_threadwise.hpp:23
static __device__ void Reduce(const SrcBufferType &src_buf, DstBufferType &dst_buf)
Definition reduction_functions_threadwise.hpp:36
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition utility/functional.hpp:100
Definition reduction_functions_accumulate.hpp:28
Definition reduction_common.hpp:20
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340