gridwise_2d_reduction_threadwise.hpp Source File

gridwise_2d_reduction_threadwise.hpp Source File#

Composable Kernel: gridwise_2d_reduction_threadwise.hpp Source File
gridwise_2d_reduction_threadwise.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
13
14namespace ck {
15
16template <typename GridwiseReduction,
17 bool OutputIndex,
18 bool TransformIndexKtoGlobal,
19 bool HaveIndexInput,
20 typename InDataType,
21 typename OutDataType,
22 typename AccDataType,
23 typename IndexDataType,
24 typename InGridDesc_M_K,
25 typename OutGridDesc_M,
26 typename InElementwiseOperation,
27 typename AccElementwiseOperation>
28__global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k,
29 const OutGridDesc_M out_grid_desc_m,
30 const InElementwiseOperation in_elementwise_op,
31 const AccElementwiseOperation acc_elementwise_op,
32 AccDataType alpha,
33 const InDataType* const __restrict__ p_in_value_global,
34 const IndexDataType* const __restrict__ p_in_index_global,
35 AccDataType beta,
36 OutDataType* const __restrict__ p_out_value_global,
37 IndexDataType* const __restrict__ p_out_index_global)
38{
39 if constexpr(!OutputIndex)
40 {
41 GridwiseReduction::Run(in_grid_desc_m_k,
42 out_grid_desc_m,
43 in_elementwise_op,
44 acc_elementwise_op,
45 alpha,
46 p_in_value_global,
47 beta,
48 p_out_value_global);
49 }
50 else
51 {
52 GridwiseReduction::template RunWithIndex<TransformIndexKtoGlobal, HaveIndexInput>(
53 in_grid_desc_m_k,
54 out_grid_desc_m,
55 in_elementwise_op,
56 acc_elementwise_op,
57 alpha,
58 p_in_value_global,
59 p_in_index_global,
60 beta,
61 p_out_value_global,
62 p_out_index_global);
63 };
64};
65
// Template header and compile-time preamble of the struct that owns the static
// Run / RunWithIndex entry points further down in this file.
// NOTE(review): the listing gutter jumps from 82 to 84, so the `struct ...` declaration
// line itself is not present in this extract; the struct's name cannot be read from here.
66template <typename InDataType,
67 typename OutDataType,
68 typename AccDataType,
69 typename IndexDataType,
70 typename InGridDesc_M_K,
71 typename OutGridDesc_M,
72 typename ReduceOperation,
73 typename InElementwiseOperation,
74 typename AccElementwiseOperation,
75 InMemoryDataOperationEnum OutMemoryDataOperation,
76 bool PropagateNan,
77 index_t BlockSize,
78 index_t MThreadSliceSize,
79 index_t KThreadSliceSize,
80 index_t InSrcVectorDim,
81 index_t InSrcVectorSize,
82 index_t OutDstVectorSize>
84{
// Compile-time configuration check. It requires that
//   - InSrcVectorDim is 0 and MThreadSliceSize is a multiple of InSrcVectorSize, or
//     InSrcVectorDim is 1 and KThreadSliceSize is a multiple of InSrcVectorSize; and
//   - MThreadSliceSize is a multiple of OutDstVectorSize.
// Any other InSrcVectorDim value fails the assertion.
85 static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
86 (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
87 (MThreadSliceSize % OutDstVectorSize == 0),
88 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
89
// NOTE(review): listing lines 90-91, 93-96 and 98 are missing from this extract;
// names used later but not declared in any visible line (e.g. ThreadReduceDstDesc_M,
// PassThroughOp) are presumably introduced there - confirm against the real header.
92
97
99
// Compile-time constant 0; the methods below pass it to make_tuple(I0, ...) as the
// origin index of the thread-local buffers.
100 static constexpr auto I0 = Number<0>{};
101
// Threadwise reduction without index output.
//
// Per thread: MThreadSliceSize accumulators are initialised with the reduce operation's
// identity value. The thread then repeatedly loads a tile of the input (window origin
// (thread_global_1d_id * MThreadSliceSize, 0), advanced by (0, KThreadSliceSize) after
// every iteration), applies in_elementwise_op to each loaded element and passes the tile
// to ThreadwiseReduce::Reduce. Afterwards every accumulator goes through
// acc_elementwise_op, is multiplied by alpha, optionally gets beta * (prior output value)
// added, and is stored through out_grid_desc_m.
//
// Parameters:
//   in_grid_desc_m_k   - input descriptor; GetLength(Number<1>{}) bounds the load loop.
//   out_grid_desc_m    - output descriptor used for the prior-value load and the store.
//   in_elementwise_op  - invoked on every loaded element (same element as both arguments).
//   acc_elementwise_op - invoked on every accumulator (same element as both arguments).
//   alpha              - factor multiplied into every accumulator.
//   p_in_value_global  - input values, wrapped as a Global address space buffer.
//   beta               - factor applied to the prior output value; the prior value is
//                        only loaded when float_equal_zero{}(beta) is false.
//   p_out_value_global - output values (read for the beta term, then written).
//
// NOTE(review): this listing is an extract with gaps (the gutter numbers skip, e.g.
// 111 -> 114, 122 -> 124, 125 -> 127), so several declarations used below
// (accu_value_buf, dst_global_buf, priorDstValue_buf, some static_for headers and the
// transfer type names) are not visible here.
102 __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
103 const OutGridDesc_M& out_grid_desc_m,
104 const InElementwiseOperation& in_elementwise_op,
105 const AccElementwiseOperation& acc_elementwise_op,
106 AccDataType alpha,
107 const InDataType* const __restrict__ p_in_value_global,
108 AccDataType beta,
109 OutDataType* const __restrict__ p_out_value_global)
110 {
111 using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
114 ReduceOperation,
115 PropagateNan>;
116
// Identity value of the reduce operation in the accumulator type; used below as the
// initial value of every accumulator.
117 const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
118
// Input wrapped as a Global address space buffer. The third argument is the identity
// value in the input type - presumably what the buffer yields for invalid elements;
// confirm against make_dynamic_buffer.
119 const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
120 p_in_value_global,
121 in_grid_desc_m_k.GetElementSpaceSize(),
122 ReduceOperation::template GetIdentityValue<InDataType>());
// NOTE(review): listing line 123 is missing; the next line is presumably the tail of
// the dst_global_buf declaration that is used further down.
124 p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
125
127 in_thread_buf;
128
130
// Start every accumulator at the identity value.
131 static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; });
132
// Length of descriptor dimension 1: the loop below runs until reducedLength reaches it.
133 const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
134
// Thread-local tile of MThreadSliceSize x KThreadSliceSize elements, packed layout.
135 using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
136 constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
138
// Flat thread id across the grid, computed with the compile-time BlockSize.
139 index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
140
// Input loader; its window starts at row thread_global_1d_id * MThreadSliceSize,
// column 0 of the input descriptor.
141 auto threadwise_src_val_load =
143 AccDataType,
144 InGridDesc_M_K,
145 decltype(thread_buffer_desc),
146 ThreadBufferLengths,
148 InSrcVectorDim,
149 InSrcVectorSize,
150 1,
151 false>(
152 in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
153
// Window step per iteration: stay on the same rows, advance KThreadSliceSize columns.
154 constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);
155
156 index_t reducedLength = 0;
// do/while: the body runs at least once, even when toReduceLength is 0.
157 do
158 {
159 threadwise_src_val_load.Run(in_grid_desc_m_k,
160 in_global_val_buf,
161 thread_buffer_desc,
162 make_tuple(I0, I0),
163 in_thread_buf);
164
166 // do element-wise pre-reduction operation
// Each element is passed as both arguments of in_elementwise_op (presumably an
// in-place update - confirm the operation's argument order).
168 constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
169 in_elementwise_op(in_thread_buf(Number<offset>{}),
170 in_thread_buf(Number<offset>{}));
171 });
172 });
173
// Fold the loaded tile into the accumulators.
174 ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
175
176 threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
177
178 reducedLength += KThreadSliceSize;
179 } while(reducedLength < toReduceLength);
180
// Post-process every accumulator: acc_elementwise_op first, then scale by alpha.
182 acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
183
184 accu_value_buf(I) *= alpha;
185 });
186
187 constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
188
// Only when beta is non-zero: read the values currently stored in the output and add
// beta times them to the accumulators.
189 if(!float_equal_zero{}(beta))
190 {
191 auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<OutDataType,
192 OutDataType,
193 OutGridDesc_M,
194 decltype(reduced_data_desc),
197 0,
198 1,
199 1,
// NOTE(review): this last template argument is `true` here, while the corresponding
// prior-output load in RunWithIndex passes `false` - confirm the difference is intended.
200 true>(
201 out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize));
202
204 priorDstValue_buf;
205
206 threadwise_dst_load.Run(out_grid_desc_m,
207 dst_global_buf,
208 reduced_data_desc,
209 make_tuple(I0),
210 priorDstValue_buf);
211
// Prior values are converted to the accumulator type before being scaled by beta.
213 accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
214 });
215 };
216
// Store the accumulators to the output, starting at element
// thread_global_1d_id * MThreadSliceSize, using OutMemoryDataOperation.
217 auto threadwise_dst_store = ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
218 OutDataType,
219 decltype(reduced_data_desc),
220 OutGridDesc_M,
224 0,
225 OutDstVectorSize,
226 OutMemoryDataOperation,
227 1,
228 false>(
229 out_grid_desc_m,
230 make_multi_index(thread_global_1d_id * MThreadSliceSize),
231 PassThroughOp{});
232
233 threadwise_dst_store.Run(
234 reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, dst_global_buf);
235 };
236
// Threadwise reduction that produces an index alongside every reduced value.
//
// Per thread: value accumulators start at the reduce operation's identity value and
// index accumulators at 0. Tiles of the input are loaded in a loop (window origin
// (thread_global_1d_id * MThreadSliceSize, 0), advanced by (0, KThreadSliceSize) per
// iteration) and passed, together with a matching tile of indices, to
// ThreadwiseReduceWithIndex::Reduce. The index tile comes from one of two sources:
//   - HaveIndexInput == true : loaded from p_in_index_global with the same window.
//   - HaveIndexInput == false: generated as indexStart + iK(); in this branch only, when
//     TransformIndexKtoGlobal is true, each resulting index is replaced by the offset of
//     the coordinate (row, index) in in_grid_desc_m_k.
// Afterwards every value accumulator goes through acc_elementwise_op, is multiplied by
// alpha, optionally gets beta * (prior output value) added, and values and indices are
// stored through out_grid_desc_m.
//
// Parameters:
//   in_grid_desc_m_k   - input descriptor; GetLength(Number<1>{}) bounds the load loop.
//   out_grid_desc_m    - output descriptor for the prior-value load and both stores.
//   in_elementwise_op  - invoked on every loaded value (same element as both arguments).
//   acc_elementwise_op - invoked on every value accumulator after the loop.
//   alpha              - factor multiplied into every value accumulator.
//   p_in_value_global  - input values.
//   p_in_index_global  - input indices; only read when HaveIndexInput is true.
//   beta               - factor applied to the prior output value; the prior value is
//                        only loaded when float_equal_zero{}(beta) is false.
//   p_out_value_global - output values (read for the beta term, then written).
//   p_out_index_global - output indices (written).
//
// NOTE(review): this listing is an extract with gaps (the gutter numbers skip, e.g.
// 250 -> 253, 271 -> 273, 292 -> 294), so several declarations used below
// (accu_value_buf, accu_index_buf, priorDstValue_buf, some static_for headers and the
// transfer type names) are not visible here.
237 template <bool TransformIndexKtoGlobal, bool HaveIndexInput>
238 __device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
239 const OutGridDesc_M& out_grid_desc_m,
240 const InElementwiseOperation& in_elementwise_op,
241 const AccElementwiseOperation& acc_elementwise_op,
242 AccDataType alpha,
243 const InDataType* const __restrict__ p_in_value_global,
244 const IndexDataType* const __restrict__ p_in_index_global,
245 AccDataType beta,
246 OutDataType* const __restrict__ p_out_value_global,
247 IndexDataType* const __restrict__ p_out_index_global)
248 {
249 using ThreadwiseReduceWithIndex = ThreadwiseReductionWithIndex<AccDataType,
250 IndexDataType,
253 ReduceOperation,
254 PropagateNan>;
255
// Marks the parameter as used. Note that acc_elementwise_op is nevertheless invoked on
// the accumulators after the reduction loop further down.
256 (void)acc_elementwise_op;
257
// Identity value of the reduce operation in the accumulator type; initial value of
// every value accumulator.
258 const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
259
// Input values as a Global address space buffer. The third argument is the identity
// value in the input type - presumably what the buffer yields for invalid elements;
// confirm against make_dynamic_buffer.
260 const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
261 p_in_value_global,
262 in_grid_desc_m_k.GetElementSpaceSize(),
263 ReduceOperation::template GetIdentityValue<InDataType>());
// Input indices share the element space size of the input value descriptor.
264 const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
265 p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
266
// Output value and index buffers, both sized by the output descriptor.
267 auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
268 p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
269 auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
270 p_out_index_global, out_grid_desc_m.GetElementSpaceSize());
271
273 in_thread_val_buf;
274
// Thread-local index tile: one IndexDataType entry per element of the value tile
// (MThreadSliceSize * KThreadSliceSize entries).
276 IndexDataType,
277 MThreadSliceSize * KThreadSliceSize,
278 true>
279 in_thread_idx_buf;
280
283
// Initialise the accumulators: identity value for values, 0 for indices.
285 accu_value_buf(I) = identityVal;
286 accu_index_buf(I) = 0;
287 });
288
// Length of descriptor dimension 1: the loops below run until reducedLength reaches it.
289 const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
290
// Thread-local tile of MThreadSliceSize x KThreadSliceSize elements, packed layout.
291 using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
292 constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
294
// Flat thread id across the grid, computed with the compile-time BlockSize.
295 index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
296
// Value loader; its window starts at row thread_global_1d_id * MThreadSliceSize,
// column 0 of the input descriptor.
297 auto threadwise_src_val_load =
299 AccDataType,
300 InGridDesc_M_K,
301 decltype(thread_buffer_desc),
302 ThreadBufferLengths,
304 InSrcVectorDim,
305 InSrcVectorSize,
306 1,
307 false>(
308 in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
309
// Window step per iteration: stay on the same rows, advance KThreadSliceSize columns.
310 constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);
311
312 index_t indexStart = 0;
313 index_t reducedLength = 0;
// Branch 1: indices are supplied by the caller and loaded next to the values.
314 if constexpr(HaveIndexInput)
315 {
// Index loader with the same descriptor and window origin as the value loader.
316 auto threadwise_src_idx_load =
318 IndexDataType,
319 InGridDesc_M_K,
320 decltype(thread_buffer_desc),
321 ThreadBufferLengths,
323 InSrcVectorDim,
324 InSrcVectorSize,
325 1,
326 false>(
327 in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
328
// do/while: the body runs at least once, even when toReduceLength is 0.
329 do
330 {
331 threadwise_src_val_load.Run(in_grid_desc_m_k,
332 in_global_val_buf,
333 thread_buffer_desc,
334 make_tuple(I0, I0),
335 in_thread_val_buf);
336
337 threadwise_src_idx_load.Run(in_grid_desc_m_k,
338 in_global_idx_buf,
339 thread_buffer_desc,
340 make_tuple(I0, I0),
341 in_thread_idx_buf);
342
344 // do element-wise pre-reduction operation
346 constexpr auto offset =
347 thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
348
349 in_elementwise_op(in_thread_val_buf(Number<offset>{}),
350 in_thread_val_buf(Number<offset>{}));
351 });
352 });
353
// Fold the value tile and its index tile into the value/index accumulators.
354 ThreadwiseReduceWithIndex::Reduce(
355 in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf);
356
// Both windows advance together so values and indices stay aligned.
357 threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
358 threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
359
// NOTE(review): indexStart is advanced here although no visible line of this branch
// reads it - confirm whether it is needed when indices come from the input.
360 indexStart += KThreadSliceSize;
361 reducedLength += KThreadSliceSize;
362 } while(reducedLength < toReduceLength);
363 }
// Branch 2: no index input; indices are generated while the values are processed.
364 else
365 {
// do/while: the body runs at least once, even when toReduceLength is 0.
366 do
367 {
368 threadwise_src_val_load.Run(in_grid_desc_m_k,
369 in_global_val_buf,
370 thread_buffer_desc,
371 make_tuple(I0, I0),
372 in_thread_val_buf);
373
375 // do element-wise pre-reduction operation
377 constexpr auto offset =
378 thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
379
// Generated index: tile start (indexStart) plus the tile-local iK position
// (presumably the column within the tile - the enclosing static_for headers are not
// visible in this extract).
380 in_thread_idx_buf(Number<offset>{}) = indexStart + iK();
381
382 in_elementwise_op(in_thread_val_buf(Number<offset>{}),
383 in_thread_val_buf(Number<offset>{}));
384 });
385 });
386
387 ThreadwiseReduceWithIndex::Reduce(
388 in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf);
389
390 threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
391
392 indexStart += KThreadSliceSize;
393 reducedLength += KThreadSliceSize;
394 } while(reducedLength < toReduceLength);
395
// Optional translation of the generated index: build the coordinate
// (thread row + I, accumulated index) in the input descriptor and store its offset
// instead of the index itself.
396 if constexpr(TransformIndexKtoGlobal)
397 {
399 const auto coord = make_tensor_coordinate(
400 in_grid_desc_m_k,
401 make_multi_index(thread_global_1d_id * MThreadSliceSize + I,
402 accu_index_buf(I)));
403
404 accu_index_buf(I) = coord.GetOffset();
405 });
406 }
407 };
408
409 // for indexed operation, acc_elementwise_op should do nothing
411 acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
412
413 accu_value_buf(I) *= alpha;
414 });
415
416 constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
417
// Only when beta is non-zero: read the values currently stored in the output and add
// beta times them to the value accumulators. Indices are not affected.
418 if(!float_equal_zero{}(beta))
419 {
420 auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<OutDataType,
421 OutDataType,
422 OutGridDesc_M,
423 decltype(reduced_data_desc),
426 0,
427 1,
428 1,
// NOTE(review): this last template argument is `false` here, while the corresponding
// prior-output load in Run passes `true` - confirm the difference is intended.
429 false>(
430 out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize));
431
433 priorDstValue_buf;
434
435 threadwise_dst_load.Run(out_grid_desc_m,
436 out_global_val_buf,
437 reduced_data_desc,
438 make_tuple(I0),
439 priorDstValue_buf);
440
// Prior values are converted to the accumulator type before being scaled by beta.
442 accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
443 });
444 };
445
// Value store, starting at element thread_global_1d_id * MThreadSliceSize of the output.
446 auto threadwise_dst_val_store =
448 OutDataType,
449 decltype(reduced_data_desc),
450 OutGridDesc_M,
454 0,
455 OutDstVectorSize,
456 OutMemoryDataOperation,
457 1,
458 false>(
459 out_grid_desc_m,
460 make_multi_index(thread_global_1d_id * MThreadSliceSize),
461 PassThroughOp{});
462
// Index store: same origin, vector size and OutMemoryDataOperation as the value store.
463 auto threadwise_dst_idx_store =
465 IndexDataType,
466 decltype(reduced_data_desc),
467 OutGridDesc_M,
471 0,
472 OutDstVectorSize,
473 OutMemoryDataOperation,
474 1,
475 false>(
476 out_grid_desc_m,
477 make_multi_index(thread_global_1d_id * MThreadSliceSize),
478 PassThroughOp{});
479
480 threadwise_dst_val_store.Run(
481 reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_val_buf);
482
483 threadwise_dst_idx_store.Run(
484 reduced_data_desc, make_tuple(I0), accu_index_buf, out_grid_desc_m, out_global_idx_buf);
485 };
486};
487
488} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
__global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k, const OutGridDesc_M out_grid_desc_m, const InElementwiseOperation in_elementwise_op, const AccElementwiseOperation acc_elementwise_op, AccDataType alpha, const InDataType *const __restrict__ p_in_value_global, const IndexDataType *const __restrict__ p_in_index_global, AccDataType beta, OutDataType *const __restrict__ p_out_value_global, IndexDataType *const __restrict__ p_out_index_global)
Definition gridwise_2d_reduction_threadwise.hpp:28
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc &tensor_desc, const VisibleIndex &idx_visible)
Definition tensor_description/tensor_descriptor.hpp:407
Definition gridwise_2d_reduction_threadwise.hpp:84
static __device__ void RunWithIndex(const InGridDesc_M_K &in_grid_desc_m_k, const OutGridDesc_M &out_grid_desc_m, const InElementwiseOperation &in_elementwise_op, const AccElementwiseOperation &acc_elementwise_op, AccDataType alpha, const InDataType *const __restrict__ p_in_value_global, const IndexDataType *const __restrict__ p_in_index_global, AccDataType beta, OutDataType *const __restrict__ p_out_value_global, IndexDataType *const __restrict__ p_out_index_global)
Definition gridwise_2d_reduction_threadwise.hpp:238
static __device__ void Run(const InGridDesc_M_K &in_grid_desc_m_k, const OutGridDesc_M &out_grid_desc_m, const InElementwiseOperation &in_elementwise_op, const AccElementwiseOperation &acc_elementwise_op, AccDataType alpha, const InDataType *const __restrict__ p_in_value_global, AccDataType beta, OutDataType *const __restrict__ p_out_value_global)
Definition gridwise_2d_reduction_threadwise.hpp:102
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:16
Definition reduction_functions_threadwise.hpp:23
Definition reduction_functions_threadwise.hpp:65
Definition threadwise_tensor_slice_transfer.hpp:39
__device__ void Run(const SrcDesc &, const SrcSliceOriginIdx &, const SrcBuffer &src_buf, const DstDesc &dst_desc, DstBuffer &dst_buf)
Definition threadwise_tensor_slice_transfer.hpp:66
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition utility/functional.hpp:100
Definition reduction_common.hpp:20
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340