17template <
typename GridwiseReduction,
21 typename InGridDesc_M_K,
22 typename DsGridDesc_M,
23 typename OutGridDesc_M,
24 typename InElementwiseOperation,
25 typename OutElementwiseOperation,
26 typename DsGridPointer>
29 const DsGridDesc_M ds_grid_desc_m,
30 const OutGridDesc_M out_grid_desc_m,
31 const InElementwiseOperation in_elementwise_op,
32 const OutElementwiseOperation out_elementwise_op,
33 const InDataType*
const __restrict__ p_in_value_global,
34 const DsGridPointer p_ds_value_global,
35 OutDataType*
const __restrict__ p_out_value_global)
37 GridwiseReduction::Run(in_grid_desc_m_k,
47template <
typename InDataType,
51 typename InGridDesc_M_K,
52 typename DsGridDesc_M,
53 typename OutGridDesc_M,
54 typename ReduceOperation,
55 typename InElementwiseOperation,
56 typename OutElementwiseOperation,
64 typename DsVectorSize>
67 static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
68 (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
69 (MThreadSliceSize % OutDstVectorSize == 0),
70 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
93 return static_cast<const DDataType*
>(
nullptr);
100 __device__
static void Run(
const InGridDesc_M_K& in_grid_desc_m_k,
101 const DsGridDesc_M& ds_grid_desc_m,
102 const OutGridDesc_M& out_grid_desc_m,
103 const InElementwiseOperation& in_elementwise_op,
104 const OutElementwiseOperation& out_elementwise_op,
105 const InDataType*
const __restrict__ p_in_value_global,
107 OutDataType*
const __restrict__ p_out_value_global)
115 const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
119 in_grid_desc_m_k.GetElementSpaceSize(),
120 ReduceOperation::template GetIdentityValue<InDataType>());
122 p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
131 const auto toReduceLength = in_grid_desc_m_k.GetLength(
Number<1>{});
139 auto threadwise_src_val_load =
143 decltype(thread_buffer_desc),
150 in_grid_desc_m_k,
make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
157 threadwise_src_val_load.Run(in_grid_desc_m_k,
166 constexpr auto offset = thread_buffer_desc.CalculateOffset(
make_tuple(iM, iK));
172 ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
174 threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
176 reducedLength += KThreadSliceSize;
177 }
while(reducedLength < toReduceLength);
193 p_ds_grid[I], ds_grid_desc_m[I].GetElementSpaceSize());
204 decltype(ds_grid_desc_m[I]),
205 decltype(reduced_data_desc),
212 ds_grid_desc_m[I],
make_multi_index(thread_global_1d_id * MThreadSliceSize)};
217 ds_global_load(I).Run(ds_grid_desc_m[I],
230 tie(accu_value_buf[I]),
231 generate_tie([&](
auto Id) ->
const auto& {
return ds_thread_buf[Id][I]; },
234 unpack2(out_elementwise_op,
tie(out_value_buf(I)), c_ds_buf_refs);
240 decltype(reduced_data_desc),
247 OutMemoryDataOperation,
254 threadwise_dst_store.Run(
255 reduced_data_desc,
make_tuple(
I0), out_value_buf, out_grid_desc_m, dst_global_buf);
__global__ void kernel_reduce_threadwise_multi_d(const InGridDesc_M_K in_grid_desc_m_k, const DsGridDesc_M ds_grid_desc_m, const OutGridDesc_M out_grid_desc_m, const InElementwiseOperation in_elementwise_op, const OutElementwiseOperation out_elementwise_op, const InDataType *const __restrict__ p_in_value_global, const DsGridPointer p_ds_value_global, OutDataType *const __restrict__ p_out_value_global)
Definition gridwise_2d_reduction_threadwise_multi_d.hpp:28
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
typename remove_cv< T >::type remove_cv_t
Definition type.hpp:295
__host__ __device__ constexpr auto unpack2(F &&f, X &&x, Y &&y)
Definition functional4.hpp:55
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition tuple_helper.hpp:42
Definition gridwise_2d_reduction_threadwise_multi_d.hpp:66
ck::GridwiseReduction_mk_to_m_threadwise_multi_d< InDataType, DsDataType, OutDataType, AccDataType, InGridDesc_M_K, DsGridDesc_M, OutGridDesc_M, ReduceOperation, InElementwiseOperation, OutElementwiseOperation, InMemoryDataOperationEnum::Set, BlockSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize, DsVectorSizeSequence >::NumDTensor static constexpr index_t NumDTensor
Definition gridwise_2d_reduction_threadwise_multi_d.hpp:84
ck::GridwiseReduction_mk_to_m_threadwise_multi_d< InDataType, DsDataType, OutDataType, AccDataType, InGridDesc_M_K, DsGridDesc_M, OutGridDesc_M, ReduceOperation, InElementwiseOperation, OutElementwiseOperation, InMemoryDataOperationEnum::Set, BlockSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize, DsVectorSizeSequence >::ThreadReduceSrcDesc_M_K decltype(make_naive_tensor_descriptor_packed(make_tuple(Number< MThreadSliceSize >{}, Number< KThreadSliceSize >{}))) ThreadReduceSrcDesc_M_K
Definition gridwise_2d_reduction_threadwise_multi_d.hpp:75
ck::GridwiseReduction_mk_to_m_threadwise_multi_d< InDataType, DsDataType, OutDataType, AccDataType, InGridDesc_M_K, DsGridDesc_M, OutGridDesc_M, ReduceOperation, InElementwiseOperation, OutElementwiseOperation, InMemoryDataOperationEnum::Set, BlockSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize, DsVectorSizeSequence >::I0 static constexpr auto I0
Definition gridwise_2d_reduction_threadwise_multi_d.hpp:82
ck::GridwiseReduction_mk_to_m_threadwise_multi_d< InDataType, DsDataType, OutDataType, AccDataType, InGridDesc_M_K, DsGridDesc_M, OutGridDesc_M, ReduceOperation, InElementwiseOperation, OutElementwiseOperation, InMemoryDataOperationEnum::Set, BlockSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize, DsVectorSizeSequence >::DsGridPointer decltype(MakeDsGridPointer()) DsGridPointer
Definition gridwise_2d_reduction_threadwise_multi_d.hpp:98
ck::GridwiseReduction_mk_to_m_threadwise_multi_d< InDataType, DsDataType, OutDataType, AccDataType, InGridDesc_M_K, DsGridDesc_M, OutGridDesc_M, ReduceOperation, InElementwiseOperation, OutElementwiseOperation, InMemoryDataOperationEnum::Set, BlockSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize, DsVectorSizeSequence >::ThreadBufferDimAccessOrder typename conditional< InSrcVectorDim==0, Sequence< 1, 0 >, Sequence< 0, 1 > >::type ThreadBufferDimAccessOrder
Definition gridwise_2d_reduction_threadwise_multi_d.hpp:72
ck::GridwiseReduction_mk_to_m_threadwise_multi_d< InDataType, DsDataType, OutDataType, AccDataType, InGridDesc_M_K, DsGridDesc_M, OutGridDesc_M, ReduceOperation, InElementwiseOperation, OutElementwiseOperation, InMemoryDataOperationEnum::Set, BlockSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize, DsVectorSizeSequence >::ThreadReduceDstDesc_M decltype(make_naive_tensor_descriptor_packed(make_tuple(Number< MThreadSliceSize >{}))) ThreadReduceDstDesc_M
Definition gridwise_2d_reduction_threadwise_multi_d.hpp:77
ck::GridwiseReduction_mk_to_m_threadwise_multi_d< InDataType, DsDataType, OutDataType, AccDataType, InGridDesc_M_K, DsGridDesc_M, OutGridDesc_M, ReduceOperation, InElementwiseOperation, OutElementwiseOperation, InMemoryDataOperationEnum::Set, BlockSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize, DsVectorSizeSequence >::PassThrough tensor_operation::element_wise::PassThrough PassThrough
Definition gridwise_2d_reduction_threadwise_multi_d.hpp:80
static constexpr auto MakeDsGridPointer()
Definition gridwise_2d_reduction_threadwise_multi_d.hpp:87
static __device__ void Run(const InGridDesc_M_K &in_grid_desc_m_k, const DsGridDesc_M &ds_grid_desc_m, const OutGridDesc_M &out_grid_desc_m, const InElementwiseOperation &in_elementwise_op, const OutElementwiseOperation &out_elementwise_op, const InDataType *const __restrict__ p_in_value_global, const DsGridPointer p_ds_grid, OutDataType *const __restrict__ p_out_value_global)
Definition gridwise_2d_reduction_threadwise_multi_d.hpp:100
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:16
Definition reduction_functions_threadwise.hpp:23
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition utility/functional.hpp:100
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340