layernorm2d_fwd_pipeline_one_pass.hpp Source File
layernorm2d_fwd_pipeline_one_pass.hpp
Go to the documentation of this file.
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_DEVICE constexpr index_t block_tile_welford_calculate_max_count(int row_size)
Definition block_norm_reduce.hpp:361
CK_TILE_DEVICE constexpr void block_tile_welford_post_scale_var(VarDistributedTensor_ &var_tensor, int count, bool_constant< FastFdiv_ >={})
Definition block_norm_reduce.hpp:393
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
@ SMOOTH_DYNAMIC_QUANT
Definition layernorm2d_fwd_traits.hpp:42
@ DYNAMIC_QUANT
Definition layernorm2d_fwd_traits.hpp:43
CK_TILE_HOST_DEVICE constexpr void sweep_tile(const F &f, UnpacksPerXDim={})
Definition sweep_tile.hpp:231
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
@ PRE_ADD_STORE
Definition layernorm2d_fwd_traits.hpp:27
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition tile_elementwise.hpp:177
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition layernorm2d_fwd_pipeline_one_pass.hpp:16
static constexpr auto kFusedQuant
Definition layernorm2d_fwd_pipeline_one_pass.hpp:44
XDataType YResidualDataType
Definition layernorm2d_fwd_pipeline_one_pass.hpp:30
ck_tile::remove_cvref_t< typename Problem::BetaDataType > BetaDataType
Definition layernorm2d_fwd_pipeline_one_pass.hpp:23
ck_tile::remove_cvref_t< typename Problem::YDataType > YDataType
Definition layernorm2d_fwd_pipeline_one_pass.hpp:25
ck_tile::remove_cvref_t< typename Problem::MeanDataType > MeanDataType
Definition layernorm2d_fwd_pipeline_one_pass.hpp:26
static constexpr bool kHasBeta
Definition layernorm2d_fwd_pipeline_one_pass.hpp:33
ck_tile::remove_cvref_t< typename Problem::XBiasDataType > XBiasDataType
Definition layernorm2d_fwd_pipeline_one_pass.hpp:21
ck_tile::remove_cvref_t< typename Problem::GammaDataType > GammaDataType
Definition layernorm2d_fwd_pipeline_one_pass.hpp:22
ck_tile::remove_cvref_t< Policy_ > Policy
Definition layernorm2d_fwd_pipeline_one_pass.hpp:18
static constexpr auto kFusedAdd
Definition layernorm2d_fwd_pipeline_one_pass.hpp:43
static constexpr const char * name
Definition layernorm2d_fwd_pipeline_one_pass.hpp:46
static constexpr auto kXbias
Definition layernorm2d_fwd_pipeline_one_pass.hpp:42
ck_tile::remove_cvref_t< typename Problem::XDataType > XDataType
Definition layernorm2d_fwd_pipeline_one_pass.hpp:20
ck_tile::remove_cvref_t< typename Problem::InvStdDataType > InvStdDataType
Definition layernorm2d_fwd_pipeline_one_pass.hpp:27
static constexpr bool kFastFDiv
Definition layernorm2d_fwd_pipeline_one_pass.hpp:40
XDataType XResidualDataType
Definition layernorm2d_fwd_pipeline_one_pass.hpp:29
static constexpr bool kPadN
Definition layernorm2d_fwd_pipeline_one_pass.hpp:39
static constexpr bool kSaveMean
Definition layernorm2d_fwd_pipeline_one_pass.hpp:34
static constexpr bool kNeedCrossWarpSync
Definition layernorm2d_fwd_pipeline_one_pass.hpp:37
static constexpr bool kPadM
Definition layernorm2d_fwd_pipeline_one_pass.hpp:38
ck_tile::remove_cvref_t< Problem_ > Problem
Definition layernorm2d_fwd_pipeline_one_pass.hpp:17
static constexpr bool kHasGamma
Definition layernorm2d_fwd_pipeline_one_pass.hpp:32
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition layernorm2d_fwd_pipeline_one_pass.hpp:53
static constexpr bool kSaveInvStd
Definition layernorm2d_fwd_pipeline_one_pass.hpp:35
CK_TILE_DEVICE auto operator()(const XWindow &x_window_, const XResidualWindow &x_residual_window_, const XBiasWindow &x_bias_window_, const GammaWindow &gamma_window_, const BetaWindow &beta_window_, YWindow &y_window_, const YResidualWindow &y_residual_window_, MeanWindow &mean_window, InvStdWindow &inv_std_window, const SmoothScaleWindow &sm_scale_window_, YScaleWindow &y_scale_window, ComputeDataType epsilon, ck_tile::index_t row_size, void *smem, Epilogue) const
Definition layernorm2d_fwd_pipeline_one_pass.hpp:70
ck_tile::remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition layernorm2d_fwd_pipeline_one_pass.hpp:24
static constexpr bool kWelford
Definition layernorm2d_fwd_pipeline_one_pass.hpp:41
Definition tile/core/numeric/integral_constant.hpp:13