tensor_space_filling_curve.hpp Source File

tensor_space_filling_curve.hpp Source File

Composable Kernel: tensor_space_filling_curve.hpp Source File
tensor_space_filling_curve.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck/utility/math.hpp"
12
13namespace ck {
14
15template <typename TensorLengths,
16 typename DimAccessOrder,
17 typename ScalarsPerAccess,
18 bool SnakeCurved = true> // # of scalars per access in each dimension
20{
21 static constexpr index_t nDim = TensorLengths::Size();
22
24
25 static constexpr index_t ScalarPerVector =
26 reduce_on_sequence(ScalarsPerAccess{}, math::multiplies{}, Number<1>{});
27
28 static constexpr auto access_lengths = TensorLengths{} / ScalarsPerAccess{};
29 static constexpr auto dim_access_order = DimAccessOrder{};
30 static constexpr auto ordered_access_lengths =
32
37
38 static constexpr auto I0 = Number<0>{};
39 static constexpr auto I1 = Number<1>{};
40
41 __host__ __device__ static constexpr index_t GetNumOfAccess()
42 {
43 static_assert(TensorLengths::Size() == ScalarsPerAccess::Size());
44 static_assert(TensorLengths{} % ScalarsPerAccess{} ==
45 typename uniform_sequence_gen<TensorLengths::Size(), 0>::type{});
46
47 return reduce_on_sequence(TensorLengths{}, math::multiplies{}, Number<1>{}) /
49 }
50
51 template <index_t AccessIdx1dBegin, index_t AccessIdx1dEnd>
52 static __device__ __host__ constexpr auto GetStepBetween(Number<AccessIdx1dBegin>,
54 {
55 static_assert(AccessIdx1dBegin >= 0, "1D index should be non-negative");
56 static_assert(AccessIdx1dBegin < GetNumOfAccess(), "1D index should be larger than 0");
57 static_assert(AccessIdx1dEnd >= 0, "1D index should be non-negative");
58 static_assert(AccessIdx1dEnd < GetNumOfAccess(), "1D index should be larger than 0");
59
60 constexpr auto idx_begin = GetIndex(Number<AccessIdx1dBegin>{});
61 constexpr auto idx_end = GetIndex(Number<AccessIdx1dEnd>{});
62 return idx_end - idx_begin;
63 }
64
65 template <index_t AccessIdx1d>
66 static __device__ __host__ constexpr auto GetForwardStep(Number<AccessIdx1d>)
67 {
68 static_assert(AccessIdx1d < GetNumOfAccess(), "1D index should be larger than 0");
70 }
71
72 template <index_t AccessIdx1d>
73 static __device__ __host__ constexpr auto GetBackwardStep(Number<AccessIdx1d>)
74 {
75 static_assert(AccessIdx1d > 0, "1D index should be larger than 0");
76
77 return GetStepBetween(Number<AccessIdx1d>{}, Number<AccessIdx1d - 1>{});
78 }
79
80 template <index_t AccessIdx1d>
81 static __device__ __host__ constexpr Index GetIndex(Number<AccessIdx1d>)
82 {
83#if 0
84 /*
85 * \todo: TensorAdaptor::CalculateBottomIndex does NOT return constexpr as expected.
86 */
87 constexpr auto ordered_access_idx = to_index_adaptor.CalculateBottomIndex(make_multi_index(Number<AccessIdx1d>{}));
88#else
89
90 constexpr auto access_strides = container_reverse_exclusive_scan(
92
93 constexpr auto idx_1d = Number<AccessIdx1d>{};
94 // Given tensor strides \p access_lengths, and 1D index of space-filling-curve, compute the
95 // idim-th element of multidimensional index.
96 // All constexpr variables have to be captured by VALUE.
97 constexpr auto compute_index = [idx_1d, access_strides](auto idim) constexpr {
98 constexpr auto compute_index_impl = [idx_1d, access_strides](auto jdim) constexpr {
99 auto res = idx_1d.value;
100 auto id = 0;
101
102 static_for<0, jdim.value + 1, 1>{}([&](auto kdim) {
103 id = res / access_strides[kdim].value;
104 res -= id * access_strides[kdim].value;
105 });
106
107 return id;
108 };
109
110 constexpr auto id = compute_index_impl(idim);
111 return Number<id>{};
112 };
113
114 constexpr auto ordered_access_idx = generate_tuple(compute_index, Number<nDim>{});
115#endif
116 constexpr auto forward_sweep = [&]() {
118
119 forward_sweep_(I0) = true;
120
121 static_for<1, nDim, 1>{}([&](auto idim) {
122 index_t tmp = ordered_access_idx[I0];
123
125 [&](auto j) { tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; });
126
127 forward_sweep_(idim) = tmp % 2 == 0;
128 });
129
130 return forward_sweep_;
131 }();
132
133 // calculate multi-dim tensor index
134 auto idx_md = [&]() {
135 Index ordered_idx;
136
137 static_for<0, nDim, 1>{}([&](auto idim) {
138 ordered_idx(idim) =
139 !SnakeCurved || forward_sweep[idim]
140 ? ordered_access_idx[idim]
141 : ordered_access_lengths[idim] - 1 - ordered_access_idx[idim];
142 });
143
145 ScalarsPerAccess{};
146 }();
147 return idx_md;
148 }
149
150 // FIXME: rename this function
151 template <index_t AccessIdx1d>
152 static __device__ __host__ constexpr auto GetIndexTupleOfNumber(Number<AccessIdx1d>)
153 {
154 constexpr auto idx = GetIndex(Number<AccessIdx1d>{});
155
156 return generate_tuple([&](auto i) { return Number<idx[i]>{}; }, Number<nDim>{});
157 }
158};
159
160} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto container_reorder_given_old2new(const Array< TData, NSize > &old_array, Sequence< IRs... > old2new)
Definition utility/container_helper.hpp:54
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto container_reverse_exclusive_scan(const Array< TData, NSize > &x, Reduce f, TData init)
Definition utility/container_helper.hpp:213
__host__ __device__ constexpr index_t reduce_on_sequence(Seq, Reduce f, Number< Init >)
Definition utility/sequence.hpp:884
__host__ __device__ constexpr auto container_reorder_given_new2old(const Array< TData, NSize > &old_array, Sequence< IRs... >)
Definition utility/container_helper.hpp:43
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
static __device__ __host__ constexpr auto GetStepBetween(Number< AccessIdx1dBegin >, Number< AccessIdx1dEnd >)
Definition tensor_space_filling_curve.hpp:52
static __device__ __host__ constexpr auto GetIndexTupleOfNumber(Number< AccessIdx1d >)
Definition tensor_space_filling_curve.hpp:152
__host__ static __device__ constexpr index_t GetNumOfAccess()
Definition tensor_space_filling_curve.hpp:41
static __device__ __host__ constexpr auto GetBackwardStep(Number< AccessIdx1d >)
Definition tensor_space_filling_curve.hpp:73
static __device__ __host__ constexpr Index GetIndex(Number< AccessIdx1d >)
Definition tensor_space_filling_curve.hpp:81
static __device__ __host__ constexpr auto GetForwardStep(Number< AccessIdx1d >)
Definition tensor_space_filling_curve.hpp:66
MultiIndex< nDim > Index
Definition tensor_space_filling_curve.hpp:23
typename conditional< kHasContent, type0, type1 >::type type
Definition utility/sequence.hpp:271
Definition utility/math.hpp:34
Definition functional2.hpp:33
Definition utility/sequence.hpp:289