device_sparse_embeddings_forward_layernorm.hpp Source File

device_sparse_embeddings_forward_layernorm.hpp Source File

Composable Kernel: device_sparse_embeddings_forward_layernorm.hpp Source File
device_sparse_embeddings_forward_layernorm.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
#include <iostream>
#include <sstream>
#include <stdexcept>
8
15#if __clang_major__ >= 20
17#else
19#endif
20
21namespace ck {
22namespace tensor_operation {
23namespace device {
24
25template <typename EmbType,
26 typename IndexType,
27 typename GammaDataType,
28 typename BetaDataType,
29 typename AccDataType,
30 typename OutType,
31 typename EmbElementwiseOperation,
32 ck::index_t BlockSize,
33 ck::index_t DimClusterSize,
34 ck::index_t RowClusterSize,
35 ck::index_t DimPerBlock,
36 ck::index_t RowPerBlock,
37 ck::index_t DimThreadSize,
38 ck::index_t RowVectorSize,
39 ck::index_t NumEmbeddings>
41{
42 static auto MakeOutputDescriptor(const index_t index_length, const index_t rows)
43 {
44 return make_naive_tensor_descriptor_packed(make_tuple(index_length, rows));
45 }
46
47 struct Argument : public BaseArgument
48 {
49 Argument(OutType* p_out,
52 const GammaDataType* p_gamma,
53 const BetaDataType* p_beta,
54 const ck::index_t EmbeddingDim,
55 const ck::index_t IndexLength,
56 const AccDataType epsilon,
57 const EmbElementwiseOperation emb_elementwise_op)
58 : p_out_(p_out),
59 p_embs_(p_embs),
60 p_indexs_(p_indexs),
61 p_gamma_(p_gamma),
62 p_beta_(p_beta),
63 EmbeddingDim_(EmbeddingDim),
64 IndexLength_(IndexLength),
65 epsilon_(epsilon),
66 emb_elementwise_op_(emb_elementwise_op)
67 {
68 grid_size_ = (IndexLength + DimClusterSize - 1) / DimClusterSize;
69 }
70
71 OutType* p_out_;
74 const GammaDataType* p_gamma_;
75 const BetaDataType* p_beta_;
78 AccDataType epsilon_;
79 EmbElementwiseOperation emb_elementwise_op_;
80
81 size_t grid_size_;
82 };
83
84 std::unique_ptr<BaseArgument>
88 const void* p_gamma,
89 const void* p_beta,
90 ck::index_t EmbeddingDim,
91 ck::index_t IndexLength,
92 const AccDataType epsilon,
93 const EmbElementwiseOperation emb_elementwise_op)
94 {
95 return std::make_unique<Argument>(reinterpret_cast<OutType*>(p_out),
96 p_embs,
97 p_indexs,
98 reinterpret_cast<const GammaDataType*>(p_gamma),
99 reinterpret_cast<const BetaDataType*>(p_beta),
100 EmbeddingDim,
101 IndexLength,
102 epsilon,
103 emb_elementwise_op);
104 }
105
108 IndexType,
109 GammaDataType,
110 BetaDataType,
111 AccDataType,
112 OutType,
113 decltype(MakeOutputDescriptor(1, 1)),
114 EmbElementwiseOperation,
115 BlockSize,
116 DimClusterSize,
117 RowClusterSize,
118 DimPerBlock,
119 RowPerBlock,
120 DimThreadSize,
121 RowVectorSize,
122 NumEmbeddings>;
123
124 struct Invoker : public BaseInvoker
125 {
// Typed entry point: launches the sparse-embeddings + layernorm kernel for
// `arg` through launch_and_time_kernel and returns the time it reports
// (accumulated into avg_time, which starts at 0).
// NOTE(review): this extraction is missing source line 130 (the kernel
// template name, presumably kernel_sparse_embeddings_forward_layernorm<...,
// plus its first template argument) and source line 153 (the final kernel
// argument, presumably arg.emb_elementwise_op_ per the kernel signature) —
// confirm against the real header before editing this code.
126 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
127 {
// Output descriptor has lengths (IndexLength_, EmbeddingDim_).
128 auto out_desc = MakeOutputDescriptor(arg.IndexLength_, arg.EmbeddingDim_);
129 const auto kernel_main =
131 EmbType,
132 IndexType,
133 GammaDataType,
134 BetaDataType,
135 AccDataType,
136 OutType,
137 decltype(out_desc),
138 EmbElementwiseOperation,
139 NumEmbeddings>;
140 float avg_time = 0;
// Grid of arg.grid_size_ workgroups, BlockSize threads each, lds_byte = 0;
// the remaining arguments are passed straight to the kernel.
141 avg_time += launch_and_time_kernel(stream_config,
142 kernel_main,
143 dim3(arg.grid_size_),
144 dim3(BlockSize),
145 0,
146 arg.p_out_,
147 arg.p_embs_,
148 arg.p_indexs_,
149 arg.p_gamma_,
150 arg.p_beta_,
151 out_desc,
152 arg.epsilon_,
154
155 return (avg_time);
156 }
157
158 float Run(const BaseArgument* p_arg,
159 const StreamConfig& stream_config = StreamConfig{}) override
160 {
161 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
162 };
163 };
164
165 static bool IsSupportedArgument(const Argument* p_arg)
166 {
167 return (RowPerBlock == p_arg->EmbeddingDim_);
168 }
169
170 bool IsSupportedArgument(const BaseArgument* p_arg) override
171 {
172 return IsSupportedArgument(dynamic_cast<const Argument*>(p_arg));
173 }
174
175 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer()
176 {
177 return std::make_unique<Invoker>();
178 }
179
180 std::string GetTypeString() const override
181 {
182 auto str = std::stringstream();
183
184 // clang-format off
185 str << "DeviceSparseEmbeddingsForwardLayernorm_"<< BlockSize << "_" <<
186 DimClusterSize << "x" << RowClusterSize << "_" <<
187 DimPerBlock << "x" << RowPerBlock << "_" <<
188 DimThreadSize << "x" << RowVectorSize;
189 // clang-format on
190
191 return str.str();
192 }
193};
194
195} // namespace device
196} // namespace tensor_operation
197} // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
__global__ void kernel_sparse_embeddings_forward_layernorm(OutType *p_out, const ck::Array< EmbType *, NumEmbeddings > p_embs, const ck::Array< IndexType *, NumEmbeddings > p_indexes, const GammaDataType *p_gamma, const BetaDataType *p_beta, const OutGridDesc out_grid_desc, const AccDataType epsilon, const EmbElementwiseOperation emb_elementwise_op)
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
Definition ck/stream_config.hpp:10
Definition utility/array.hpp:14
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:57
Definition device_base.hpp:197
Definition device_sparse_embeddings_forward_layernorm.hpp:48
const GammaDataType * p_gamma_
Definition device_sparse_embeddings_forward_layernorm.hpp:74
ck::index_t IndexLength_
Definition device_sparse_embeddings_forward_layernorm.hpp:77
ck::Array< EmbType *, NumEmbeddings > p_embs_
Definition device_sparse_embeddings_forward_layernorm.hpp:72
size_t grid_size_
Definition device_sparse_embeddings_forward_layernorm.hpp:81
OutType * p_out_
Definition device_sparse_embeddings_forward_layernorm.hpp:71
const BetaDataType * p_beta_
Definition device_sparse_embeddings_forward_layernorm.hpp:75
ck::index_t EmbeddingDim_
Definition device_sparse_embeddings_forward_layernorm.hpp:76
Argument(OutType *p_out, const ck::Array< EmbType *, NumEmbeddings > &p_embs, const ck::Array< IndexType *, NumEmbeddings > &p_indexs, const GammaDataType *p_gamma, const BetaDataType *p_beta, const ck::index_t EmbeddingDim, const ck::index_t IndexLength, const AccDataType epsilon, const EmbElementwiseOperation emb_elementwise_op)
Definition device_sparse_embeddings_forward_layernorm.hpp:49
AccDataType epsilon_
Definition device_sparse_embeddings_forward_layernorm.hpp:78
ck::Array< IndexType *, NumEmbeddings > p_indexs_
Definition device_sparse_embeddings_forward_layernorm.hpp:73
EmbElementwiseOperation emb_elementwise_op_
Definition device_sparse_embeddings_forward_layernorm.hpp:79
Definition device_sparse_embeddings_forward_layernorm.hpp:125
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_sparse_embeddings_forward_layernorm.hpp:126
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_sparse_embeddings_forward_layernorm.hpp:158
Definition device_sparse_embeddings_forward_layernorm.hpp:41
std::unique_ptr< BaseArgument > MakeArgumentPointer(void *p_out, const ck::Array< EmbType *, NumEmbeddings > &p_embs, const ck::Array< IndexType *, NumEmbeddings > &p_indexs, const void *p_gamma, const void *p_beta, ck::index_t EmbeddingDim, ck::index_t IndexLength, const AccDataType epsilon, const EmbElementwiseOperation emb_elementwise_op)
Definition device_sparse_embeddings_forward_layernorm.hpp:85
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()
Definition device_sparse_embeddings_forward_layernorm.hpp:175
static bool IsSupportedArgument(const Argument *p_arg)
Definition device_sparse_embeddings_forward_layernorm.hpp:165
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_sparse_embeddings_forward_layernorm.hpp:170
std::string GetTypeString() const override
Definition device_sparse_embeddings_forward_layernorm.hpp:180
GridwiseSparseEmbeddingsForwardLayernorm< EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, decltype(MakeOutputDescriptor(1, 1)), EmbElementwiseOperation, BlockSize, DimClusterSize, RowClusterSize, DimPerBlock, RowPerBlock, DimThreadSize, RowVectorSize, NumEmbeddings > GridwiseSparseEmbedding
Definition device_sparse_embeddings_forward_layernorm.hpp:106
static auto MakeOutputDescriptor(const index_t index_length, const index_t rows)
Definition device_sparse_embeddings_forward_layernorm.hpp:42