device_elementwise_scale_impl.hpp Source File

device_elementwise_scale_impl.hpp Source File#

Composable Kernel: device_elementwise_scale_impl.hpp Source File
device_elementwise_scale_impl.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
9#include "ck/utility/math.hpp"
14
17
18namespace ck {
19namespace tensor_operation {
20namespace device {
21
26template <typename InDataTypeTuple,
27 typename OutDataTypeTuple,
28 typename ElementwiseOperation,
29 typename UnaryOperation,
30 typename Scale,
31 index_t NumDim,
32 index_t MPerThread,
33 typename InScalarPerVectorSeq,
34 typename OutScalarPerVectorSeq>
35struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple,
36 OutDataTypeTuple,
37 ElementwiseOperation,
38 UnaryOperation,
39 Scale,
40 NumDim>
41{
42 static constexpr int NumInput = InDataTypeTuple::Size();
43 static constexpr int NumOutput = OutDataTypeTuple::Size();
44
45 static_assert(NumInput == InScalarPerVectorSeq::Size() &&
46 NumOutput == OutScalarPerVectorSeq::Size(),
47 "Tuple size is inconsistent with the number of in/out!");
48
50 {
51 return generate_tuple(
52 [&](auto I) {
53 using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
54
55 return static_cast<const DataType*>(nullptr);
56 },
58 };
59
61 {
62 return generate_tuple(
63 [&](auto I) {
64 using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
65
66 return static_cast<DataType*>(nullptr);
67 },
69 };
70
73
74 template <typename Desc_M>
75 static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
76 {
77 constexpr auto I0 = Number<0>{};
78
79 const auto m = desc_m.GetLength(I0);
80 const index_t loop_step = gridSize * blockSize * MPerThread;
81 const auto pad = math::integer_least_multiple(m, loop_step) - m;
82 const auto desc_m_pad =
87 return desc_m_pad;
88 }
89
90 static auto MakeDescriptor_M(const std::array<index_t, NumDim>& lengths,
91 const std::array<index_t, NumDim>& stride,
92 index_t gridSize,
93 index_t blockSize)
94 {
95 auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<NumDim>{});
96 auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<NumDim>{});
97
98 // nd desc - [s0, s1, s2, ...]
99 const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
100
101 // merge nd to 1d desc - [s0 * s1 * ...]
102 if constexpr(NumDim > 1)
103 {
104 const auto desc_m = transform_tensor_descriptor(
105 desc,
106 make_tuple(make_merge_transform(tupleOfShape)),
107 make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NumDim>{})),
109
110 return PadDescriptor_M_1d(desc_m, gridSize, blockSize);
111 }
112 else
113 return PadDescriptor_M_1d(desc, gridSize, blockSize);
114 }
115
116 template <index_t TupleSize>
118 {
119 return generate_tuple(
120 [&](auto) {
121 if constexpr(NumDim > 1)
122 {
123 return MakeDescriptor_M({1, 1}, {1, 1}, 1, 1);
124 }
125 else
126 {
127 return MakeDescriptor_M({1}, {1}, 1, 1);
128 };
129 },
131 };
132
135
140 ElementwiseOperation,
141 UnaryOperation,
142 Scale,
143 MPerThread,
144 InScalarPerVectorSeq,
145 OutScalarPerVectorSeq>;
146
147 struct Argument : public BaseArgument
148 {
149 Argument(const std::array<index_t, NumDim> lengths,
150 const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
151 const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
152 const std::array<const void*, NumInput> in_dev_buffers,
153 const std::array<void*, NumOutput> out_dev_buffers,
154 ElementwiseOperation elementwise_op,
155 UnaryOperation unary_op,
156 Scale scale_op)
157
158 : lengths_(lengths),
159 inStridesArray_(inStridesArray),
160 outStridesArray_(outStridesArray),
161 elementwise_op_(elementwise_op),
162 unary_op_(unary_op),
163 scale_op_(scale_op),
164 blockSize_(256)
165 {
167 [&](auto I) {
168 using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
169 return static_cast<const DataType*>(in_dev_buffers[I.value]);
170 },
172
174 [&](auto I) {
175 using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
176 return static_cast<DataType*>(out_dev_buffers[I.value]);
177 },
179 }
180
183
184 std::array<index_t, NumDim> lengths_;
185 std::array<std::array<index_t, NumDim>, NumInput> inStridesArray_;
186 std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray_;
187
188 ElementwiseOperation elementwise_op_;
189 UnaryOperation unary_op_;
192 };
193
194 struct Invoker : public BaseInvoker
195 {
196 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
197 {
198 index_t gridSize = getAvailableComputeUnitCount(stream_config);
199
200 auto in_grid_1d_desc_tuple = generate_tuple(
201 [&](auto I) {
202 return MakeDescriptor_M(
203 arg.lengths_, arg.inStridesArray_[I.value], gridSize, arg.blockSize_);
204 },
206
207 auto out_grid_1d_desc_tuple = generate_tuple(
208 [&](auto I) {
209 return MakeDescriptor_M(
210 arg.lengths_, arg.outStridesArray_[I.value], gridSize, arg.blockSize_);
211 },
213
214 const auto kernel = kernel_elementwise_1d<GridwiseElementwise,
219 ElementwiseOperation,
220 UnaryOperation,
221 Scale>;
222
223 float elapsed_time = launch_and_time_kernel(stream_config,
224 kernel,
225 dim3(gridSize),
226 dim3(arg.blockSize_),
227 0,
228 in_grid_1d_desc_tuple,
229 out_grid_1d_desc_tuple,
230 arg.in_dev_buffers_,
232 arg.elementwise_op_,
233 arg.unary_op_,
234 arg.scale_op_);
235 return elapsed_time;
236 }
237
238 // polymorphic
239 float Run(const BaseArgument* p_arg,
240 const StreamConfig& stream_config = StreamConfig{}) override
241 {
242 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
243 }
244 };
245
246 static bool IsSupportedArgument(const Argument& arg)
247 {
248 if(arg.lengths_.back() % MPerThread != 0)
249 return false;
250
251 auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
252 const std::array<index_t, NumDim>& strides,
253 index_t scalarPerVector) {
254 if(strides.back() == 1 && lengths.back() % scalarPerVector == 0)
255 return true;
256
257 if(strides.back() != 1 && scalarPerVector == 1)
258 return true;
259
260 return false;
261 };
262
263 bool valid = true;
264 static_for<0, NumInput, 1>{}([&](auto I) {
265 if(!IsScalarPerVectorValid(
266 arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I)))
267 valid = false;
268 });
269
270 static_for<0, NumOutput, 1>{}([&](auto I) {
271 if(!IsScalarPerVectorValid(
272 arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I)))
273 valid = false;
274 });
275
276 return valid;
277 };
278
279 bool IsSupportedArgument(const BaseArgument* p_arg) override
280 {
281 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
282 }
283
284 static auto
285 MakeArgument(const std::array<index_t, NumDim> lengths,
286 const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
287 const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
288 const std::array<const void*, NumInput> in_dev_buffers,
289 const std::array<void*, NumOutput> out_dev_buffers,
290 ElementwiseOperation elementwise_op,
291 UnaryOperation unary_op,
292 Scale scale_op)
293 {
294 return Argument{lengths,
295 inStridesArray,
296 outStridesArray,
297 in_dev_buffers,
298 out_dev_buffers,
299 elementwise_op,
300 unary_op,
301 scale_op};
302 }
303
304 std::unique_ptr<BaseArgument>
305 MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
306 const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
307 const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
308 const std::array<const void*, NumInput> in_dev_buffers,
309 const std::array<void*, NumOutput> out_dev_buffers,
310 ElementwiseOperation elementwise_op,
311 UnaryOperation unary_op,
312 Scale scale_op) override
313 {
314 return std::make_unique<Argument>(lengths,
315 inStridesArray,
316 outStridesArray,
317 in_dev_buffers,
318 out_dev_buffers,
319 elementwise_op,
320 unary_op,
321 scale_op);
322 }
323
324 static auto MakeInvoker() { return Invoker{}; }
325 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
326 {
327 return std::make_unique<Invoker>();
328 };
329
330 std::string GetTypeString() const override
331 {
332 auto str = std::stringstream();
333
334 // clang-format off
335 str << "DeviceElementwiseNormalizationImpl<";
336 str << NumDim << ", ";
337 str << MPerThread << ">";
338 // clang-format on
339
340 return str.str();
341 }
342}; // struct DeviceElementwiseImpl
343
344} // namespace device
345} // namespace tensor_operation
346} // namespace ck
auto pad(ck::index_t mpb, ck::index_t npb, ck::index_t kpb, ck::tensor_operation::device::GemmSpecialization gemm, CDesc_MRaw_NRaw conv)
Definition helper.hpp:70
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
__global__ void kernel_elementwise_1d(const InGrid1dDescTuple in_grid_1d_desc_tuple, const OutGrid1dDescTuple out_grid_1d_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const ElementwiseOperation elementwise_op, const UnaryOperation unary_op, const Scale scale_op)
Definition gridwise_elementwise_1d_scale.hpp:21
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto generate_sequence_v2(F &&f, Number< N >)
Definition sequence_helper.hpp:25
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
Definition ck/stream_config.hpp:10
Definition gridwise_elementwise_1d_scale.hpp:49
Definition utility/sequence.hpp:43
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_base.hpp:208
Definition device_elementwise.hpp:21
Definition device_elementwise_dynamic_vector_dims_impl.hpp:214
Scale scale_op_
Definition device_elementwise_scale_impl.hpp:190
InDataTypePointerTuple in_dev_buffers_
Definition device_elementwise_dynamic_vector_dims_impl.hpp:242
UnaryOperation unary_op_
Definition device_elementwise_scale_impl.hpp:189
std::array< index_t, NumDim > lengths_
Definition device_elementwise_dynamic_vector_dims_impl.hpp:245
Argument(const std::array< index_t, NumDim > lengths, const std::array< std::array< index_t, NumDim >, NumInput > inStridesArray, const std::array< std::array< index_t, NumDim >, NumOutput > outStridesArray, const std::array< const void *, NumInput > in_dev_buffers, const std::array< void *, NumOutput > out_dev_buffers, ElementwiseOperation elementwise_op, UnaryOperation unary_op, Scale scale_op)
Definition device_elementwise_scale_impl.hpp:149
OutDataTypePointerTuple out_dev_buffers_
Definition device_elementwise_dynamic_vector_dims_impl.hpp:243
index_t blockSize_
Definition device_elementwise_scale_impl.hpp:191
ElementwiseOperation elementwise_op_
Definition device_elementwise_dynamic_vector_dims_impl.hpp:249
std::array< std::array< index_t, NumDim >, NumInput > inStridesArray_
Definition device_elementwise_dynamic_vector_dims_impl.hpp:246
std::array< std::array< index_t, NumDim >, NumOutput > outStridesArray_
Definition device_elementwise_dynamic_vector_dims_impl.hpp:247
Definition device_elementwise_dynamic_vector_dims_impl.hpp:253
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_elementwise_scale_impl.hpp:196
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_elementwise_scale_impl.hpp:239
Definition device_elementwise_dynamic_vector_dims_impl.hpp:37
static auto MakeInvoker()
Definition device_elementwise_scale_impl.hpp:324
decltype(GenerateInOutGrid1dDescTuple(Number< NumInput >{})) InGrid1dDescTuple
Definition device_elementwise_scale_impl.hpp:133
static auto MakeArgument(const std::array< index_t, NumDim > lengths, const std::array< std::array< index_t, NumDim >, NumInput > inStridesArray, const std::array< std::array< index_t, NumDim >, NumOutput > outStridesArray, const std::array< const void *, NumInput > in_dev_buffers, const std::array< void *, NumOutput > out_dev_buffers, ElementwiseOperation elementwise_op, UnaryOperation unary_op, Scale scale_op)
Definition device_elementwise_scale_impl.hpp:285
static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
Definition device_elementwise_scale_impl.hpp:75
static constexpr auto I0
Definition device_elementwise_dynamic_vector_dims_impl.hpp:41
std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::array< index_t, NumDim > lengths, const std::array< std::array< index_t, NumDim >, NumInput > inStridesArray, const std::array< std::array< index_t, NumDim >, NumOutput > outStridesArray, const std::array< const void *, NumInput > in_dev_buffers, const std::array< void *, NumOutput > out_dev_buffers, ElementwiseOperation elementwise_op, UnaryOperation unary_op, Scale scale_op) override
Definition device_elementwise_scale_impl.hpp:305
decltype(GenerateInDataTypePointerTuple()) InDataTypePointerTuple
Definition device_elementwise_dynamic_vector_dims_impl.hpp:70
static auto MakeDescriptor_M(const std::array< index_t, NumDim > &lengths, const std::array< index_t, NumDim > &stride, index_t gridSize, index_t blockSize)
Definition device_elementwise_scale_impl.hpp:90
decltype(GenerateInOutGrid1dDescTuple(Number< NumOutput >{})) OutGrid1dDescTuple
Definition device_elementwise_scale_impl.hpp:134
decltype(GenerateOutDataTypePointerTuple()) OutDataTypePointerTuple
Definition device_elementwise_dynamic_vector_dims_impl.hpp:71
static auto GenerateInOutGrid1dDescTuple(Number< TupleSize >)
Definition device_elementwise_scale_impl.hpp:117
static constexpr int NumInput
Definition device_elementwise_dynamic_vector_dims_impl.hpp:38
static bool IsSupportedArgument(const Argument &arg)
Definition device_elementwise_scale_impl.hpp:246
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_elementwise_scale_impl.hpp:325
GridwiseElementwise_1D< InGrid1dDescTuple, OutGrid1dDescTuple, InDataTypePointerTuple, OutDataTypePointerTuple, ElementwiseOperation, UnaryOperation, Scale, MPerThread, InScalarPerVectorSeq, OutScalarPerVectorSeq > GridwiseElementwise
Definition device_elementwise_scale_impl.hpp:136
std::string GetTypeString() const override
Definition device_elementwise_scale_impl.hpp:330
static constexpr int NumOutput
Definition device_elementwise_dynamic_vector_dims_impl.hpp:39
static auto GenerateInDataTypePointerTuple()
Definition device_elementwise_scale_impl.hpp:49
static auto GenerateOutDataTypePointerTuple()
Definition device_elementwise_scale_impl.hpp:60
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_elementwise_scale_impl.hpp:279