tensor_partition.hpp Source File

tensor_partition.hpp Source File#

Composable Kernel: tensor_partition.hpp Source File
tensor_partition.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "tensor_utils.hpp"
7#include "layout_utils.hpp"
8
11
12// Disable from doxygen docs generation
14namespace ck {
15namespace wrapper {
17
18// Disable from doxygen docs generation
20namespace {
21
22namespace detail {
23
32template <typename... Ts, typename... Ls>
33__host__ __device__ constexpr auto CalculateLocalPartitionShape(const Tuple<Ts...>& shape,
34 const Tuple<Ls...>& thread_lengths)
35{
36 static_assert(Tuple<Ts...>::Size() == Tuple<Ls...>::Size(), "Wrong thread_lengths shape.");
37 return generate_tuple(
38 [&](auto i) {
39 constexpr auto num_i = Number<i>{};
40 const auto slice_len =
41 ck::math::integer_divide_ceil(size<num_i>(shape), thread_lengths.At(num_i));
42 return slice_len;
43 },
44 Number<Tuple<Ls...>::Size()>{});
45}
46
56template <typename MultiIndex, typename ProjectionTuple>
57__host__ __device__ constexpr auto
58ApplyProjection([[maybe_unused]] const MultiIndex& base_tuple,
59 [[maybe_unused]] const ProjectionTuple& projection)
60{
61 if constexpr(is_same_v<ProjectionTuple, Tuple<>>)
62 {
63 return Tuple<>{};
64 }
65 else
66 {
67 auto base_tuple_after_projection = generate_tuple(
68 [&](auto i) {
69 const auto i_num = Number<i.value>{};
70 static_assert(
71 is_detected<is_slice, tuple_element_t<i_num, ProjectionTuple>>::value ||
72 is_same_v<tuple_element_t<i_num, ProjectionTuple>, Number<1>>);
73 if constexpr(is_detected<is_slice, tuple_element_t<i_num, ProjectionTuple>>::value)
74 {
75 // When slice (to remove), then insert empty tuple (will be removed in next
76 // step).
77 return Tuple<>{};
78 }
79 else
80 {
81 return make_tuple(base_tuple.At(i_num));
82 }
83 },
84 Number<MultiIndex::Size()>{});
85 // Remove empty tuples
86 return UnrollNestedTuple<0, 1>(base_tuple_after_projection);
87 }
88}
89
99template <typename... Ts, typename... Ps>
100__host__ __device__ constexpr auto CalculateShapeWithProjection(const Tuple<Ts...>& shape,
101 const Tuple<Ps...>& projection)
102{
103 return generate_tuple(
104 [&](auto i) {
105 if constexpr(is_detected<is_slice, tuple_element_t<i, Tuple<Ps...>>>::value)
106 {
107 return size<i>(projection).to_;
108 }
109 else
110 {
111 // number of shape element in actual fragment of shape and projection (method to
112 // calculate shape idx)
113 constexpr index_t shape_i =
114 detail::ApplyProjection(TupleSlice<0, i>(Tuple<Ts...>{}),
115 TupleSlice<0, i>(Tuple<Ps...>{}))
116 .Size();
117 return size<shape_i>(shape);
118 }
119 },
120 Number<Tuple<Ps...>::Size()>{});
121}
122
130template <typename... Ts, typename... Ls, typename... Ps>
131__host__ __device__ constexpr auto CalculateGridSize(const Tuple<Ts...>& shape,
132 const Tuple<Ls...>& tile_shape)
133{
134 return generate_tuple(
135 [&](auto i) { return ck::math::integer_divide_ceil(size<i>(shape), size<i>(tile_shape)); },
136 Number<Tuple<Ls...>::Size()>{});
137}
138
147template <typename ThreadIdxs, typename PartitionLengthsSeq, typename OldOffsetIdxs>
148__host__ __device__ constexpr auto
149CalculateOffsetMultiIdxs(const ThreadIdxs& thread_idxs,
150 const PartitionLengthsSeq& partition_lengths_seq,
151 const OldOffsetIdxs& old_offset_idxs)
152{
153 return thread_idxs * partition_lengths_seq + old_offset_idxs;
154}
155
162template <typename BlockIdxs>
163__host__ __device__ constexpr auto GetDimsToPartition([[maybe_unused]] const BlockIdxs& block_idxs)
164{
165 const auto dims_to_partition = generate_tuple(
166 [&](auto i) {
167 if constexpr(!is_detected<is_slice, tuple_element_t<i, BlockIdxs>>::value)
168 {
169 return Number<i>{};
170 }
171 else
172 {
173 return Tuple<>{};
174 }
175 },
176 Number<BlockIdxs::Size()>{});
177 // Remove empty tuples
178 return UnrollNestedTuple<0, 1>(dims_to_partition);
179}
180
187template <typename BlockIdxs>
188__host__ __device__ constexpr auto ReplaceSlicesWithZeros(const BlockIdxs& block_idxs)
189{
190 return generate_tuple(
191 [&](auto i) {
192 if constexpr(!is_detected<is_slice, tuple_element_t<i, BlockIdxs>>::value)
193 {
194 return block_idxs.At(i);
195 }
196 else
197 {
198 return Number<0>{};
199 }
200 },
201 Number<BlockIdxs::Size()>{});
202}
203
210template <typename TileShape>
211__host__ __device__ constexpr auto
212GenerateDefaultProjection([[maybe_unused]] const TileShape tile_shape)
213{
214 return generate_tuple([&](auto) { return Number<1>{}; }, Number<TileShape::Size()>{});
215}
216
224template <typename ThreadShape, typename ThreadUnrolledDesc>
225__host__ __device__ constexpr auto CalculateThreadMultiIdx(
226 [[maybe_unused]] const Layout<ThreadShape, ThreadUnrolledDesc>& thread_layout,
227 const index_t thread_id)
228{
229 static_assert(ThreadUnrolledDesc::GetNumOfTransform() == 1,
230 "Thread layout should not be transformed.");
231 constexpr auto embed_transform = ThreadUnrolledDesc{}.GetTransforms().At(Number<0>{});
232 constexpr auto shape = ThreadShape{};
233 constexpr auto strides = embed_transform.coefficients_;
234
235 return generate_tuple(
236 [&](auto i) {
237 constexpr auto num_i = Number<i>{};
238 return (thread_id / strides.At(num_i)) % shape.At(num_i);
239 },
240 Number<ThreadShape::Size()>{});
241}
242} // namespace detail
243} // namespace
245
258template <typename TensorType,
259 typename ThreadShape,
260 typename ThreadUnrolledDesc,
261 typename ProjectionTuple>
262__host__ __device__ constexpr auto
263make_local_partition(TensorType& tensor,
264 [[maybe_unused]] const Layout<ThreadShape, ThreadUnrolledDesc>& thread_layout,
265 const index_t thread_id,
266 const ProjectionTuple& projection)
267{
268 static_assert(!IsNestedTuple(ThreadShape{}));
269 // Calculate new partition shape
270 const auto& tensor_shape = shape(tensor);
271 // Calculate projected thread lengths
272 constexpr auto projected_thread_lengths =
273 detail::ApplyProjection(ThreadShape{}, ProjectionTuple{});
274 constexpr auto partition_shape =
275 detail::CalculateLocalPartitionShape(decltype(tensor_shape){}, projected_thread_lengths);
276 constexpr auto partition_shape_seq =
277 generate_sequence_v2([&](auto I) { return size<I>(partition_shape); },
278 Number<decltype(partition_shape)::Size()>{});
279 // Calculate thread idxs and offsets
280 const auto thread_idxs = detail::CalculateThreadMultiIdx(thread_layout, thread_id);
281 // Apply projection on thread idxs to remove not needed idxs
282 const auto projected_thread_idxs = detail::ApplyProjection(thread_idxs, projection);
283 const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
284 projected_thread_idxs, partition_shape_seq, tensor.GetMultiIdxOffsets());
285 // Create new layout and tensor
286 auto& unrolled_desc = layout(tensor).GetUnrolledDescriptor();
287 // Slice descriptor
288 const auto transforms = generate_tuple(
289 [&](auto i) {
290 return make_slice_transform(partition_shape.At(i),
291 offset_multi_idxs.At(i),
292 partition_shape.At(i) + offset_multi_idxs.At(i));
293 },
294 Number<remove_reference_t<decltype(tensor_shape)>::Size()>{});
295 const auto lower_upper_dims =
296 generate_tuple([&](auto i) { return Sequence<i.value>{}; },
297 Number<remove_reference_t<decltype(tensor_shape)>::Size()>{});
298 auto sliced_desc =
299 transform_tensor_descriptor(unrolled_desc, transforms, lower_upper_dims, lower_upper_dims);
300 // Create layout
301 const auto partition_layout =
302 Layout<remove_reference_t<decltype(partition_shape)>, decltype(sliced_desc)>(
303 partition_shape, sliced_desc);
304 auto partition_tensor =
305 make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), partition_layout);
306 // Apply offsets
307 return partition_tensor;
308}
309
319template <typename TensorType, typename ThreadShape, typename ThreadUnrolledDesc>
320__host__ __device__ constexpr auto
321make_local_partition(TensorType& tensor,
322 const Layout<ThreadShape, ThreadUnrolledDesc>& thread_lengths,
323 const index_t thread_id)
324{
325 const auto projection = detail::GenerateDefaultProjection(ThreadShape{});
326 return make_local_partition(tensor, thread_lengths, thread_id, projection);
327}
328
346template <typename TensorType,
347 typename BlockShapeTuple,
348 typename BlockIdxs,
349 typename ProjectionTuple>
350__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
351 const BlockShapeTuple& tile_shape,
352 const BlockIdxs& block_idxs,
353 const ProjectionTuple& projection)
354{
355 static_assert(!IsNestedTuple(BlockShapeTuple{}));
356 static_assert(!IsNestedTuple(BlockIdxs{}));
357
358 constexpr auto I0 = Number<0>{};
359 constexpr auto I1 = Number<1>{};
360 constexpr auto I2 = Number<2>{};
361
362 auto& aligned_desc = layout(tensor).GetMergedNestingDescriptor();
363
364 constexpr auto projected_tile_shape =
365 detail::ApplyProjection(BlockShapeTuple{}, ProjectionTuple{});
366 // Number of dims which are partitioned
367 constexpr auto dims_to_partition = detail::GetDimsToPartition(BlockIdxs{});
368 const auto parsed_block_idxs = detail::ReplaceSlicesWithZeros(block_idxs);
369 if constexpr(decltype(dims_to_partition)::Size() == I2)
370 {
371 const auto shape_with_projection_dims =
372 detail::CalculateShapeWithProjection(shape(tensor), projection);
373 // Set Value for M, N partition
374 const auto M = shape_with_projection_dims.At(dims_to_partition.At(I0));
375 const auto N = shape_with_projection_dims.At(dims_to_partition.At(I1));
376 constexpr auto MPerBlock = BlockShapeTuple{}.At(dims_to_partition.At(I0));
377 constexpr auto NPerBlock = BlockShapeTuple{}.At(dims_to_partition.At(I1));
378 auto m_n_desc = make_naive_tensor_descriptor_packed(make_tuple(M, N));
379 // Get 1D block id
380 const auto grid_size = detail::CalculateGridSize(shape_with_projection_dims, tile_shape);
381 const auto block_lengths_desc = make_naive_tensor_descriptor_packed(grid_size);
382 const index_t block_id_1d = block_lengths_desc.CalculateOffset(parsed_block_idxs);
383 // Optimized version for 2d tile shape [MxN]
384 const auto block_2_tile_map =
385 BlockToCTileMap_M00_N0_M01Adapt<MPerBlock,
386 NPerBlock,
387 remove_cvref_t<decltype(m_n_desc)>>(m_n_desc);
388 const auto block_work_idx =
389 block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id_1d));
390 const index_t m_block_data_idx_on_grid =
391 __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
392 const index_t n_block_data_idx_on_grid =
393 __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
394 // Apply 0 for non partitioned dims
395 const auto offset_multi_idxs = generate_tuple(
396 [&](auto i) {
397 if constexpr(i == dims_to_partition.At(I0))
398 {
399 return m_block_data_idx_on_grid;
400 }
401 else if constexpr(i == dims_to_partition.At(I1))
402 {
403 return n_block_data_idx_on_grid;
404 }
405 else
406 {
407 return Number<0>{};
408 }
409 },
410 Number<BlockShapeTuple::Size()>{});
411 const auto projected_offset_multi_idxs =
412 detail::ApplyProjection(offset_multi_idxs, projection);
413 // Create new layout and tensor
414 const auto tile_layout =
415 Layout<remove_reference_t<decltype(projected_tile_shape)>, decltype(aligned_desc)>(
416 projected_tile_shape, aligned_desc);
417 auto tile_tensor =
418 make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), tile_layout);
419 // Apply offsets
420 tile_tensor.SetMultiIdxOffset(to_multi_index(projected_offset_multi_idxs));
421 return tile_tensor;
422 }
423 else
424 {
425 // Calculate offsets
426 // Sequence with data to process per block
427 using ProjectedTileShapeTuple = decltype(projected_tile_shape);
428 constexpr auto projected_tile_shape_seq =
429 generate_sequence_v2([](auto I) { return ProjectedTileShapeTuple{}.At(I); },
430 Number<ProjectedTileShapeTuple::Size()>{});
431 // Tuple with number of blocks
432 const auto projected_block_idxs =
433 to_multi_index(detail::ApplyProjection(parsed_block_idxs, projection));
434 const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
435 projected_block_idxs, projected_tile_shape_seq, tensor.GetMultiIdxOffsets());
436 // Create new layout and tensor
437 const auto tile_layout =
439 projected_tile_shape, aligned_desc);
440 auto tile_tensor =
441 make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), tile_layout);
442 // Apply offsets
443 tile_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
444 return tile_tensor;
445 }
446}
447
460template <typename TensorType, typename BlockShapeTuple, typename BlockIdxs>
461__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
462 const BlockShapeTuple& tile_shape,
463 const BlockIdxs& block_idxs)
464{
465 const auto projection = detail::GenerateDefaultProjection(BlockShapeTuple{});
466 return make_local_tile(tensor, tile_shape, block_idxs, projection);
467}
468
469} // namespace wrapper
470} // namespace ck
__host__ __device__ constexpr const auto & shape(const LayoutType &layout)
Get Layout shape.
Definition layout_utils.hpp:431
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition is_detected.hpp:34
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
Layout wrapper that performs the tensor descriptor logic.
Definition layout.hpp:24
static constexpr T value
Definition utility/integral_constant.hpp:21
__host__ __device__ constexpr auto make_local_partition(TensorType &tensor, const Layout< ThreadShape, ThreadUnrolledDesc > &thread_layout, const index_t thread_id, const ProjectionTuple &projection)
Create local partition for thread (currently only packed partitions are supported).
Definition tensor_partition.hpp:263
__host__ __device__ constexpr auto make_local_tile(const TensorType &tensor, const BlockShapeTuple &tile_shape, const BlockIdxs &block_idxs, const ProjectionTuple &projection)
Create local tile for thread block (currently only packed tiles are supported).
Definition tensor_partition.hpp:350
__host__ __device__ constexpr const auto & layout(const Tensor< BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType > &tensor)
Get Tensor Layout.
Definition tensor_utils.hpp:162
constexpr auto make_tensor(ElementType *pointer, const Layout< Shape, UnrolledDescriptorType > &layout)
Make tensor function.
Definition tensor_utils.hpp:112