gridwise_permute.hpp Source File

gridwise_permute.hpp Source File#

Composable Kernel: gridwise_permute.hpp Source File
gridwise_permute.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <functional>
7#include <numeric>
8#include <iterator>
9
15
16namespace ck {
17
// Kernel entry point for the permute operation.
//
// Declares a statically sized __shared__ byte array and forwards every kernel argument,
// together with that array, to GridwisePermute::Run. The kernel itself performs no
// indexing or bounds logic; all of that is delegated to GridwisePermute.
//
// Template parameters:
//   GridwisePermute      - type providing the static members GetSharedMemoryNumberOfByte() and Run()
//   InGridDesc           - type of the input descriptor, passed by value
//   OutGridDesc          - type of the output descriptor, passed by value
//   InDataType           - element type behind p_in_global
//   OutDataType          - element type behind p_out_global
//   ElementwiseOperation - type of the functor forwarded to Run
//   Block2TileMap        - type of the block-to-tile mapping object forwarded to Run
18template <typename GridwisePermute,
19 typename InGridDesc,
20 typename OutGridDesc,
21 typename InDataType,
22 typename OutDataType,
23 typename ElementwiseOperation,
24 typename Block2TileMap>
25__global__ void kernel_nd_permute(const InGridDesc in_grid_desc,
26 const OutGridDesc out_grid_desc,
27 const InDataType* p_in_global,
28 OutDataType* p_out_global,
29 const ElementwiseOperation elementwise_op,
30 const Block2TileMap block_2_tile_map)
31{
// scratch size is a compile-time constant supplied by GridwisePermute
32 __shared__ char p_shared[GridwisePermute::GetSharedMemoryNumberOfByte()];
33
// hand everything over to the gridwise implementation
34 GridwisePermute::Run(in_grid_desc,
35 out_grid_desc,
36 p_in_global,
37 p_out_global,
38 p_shared,
39 elementwise_op,
40 block_2_tile_map);
41}
42
43template <typename InGridDesc,
44 typename OutGridDesc,
45 typename InDataType,
46 typename OutDataType,
47 typename ElementwiseOperation,
48 index_t BlockSize,
49 index_t NPerBlock,
50 index_t HPerBlock,
51 index_t WPerBlock,
52 index_t InBlockLdsExtraW,
53 typename InBlockTransferThreadClusterLengths,
54 typename InBlockTransferThreadClusterArrangeOrder,
55 index_t SrcVectorDim,
56 index_t DstVectorDim,
57 index_t SrcScalarPerVector,
58 index_t DstScalarPerVector>
60{
61 static_assert(InGridDesc::GetNumOfDimension() == OutGridDesc::GetNumOfDimension());
62 static_assert(3 <= InGridDesc::GetNumOfDimension());
63 static_assert((InGridDesc::GetNumOfDimension() - 2) <= SrcVectorDim &&
64 SrcVectorDim < InGridDesc::GetNumOfDimension());
65 static_assert((OutGridDesc::GetNumOfDimension() - 2) <= DstVectorDim &&
66 DstVectorDim < OutGridDesc::GetNumOfDimension());
67 static_assert(SrcVectorDim != DstVectorDim);
68
69 static constexpr auto I0 = Number<0>{};
70 static constexpr auto I1 = Number<1>{};
71 static constexpr auto I2 = Number<2>{};
72
74
76 {
77 static constexpr index_t NumDim = InGridDesc::GetNumOfDimension();
78 static_assert(3 <= NumDim);
79
80 static constexpr auto I0 = Number<0>{};
81
82 Block2TileMap() = delete;
83 Block2TileMap(const Block2TileMap&) = default;
85
86 ~Block2TileMap() = default;
87
90
91 explicit Block2TileMap(const InGridDesc& desc) : desc_(desc) {}
92
// Returns the number of tiles needed to cover `desc`: the product of the last three
// dimension lengths, each divided (rounding up) by NPerBlock, HPerBlock and WPerBlock
// respectively.
//
// Reads the `desc` argument, not the descriptor stored in this object.
//
// NOTE(review): only dimensions NumDim-3, NumDim-2 and NumDim-1 contribute; any
// dimensions before NumDim-3 are not counted here. Presumably callers pass a 3D
// (already merged) descriptor - confirm against the call sites.
93 __host__ constexpr index_t CalculateGridSize(const InGridDesc& desc) const
94 {
95 const auto N0 =
96 math::integer_divide_ceil(desc.GetLength(Number<NumDim - 3>{}), NPerBlock);
97 const auto H0 =
98 math::integer_divide_ceil(desc.GetLength(Number<NumDim - 2>{}), HPerBlock);
99 const auto W0 =
100 math::integer_divide_ceil(desc.GetLength(Number<NumDim - 1>{}), WPerBlock);
101
102 const index_t grid_size = N0 * H0 * W0;
103
104 return grid_size;
105 }
106
// Converts a one-element top index (a flat block id) into a tuple of three tile
// indices (idx_N0, idx_H0, idx_W0).
//
// The per-dimension tile counts N0, H0 and W0 are computed from the descriptor stored
// in this object (desc_), using the same ceil-division by NPerBlock / HPerBlock /
// WPerBlock as CalculateGridSize. The flat id is decomposed with W0 varying fastest,
// then H0, then N0.
//
// idx_top must have exactly one element (enforced by the static_assert).
107 template <typename TopIdx>
108 __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
109 {
110 static_assert(TopIdx::Size() == 1);
111
112 auto block_1d_id = idx_top[I0];
113
114 const auto N0 =
115 math::integer_divide_ceil(desc_.GetLength(Number<NumDim - 3>{}), NPerBlock);
116 const auto H0 =
117 math::integer_divide_ceil(desc_.GetLength(Number<NumDim - 2>{}), HPerBlock);
118 const auto W0 =
119 math::integer_divide_ceil(desc_.GetLength(Number<NumDim - 1>{}), WPerBlock);
120
// wrap the id into [0, N0 * H0 * W0) so an id beyond the tile count maps back onto a tile
121 block_1d_id = block_1d_id % (N0 * H0 * W0);
122
// split the flat id: N0 is the slowest-varying index, W0 the fastest
123 index_t idx_N0 = block_1d_id / (H0 * W0);
124 index_t idx_H0 = (block_1d_id % (H0 * W0)) / W0;
125 index_t idx_W0 = block_1d_id % W0;
126
127 return make_tuple(idx_N0, idx_H0, idx_W0);
128 }
129
130 private:
131 const InGridDesc desc_;
132 };
133
135
136 // use an [NPerBlock, HPerBlock, WPerBlock] tensor as element-copy relay
137 __host__ __device__ static constexpr auto GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock()
138 {
141 make_tuple(Number<HPerBlock*(WPerBlock + InBlockLdsExtraW)>{},
143 I1));
144 }
145
146 // For an N-dimensional descriptor, keep its last 2 dimensions and merge all of its leading
147 // dimensions into a single one, forming a 3D descriptor: [d(0), d(1), ..., d(N - 2), d(N - 1)] ->
148 // [(d(0) x d(1) x ...), d(N - 2), d(N - 1)]
149 template <typename GridDesc>
150 __host__ __device__ static constexpr auto GetMergedDesc(const GridDesc& desc)
151 {
152 constexpr index_t NumDim = GridDesc::GetNumOfDimension();
153 static_assert(3 <= NumDim);
154
155 const auto merged_desc = transform_tensor_descriptor(
156 desc,
158 [&](auto I) { return desc.GetLength(I); }, Number<NumDim - 2>{})),
161 make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NumDim - 2>{}),
162 Sequence<NumDim - 2>{},
163 Sequence<NumDim - 1>{}),
165 return merged_desc;
166 }
167
// Returns the shared-memory footprint in bytes: the element space size of the in-block
// descriptor multiplied by sizeof(InDataType).
//
// NOTE(review): the initializer of the local descriptor is not visible in this listing;
// presumably it is GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock() - confirm against the
// original header.
168 __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
169 {
170 constexpr auto in_block_desc_nperblock_hperblock_wperblock =
172
173 return in_block_desc_nperblock_hperblock_wperblock.GetElementSpaceSize() *
174 sizeof(InDataType);
175 }
176
// Builds the default block-to-tile map by constructing a DefaultBlock2TileMap from the
// given input descriptor.
177 __host__ __device__ static constexpr auto MakeDefaultBlock2TileMap(const InGridDesc& desc)
178 {
179 return DefaultBlock2TileMap{desc};
180 }
181
// Checks that the output descriptor is the input descriptor with only its last two
// dimensions swapped.
//
// Returns true when both of the following hold:
//   - every dimension in [0, NumDim - 2) has the same length in both descriptors;
//   - the input's last dimension length equals the output's second-to-last length, and
//     the input's second-to-last length equals the output's last length.
//
// Only lengths are compared; nothing else about the descriptors is inspected here.
182 __host__ __device__ static constexpr bool CheckValidity(const InGridDesc& in_grid_desc,
183 const OutGridDesc& out_grid_desc)
184 {
185 constexpr index_t NumDim = InGridDesc::GetNumOfDimension();
186
187 // check if we only swap last 2 dimensions
188 bool valid = true;
// compare the leading (non-swapped) dimensions; once a mismatch is found the flag stays false
189 static_for<0, NumDim - 2, 1>{}([&](auto I) {
190 if(valid && in_grid_desc.GetLength(I) != out_grid_desc.GetLength(I))
191 {
192 valid = false;
193 }
194 });
195
// the trailing two lengths must be exchanged between input and output
196 return valid &&
197 (in_grid_desc.GetLength(Number<NumDim - 1>{}) ==
198 out_grid_desc.GetLength(Number<NumDim - 2>{})) &&
199 (in_grid_desc.GetLength(Number<NumDim - 2>{}) ==
200 out_grid_desc.GetLength(Number<NumDim - 1>{}));
201 }
202
203 template <typename Block2TileMap>
204 __device__ static void Run(const InGridDesc in_grid_desc,
205 const OutGridDesc out_grid_desc,
206 const InDataType* p_in_global,
207 OutDataType* p_out_global,
208 void* __restrict__ p_shared,
209 const ElementwiseOperation elementwise_op,
210 const Block2TileMap& block_2_tile_map)
211 {
213 p_in_global, in_grid_desc.GetElementSpaceSize());
214
216 p_out_global, out_grid_desc.GetElementSpaceSize());
217
218 // each workgroup handles an [NPerBlock, HPerBlock, WPerBlock] slice-transpose problem
219 const auto block_work_idx =
221
222 const index_t n_block_data_idx_on_grid =
223 __builtin_amdgcn_readfirstlane(block_work_idx[I0] * NPerBlock);
224
225 const index_t h_block_data_idx_on_grid =
226 __builtin_amdgcn_readfirstlane(block_work_idx[I1] * HPerBlock);
227
228 const index_t w_block_data_idx_on_grid =
229 __builtin_amdgcn_readfirstlane(block_work_idx[I2] * WPerBlock);
230
231 // create [NPerBlock, HPerBlock, WPerBlock] shaped LDS buffer
232 constexpr auto in_block_desc_nperblock_hperblock_wperblock =
234
236 static_cast<InDataType*>(p_shared),
237 in_block_desc_nperblock_hperblock_wperblock.GetElementSpaceSize());
238
239 using BlockSliceLengths = Sequence<NPerBlock, HPerBlock, WPerBlock>;
240 using InBlockTransferAccessOrder = Sequence<0, 1, 2>;
241
242 constexpr index_t SrcVectorDimAfterMerge =
243 SrcVectorDim - (InGridDesc::GetNumOfDimension() - 3);
244 constexpr index_t DstVectorDimAfterMerge = SrcVectorDimAfterMerge;
245
247
248 // merge input descriptor into [(in_grid_desc.GetLength(0) x in_grid_desc.GetLength(1) x
249 // ...), in_grid_desc.GetLength(NumDim - 2), in_grid_desc.GetLength(NumDim - 1)]
250 const auto in_grid_desc_n_h_w = GetMergedDesc(in_grid_desc);
251
252 // a workgroup copies an [NPerBlock, HPerBlock, WPerBlock] slice from global memory to LDS
253 auto in_global_load = ThreadGroupTensorSliceTransfer_v4r1<
255 ElementwiseOperation,
258 BlockSliceLengths,
259 InBlockTransferThreadClusterLengths,
260 InBlockTransferThreadClusterArrangeOrder,
261 InDataType,
262 InDataType,
263 decltype(in_grid_desc_n_h_w),
264 decltype(in_block_desc_nperblock_hperblock_wperblock),
265 InBlockTransferAccessOrder,
266 InBlockTransferAccessOrder,
267 SrcVectorDimAfterMerge,
268 2,
269 SrcScalarPerVector,
270 1,
271 1,
272 1,
273 true,
274 true>(in_grid_desc_n_h_w,
276 n_block_data_idx_on_grid, h_block_data_idx_on_grid, w_block_data_idx_on_grid),
277 PassThrough{},
278 in_block_desc_nperblock_hperblock_wperblock,
279 make_multi_index(0, 0, 0),
280 PassThrough{});
281
282 // merge output descriptor into [(out_grid_desc.GetLength(0) x out_grid_desc.GetLength(1) x
283 // ...), out_grid_desc.GetLength(NumDim - 2), out_grid_desc.GetLength(NumDim - 1)]
284 const auto out_grid_desc_n_w_h = GetMergedDesc(out_grid_desc);
285
286 // create transposed view of output tensor
287 const auto out_grid_desc_n_h_w = transform_tensor_descriptor(
288 out_grid_desc_n_w_h,
289 make_tuple(make_pass_through_transform(out_grid_desc_n_w_h.GetLength(I0)),
290 make_pass_through_transform(out_grid_desc_n_w_h.GetLength(I1)),
291 make_pass_through_transform(out_grid_desc_n_w_h.GetLength(I2))),
294
295 // a workgroup copies an [NPerBlock, HPerBlock, WPerBlock] slice from LDS to global memory
296 auto out_global_store = ThreadGroupTensorSliceTransfer_v4r1<
298 ElementwiseOperation,
301 BlockSliceLengths,
302 InBlockTransferThreadClusterLengths,
303 InBlockTransferThreadClusterArrangeOrder,
304 InDataType,
305 OutDataType,
306 decltype(in_block_desc_nperblock_hperblock_wperblock),
307 decltype(out_grid_desc_n_h_w),
308 InBlockTransferAccessOrder,
309 InBlockTransferAccessOrder,
310 2,
311 DstVectorDimAfterMerge,
312 1,
313 DstScalarPerVector,
314 1,
315 1,
316 true,
317 true>(in_block_desc_nperblock_hperblock_wperblock,
318 make_multi_index(0, 0, 0),
319 PassThrough{},
320 out_grid_desc_n_h_w,
322 n_block_data_idx_on_grid, h_block_data_idx_on_grid, w_block_data_idx_on_grid),
323 elementwise_op);
324
325 in_global_load.Run(in_grid_desc_n_h_w,
326 in_global_buf,
327 in_block_desc_nperblock_hperblock_wperblock,
328 in_block_buf,
329 I0);
330
331 out_global_store.Run(in_block_desc_nperblock_hperblock_wperblock,
332 in_block_buf,
333 out_grid_desc_n_h_w,
334 out_global_buf,
335 I0);
336 }
337};
338
339} // namespace ck
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
integral_constant< index_t, N > Number
Definition number.hpp:12
__global__ void kernel_nd_permute(const InGridDesc in_grid_desc, const OutGridDesc out_grid_desc, const InDataType *p_in_global, OutDataType *p_out_global, const ElementwiseOperation elementwise_op, const Block2TileMap block_2_tile_map)
Definition gridwise_permute.hpp:25
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr auto generate_sequence_v2(F &&f, Number< N >)
Definition sequence_helper.hpp:25
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
Definition gridwise_permute.hpp:76
Block2TileMap & operator=(const Block2TileMap &)=delete
Block2TileMap(const InGridDesc &desc)
Definition gridwise_permute.hpp:91
__host__ constexpr index_t CalculateGridSize(const InGridDesc &desc) const
Definition gridwise_permute.hpp:93
static constexpr index_t NumDim
Definition gridwise_permute.hpp:77
Block2TileMap & operator=(Block2TileMap &&)=delete
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition gridwise_permute.hpp:108
static constexpr auto I0
Definition gridwise_permute.hpp:80
Block2TileMap(Block2TileMap &&)=delete
Block2TileMap(const Block2TileMap &)=default
Definition gridwise_permute.hpp:60
__host__ static __device__ constexpr bool CheckValidity(const InGridDesc &in_grid_desc, const OutGridDesc &out_grid_desc)
Definition gridwise_permute.hpp:182
static constexpr auto I2
Definition gridwise_permute.hpp:71
static constexpr auto I0
Definition gridwise_permute.hpp:69
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition gridwise_permute.hpp:73
static constexpr auto I1
Definition gridwise_permute.hpp:70
__host__ static __device__ constexpr auto GetMergedDesc(const GridDesc &desc)
Definition gridwise_permute.hpp:150
__host__ static __device__ constexpr index_t GetSharedMemoryNumberOfByte()
Definition gridwise_permute.hpp:168
__host__ static __device__ constexpr auto GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock()
Definition gridwise_permute.hpp:137
__host__ static __device__ constexpr auto MakeDefaultBlock2TileMap(const InGridDesc &desc)
Definition gridwise_permute.hpp:177
Block2TileMap DefaultBlock2TileMap
Definition gridwise_permute.hpp:134
static __device__ void Run(const InGridDesc in_grid_desc, const OutGridDesc out_grid_desc, const InDataType *p_in_global, OutDataType *p_out_global, void *__restrict__ p_shared, const ElementwiseOperation elementwise_op, const Block2TileMap &block_2_tile_map)
Definition gridwise_permute.hpp:204
Definition multi_index_transform.hpp:13
Definition utility/sequence.hpp:43
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
__device__ void Run(const SrcDesc &src_desc, const SrcBuffer &src_buf, const DstDesc &dst_desc, DstBuffer &dst_buf, Number< ThreadScratchId > thread_scratch_id)
Definition thread_group_tensor_slice_transfer_v4r1.hpp:143
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340