layernorm2d_fwd_pipeline_default_policy.hpp Source File

layernorm2d_fwd_pipeline_default_policy.hpp Source File

Composable Kernel: layernorm2d_fwd_pipeline_default_policy.hpp Source File
layernorm2d_fwd_pipeline_default_policy.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
9
10namespace ck_tile {
11
13{
14 template <typename Problem>
29
30 template <typename Problem>
44
45 template <typename Problem>
46 static CK_TILE_HOST_DEVICE constexpr auto GetBlockNormReduce()
47 {
48 using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
49 typename Problem::ComputeDataType,
50 typename Problem::BlockShape,
51 Problem::Traits::kFastFDiv,
52 Problem::Traits::kWelford>;
53 return BlockNormReduce<P_>{};
54 }
55
56 template <typename Problem>
57 static CK_TILE_HOST_DEVICE constexpr auto GetBlockNormReduceSync()
58 {
59 using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
60 typename Problem::ComputeDataType,
61 typename Problem::BlockShape,
62 Problem::Traits::kFastFDiv,
63 Problem::Traits::kWelford>;
64
66 }
67
68 template <typename Problem>
69 static CK_TILE_HOST_DEVICE constexpr auto GetBlockNormReduceCrossWarpSync()
70 {
71 using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
72 typename Problem::ComputeDataType,
73 typename Problem::BlockShape,
74 Problem::Traits::kFastFDiv,
75 Problem::Traits::kWelford>;
76
78 }
79
80 template <typename Problem>
81 static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
82 {
83 if constexpr(Problem::kNeedCrossWarpSync)
84 {
85 using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
86 typename Problem::ComputeDataType,
87 typename Problem::BlockShape,
88 Problem::Traits::kFastFDiv,
89 Problem::Traits::kWelford>;
90
91 using block_welford = BlockNormReduce<P_>;
92 using x_block_tile =
95 using mean_var_block_tile =
96 decltype(block_welford::template MakeMeanVarBlockTile<x_block_tile>());
97
100 }
101 else
102 {
103 return 1; // zero size arrays are an extension
104 }
105 }
106};
107} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
Definition block_norm_reduce.hpp:199
Definition block_norm_reduce.hpp:13
Definition block_norm_reduce_problem.hpp:16
Definition block_norm_reduce.hpp:102
Definition layernorm2d_fwd_pipeline_default_policy.hpp:13
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition layernorm2d_fwd_pipeline_default_policy.hpp:81
static CK_TILE_DEVICE constexpr auto MakeGammaBetaBlockTileDistribution()
Definition layernorm2d_fwd_pipeline_default_policy.hpp:31
static CK_TILE_HOST_DEVICE constexpr auto GetBlockNormReduce()
Definition layernorm2d_fwd_pipeline_default_policy.hpp:46
static CK_TILE_HOST_DEVICE constexpr auto GetBlockNormReduceSync()
Definition layernorm2d_fwd_pipeline_default_policy.hpp:57
static CK_TILE_HOST_DEVICE constexpr auto GetBlockNormReduceCrossWarpSync()
Definition layernorm2d_fwd_pipeline_default_policy.hpp:69
static CK_TILE_DEVICE constexpr auto MakeXBlockTileDistribution()
Definition layernorm2d_fwd_pipeline_default_policy.hpp:15
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192