threadwise_tensor_slice_transfer_v7.hpp Source File

threadwise_tensor_slice_transfer_v7.hpp Source File#

Composable Kernel: threadwise_tensor_slice_transfer_v7.hpp Source File
threadwise_tensor_slice_transfer_v7.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
10
11namespace ck {
12
// Thread-level multi-source, multi-destination tensor slice data movement
// Assumes:
// 1. All sources and destinations are DynamicBuffer
// 2. Same VectorDim and ScalarPerVector for all sources and destinations
// 3. DstInMemOps are per destination tensor
// 4. SrcResetCoordinateAfterRunFlags are per source tensor
// 5. DstResetCoordinateAfterRunFlags are per destination tensor
// 6. Does not need to know src_descs and dst_descs at compile-time
// 7. Does not need to know src_slice_origins and dst_slice_origins at compile-time
//
// Does the following things to avoid scratch memory issues:
// 1. Use StaticallyIndexedArray or vector_type instead of C array for thread buffer
// 2. Pass tensor descriptors by reference (or tuple of references)
// 3. Does not keep reference to tensor descriptor
// 4. Does not construct new tensor coordinates when Run() is called
28template <typename SrcDatas,
29 typename DstDatas,
30 typename SrcDescs,
31 typename DstDescs,
32 typename ElementwiseOperation,
33 typename DstInMemOps, // Sequence<InMemoryDataOperationEnum ...>
34 typename SliceLengths,
35 typename DimAccessOrder,
36 index_t VectorDim,
37 index_t ScalarPerVector,
38 typename SrcResetCoordinateAfterRunFlags, // Sequence<bool ...>
39 typename DstResetCoordinateAfterRunFlags> // Sequence<bool ...>
41{
42 static constexpr auto I0 = Number<0>{};
43
44 static constexpr index_t nDim = SliceLengths::Size();
45
46 static constexpr index_t nSrc = SrcDescs::Size();
47 static constexpr index_t nDst = DstDescs::Size();
48
50
51 // return a tuple of coordiantes for a tuple of tensor
52 template <typename Descs,
53 typename Indices,
54 enable_if_t<Descs::Size() == Indices::Size(), bool> = false>
55 static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices)
56 {
57 return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); },
58 Number<Descs::Size()>{});
59 }
60
63
64 // scalar per access on each dim
65 // FIXME: don't use lambda_scalar_per_access
66 static constexpr auto scalar_per_access = generate_sequence(
68
70 SpaceFillingCurve<SliceLengths, DimAccessOrder, remove_cv_t<decltype(scalar_per_access)>>;
71
73 const SrcDescs& src_descs,
74 const StaticallyIndexedArray<Index, nSrc>& src_slice_origins,
75 const DstDescs& dst_descs,
76 const StaticallyIndexedArray<Index, nDst>& dst_slice_origins,
77 const ElementwiseOperation& element_op)
78 : src_coords_(MakeCoordinates(src_descs, src_slice_origins)),
79 dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)),
80 element_op_(element_op)
81 {
82 static_assert(SliceLengths::At(Number<VectorDim>{}) % ScalarPerVector == 0,
83 "wrong! cannot evenly divide");
84 }
85
86 template <typename Indices, enable_if_t<SrcDescs::Size() == Indices::Size(), bool> = false>
87 __device__ void SetSrcSliceOrigins(const SrcDescs& src_descs,
88 const Indices& src_slice_origin_idxs)
89 {
90 static_for<0, nSrc, 1>{}([&](auto i) {
91 src_coords_(i) = make_tensor_coordinate(src_descs[i], src_slice_origin_idxs[i]);
92 });
93 }
94
95 template <typename Indices, enable_if_t<DstDescs::Size() == Indices::Size(), bool> = false>
96 __device__ void SetDstSliceOrigins(const DstDescs& dst_descs,
97 const Indices& dst_slice_origin_idxs)
98 {
99 static_for<0, nDst, 1>{}([&](auto i) {
100 dst_coords_(i) = make_tensor_coordinate(dst_descs[i], dst_slice_origin_idxs[i]);
101 });
102 }
103
104 // SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
105 // SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
106 // DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
107 // DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
108 template <typename SrcBuffers,
109 typename DstBuffers,
110 enable_if_t<SrcDescs::Size() == SrcBuffers::Size() &&
111 DstDescs::Size() == DstBuffers::Size(),
112 bool> = false>
113 __device__ void Run(const SrcDescs& src_descs,
114 const SrcBuffers& src_bufs,
115 const DstDescs& dst_descs,
116 DstBuffers dst_bufs)
117 {
118 auto generate_vectors = [&](auto data_types) {
119 constexpr index_t num = data_types.Size();
120
121 return generate_tuple(
122 [&](auto i) {
123 using DataType = remove_cvref_t<decltype(data_types[i])>;
124
126 },
127 Number<num>{});
128 };
129
130 constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
131
132 // loop over space-filling curve
133 static_for<0, num_access, 1>{}([&](auto iAccess) {
134 auto src_vectors = generate_vectors(SrcDatas{});
135 auto dst_vectors = generate_vectors(DstDatas{});
136
137 // copy data from src_bufs into src_vectors
138 static_for<0, nSrc, 1>{}([&](auto i) {
139 using src_vector_t = typename remove_cvref_t<decltype(src_vectors[i])>::type;
140
141 const bool is_src_valid =
143 src_coords_[i]);
144
145 src_vectors(i).template AsType<src_vector_t>()(I0) =
146 src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(),
147 is_src_valid);
148 });
149
150 // apply pointwise function
152 // get reference to src data
153 const auto src_data_refs = generate_tie(
154 // return type should be lvalue
155 [&](auto iSrc) -> const auto& {
156 using SrcData = remove_cvref_t<tuple_element_t<iSrc.value, SrcDatas>>;
157
158 return src_vectors[iSrc].template AsType<SrcData>()[i];
159 },
160 Number<nSrc>{});
161
162 // get reference to dst data
163 auto dst_data_refs = generate_tie(
164 // return type should be lvalue
165 [&](auto iDst) -> auto& {
166 using DstData = remove_cvref_t<tuple_element_t<iDst.value, DstDatas>>;
167
168 return dst_vectors(iDst).template AsType<DstData>()(i);
169 },
170 Number<nDst>{});
171
172 // apply pointwise function
173 // pointwise function signature:
174 // element_op_(dst_data_refs[I0],
175 // dst_data_refs[I1],
176 // ...,
177 // src_data_refs[I0],
178 // src_data_refs[I1],
179 // ...)
180 unpack2(element_op_, dst_data_refs, src_data_refs);
181 });
182
183 // copy data from buf_vectors into dst_bufs
184 static_for<0, nDst, 1>{}([&](auto i) {
185 using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
186
187 const bool is_dst_valid =
189 dst_coords_[i]);
190
191 constexpr InMemoryDataOperationEnum DstInMemOp =
192 static_cast<InMemoryDataOperationEnum>(DstInMemOps::At(i.value));
193
194 dst_bufs(i).template Update<DstInMemOp, dst_vector_t>(
195 dst_coords_[i].GetOffset(),
196 is_dst_valid,
197 dst_vectors[i].template AsType<dst_vector_t>()[I0]);
198 });
199
200 // move coordinate
201 if constexpr(iAccess.value != num_access - 1)
202 {
203 constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(iAccess);
204
205 static_for<0, nSrc, 1>{}([&](auto i) {
206 move_tensor_coordinate(src_descs[i],
207 src_coords_(i),
208 make_tensor_coordinate_step(src_descs[i], forward_step));
209 });
210
211 static_for<0, nDst, 1>{}([&](auto i) {
212 move_tensor_coordinate(dst_descs[i],
213 dst_coords_(i),
214 make_tensor_coordinate_step(dst_descs[i], forward_step));
215 });
216 }
217 });
218
219 // move coordinate back to slice origin (or not)
220 static_for<0, nSrc, 1>{}([&](auto i) {
221 if constexpr(SrcResetCoordinateAfterRunFlags::At(i))
222 {
223 const auto src_reset_step =
225
226 move_tensor_coordinate(src_descs[i], src_coords_(i), src_reset_step);
227 }
228 });
229
230 static_for<0, nDst, 1>{}([&](auto i) {
231 if constexpr(DstResetCoordinateAfterRunFlags::At(i))
232 {
233 const auto dst_reset_step =
235
236 move_tensor_coordinate(dst_descs[i], dst_coords_(i), dst_reset_step);
237 }
238 });
239 }
240
241 __device__ static constexpr auto GetCoordinateResetStep()
242 {
243 constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
244
245 if constexpr(num_access == 0)
246 {
247 return typename SpaceFillingCurve::Index{};
248 }
249 else
250 {
251 constexpr auto reset_step =
253
254 return reset_step;
255 }
256 }
257
258 // src_slice_origin_step_idx need to be known at compile-time, for performance reason
259 template <index_t ISrc>
260 __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs,
261 Number<ISrc> iSrc,
262 const Index& src_slice_origin_step_idx)
263 {
264 // if src coord was not reset by RunRead(), then need to adjust the step here
265 const auto adjusted_step_idx = SrcResetCoordinateAfterRunFlags::At(iSrc)
266 ? src_slice_origin_step_idx
267 : src_slice_origin_step_idx + GetCoordinateResetStep();
268
269 // is it OK to construct a new step every time?
270 const auto adjusted_step = make_tensor_coordinate_step(src_descs[iSrc], adjusted_step_idx);
271
272 move_tensor_coordinate(src_descs[iSrc], src_coords_(iSrc), adjusted_step);
273 }
274
275 // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
276 template <index_t IDst>
277 __device__ void MoveDstSliceWindow(const DstDescs& dst_descs,
278 Number<IDst> iDst,
279 const Index& dst_slice_origin_step_idx)
280 {
281 // if dst coord was not reset by Run(), then need to adjust the step here
282 const auto adjusted_step_idx = DstResetCoordinateAfterRunFlags::At(iDst)
283 ? dst_slice_origin_step_idx
284 : dst_slice_origin_step_idx + GetCoordinateResetStep();
285
286 // is it OK to construct a new step every time?
287 const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx);
288
289 move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step);
290 }
291
292 private:
293 SrcCoords src_coords_;
294 DstCoords dst_coords_;
295 const ElementwiseOperation element_op_;
296};
297
298} // namespace ck
Definition ck.hpp:268
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc &, const VisibleIndex &idx_diff_visible, UpdateLowerIndexHack)
Definition tensor_description/tensor_descriptor.hpp:444
__host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const TensorCoordStep &coord_step)
Definition tensor_description/tensor_descriptor.hpp:508
InMemoryDataOperationEnum
Definition ck.hpp:277
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr bool coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc &tensor_desc, const TensorCoord &coord)
Definition tensor_description/tensor_descriptor.hpp:560
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto generate_sequence(F, Number< N >)
Definition sequence_helper.hpp:18
typename remove_cv< T >::type remove_cv_t
Definition type.hpp:295
__host__ __device__ constexpr auto unpack2(F &&f, X &&x, Y &&y)
Definition functional4.hpp:55
typename std::enable_if< B, T >::type enable_if_t
Definition enable_if.hpp:27
__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc &tensor_desc, const VisibleIndex &idx_visible)
Definition tensor_description/tensor_descriptor.hpp:407
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
typename vector_type_maker< T, N >::type vector_type_maker_t
Definition dtype_vector.hpp:54
static __device__ __host__ constexpr auto GetStepBetween(Number< AccessIdx1dBegin >, Number< AccessIdx1dEnd >)
Definition tensor_space_filling_curve.hpp:52
__host__ static __device__ constexpr index_t GetNumOfAccess()
Definition tensor_space_filling_curve.hpp:41
static __device__ __host__ constexpr auto GetForwardStep(Number< AccessIdx1d >)
Definition tensor_space_filling_curve.hpp:66
MultiIndex< nDim > Index
Definition tensor_space_filling_curve.hpp:23
__device__ void SetDstSliceOrigins(const DstDescs &dst_descs, const Indices &dst_slice_origin_idxs)
Definition threadwise_tensor_slice_transfer_v7.hpp:96
__device__ void MoveDstSliceWindow(const DstDescs &dst_descs, Number< IDst > iDst, const Index &dst_slice_origin_step_idx)
Definition threadwise_tensor_slice_transfer_v7.hpp:277
__device__ void MoveSrcSliceWindow(const SrcDescs &src_descs, Number< ISrc > iSrc, const Index &src_slice_origin_step_idx)
Definition threadwise_tensor_slice_transfer_v7.hpp:260
__device__ constexpr ThreadwiseTensorSliceTransfer_v7(const SrcDescs &src_descs, const StaticallyIndexedArray< Index, nSrc > &src_slice_origins, const DstDescs &dst_descs, const StaticallyIndexedArray< Index, nDst > &dst_slice_origins, const ElementwiseOperation &element_op)
Definition threadwise_tensor_slice_transfer_v7.hpp:72
static __device__ constexpr auto GetCoordinateResetStep()
Definition threadwise_tensor_slice_transfer_v7.hpp:241
static constexpr auto MakeCoordinates(const Descs &descs, const Indices &indices)
Definition threadwise_tensor_slice_transfer_v7.hpp:55
__device__ void SetSrcSliceOrigins(const SrcDescs &src_descs, const Indices &src_slice_origin_idxs)
Definition threadwise_tensor_slice_transfer_v7.hpp:87
__device__ void Run(const SrcDescs &src_descs, const SrcBuffers &src_bufs, const DstDescs &dst_descs, DstBuffers dst_bufs)
Definition threadwise_tensor_slice_transfer_v7.hpp:113
Definition threadwise_tensor_slice_transfer_util.hpp:20
Definition functional2.hpp:33