gridwise_2d_reduction_multiblock.hpp Source File

gridwise_2d_reduction_multiblock.hpp Source File#

Composable Kernel: gridwise_2d_reduction_multiblock.hpp Source File
gridwise_2d_reduction_multiblock.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
13
14namespace ck {
15
// Kernel entry point: forwards its arguments to GridwiseReduction.
//
// The compile-time flag OutputIndex selects the path:
//   - false: calls GridwiseReduction::Run with the value pointers only; both index
//            pointers are unused (cast to void) and HaveIndexInput is not consulted.
//   - true : calls GridwiseReduction::RunWithIndex<HaveIndexInput> with the value and
//            index pointers; block_group_size is NOT forwarded on this path.
//
// Descriptors and element-wise operations are taken by value as kernel arguments.
16template <typename GridwiseReduction,
17 bool OutputIndex,
18 bool HaveIndexInput,
19 typename InDataType,
20 typename OutDataType,
21 typename AccDataType,
22 typename IndexDataType,
23 typename InGridDesc_M_K,
24 typename OutGridDesc_M,
25 typename InElementwiseOperation,
26 typename AccElementwiseOperation>
27__global__ void kernel_reduce_multiblock(const InGridDesc_M_K in_grid_desc_m_k,
28 const OutGridDesc_M out_grid_desc_m,
29 const InElementwiseOperation in_elementwise_op,
30 const AccElementwiseOperation acc_elementwise_op,
31 index_t block_group_size,
32 index_t num_k_block_tile_iteration,
33 AccDataType alpha,
34 const InDataType* const __restrict__ p_in_value_global,
35 const IndexDataType* const __restrict__ p_in_index_global,
36 AccDataType beta,
37 OutDataType* const __restrict__ p_out_value_global,
38 IndexDataType* const __restrict__ p_out_index_global)
39{
    // Value-only reduction: the branch is resolved at compile time.
40 if constexpr(!OutputIndex)
41 {
    // Index pointers are not needed here; the casts mark them as intentionally unused.
42 (void)p_in_index_global;
43 (void)p_out_index_global;
44
45 GridwiseReduction::Run(in_grid_desc_m_k,
46 out_grid_desc_m,
47 in_elementwise_op,
48 acc_elementwise_op,
49 block_group_size,
50 num_k_block_tile_iteration,
51 alpha,
52 p_in_value_global,
53 beta,
54 p_out_value_global);
55 }
56 else
57 {
    // Value + index reduction. Note that block_group_size is not passed along.
58 GridwiseReduction::template RunWithIndex<HaveIndexInput>(in_grid_desc_m_k,
59 out_grid_desc_m,
60 in_elementwise_op,
61 acc_elementwise_op,
62 num_k_block_tile_iteration,
63 alpha,
64 p_in_value_global,
65 p_in_index_global,
66 beta,
67 p_out_value_global,
68 p_out_index_global);
69 };
70};
71
72template <typename InDataType,
73 typename OutDataType,
74 typename AccDataType,
75 typename IndexDataType,
76 typename InGridDesc_M_K,
77 typename OutGridDesc_M,
78 typename ReduceOperation,
79 typename InElementwiseOperation,
80 typename AccElementwiseOperation,
81 InMemoryDataOperationEnum OutMemoryDataOperation,
82 bool PropagateNan,
83 index_t BlockSize,
84 index_t MThreadClusterSize,
85 index_t KThreadClusterSize,
86 index_t MThreadSliceSize,
87 index_t KThreadSliceSize,
88 index_t InSrcVectorDim,
89 index_t InSrcVectorSize,
90 index_t OutDstVectorSize>
92{
93 static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
94 (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
95 (MThreadSliceSize % OutDstVectorSize == 0),
96 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
97
98 static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
99
101
104
107
108 static constexpr auto thread_cluster_desc =
110
115
117 BlockSize,
120 ReduceOperation,
121 PropagateNan>;
122
126 ReduceOperation,
127 PropagateNan>;
128
130
131 static constexpr auto I0 = Number<0>{};
132 static constexpr auto I1 = Number<1>{};
133
134 static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
135 static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
136
138
// Value-only reduction in which each output M tile is served by a group of
// block_group_size blocks.
//
// Visible flow:
//   1. The block id is split into a group id (selects the M tile) and a block-local id
//      (selects this block's K range of K_BlockTileSize * num_k_block_tile_iteration
//      elements).
//   2. Each thread loads its slice tile by tile, applies in_elementwise_op in place and
//      reduces into accu_value_buf.
//   3. BlockwiseReduce::Reduce combines the per-thread values through the LDS work buffer.
//   4. Threads with thread_k_cluster_id == 0 apply acc_elementwise_op, scale by alpha,
//      optionally add beta * (prior output), and store through a transfer that is
//      parameterized with OutMemoryDataOperation.
//
// NOTE(review): several statements of this function are missing from this listing
// (gaps in the embedded line numbers, e.g. 165, 168, 187, 215, 217, 233, 236). The
// comments below describe only what the visible lines show.
139 __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
140 const OutGridDesc_M& out_grid_desc_m,
141 const InElementwiseOperation& in_elementwise_op,
142 const AccElementwiseOperation& acc_elementwise_op,
143 index_t block_group_size,
144 index_t num_k_block_tile_iteration,
145 AccDataType alpha,
146 const InDataType* const __restrict__ p_in_value_global,
147 AccDataType beta,
148 OutDataType* const __restrict__ p_out_value_global)
149 {
    // Identity element of the reduction in AccDataType; used to initialize accumulators.
150 const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
151
152 // LDS
153 __shared__ AccDataType p_reduce_work_buffer[BlockSize];
154
    // Input buffer over global memory. The third argument is the reduction identity in
    // InDataType — presumably the value substituted for invalid elements; TODO confirm
    // against make_dynamic_buffer.
155 const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
156 p_in_value_global,
157 in_grid_desc_m_k.GetElementSpaceSize(),
158 ReduceOperation::template GetIdentityValue<InDataType>());
159 auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
160 p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
161
    // LDS view with one AccDataType slot per thread of the block.
162 auto reduce_work_buf =
163 make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
164
    // NOTE(review): the declaration that this line completes (listing line 165) and the
    // declaration of accu_value_buf (listing line 168) are missing from this listing.
166 in_thread_buf;
167
169
    // Start every per-thread accumulator at the reduction identity.
170 static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; });
171
172 const index_t thread_local_id = get_thread_local_1d_id();
173 const index_t block_global_id = get_block_1d_id();
    // Group id picks the M tile; block-local id picks this block's K range.
174 const index_t blkgroup_id = block_global_id / block_group_size;
175 const index_t block_local_id = block_global_id % block_group_size;
176
    // Map the flat thread id to (M, K) coordinates inside the thread cluster.
177 const auto thread_cluster_idx =
178 thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
179
180 const auto thread_m_cluster_id = thread_cluster_idx[I0];
181 const auto thread_k_cluster_id = thread_cluster_idx[I1];
182
    // Number of K elements covered by one block over all of its tile iterations.
183 const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
184
185 using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
186 constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
188
    // Source window origin for this thread:
    //   M = blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize
    //   K = block_local_id * reduceSizePerBlock + thread_k_cluster_id * KThreadSliceSize
189 auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
190 AccDataType,
191 InGridDesc_M_K,
192 decltype(thread_buffer_desc),
193 ThreadBufferLengths,
195 InSrcVectorDim,
196 InSrcVectorSize,
197 1,
198 false>(
199 in_grid_desc_m_k,
200 make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
201 block_local_id * reduceSizePerBlock +
202 thread_k_cluster_id * KThreadSliceSize));
203
    // Each iteration advances the source window by one block tile along K.
204 constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
205
    // do/while: the body runs at least once, then until num_k_block_tile_iteration tiles
    // have been consumed.
206 index_t reducedTiles = 0;
207 do
208 {
209 threadwise_src_load.Run(in_grid_desc_m_k,
210 in_global_val_buf,
211 thread_buffer_desc,
212 make_tuple(I0, I0),
213 in_thread_buf);
214
    // NOTE(review): the loop headers over iM / iK (listing lines 215 and 217) are missing
    // from this listing; in_elementwise_op is applied in place to each loaded element.
216 // do element-wise pre-reduction operation
218 constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
219 in_elementwise_op(in_thread_buf(Number<offset>{}),
220 in_thread_buf(Number<offset>{}));
221 });
222 });
223
    // Fold this tile's per-thread values into accu_value_buf.
224 ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
225
226 threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
227
228 reducedTiles++;
229 } while(reducedTiles < num_k_block_tile_iteration);
230
231 constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
232
    // Block-level reduction of each accumulator through the LDS work buffer
    // (the enclosing loop header, listing line 233, is missing from this listing).
234 [&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });
235
    // Post-processing is done only by threads whose K cluster id is 0
    // (the enclosing loop header, listing line 236, is missing from this listing).
237 if(thread_k_cluster_id == 0)
238 {
239 acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
240
241 accu_value_buf(I) *= alpha;
242 }
243 });
244
245 if(thread_k_cluster_id == 0)
246 {
    // The prior output is read only when beta is non-zero.
247 if(!float_equal_zero{}(beta))
248 {
250 priorDstValueBuf;
251
252 auto threadwise_dst_load =
254 OutDataType,
255 OutGridDesc_M,
256 decltype(reduced_data_desc),
259 0,
260 OutDstVectorSize,
261 1,
262 false>(
263 out_grid_desc_m,
264 make_multi_index(blkgroup_id * M_BlockTileSize +
265 thread_m_cluster_id * MThreadSliceSize));
266
267 threadwise_dst_load.Run(out_grid_desc_m,
268 out_global_val_buf,
269 reduced_data_desc,
270 make_tuple(I0),
271 priorDstValueBuf);
272
    // Blend: accumulator += beta * prior output (converted to AccDataType).
274 accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
275 });
276 };
277
    // Store transfer; OutMemoryDataOperation is one of its template arguments and the
    // destination origin matches the M origin used for the loads above.
278 auto threadwise_dst_store =
280 OutDataType,
281 decltype(reduced_data_desc),
282 OutGridDesc_M,
286 0,
287 OutDstVectorSize,
288 OutMemoryDataOperation,
289 1,
290 true>(
291 out_grid_desc_m,
292 make_multi_index(blkgroup_id * M_BlockTileSize +
293 thread_m_cluster_id * MThreadSliceSize),
294 PassThroughOp{});
295
296 threadwise_dst_store.Run(reduced_data_desc,
297 make_tuple(I0),
298 accu_value_buf,
299 out_grid_desc_m,
300 out_global_val_buf);
301 }
302 };
303
// Reduction that produces both a reduced value and its index for each output element.
//
// Unlike Run, there is no block_group_size parameter: the M origin is derived from the
// block id directly and the K origin starts at thread_k_cluster_id * KThreadSliceSize.
//
// HaveIndexInput (compile-time):
//   - true : indices are loaded from p_in_index_global alongside the values; the visible
//            lines of this branch do not call in_elementwise_op.
//   - false: indices are generated as
//            indexOffset + thread_k_cluster_id * KThreadSliceSize + iK, with indexOffset
//            advancing by K_BlockTileSize per tile, and in_elementwise_op is applied in
//            place to each loaded value.
//
// Threads with thread_k_cluster_id == 0 apply acc_elementwise_op, scale by alpha,
// optionally add beta * (prior output), then store the values and the indices.
//
// NOTE(review): several statements of this function are missing from this listing
// (gaps in the embedded line numbers); the comments describe only the visible lines.
304 template <bool HaveIndexInput>
305 __device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
306 const OutGridDesc_M& out_grid_desc_m,
307 const InElementwiseOperation in_elementwise_op,
308 const AccElementwiseOperation acc_elementwise_op,
309 index_t num_k_block_tile_iteration,
310 AccDataType alpha,
311 const InDataType* const __restrict__ p_in_value_global,
312 const IndexDataType* const __restrict__ p_in_index_global,
313 AccDataType beta,
314 OutDataType* const __restrict__ p_out_value_global,
315 IndexDataType* const __restrict__ p_out_index_global)
316 {
317 using BlockwiseReduceWithIndex =
319 IndexDataType,
320 BlockSize,
323 ReduceOperation,
324 PropagateNan>;
325
326 using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
327 ReduceOperation,
328 AccDataType,
329 IndexDataType>;
330
    // Marks the parameter as used: the HaveIndexInput branch (as visible here) never calls
    // it, while the other branch does.
331 (void)in_elementwise_op;
332
333 // LDS
334 __shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
335 __shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
336
337 const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
338
    // Global-memory views of the value/index inputs and outputs. The third argument of the
    // value input is the reduction identity in InDataType — presumably the value
    // substituted for invalid elements; TODO confirm against make_dynamic_buffer.
339 const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
340 p_in_value_global,
341 in_grid_desc_m_k.GetElementSpaceSize(),
342 ReduceOperation::template GetIdentityValue<InDataType>());
343 const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
344 p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
345 auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
346 p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
347 auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
348 p_out_index_global, out_grid_desc_m.GetElementSpaceSize());
349
    // LDS views, one slot per thread, for the value and index halves of the block reduction.
350 auto reduce_work_val_buf =
351 make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
352 auto reduce_work_idx_buf =
353 make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_idx_buffer, BlockSize);
354
    // NOTE(review): the declarations these lines complete (listing lines 355 and 358) and
    // the accumulator declarations (listing lines 364-365) are missing from this listing.
356 in_thread_val_buf;
357
359 IndexDataType,
360 MThreadSliceSize * KThreadSliceSize,
361 true>
362 in_thread_idx_buf;
363
366
367 const index_t thread_local_id = get_thread_local_1d_id();
368 const index_t block_global_1d_id = get_block_1d_id();
369
    // Map the flat thread id to (M, K) coordinates inside the thread cluster.
370 const auto thread_cluster_idx =
371 thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
372
373 const auto thread_m_cluster_id = thread_cluster_idx[I0];
374 const auto thread_k_cluster_id = thread_cluster_idx[I1];
375
376 using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
377 constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
379
    // Value-load window origin for this thread:
    //   M = block_global_1d_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize
    //   K = thread_k_cluster_id * KThreadSliceSize
380 auto threadwise_src_val_load =
382 AccDataType,
383 InGridDesc_M_K,
384 decltype(thread_buffer_desc),
385 ThreadBufferLengths,
387 InSrcVectorDim,
388 InSrcVectorSize,
389 1,
390 false>(
391 in_grid_desc_m_k,
392 make_multi_index(block_global_1d_id * M_BlockTileSize +
393 thread_m_cluster_id * MThreadSliceSize,
394 thread_k_cluster_id * KThreadSliceSize));
395
    // Initialize each accumulator to (identity, index 0); the loop header (listing line
    // 396) is missing from this listing.
397 accu_value_buf(I) = identityVal;
398 accu_index_buf(I) = 0;
399 });
400
    // Each iteration advances the source window(s) by one block tile along K.
401 constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
402
403 index_t reducedTiles = 0;
404
405 if constexpr(HaveIndexInput)
406 {
    // Index-load window with the same origin as the value-load window.
407 auto threadwise_src_idx_load =
409 IndexDataType,
410 InGridDesc_M_K,
411 decltype(thread_buffer_desc),
412 ThreadBufferLengths,
414 InSrcVectorDim,
415 InSrcVectorSize,
416 1,
417 false>(
418 in_grid_desc_m_k,
419 make_multi_index(block_global_1d_id * M_BlockTileSize +
420 thread_m_cluster_id * MThreadSliceSize,
421 thread_k_cluster_id * KThreadSliceSize));
422
    // do/while: the body runs at least once.
423 do
424 {
425 // load the thread slice
426 threadwise_src_val_load.Run(in_grid_desc_m_k,
427 in_global_val_buf,
428 thread_buffer_desc,
429 make_tuple(I0, I0),
430 in_thread_val_buf);
431 threadwise_src_idx_load.Run(in_grid_desc_m_k,
432 in_global_idx_buf,
433 thread_buffer_desc,
434 make_tuple(I0, I0),
435 in_thread_idx_buf);
436
    // Per M slice (the loop headers on listing lines 437 and 441 are missing here):
    // fold the thread's (value, index) pairs into tmpValue/tmpIndex, reduce across the
    // block through the LDS buffers, then fold the result into the running accumulator.
438 AccDataType tmpValue = identityVal;
439 IndexDataType tmpIndex = 0;
440
442 constexpr auto offset =
443 thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
444
445 AccumulationWithIndex::Calculate(tmpValue,
446 in_thread_val_buf[Number<offset>{}],
447 tmpIndex,
448 in_thread_idx_buf[Number<offset>{}]);
449 });
450
451 BlockwiseReduceWithIndex::Reduce(
452 reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
453
454 AccumulationWithIndex::Calculate(
455 accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
456 });
457
    // Advance both windows in step so values and indices stay aligned.
458 threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
459 threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
460
461 reducedTiles++;
462 } while(reducedTiles < num_k_block_tile_iteration);
463 }
464 else
465 {
    // K offset of the current tile; used to generate indices since none are supplied.
466 index_t indexOffset = 0;
467
468 do
469 {
470 // load the thread slice
471 threadwise_src_val_load.Run(in_grid_desc_m_k,
472 in_global_val_buf,
473 thread_buffer_desc,
474 make_tuple(I0, I0),
475 in_thread_val_buf);
476
    // NOTE(review): the loop headers over iM / iK (listing lines 477-478) are missing
    // from this listing.
479 constexpr auto offset =
480 thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
481
482 // initialize the indices for the per-thread to-reduce values
483 in_thread_idx_buf(Number<offset>{}) =
484 indexOffset + thread_k_cluster_id * KThreadSliceSize + iK();
485
486 // do element-wise pre-reduction operation
487 in_elementwise_op(in_thread_val_buf(Number<offset>{}),
488 in_thread_val_buf(Number<offset>{}));
489 });
490
    // Same per-M-slice sequence as in the HaveIndexInput branch: thread-level fold,
    // block reduction through LDS, then fold into the running accumulator.
491 AccDataType tmpValue = identityVal;
492 IndexDataType tmpIndex = 0;
493
495 constexpr auto offset =
496 thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
497
498 AccumulationWithIndex::Calculate(tmpValue,
499 in_thread_val_buf[Number<offset>{}],
500 tmpIndex,
501 in_thread_idx_buf[Number<offset>{}]);
502 });
503
504 BlockwiseReduceWithIndex::Reduce(
505 reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
506
507 AccumulationWithIndex::Calculate(
508 accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
509 });
510
511 threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
512
    // Keep the generated indices in step with the window move above.
513 indexOffset += K_BlockTileSize;
514 reducedTiles++;
515 } while(reducedTiles < num_k_block_tile_iteration);
516 };
517
518 constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
519
    // Post-processing is done only by threads whose K cluster id is 0
    // (the enclosing loop header, listing line 520, is missing from this listing).
521 if(thread_k_cluster_id == 0)
522 {
523 // for indexed operation, acc_elementwise_op should do nothing
524 acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
525
526 accu_value_buf(I) *= alpha;
527 }
528 });
529
530 if(thread_k_cluster_id == 0)
531 {
    // The prior output is read only when beta is non-zero.
532 if(!float_equal_zero{}(beta))
533 {
535 priorDstValueBuf;
536
537 auto threadwise_dst_load =
539 OutDataType,
540 OutGridDesc_M,
541 decltype(reduced_data_desc),
544 0,
545 OutDstVectorSize,
546 1,
547 true>(
548 out_grid_desc_m,
549 make_multi_index(block_global_1d_id * M_BlockTileSize +
550 thread_m_cluster_id * MThreadSliceSize));
551
552 threadwise_dst_load.Run(out_grid_desc_m,
553 out_global_val_buf,
554 reduced_data_desc,
555 make_tuple(I0),
556 priorDstValueBuf);
557
    // Blend: accumulator += beta * prior output (converted to AccDataType).
559 accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
560 });
561 };
562
    // Two store transfers with the same destination origin: one for the reduced values,
    // one for their indices.
563 auto threadwise_dst_val_store =
565 OutDataType,
566 decltype(reduced_data_desc),
567 OutGridDesc_M,
571 0,
572 OutDstVectorSize,
574 1,
575 true>(
576 out_grid_desc_m,
577 make_multi_index(block_global_1d_id * M_BlockTileSize +
578 thread_m_cluster_id * MThreadSliceSize),
579 PassThroughOp{});
580
581 auto threadwise_dst_idx_store =
583 IndexDataType,
584 decltype(reduced_data_desc),
585 OutGridDesc_M,
589 0,
590 OutDstVectorSize,
592 1,
593 true>(
594 out_grid_desc_m,
595 make_multi_index(block_global_1d_id * M_BlockTileSize +
596 thread_m_cluster_id * MThreadSliceSize),
597 PassThroughOp{});
598
599 threadwise_dst_val_store.Run(reduced_data_desc,
600 make_tuple(I0),
601 accu_value_buf,
602 out_grid_desc_m,
603 out_global_val_buf);
604 threadwise_dst_idx_store.Run(reduced_data_desc,
605 make_tuple(I0),
606 accu_index_buf,
607 out_grid_desc_m,
608 out_global_idx_buf);
609 }
610 };
611};
612
613} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
__global__ void kernel_reduce_multiblock(const InGridDesc_M_K in_grid_desc_m_k, const OutGridDesc_M out_grid_desc_m, const InElementwiseOperation in_elementwise_op, const AccElementwiseOperation acc_elementwise_op, index_t block_group_size, index_t num_k_block_tile_iteration, AccDataType alpha, const InDataType *const __restrict__ p_in_value_global, const IndexDataType *const __restrict__ p_in_index_global, AccDataType beta, OutDataType *const __restrict__ p_out_value_global, IndexDataType *const __restrict__ p_out_index_global)
Definition gridwise_2d_reduction_multiblock.hpp:27
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
Definition gridwise_2d_reduction_multiblock.hpp:92
decltype(make_naive_tensor_descriptor_packed( make_tuple(Number< MThreadSliceSize >{}, Number< KThreadSliceSize >{}))) ThreadReduceSrcDesc_M_K
Definition gridwise_2d_reduction_multiblock.hpp:111
typename conditional< reorder_thread_cluster, Sequence< 1, 0 >, Sequence< 0, 1 > >::type ThreadClusterArrangeOrder
Definition gridwise_2d_reduction_multiblock.hpp:105
static constexpr bool reorder_thread_cluster
Definition gridwise_2d_reduction_multiblock.hpp:98
ThreadwiseReduction< AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M, ReduceOperation, PropagateNan > ThreadwiseReduce
Definition gridwise_2d_reduction_multiblock.hpp:123
static __device__ void Run(const InGridDesc_M_K &in_grid_desc_m_k, const OutGridDesc_M &out_grid_desc_m, const InElementwiseOperation &in_elementwise_op, const AccElementwiseOperation &acc_elementwise_op, index_t block_group_size, index_t num_k_block_tile_iteration, AccDataType alpha, const InDataType *const __restrict__ p_in_value_global, AccDataType beta, OutDataType *const __restrict__ p_out_value_global)
Definition gridwise_2d_reduction_multiblock.hpp:139
typename conditional< reorder_thread_cluster, Sequence< 1, 0 >, Sequence< 0, 1 > >::type ThreadBufferDimAccessOrder
Definition gridwise_2d_reduction_multiblock.hpp:102
static constexpr index_t M_BlockTileSize
Definition gridwise_2d_reduction_multiblock.hpp:134
static constexpr auto I0
Definition gridwise_2d_reduction_multiblock.hpp:131
Sequence< MThreadClusterSize, KThreadClusterSize > ThreadClusterLengths_M_K
Definition gridwise_2d_reduction_multiblock.hpp:100
static constexpr auto thread_cluster_desc
Definition gridwise_2d_reduction_multiblock.hpp:108
static constexpr auto I1
Definition gridwise_2d_reduction_multiblock.hpp:132
PartitionedBlockwiseReduction< AccDataType, BlockSize, ThreadClusterLengths_M_K, ThreadClusterArrangeOrder, ReduceOperation, PropagateNan > BlockwiseReduce
Definition gridwise_2d_reduction_multiblock.hpp:116
detail::AccumulateWithNanCheck< PropagateNan, ReduceOperation, AccDataType > Accumulation
Definition gridwise_2d_reduction_multiblock.hpp:137
static __device__ void RunWithIndex(const InGridDesc_M_K &in_grid_desc_m_k, const OutGridDesc_M &out_grid_desc_m, const InElementwiseOperation in_elementwise_op, const AccElementwiseOperation acc_elementwise_op, index_t num_k_block_tile_iteration, AccDataType alpha, const InDataType *const __restrict__ p_in_value_global, const IndexDataType *const __restrict__ p_in_index_global, AccDataType beta, OutDataType *const __restrict__ p_out_value_global, IndexDataType *const __restrict__ p_out_index_global)
Definition gridwise_2d_reduction_multiblock.hpp:305
tensor_operation::element_wise::PassThrough PassThroughOp
Definition gridwise_2d_reduction_multiblock.hpp:129
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number< MThreadSliceSize >{}))) ThreadReduceDstDesc_M
Definition gridwise_2d_reduction_multiblock.hpp:113
static constexpr index_t K_BlockTileSize
Definition gridwise_2d_reduction_multiblock.hpp:135
Definition reduction_functions_blockwise.hpp:28
static __device__ void Reduce(BufferType &work_buffer, AccDataType &in_out_value)
Definition reduction_functions_blockwise.hpp:44
Definition reduction_functions_blockwise.hpp:175
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:16
Definition reduction_functions_threadwise.hpp:23
static __device__ void Reduce(const SrcBufferType &src_buf, DstBufferType &dst_buf)
Definition reduction_functions_threadwise.hpp:36
Definition threadwise_tensor_slice_transfer.hpp:39
__device__ void Run(const SrcDesc &, const SrcSliceOriginIdx &, const SrcBuffer &src_buf, const DstDesc &dst_desc, DstBuffer &dst_buf)
Definition threadwise_tensor_slice_transfer.hpp:66
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition utility/functional.hpp:100
Definition reduction_functions_accumulate.hpp:65
Definition reduction_functions_accumulate.hpp:28
Definition reduction_common.hpp:20
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340