11template <
typename Problem_,
typename Policy_ =
void>
17 static constexpr bool kFastFDiv = Problem::kFastFDiv;
18 static constexpr bool kWelford = Problem::kWelford;
26 template <
typename XDistributedTensor_,
27 typename MeanDistributedTensor_,
28 typename VarDistributedTensor_>
30 MeanDistributedTensor_& mean_tensor,
31 VarDistributedTensor_& var_tensor,
33 const int& max_count_)
38 constexpr auto spans = XDistributedTensor_::get_distributed_spans();
41 if(cur_count_ < max_count_)
45 constexpr auto in_dstr_idx =
make_tuple(dstr_idx_i0, dstr_idx_i1);
46 constexpr auto out_dstr_idx =
make_tuple(dstr_idx_i0);
52 var_tensor(out_dstr_idx),
59 mean_tensor(out_dstr_idx) += x;
60 var_tensor(out_dstr_idx) += x * x;
67 template <
typename XDistributedTensor_>
70 static_assert(std::is_same_v<XDataType, typename XDistributedTensor_::DataType>,
"wrong!");
76 XDistributedTensor_::get_tile_distribution()
77 .get_static_tile_distribution_encoding(),
85 template <
typename XDistributedTensor_>
87 operator()(
const XDistributedTensor_& x_tensor,
int& cur_count_,
const int& max_count_)
94 (*this)(x_tensor, mean_tensor, var_tensor, cur_count_, max_count_);
100template <
typename Problem_,
typename Policy_ =
void>
105 static constexpr bool kWelford = Problem::kWelford;
107 template <
typename MeanDistributedTensor_,
typename VarDistributedTensor_>
109 operator()(MeanDistributedTensor_& mean_tensor, VarDistributedTensor_& var_tensor,
int& count)
111 using Dstr =
typename MeanDistributedTensor_::StaticTileDistribution;
112 using DstrEncode =
typename Dstr::DstrEncode;
113 using DstrEncodeDetail =
typename DstrEncode::detail;
115 static_assert(std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
118 constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
119 constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
121 constexpr index_t idim_p_lane = NDimP - 1;
127 constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
128 static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
130 const int original_count = count;
134 auto v_local_mean = mean_tensor.get_thread_buffer()[i];
135 auto v_local_var = var_tensor.get_thread_buffer()[i];
136 auto v_local_count = original_count;
143 if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
145 constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
147 constexpr index_t lid_over_rid_derivative =
148 DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];
151 "wrong! only support power of 2 reduction");
160 (
number<lid_over_rid_derivative << istage.value>{}.value);
163 const auto v_remote_mean =
warp_shuffle(v_local_mean, src_lane);
164 const auto v_remote_var =
warp_shuffle(v_local_var, src_lane);
167 const auto v_remote_count =
warp_shuffle(v_local_count, src_lane);
170 welford_merge(v_local_mean,
180 v_local_mean += v_remote_mean;
181 v_local_var += v_remote_var;
187 mean_tensor.get_thread_buffer()(i) = v_local_mean;
188 var_tensor.get_thread_buffer()(i) = v_local_var;
191 count = v_local_count;
197template <
typename Problem_,
typename Policy_ =
void>
203 static constexpr bool kWelford = Problem::kWelford;
204 using smem_dtype = std::conditional_t<kWelford, fp32x4_t, fp32x2_t>;
206 template <
typename MeanDistributedTensor_>
209 constexpr index_t num_reduce_warps = [&]() {
210 using Dstr =
typename MeanDistributedTensor_::StaticTileDistribution;
211 using DstrEncode =
typename Dstr::DstrEncode;
212 using DstrEncodeDetail =
typename DstrEncode::detail;
214 constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
216 constexpr index_t idim_p_warp = 0;
220 if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_warp][idim_r])
222 constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
228 return num_reduce_warps;
232 template <
typename MeanDistributedTensor_>
238 constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
254 return num_warps * 4 * thread_buf_size *
sizeof(float);
257 template <
typename MeanDistributedTensor_,
typename VarDistributedTensor_>
259 VarDistributedTensor_& var_tensor,
263 using DataType =
typename MeanDistributedTensor_::DataType;
264 using Dstr =
typename MeanDistributedTensor_::StaticTileDistribution;
268 static_assert(std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
271 constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
272 static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
280 const index_t smem_offset = warp_id;
283 if constexpr(num_reduce_warps == 1)
291 local_scratch_[0] =
bit_cast<float>(mean_tensor.get_thread_buffer()[i]);
292 local_scratch_[1] =
bit_cast<float>(var_tensor.get_thread_buffer()[i]);
297 smem_ptr[smem_offset + i * num_warps] = local_scratch_;
303 index_t local_warp_id = warp_id / num_reduce_warps;
304 index_t local_smem_os = local_warp_id * num_reduce_warps;
305 smem_dtype all_scratch[thread_buf_size * num_reduce_warps];
308 all_scratch[i_0 * num_reduce_warps + i_1] =
309 smem_ptr[i_0 * num_warps + local_smem_os + i_1];
318 auto v_local = all_scratch[i_0 * num_reduce_warps];
324 static_for<0, num_reduce_warps - 1, 1>{}([&](
auto i_1_n1) {
326 const smem_dtype v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
333 welford_merge(v_local_mean,
343 v_local_mean += v_remote_mean;
344 v_local_var += v_remote_var;
348 mean_tensor.get_thread_buffer()(i_0) = v_local_mean;
349 var_tensor.get_thread_buffer()(i_0) = v_local_var;
351 count = v_local_count;
360template <
typename BlockShape>
364 using S = BlockShape;
365 index_t LastloopN = row_size % S::Block_N == 0 ? S::Block_N : row_size % S::Block_N;
366 constexpr index_t NThread = S::WarpPerBlock_N * S::ThreadPerWarp_N;
368 index_t iN0 = LastloopN / (S::Vector_N * S::ThreadPerWarp_N);
369 index_t iN1 = (LastloopN % (S::Vector_N * S::ThreadPerWarp_N)) / S::Vector_N;
370 index_t N2 = (LastloopN % (S::Vector_N * S::ThreadPerWarp_N)) % S::Vector_N;
371 index_t iN3 = iNLane < iN1 ? S::Vector_N : iNLane == iN1 ? N2 : 0;
372 return iN0 * S::Vector_N + iN3;
374 using S_ = BlockShape;
375 constexpr index_t ThreadsPerBlock_N = S_::WarpPerBlock_N * S_::ThreadPerWarp_N;
378 const index_t element_per_row = row_size / S_::Vector_N;
384 index_t _a = lane_id_n < element_per_row ? 1 : 0;
386 lane_id_n += ThreadsPerBlock_N;
388 return cnt * S_::Vector_N;
392template <
typename VarDistributedTensor_,
bool FastFdiv_ = false>
397 using DataType =
typename VarDistributedTensor_::DataType;
400 if(FastFdiv_ && std::is_same_v<DataType, float>)
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
CK_TILE_HOST_DEVICE constexpr auto make_reduce_tile_distribution_encoding(InDstr, sequence< InReduceDimXs... > reduce_dim_xs_in)
Definition tile_distribution_encoding.hpp:762
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_DEVICE constexpr index_t block_tile_welford_calculate_max_count(int row_size)
Definition block_norm_reduce.hpp:361
CK_TILE_DEVICE constexpr void block_tile_welford_post_scale_var(VarDistributedTensor_ &var_tensor, int count, bool_constant< FastFdiv_ >={})
Definition block_norm_reduce.hpp:393
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr bool is_power_of_two_integer(int32_t x)
Definition tile/core/numeric/math.hpp:462
CK_TILE_DEVICE index_t get_lane_id()
Definition arch.hpp:101
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_DEVICE T warp_shuffle(const T &v_local, uint32_t src_lane)
Definition utility.hpp:78
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_DEVICE index_t get_warp_id(bool_constant< ReturnSgpr >={})
Definition arch.hpp:104
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X &x)
Definition bit_cast.hpp:11
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_HOST_DEVICE constexpr int32_t integer_log2_floor(int32_t x)
Definition tile/core/numeric/math.hpp:455
CK_TILE_DEVICE index_t get_thread_id()
Definition arch.hpp:117
CK_TILE_DEVICE void welford_update(T &mean, T &var, T x, int count, bool_constant< kFastFDiv >={})
Definition thread_welford.hpp:11
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition tile_elementwise.hpp:177
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_norm_reduce.hpp:199
static constexpr bool kFastFDiv
Definition block_norm_reduce.hpp:202
static constexpr bool kWelford
Definition block_norm_reduce.hpp:203
typename Problem::BlockShape BlockShape
Definition block_norm_reduce.hpp:201
remove_cvref_t< Problem_ > Problem
Definition block_norm_reduce.hpp:200
static CK_TILE_DEVICE constexpr index_t GetReduceWarps()
Definition block_norm_reduce.hpp:207
std::conditional_t< kWelford, fp32x4_t, fp32x2_t > smem_dtype
Definition block_norm_reduce.hpp:204
CK_TILE_DEVICE void operator()(MeanDistributedTensor_ &mean_tensor, VarDistributedTensor_ &var_tensor, int &count, void *smem)
Definition block_norm_reduce.hpp:258
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition block_norm_reduce.hpp:233
typename Problem::ComputeDataType ComputeDataType
Definition block_norm_reduce.hpp:16
static CK_TILE_DEVICE auto MakeMeanVarBlockTile()
Definition block_norm_reduce.hpp:68
remove_cvref_t< Problem_ > Problem
Definition block_norm_reduce.hpp:14
static constexpr bool kWelford
Definition block_norm_reduce.hpp:18
CK_TILE_DEVICE void operator()(const XDistributedTensor_ &x_tensor, MeanDistributedTensor_ &mean_tensor, VarDistributedTensor_ &var_tensor, int &cur_count_, const int &max_count_)
Definition block_norm_reduce.hpp:29
CK_TILE_DEVICE auto operator()(const XDistributedTensor_ &x_tensor, int &cur_count_, const int &max_count_)
Definition block_norm_reduce.hpp:87
CK_TILE_DEVICE constexpr BlockNormReduce()
Definition block_norm_reduce.hpp:20
typename Problem::XDataType XDataType
Definition block_norm_reduce.hpp:15
static constexpr bool kFastFDiv
Definition block_norm_reduce.hpp:17
Definition block_norm_reduce.hpp:102
CK_TILE_DEVICE void operator()(MeanDistributedTensor_ &mean_tensor, VarDistributedTensor_ &var_tensor, int &count)
Definition block_norm_reduce.hpp:109
static constexpr bool kWelford
Definition block_norm_reduce.hpp:105
remove_cvref_t< Problem_ > Problem
Definition block_norm_reduce.hpp:103
static constexpr bool kFastFDiv
Definition block_norm_reduce.hpp:104
Definition tile/core/numeric/integral_constant.hpp:13
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43