device_permute_impl.hpp Source File

device_permute_impl.hpp Source File

Composable Kernel: device_permute_impl.hpp Source File
device_permute_impl.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <array>
7#include <memory>
8#include <utility>
9
10#include "ck/utility/math.hpp"
17
19
20namespace ck {
21namespace tensor_operation {
22namespace device {
23
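24// Swap the last 2 dimensions
25// input shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2], d[NumDim-1]]
26// (the two trailing dimensions d[NumDim-2] and d[NumDim-1] are the ones exchanged)
27// output shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-1], d[NumDim-2]]
28// (the same two trailing dimensions, now in swapped order)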
29template <index_t NumDim,
30 typename InDataType,
31 typename OutDataType,
32 typename ElementwiseOperation,
33 index_t BlockSize,
34 index_t NPerBlock,
35 index_t HPerBlock,
36 index_t WPerBlock,
37 index_t InBlockLdsExtraW,
38 typename InBlockTransferThreadClusterLengths,
39 typename InBlockTransferThreadClusterArrangeOrder,
40 index_t SrcVectorDim,
41 index_t DstVectorDim,
42 index_t SrcScalarPerVector,
43 index_t DstScalarPerVector>
44struct DevicePermuteImpl : DevicePermute<NumDim, InDataType, OutDataType, ElementwiseOperation>
45{
// Lengths/Strides are the base class's aliases (std::array<index_t, NumDim> per
// device_permute.hpp).
// NOTE(review): original line 46 is missing from this listing; per this page's symbol
// index it declares `using BaseType = DevicePermute<NumDim, InDataType, OutDataType,
// ElementwiseOperation>`.
47 using typename BaseType::Lengths;
48 using typename BaseType::Strides;
49
// Compile-time configuration checks: the tensor must have at least 3 dimensions, and the
// source/destination vector dimensions must each be one of the two swapped (trailing)
// dimensions, NumDim-2 or NumDim-1, but not the same one.
50 static_assert(3 <= NumDim, "Only accept at least 3D dimension tensor");
51 static_assert((NumDim - 2) <= SrcVectorDim && SrcVectorDim < NumDim);
52 static_assert((NumDim - 2) <= DstVectorDim && DstVectorDim < NumDim);
53 static_assert(SrcVectorDim != DstVectorDim);
54
55 template <index_t N = NumDim>
56 static auto ConvertArrayToTuple(const std::array<index_t, NumDim>& array)
57 {
58 static_assert(1 <= N && N <= NumDim);
59
60 return generate_tuple([&](auto I) { return array[I]; }, Number<N>{});
61 }
62
// Build the 3-D [N, H, W] grid descriptor consumed by GridwisePermute from NumDim-d
// lengths/strides: dimensions 0..NumDim-3 are merged into N, d[NumDim-2] becomes H and
// d[NumDim-1] becomes W. The final call takes the tile sizes
// (NPerBlock, HPerBlock, WPerBlock) with Sequence<true, true, true>, i.e. every one of the
// three dimensions is padded to a multiple of its tile size.
// `stride` holds one stride per dimension (all NumDim of them), despite its singular name.
// NOTE(review): original lines 68, 77-79, 83 and 85 are missing from this listing. Going by
// the symbols this page references, they presumably call make_naive_tensor_descriptor
// (with ConvertArrayToTuple of lengths/strides), make_merge_transform /
// make_pass_through_transform (using H and W below), and PadTensorDescriptor — confirm
// against the real header.
63 static auto MakeDescriptor_N_H_W(const Lengths& lengths, const Strides& stride)
64 {
65 // create nd descriptor, shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2],
66 // d[NumDim-1]]
67 const auto desc =
69
70 // merge nd to 3d descriptor, shape: [(d[0] * d[1] * d[2] * ... * d[NumDim-3]), d[NumDim-2],
71 // d[NumDim-1]]
72 // => [N, H, W]
// H is the second-to-last length, W the last one (reverse iteration over the array).
73 const index_t H = *std::next(rbegin(lengths));
74 const index_t W = *rbegin(lengths);
75 const auto desc_n_h_w = transform_tensor_descriptor(
76 desc,
// Old (lower) dimension ids consumed by each new dimension:
// {0, ..., NumDim-3} -> N, {NumDim-2} -> H, {NumDim-1} -> W.
80 make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NumDim - 2>{}),
81 Sequence<NumDim - 2>{},
82 Sequence<NumDim - 1>{}),
84
86 desc_n_h_w, make_tuple(NPerBlock, HPerBlock, WPerBlock), Sequence<true, true, true>{});
87 }
88
// Type of the [N, H, W] grid descriptor. Only the type is needed, so
// MakeDescriptor_N_H_W is named with dummy arguments inside decltype (an unevaluated
// context). The {1, 1} initializers set only the first two of the NumDim entries; the
// remaining entries are value-initialized to 0, which does not matter here.
89 using InGridDesc = decltype(MakeDescriptor_N_H_W({1, 1}, {1, 1}));
91
// NOTE(review): original lines 90 and 92-94 are missing from this listing. Per this page's
// symbol index, line 90 is `using OutGridDesc = InGridDesc;` and line 92 begins
// `using GridwisePermute = GridwisePermute<InGridDesc, OutGridDesc, ...`, continued below.
95 InDataType,
96 OutDataType,
97 ElementwiseOperation,
98 BlockSize,
99 NPerBlock,
100 HPerBlock,
101 WPerBlock,
102 InBlockLdsExtraW,
103 InBlockTransferThreadClusterLengths,
104 InBlockTransferThreadClusterArrangeOrder,
// The merged descriptor has 3 dims, so dim NumDim-2 maps to 1 (H) and NumDim-1 maps to 2 (W).
105 SrcVectorDim - (NumDim - 3), // calculate new SrcVectorDim for the merged descriptor
106 DstVectorDim - (NumDim - 3), // calculate new DstVectorDim for the merged descriptor
107 SrcScalarPerVector,
108 DstScalarPerVector>;
109
// NOTE(review): original line 110 is missing; per this page's symbol index it is
// `using Block2TileMap = typename GridwisePermute::DefaultBlock2TileMap;`.
111
// Runtime arguments of one permute call: typed device buffer pointers, the [N, H, W] grid
// descriptors built from the user lengths/strides, the raw lengths/strides (read by
// IsSupportedArgument), the element-wise operation and the block-to-tile map.
// NOTE(review): original lines 136-137, 139-142 and 146 are missing from this listing; per
// this page's symbol index they declare in_grid_desc_, out_grid_desc_, in_lengths_,
// in_strides_, out_lengths_, out_strides_ and block_2_tile_map_.
112 struct Argument : public BaseArgument
113 {
// The untyped buffers are static_cast to InDataType / OutDataType pointers; this assumes
// the caller allocated them with those element types — nothing here checks it.
114 Argument(const Lengths& in_lengths,
115 const Strides& in_strides,
116 const Lengths& out_lengths,
117 const Strides& out_strides,
118 const void* in_dev_buffer,
119 void* out_dev_buffer,
120 ElementwiseOperation elementwise_op)
121 : in_dev_buffer_(static_cast<const InDataType*>(in_dev_buffer)),
122 out_dev_buffer_(static_cast<OutDataType*>(out_dev_buffer)),
123 in_grid_desc_(MakeDescriptor_N_H_W(in_lengths, in_strides)),
124 out_grid_desc_(MakeDescriptor_N_H_W(out_lengths, out_strides)),
125 in_lengths_(in_lengths),
126 in_strides_(in_strides),
127 out_lengths_(out_lengths),
128 out_strides_(out_strides),
129 elementwise_op_(elementwise_op),
// Reads in_grid_desc_, so it relies on in_grid_desc_ being declared before
// block_2_tile_map_ (lines 136 vs 146 per the symbol index); members are initialized in
// declaration order.
130 block_2_tile_map_(GridwisePermute::MakeDefaultBlock2TileMap(in_grid_desc_))
131 {
132 }
133
134 const InDataType* in_dev_buffer_;
135 OutDataType* out_dev_buffer_;
138
143
144 ElementwiseOperation elementwise_op_;
145
147 };
148
150 {
// Launches kernel_nd_permute for an Argument.
// NOTE(review): the struct header (original line 149, presumably
// `struct Invoker : public BaseInvoker`) is missing from this listing, as are lines 156-157
// and 161 (the remaining kernel_nd_permute template arguments) and line 173 (the last
// kernel argument — by the kernel's signature, the block-to-tile map).
//
// Run(const Argument&): one block per tile as computed by the block-to-tile map,
// BlockSize threads per block, 0 bytes of dynamic LDS (the `0` argument). Returns
// whatever launch_and_time_kernel reports for this launch and stream_config.
151 static float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
152 {
153 const index_t grid_size = arg.block_2_tile_map_.CalculateGridSize(arg.in_grid_desc_);
154
155 const auto kernel = kernel_nd_permute<GridwisePermute,
158 InDataType,
159 OutDataType,
160 ElementwiseOperation,
162
163 float elapsed_time = launch_and_time_kernel(stream_config,
164 kernel,
165 dim3(grid_size),
166 dim3(BlockSize),
167 0,
168 arg.in_grid_desc_,
169 arg.out_grid_desc_,
170 arg.in_dev_buffer_,
171 arg.out_dev_buffer_,
172 arg.elementwise_op_,
174 return elapsed_time;
175 }
176
// Type-erased entry point: returns NAN (without launching) when `arg` is not this
// operator's Argument type; otherwise forwards to the typed Run above.
177 float Run(const BaseArgument* arg,
178 const StreamConfig& stream_config = StreamConfig{}) override final
179 {
180 const auto* const argument = dynamic_cast<const Argument*>(arg);
181 if(!argument)
182 {
183 return NAN;
184 }
185
186 return Run(*argument, stream_config);
187 }
188 };
189
190 static bool IsSupportedArgument(const Argument& arg)
191 {
192 constexpr auto GetPaddedLength = [](index_t length, index_t tile_length) {
193 return math::integer_divide_ceil(length, tile_length) * tile_length;
194 };
195
196 constexpr auto IsScalarPerVectorValid =
197 [](index_t length, index_t stride, index_t scalar_per_vector) {
198 if(stride == 1 && length % scalar_per_vector == 0)
199 {
200 return true;
201 }
202 else if(stride != 1 && scalar_per_vector == 1)
203 {
204 return true;
205 }
206
207 return false;
208 };
209
210 return IsScalarPerVectorValid(arg.in_lengths_[SrcVectorDim],
211 arg.in_strides_[SrcVectorDim],
212 SrcScalarPerVector) &&
213 IsScalarPerVectorValid(
214 GetPaddedLength(arg.in_lengths_[SrcVectorDim],
215 (SrcVectorDim == NumDim - 2 ? HPerBlock : WPerBlock)),
216 arg.in_strides_[SrcVectorDim],
217 SrcScalarPerVector) &&
218 IsScalarPerVectorValid(arg.out_lengths_[DstVectorDim],
219 arg.out_strides_[DstVectorDim],
220 DstScalarPerVector) &&
221 IsScalarPerVectorValid(
222 GetPaddedLength(arg.out_lengths_[DstVectorDim],
223 (DstVectorDim == NumDim - 2 ? HPerBlock : WPerBlock)),
224 arg.in_strides_[DstVectorDim],
225 DstScalarPerVector) &&
227 };
228
229 // override methods inherited from 'BaseOperator'
230 bool IsSupportedArgument(const BaseArgument* arg) override final
231 {
232 const auto* const argument = dynamic_cast<const Argument*>(arg);
233 if(!argument)
234 {
235 return false;
236 }
237
238 return IsSupportedArgument(*argument);
239 }
240
241 // override methods inherited from 'DevicePermute'
242 std::unique_ptr<BaseArgument>
243 MakeArgumentPointer(const Lengths& in_lengths,
244 const Strides& in_strides,
245 const Lengths& out_lengths,
246 const Strides& out_strides,
247 const void* in_dev_buffer,
248 void* out_dev_buffer,
249 ElementwiseOperation elementwise_op) override final
250 {
251 return std::make_unique<Argument>(in_lengths,
252 in_strides,
253 out_lengths,
254 out_strides,
255 in_dev_buffer,
256 out_dev_buffer,
257 elementwise_op);
258 }
259
260 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override final
261 {
262 return std::make_unique<Invoker>();
263 };
264
265 // other constructor methods
266 template <typename... Args>
267 static std::enable_if_t<std::is_constructible_v<Argument, Args...>, Argument>
268 MakeArgument(Args&&... args) noexcept(std::is_nothrow_constructible_v<Argument, Args...>)
269 {
270 return Argument{std::forward<Args>(args)...};
271 }
272
273 static std::enable_if_t<std::is_default_constructible_v<Invoker>, Invoker>
274 MakeInvoker() noexcept(std::is_nothrow_default_constructible_v<Invoker>)
275 {
276 return Invoker{};
277 }
278};
279
280} // namespace device
281} // namespace tensor_operation
282} // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
__host__ __device__ constexpr auto PadTensorDescriptor(const TensorDesc &desc, const TileLengths &tile_lengths, DoPads)
Definition matrix_padder.hpp:19
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
integral_constant< index_t, N > Number
Definition number.hpp:12
__global__ void kernel_nd_permute(const InGridDesc in_grid_desc, const OutGridDesc out_grid_desc, const InDataType *p_in_global, OutDataType *p_out_global, const ElementwiseOperation elementwise_op, const Block2TileMap block_2_tile_map)
Definition gridwise_permute.hpp:25
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto generate_sequence_v2(F &&f, Number< N >)
Definition sequence_helper.hpp:25
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
STL namespace.
Definition ck/stream_config.hpp:10
__host__ static __device__ constexpr bool CheckValidity(const InGridDesc &in_grid_desc, const OutGridDesc &out_grid_desc)
Definition gridwise_permute.hpp:182
Block2TileMap DefaultBlock2TileMap
Definition gridwise_permute.hpp:134
Definition utility/sequence.hpp:43
Definition device_base.hpp:197
Definition device_permute.hpp:18
Lengths Strides
Definition device_permute.hpp:20
std::array< index_t, NumDim > Lengths
Definition device_permute.hpp:19
Definition device_permute_impl.hpp:113
Strides in_strides_
Definition device_permute_impl.hpp:140
ElementwiseOperation elementwise_op_
Definition device_permute_impl.hpp:144
Argument(const Lengths &in_lengths, const Strides &in_strides, const Lengths &out_lengths, const Strides &out_strides, const void *in_dev_buffer, void *out_dev_buffer, ElementwiseOperation elementwise_op)
Definition device_permute_impl.hpp:114
OutGridDesc out_grid_desc_
Definition device_permute_impl.hpp:137
Block2TileMap block_2_tile_map_
Definition device_permute_impl.hpp:146
InGridDesc in_grid_desc_
Definition device_permute_impl.hpp:136
const InDataType * in_dev_buffer_
Definition device_permute_impl.hpp:134
Lengths in_lengths_
Definition device_permute_impl.hpp:139
Strides out_strides_
Definition device_permute_impl.hpp:142
OutDataType * out_dev_buffer_
Definition device_permute_impl.hpp:135
Lengths out_lengths_
Definition device_permute_impl.hpp:141
Definition device_permute_impl.hpp:150
static float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_permute_impl.hpp:151
float Run(const BaseArgument *arg, const StreamConfig &stream_config=StreamConfig{}) override final
Definition device_permute_impl.hpp:177
Definition device_permute_impl.hpp:45
InGridDesc OutGridDesc
Definition device_permute_impl.hpp:90
decltype(MakeDescriptor_N_H_W({1, 1}, {1, 1})) InGridDesc
Definition device_permute_impl.hpp:89
static std::enable_if_t< std::is_constructible_v< Argument, Args... >, Argument > MakeArgument(Args &&... args) noexcept(std::is_nothrow_constructible_v< Argument, Args... >)
Definition device_permute_impl.hpp:268
static auto MakeDescriptor_N_H_W(const Lengths &lengths, const Strides &stride)
Definition device_permute_impl.hpp:63
GridwisePermute< InGridDesc, OutGridDesc, InDataType, OutDataType, ElementwiseOperation, BlockSize, NPerBlock, HPerBlock, WPerBlock, InBlockLdsExtraW, InBlockTransferThreadClusterLengths, InBlockTransferThreadClusterArrangeOrder, SrcVectorDim -(NumDim - 3), DstVectorDim -(NumDim - 3), SrcScalarPerVector, DstScalarPerVector > GridwisePermute
Definition device_permute_impl.hpp:92
Lengths Strides
Definition device_permute.hpp:20
bool IsSupportedArgument(const BaseArgument *arg) override final
Definition device_permute_impl.hpp:230
std::unique_ptr< BaseArgument > MakeArgumentPointer(const Lengths &in_lengths, const Strides &in_strides, const Lengths &out_lengths, const Strides &out_strides, const void *in_dev_buffer, void *out_dev_buffer, ElementwiseOperation elementwise_op) override final
Definition device_permute_impl.hpp:243
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override final
Definition device_permute_impl.hpp:260
static auto ConvertArrayToTuple(const std::array< index_t, NumDim > &array)
Definition device_permute_impl.hpp:56
typename GridwisePermute::DefaultBlock2TileMap Block2TileMap
Definition device_permute_impl.hpp:110
static bool IsSupportedArgument(const Argument &arg)
Definition device_permute_impl.hpp:190
DevicePermute< NumDim, InDataType, OutDataType, ElementwiseOperation > BaseType
Definition device_permute_impl.hpp:46
std::array< index_t, NumDim > Lengths
Definition device_permute.hpp:19
static std::enable_if_t< std::is_default_constructible_v< Invoker >, Invoker > MakeInvoker() noexcept(std::is_nothrow_default_constructible_v< Invoker >)
Definition device_permute_impl.hpp:274