15template <
typename XDataType,
16 typename ComputeDataType,
17 typename MeanVarDataType,
18 typename XGridDesc_M_K,
19 typename MeanVarGridDesc_M_KBlock,
29 static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
30 (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
31 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
81 GetKPerThread(
int k,
int kRaw,
int kGridSize,
int block_k_cluster_id,
int thread_k_cluster_id)
83 bool is_rightmost_block = block_k_cluster_id == kGridSize - 1;
85 if(is_rightmost_block)
88 int kRightmostBlock = kRaw - left_kPerBlock * (kGridSize - 1);
92 int kPerBlockTail = kRightmostBlock - kPerThread * KThreadClusterSize;
99 int delta = thread_max_len - kPerBlockTail;
100 delta =
math::clamp(thread_max_len - kPerBlockTail, 0, XSrcVectorSize);
101 kPerThread += XSrcVectorSize - delta;
115 __device__
static void Run(
const XGridDesc_M_K& x_grid_desc_m_k,
116 const MeanVarGridDesc_M_KBlock& mean_var_grid_desc_m_kblock,
117 index_t num_k_block_tile_iteration,
118 const XDataType*
const __restrict__ p_x_global,
119 MeanVarDataType*
const p_mean_global,
120 MeanVarDataType*
const p_variance_global,
121 int32_t*
const p_welford_count_global)
127 MThreadSliceSize * XSrcVectorSize,
140 const index_t k_grid_size = mean_var_grid_desc_m_kblock.GetLength(
I1);
141 const index_t block_m_cluster_id = block_global_id / k_grid_size;
142 const index_t block_k_cluster_id = block_global_id % k_grid_size;
144 const auto thread_cluster_idx =
147 const auto thread_m_cluster_id = thread_cluster_idx[
I0];
148 const auto thread_k_cluster_id = thread_cluster_idx[
I1];
164 block_m_cluster_id *
M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
165 block_k_cluster_id * reduceSizePerBlock + thread_k_cluster_id * XSrcVectorSize));
168 block_m_cluster_id *
M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
171 auto threadwise_welford_mean_var_store =
175 MeanVarGridDesc_M_KBlock,
184 mean_var_grid_desc_m_kblock, mean_var_count_store_index,
PassThroughOp{});
189 p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
192 p_mean_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize());
195 p_variance_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize());
198 int kRaw = x_grid_desc_m_k.GetTransforms()[
I2].GetUpperLengths()[
I0];
199 threadwise_welford.max_count_ =
GetKPerThread(x_grid_desc_m_k.GetLength(
I1),
203 thread_k_cluster_id);
210 for(
index_t k = 0; k < num_k_block_tile_iteration; ++k)
213 threadwise_x_load.Run(x_grid_desc_m_k,
218 threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
219 threadwise_welford.Run(x_thread_buf[i], mean_thread_buf, var_thread_buf);
223 int welford_count = 0;
228 int count = threadwise_welford.cur_count_;
232 if constexpr(I == MThreadSliceSize - 1)
233 welford_count = count;
236 if(thread_k_cluster_id == 0)
241 mean_var_grid_desc_m_kblock,
242 mean_global_val_buf);
247 mean_var_grid_desc_m_kblock,
250 if(block_m_cluster_id == 0 && thread_m_cluster_id == 0)
251 p_welford_count_global[block_k_cluster_id] = welford_count;
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr T clamp(const T &x, const T &lowerbound, const T &upperbound)
Definition utility/math.hpp:148
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
signed int int32_t
Definition stdint.h:123
static __device__ void Run(T &mean_value, T &var_value, CountDataType &count)
Definition blockwise_welford.hpp:51
Definition gridwise_normalization_splitk_1st.hpp:28
ck::GridwiseNormalizationSplitK1st< XDataType, ComputeDataType, WorkspaceMeanVarDataType, SrcGridDesc_M_K, Kernel1MeanVarGridDesc_M_KBlock, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYVectorDim, XSrcVectorSize >::ThreadBufferLengths_M_1 Sequence< MThreadSliceSize, 1 > ThreadBufferLengths_M_1
Definition gridwise_normalization_splitk_1st.hpp:54
ck::GridwiseNormalizationSplitK1st< XDataType, ComputeDataType, WorkspaceMeanVarDataType, SrcGridDesc_M_K, Kernel1MeanVarGridDesc_M_KBlock, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYVectorDim, XSrcVectorSize >::I1 static constexpr auto I1
Definition gridwise_normalization_splitk_1st.hpp:36
ck::GridwiseNormalizationSplitK1st< XDataType, ComputeDataType, WorkspaceMeanVarDataType, SrcGridDesc_M_K, Kernel1MeanVarGridDesc_M_KBlock, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYVectorDim, XSrcVectorSize >::ThreadBufferNumber static constexpr auto ThreadBufferNumber
Definition gridwise_normalization_splitk_1st.hpp:78
ck::GridwiseNormalizationSplitK1st< XDataType, ComputeDataType, WorkspaceMeanVarDataType, SrcGridDesc_M_K, Kernel1MeanVarGridDesc_M_KBlock, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYVectorDim, XSrcVectorSize >::K_BlockTileSize static constexpr index_t K_BlockTileSize
Definition gridwise_normalization_splitk_1st.hpp:75
ck::GridwiseNormalizationSplitK1st< XDataType, ComputeDataType, WorkspaceMeanVarDataType, SrcGridDesc_M_K, Kernel1MeanVarGridDesc_M_KBlock, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYVectorDim, XSrcVectorSize >::BlockwiseWelford BlockwiseWelford< ComputeDataType, BlockSize, ThreadClusterLengths_M_K, ThreadClusterArrangeOrder, false > BlockwiseWelford
Definition gridwise_normalization_splitk_1st.hpp:66
ck::GridwiseNormalizationSplitK1st< XDataType, ComputeDataType, WorkspaceMeanVarDataType, SrcGridDesc_M_K, Kernel1MeanVarGridDesc_M_KBlock, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYVectorDim, XSrcVectorSize >::PassThroughOp tensor_operation::element_wise::PassThrough PassThroughOp
Definition gridwise_normalization_splitk_1st.hpp:72
ck::GridwiseNormalizationSplitK1st< XDataType, ComputeDataType, WorkspaceMeanVarDataType, SrcGridDesc_M_K, Kernel1MeanVarGridDesc_M_KBlock, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYVectorDim, XSrcVectorSize >::ThreadClusterArrangeOrder typename conditional< reorder_thread_cluster, Sequence< 1, 0 >, Sequence< 0, 1 > >::type ThreadClusterArrangeOrder
Definition gridwise_normalization_splitk_1st.hpp:44
ck::GridwiseNormalizationSplitK1st< XDataType, ComputeDataType, WorkspaceMeanVarDataType, SrcGridDesc_M_K, Kernel1MeanVarGridDesc_M_KBlock, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYVectorDim, XSrcVectorSize >::thread_buffer_desc_m_1 static constexpr auto thread_buffer_desc_m_1
Definition gridwise_normalization_splitk_1st.hpp:55
static __device__ void Run(const XGridDesc_M_K &x_grid_desc_m_k, const MeanVarGridDesc_M_KBlock &mean_var_grid_desc_m_kblock, index_t num_k_block_tile_iteration, const XDataType *const __restrict__ p_x_global, MeanVarDataType *const p_mean_global, MeanVarDataType *const p_variance_global, int32_t *const p_welford_count_global)
Definition gridwise_normalization_splitk_1st.hpp:115
ck::GridwiseNormalizationSplitK1st< XDataType, ComputeDataType, WorkspaceMeanVarDataType, SrcGridDesc_M_K, Kernel1MeanVarGridDesc_M_KBlock, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYVectorDim, XSrcVectorSize >::reorder_thread_cluster static constexpr bool reorder_thread_cluster
Definition gridwise_normalization_splitk_1st.hpp:33
ck::GridwiseNormalizationSplitK1st< XDataType, ComputeDataType, WorkspaceMeanVarDataType, SrcGridDesc_M_K, Kernel1MeanVarGridDesc_M_KBlock, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYVectorDim, XSrcVectorSize >::I0 static constexpr auto I0
Definition gridwise_normalization_splitk_1st.hpp:35
ck::GridwiseNormalizationSplitK1st< XDataType, ComputeDataType, WorkspaceMeanVarDataType, SrcGridDesc_M_K, Kernel1MeanVarGridDesc_M_KBlock, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYVectorDim, XSrcVectorSize >::ThreadReduceSrcDesc_M_K decltype(make_naive_tensor_descriptor_packed(make_tuple(Number< MThreadSliceSize >{}, Number< XSrcVectorSize >{}))) ThreadReduceSrcDesc_M_K
Definition gridwise_normalization_splitk_1st.hpp:58
static __device__ int GetKPerThread(int k, int kRaw, int kGridSize, int block_k_cluster_id, int thread_k_cluster_id)
Definition gridwise_normalization_splitk_1st.hpp:81
ck::GridwiseNormalizationSplitK1st< XDataType, ComputeDataType, WorkspaceMeanVarDataType, SrcGridDesc_M_K, Kernel1MeanVarGridDesc_M_KBlock, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYVectorDim, XSrcVectorSize >::M_BlockTileSize static constexpr index_t M_BlockTileSize
Definition gridwise_normalization_splitk_1st.hpp:74
ck::GridwiseNormalizationSplitK1st< XDataType, ComputeDataType, WorkspaceMeanVarDataType, SrcGridDesc_M_K, Kernel1MeanVarGridDesc_M_KBlock, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYVectorDim, XSrcVectorSize >::K_BlockTileStepSize static constexpr index_t K_BlockTileStepSize
Definition gridwise_normalization_splitk_1st.hpp:76
ck::GridwiseNormalizationSplitK1st< XDataType, ComputeDataType, WorkspaceMeanVarDataType, SrcGridDesc_M_K, Kernel1MeanVarGridDesc_M_KBlock, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYVectorDim, XSrcVectorSize >::I2 static constexpr auto I2
Definition gridwise_normalization_splitk_1st.hpp:37
ck::GridwiseNormalizationSplitK1st< XDataType, ComputeDataType, WorkspaceMeanVarDataType, SrcGridDesc_M_K, Kernel1MeanVarGridDesc_M_KBlock, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYVectorDim, XSrcVectorSize >::ThreadwiseWelford ThreadwiseWelford< ComputeDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M > ThreadwiseWelford
Definition gridwise_normalization_splitk_1st.hpp:63
ck::GridwiseNormalizationSplitK1st< XDataType, ComputeDataType, WorkspaceMeanVarDataType, SrcGridDesc_M_K, Kernel1MeanVarGridDesc_M_KBlock, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYVectorDim, XSrcVectorSize >::thread_buffer_desc_m_k static constexpr auto thread_buffer_desc_m_k
Definition gridwise_normalization_splitk_1st.hpp:51
ck::GridwiseNormalizationSplitK1st< XDataType, ComputeDataType, WorkspaceMeanVarDataType, SrcGridDesc_M_K, Kernel1MeanVarGridDesc_M_KBlock, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYVectorDim, XSrcVectorSize >::ThreadBufferLengths_M_K Sequence< MThreadSliceSize, XSrcVectorSize > ThreadBufferLengths_M_K
Definition gridwise_normalization_splitk_1st.hpp:50
ck::GridwiseNormalizationSplitK1st< XDataType, ComputeDataType, WorkspaceMeanVarDataType, SrcGridDesc_M_K, Kernel1MeanVarGridDesc_M_KBlock, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYVectorDim, XSrcVectorSize >::ThreadReduceDstDesc_M decltype(make_naive_tensor_descriptor_packed(make_tuple(Number< MThreadSliceSize >{}))) ThreadReduceDstDesc_M
Definition gridwise_normalization_splitk_1st.hpp:60
ck::GridwiseNormalizationSplitK1st< XDataType, ComputeDataType, WorkspaceMeanVarDataType, SrcGridDesc_M_K, Kernel1MeanVarGridDesc_M_KBlock, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYVectorDim, XSrcVectorSize >::ThreadBufferDimAccessOrder typename conditional< reorder_thread_cluster, Sequence< 1, 0 >, Sequence< 0, 1 > >::type ThreadBufferDimAccessOrder
Definition gridwise_normalization_splitk_1st.hpp:41
ck::GridwiseNormalizationSplitK1st< XDataType, ComputeDataType, WorkspaceMeanVarDataType, SrcGridDesc_M_K, Kernel1MeanVarGridDesc_M_KBlock, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYVectorDim, XSrcVectorSize >::thread_cluster_desc static constexpr auto thread_cluster_desc
Definition gridwise_normalization_splitk_1st.hpp:47
ck::GridwiseNormalizationSplitK1st< XDataType, ComputeDataType, WorkspaceMeanVarDataType, SrcGridDesc_M_K, Kernel1MeanVarGridDesc_M_KBlock, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYVectorDim, XSrcVectorSize >::ThreadClusterLengths_M_K Sequence< MThreadClusterSize, KThreadClusterSize > ThreadClusterLengths_M_K
Definition gridwise_normalization_splitk_1st.hpp:39
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:16
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition utility/functional.hpp:100
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340