space_filling_curve.hpp Source File

space_filling_curve.hpp Source File#

Composable Kernel: space_filling_curve.hpp Source File
space_filling_curve.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
12
13namespace ck_tile {
14
15template <typename TensorLengths,
16 typename DimAccessOrder,
17 typename ScalarsPerAccess,
18 bool SnakeCurved = true> // # of scalars per access in each dimension
20{
21 static constexpr index_t TensorSize =
22 reduce_on_sequence(TensorLengths{}, multiplies{}, number<1>{});
23 static_assert(0 < TensorSize,
24 "space_filling_curve should be used to access a non-empty tensor");
25
26 static constexpr index_t nDim = TensorLengths::size();
27
29
30 static constexpr index_t ScalarPerVector =
31 reduce_on_sequence(ScalarsPerAccess{}, multiplies{}, number<1>{});
32
33 static constexpr auto access_lengths = TensorLengths{} / ScalarsPerAccess{};
34 static constexpr auto dim_access_order = DimAccessOrder{};
35 static constexpr auto ordered_access_lengths =
37
42
43 static constexpr auto I0 = number<0>{};
44 static constexpr auto I1 = number<1>{};
45
47 {
48 static_assert(TensorLengths::size() == ScalarsPerAccess::size());
49 static_assert(TensorLengths{} % ScalarsPerAccess{} ==
50 typename uniform_sequence_gen<TensorLengths::size(), 0>::type{});
51
52 return reduce_on_sequence(TensorLengths{}, multiplies{}, number<1>{}) / ScalarPerVector;
53 }
54
55 template <index_t AccessIdx1dHead, index_t AccessIdx1dTail>
58 {
59 static_assert(AccessIdx1dHead >= 0 && AccessIdx1dHead < get_num_of_access(),
60 "1D index out of range");
61 static_assert(AccessIdx1dTail >= 0 && AccessIdx1dTail < get_num_of_access(),
62 "1D index out of range");
63
64 constexpr auto idx_head = get_index(number<AccessIdx1dHead>{});
65 constexpr auto idx_tail = get_index(number<AccessIdx1dTail>{});
66 return idx_tail - idx_head;
67 }
68
69 template <index_t AccessIdx1d>
71 {
// Compile-time upper-bound check: AccessIdx1d must be below get_num_of_access().
// The message now describes the condition being tested (it previously repeated the
// lower-bound text "should be larger than 0").
72 static_assert(AccessIdx1d < get_num_of_access(), "1D index out of range");
74 }
75
76 template <index_t AccessIdx1d>
78 {
79 static_assert(AccessIdx1d > 0, "1D index should be larger than 0");
80
81 return get_step_between(number<AccessIdx1d>{}, number<AccessIdx1d - 1>{});
82 }
83
84 // Do not use this function directly!
85 // TODO: can refactor into generic lambda in the future
86 template <index_t AccessIdx1d>
88 {
89#if 0
90 /*
91 * \todo: tensor_adaptor::calculate_bottom_index does NOT return constexpr as expected.
92 */
93 constexpr auto ordered_access_idx = to_index_adaptor.calculate_bottom_index(make_multi_index(number<AccessIdx1d>{}));
94#else
95
96 constexpr auto access_strides =
98
99 constexpr auto idx_1d = number<AccessIdx1d>{};
100 // Given the per-dimension strides \p access_strides and the 1D index of the space-filling curve,
101 // compute the idim-th element of the multi-dimensional index.
102 // All constexpr variables have to be captured by VALUE.
103 constexpr auto compute_index = [idx_1d, access_strides](auto idim) constexpr {
104 constexpr auto compute_index_impl = [idx_1d, access_strides](auto jdim) constexpr {
105 auto res = idx_1d.value;
106 auto id = 0;
107
108 static_for<0, jdim.value + 1, 1>{}([&](auto kdim) {
109 id = res / access_strides[kdim].value;
110 res -= id * access_strides[kdim].value;
111 });
112
113 return id;
114 };
115
116 constexpr auto id = compute_index_impl(idim);
117 return number<id>{};
118 };
119
120 constexpr auto ordered_access_idx = generate_tuple(compute_index, number<nDim>{});
121#endif
122 constexpr auto forward_sweep = [&]() {
124
125 forward_sweep_(I0) = true;
126
127 static_for<1, nDim, 1>{}([&](auto idim) {
128 index_t tmp = ordered_access_idx[I0];
129
131 [&](auto j) { tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; });
132
133 forward_sweep_(idim) = tmp % 2 == 0;
134 });
135
136 return forward_sweep_;
137 }();
138
139 // calculate multi-dim tensor index
140 auto idx_md = [&]() {
141 Index ordered_idx;
142
143 static_for<0, nDim, 1>{}([&](auto idim) {
144 ordered_idx(idim) =
145 !SnakeCurved || forward_sweep[idim]
146 ? ordered_access_idx[idim]
147 : ordered_access_lengths[idim] - 1 - ordered_access_idx[idim];
148 });
149
151 ScalarsPerAccess{};
152 }();
153 return idx_md;
154 }
155
156 // FIXME: returns a tuple of number<>, which is a compile-time-only variable
157 template <index_t AccessIdx1d>
159 {
160 constexpr auto idx = _get_index(number<AccessIdx1d>{});
161
162 return generate_tuple([&](auto i) { return number<idx[i]>{}; }, number<nDim>{});
163 }
164};
165
166} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_new2old(const array< TData, NSize > &old_array, sequence< IRs... >)
Definition tile/core/container/container_helper.hpp:39
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_old2new(const array< TData, NSize > &old_array, sequence< IRs... > old2new)
Definition tile/core/container/container_helper.hpp:48
CK_TILE_HOST_DEVICE constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_adaptor.hpp:359
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
array< index_t, N > multi_index
Definition tile/core/container/multi_index.hpp:17
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr index_t reduce_on_sequence(Seq, Reduce f, number< Init >)
Definition tile/core/container/sequence.hpp:982
CK_TILE_HOST_DEVICE constexpr auto container_reverse_exclusive_scan(const array< TData, NSize > &x, Reduce f, Init init)
Definition tile/core/container/container_helper.hpp:240
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
CK_TILE_HOST_DEVICE constexpr auto make_multi_index(Xs &&... xs)
Definition tile/core/container/multi_index.hpp:20
tuple_array< T, N > statically_indexed_array
Definition tile/core/container/statically_indexed_array.hpp:16
typename std::conditional< kHasContent, type0, type1 >::type type
Definition tile/core/container/sequence.hpp:302
Definition tile/core/numeric/math.hpp:98
Definition tile/core/container/sequence.hpp:49
Definition space_filling_curve.hpp:20
static CK_TILE_HOST_DEVICE constexpr Index _get_index(number< AccessIdx1d >)
Definition space_filling_curve.hpp:87
static CK_TILE_HOST_DEVICE constexpr auto get_index(number< AccessIdx1d >)
Definition space_filling_curve.hpp:158
static CK_TILE_HOST_DEVICE constexpr index_t get_num_of_access()
Definition space_filling_curve.hpp:46
static CK_TILE_HOST_DEVICE constexpr auto get_backward_step(number< AccessIdx1d >)
Definition space_filling_curve.hpp:77
static CK_TILE_HOST_DEVICE constexpr auto get_step_between(number< AccessIdx1dHead >, number< AccessIdx1dTail >)
Definition space_filling_curve.hpp:56
multi_index< nDim > Index
Definition space_filling_curve.hpp:28
static CK_TILE_HOST_DEVICE constexpr auto get_forward_step(number< AccessIdx1d >)
Definition space_filling_curve.hpp:70
Definition tile/core/utility/functional.hpp:43
Definition tile/core/container/sequence.hpp:314