thread_group_tensor_slice_transfer_direct_load.hpp Source File

thread_group_tensor_slice_transfer_direct_load.hpp Source File

Composable Kernel: thread_group_tensor_slice_transfer_direct_load.hpp Source File
thread_group_tensor_slice_transfer_direct_load.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
11
12namespace ck {
13
42template <typename ThreadGroup,
43 typename BlockSliceLengths,
44 typename ThreadClusterLengths,
45 typename ThreadClusterArrangeOrder,
46 typename SrcData,
47 typename DstData,
48 typename SrcDesc,
49 typename DstDesc,
50 typename SrcDimAccessOrder,
51 index_t SrcVectorDim,
52 index_t DstVectorDim,
53 index_t ScalarPerVector>
55{
58
// Coordinate and coordinate-step types deduced from the source and destination descriptors.
59 using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
60 using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
61
62 using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
63 using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
64
// Compile-time index constants.
65 static constexpr auto I0 = Number<0>{};
66 static constexpr auto I1 = Number<1>{};
67
// Compile-time instances of the block-slice and thread-cluster length sequences.
68 static constexpr auto block_slice_lengths = BlockSliceLengths{};
69 static constexpr auto thread_cluster_lengths = ThreadClusterLengths{};
70
73 // After each load, a thread advances by `thread_steps` instead of loading its adjacent elements.
74 // This makes the whole wavefront load a contiguous region of memory, which direct loads require.
77
78 static __device__ constexpr bool AreThreadClusterLengthsValid()
79 {
80 // Make sure that ThreadClusterLengths are set in a way that allows for contiguous writes to
81 // LDS by the threads from a single wavefront.
82 // Examples (assuming 64 threads in a wavefront, 128 in a thread block):
83 // 1. BlockSliceLengths = [K0PerBlock, MPerBlock, K1PerBlock] = [4, 128, 8],
84 // data type = fp32 -> ScalarPerVector = 1
85 // INVALID: ThreadClusterLengths = [4, 4, 8] since in the first iteration, threads 0-31
86 // write [0, 0, 0] - [0, 3, 7] and thread 32 writes [1, 0, 0] instead of
87 // [0, 4, 0].
88 // VALID: ThreadClusterLengths = [2, 8, 8] or [1, 16, 8] since in the first iteration,
89 // threads 0-63 write [0, 0, 0] - [0, 7, 7] -> 64 consecutive elements (DWORDs).
90 // 2. BlockSliceLengths = [K0PerBlock, MPerBlock, K1PerBlock] = [4, 128, 8],
91 // data type = fp16 -> ScalarPerVector = 2
92 // NOTE: ThreadClusterLengths must take into account that each thread writes two
93 // elements (single DWORD) along the contiguous dimension.
94 // INVALID: ThreadClusterLengths = [4, 4, 8] since each 8 threads would try to write
95 // 8 * 2 elements of K1PerBlock and there are only 8;
96 // ThreadClusterLengths = [4, 8, 4] since in the first iteration, threads 0-31
97 // write [0, 0, 0] - [0, 7, 7] (7 since each writes 2 elements) and thread 32
98 // writes [1, 0, 0] instead of [0, 8, 0].
99 // VALID: ThreadClusterLengths = [4, 16, 4] or [2, 32, 4] or [1, 64, 4] since in the
100 // first iteration, threads 0-63 write [0, 0, 0] - [0, 15, 7] -> 128 consecutive
101 // elements = 64 consecutive DWORDs.
102#if defined(__gfx950__)
103 int num_contiguous_dwords = 4;
104#else
105 int num_contiguous_dwords = 1;
106#endif
107 bool is_contiguous = true;
108 static_for<0, nDim, 1>{}([&](auto i) {
109 if(is_contiguous)
110 {
111 num_contiguous_dwords *= thread_cluster_lengths[nDim - i - 1];
112 }
113 if(thread_slice_lengths[nDim - i - 1] > 1)
114 {
115 is_contiguous = false;
116 }
117 });
118 constexpr index_t wavefront_size = get_warp_size();
119 const bool wave_contiguous = num_contiguous_dwords % wavefront_size == 0;
120
121 bool thread_slice_lengths_correct = true;
122 static_for<0, nDim, 1>{}([&](auto i) {
123 if(thread_slice_lengths[i] <= 0)
124 {
125 thread_slice_lengths_correct = false;
126 }
127 });
128
129 return wave_contiguous && thread_slice_lengths_correct;
130 }
131
133 const SrcDesc& src_desc,
134 const Index& src_block_slice_origin,
135 const DstDesc& dst_desc,
136 const Index& dst_block_slice_origin)
137
138 {
140 "Direct load transfer does not support datatypes conversion. Source and "
141 "destination data types must be the same.");
142
143 static_assert(
144 DstVectorDim == nDim - 1,
145 "Direct load transfer requires the destination vector dimension to be the last one.");
146
147 static_assert(ScalarPerVector == 1 || SrcVectorDim == DstVectorDim,
148 "When loading more than one element per thread at once, the contiguous "
149 "dimension must be the same between source and destination.");
150
151 // constexpr auto dword_bytes = 4;
152 // constexpr auto bytes_per_thread_load = ScalarPerVector * sizeof(SrcData);
153 // static_assert(bytes_per_thread_load == dword_bytes,
154 // "Direct load transfer requires each thread to load exactly a single "
155 // "DWORD of data.");
156
159 nDim == ThreadClusterLengths::Size(),
160 "Inconsistent number of dimensions across lengths and descriptors.");
161
162 static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
163 "The number of threads cannot be less than the number of elements in "
164 "thread cluster lengths.");
165
166 // static_assert(
167 // AreThreadClusterLengthsValid(),
168 // "Thread cluster lengths are incorrect. They must be set in a way that allows a single
169 // " "wavefront to write contiguous DWORDs into LDS memory. ");
170
// Runtime part: locate this thread in the thread cluster and its wavefront in the wave
// cluster. Then derive the source origin (block origin + per-thread offset) and the
// destination origin (block origin + per-wavefront offset).
171 const auto thread_cluster_idx =
172 thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId()));
173
// Wave cluster lengths: GetNumOfThread() / 64 wavefronts along the dimension whose
// ThreadClusterArrangeOrder entry equals nDim - 3, and 1 along every other dimension.
// NOTE(review): the wavefront size is hard-coded as 64 here and in the wave-id computation
// below, whereas AreThreadClusterLengthsValid uses get_warp_size(). On a wave32 target
// these would disagree; confirm. Also, if nDim < 3 no arrange-order entry can equal
// nDim - 3, so the wave cluster would have a single element even with several
// wavefronts; confirm this case never occurs.
174 constexpr auto wave_cluster_lengths = generate_sequence_v2(
175 [&](auto i) {
176 // FIXME: wave parallelism is not always in that dimension.
177 // The ThreadClusterLengths{} must be bigger than wave_num;
178 if constexpr(ThreadClusterArrangeOrder{}.At(i) == (nDim - 3))
179 {
180 return Number<ThreadGroup::GetNumOfThread() / 64>{};
181 }
182 else
183 {
184 return I1;
185 }
186 },
187 Number<nDim>{});
188
// Per-wavefront cluster lengths and the extent a whole wavefront covers in one load.
189 constexpr auto wave_thread_cluster_lengths = ThreadClusterLengths{} / wave_cluster_lengths;
190 constexpr auto wave_single_load_size =
191 wave_thread_cluster_lengths * thread_single_load_size;
192 constexpr auto wave_cluster_desc_ =
193 make_cluster_descriptor(wave_cluster_lengths, ThreadClusterArrangeOrder{});
194
// Wavefront index = thread id / 64 (see NOTE above about the hard-coded width).
195 const auto wave_cluster_idx = wave_cluster_desc_.CalculateBottomIndex(
196 make_multi_index(ThreadGroup::GetThreadId() / 64));
197
198 const auto thread_data_idx_begin = thread_cluster_idx * thread_single_load_size;
199 const auto wave_data_idx_begin = wave_cluster_idx * wave_single_load_size;
200
201 SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin);
202 // We don't need a per-thread offset for LDS since it is calculated by the hardware.
203 // We still need to supply the per-wavefront offset.
204 SetDstSliceOrigin(dst_desc, dst_block_slice_origin + wave_data_idx_begin);
205 }
206
207 __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
208 {
209 src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
210 src_slice_origin_ = src_slice_origin_idx;
211 }
212
213 __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
214 {
215 dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
216 dst_slice_origin_ = dst_slice_origin_idx;
217 }
218
219 __device__ void ResetDstSliceWindow(const DstDesc& dst_desc)
220 {
221 dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_);
222 }
223
// Copies this thread's portion of the block slice from a global-memory buffer (src) into an
// LDS buffer (dst) through src_buf.DirectCopyToLds, ScalarPerVector elements per access.
// The per-thread slice (thread_slice_lengths) is walked in a zig-zag order: the direction
// along each dimension flips with the parity of the outer indices. As a result, exactly one
// coordinate step (thread_steps along a single dimension) separates consecutive accesses.
// Afterwards the destination coordinate is reset to its slice origin. The source coordinate
// is left where the traversal ended; MoveSrcSliceWindow rebuilds it from the source origin.
// NOTE(review): the traversal uses the natural dimension order of thread_slice_lengths; the
// SrcDimAccessOrder template parameter is not referenced in this function.
224 template <typename SrcBuffer, typename DstBuffer>
225 __device__ void Run(const SrcDesc& src_desc,
226 const SrcBuffer& src_buf,
227 const DstDesc& dst_desc,
228 DstBuffer& dst_buf)
229 {
230 static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global,
231 "Source data must come from a global memory buffer.");
232 static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
233 "Destination data must be stored in an LDS memory buffer.");
234
235 static_assert(
237 "SrcBuffer and SrcData data types must be consistent.");
238 static_assert(
240 "DstBuffer and DstData data types must be consistent.");
241
242 constexpr auto dst_access_lengths = thread_slice_lengths;
243
// Per-dimension +thread_steps / -thread_steps coordinate steps for both descriptors.
244 const auto dst_forward_steps = generate_steps(dst_desc, 1);
245 const auto dst_backward_steps = generate_steps(dst_desc, -1);
246 const auto src_forward_steps = generate_steps(src_desc, 1);
247 const auto src_backward_steps = generate_steps(src_desc, -1);
248
249 // Loop over the destination block and copy data.
250 static_ford<decltype(dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
251 const auto src_offset = src_coord_.GetOffset();
// The LDS offset is made wave-uniform by taking the value from the first active lane.
252 const auto dst_offset = __builtin_amdgcn_readfirstlane(dst_coord_.GetOffset());
253
254 // Check if src data is not in the logic padding area.
// NOTE(review): the initializer of is_src_valid is missing from this listing; presumably
// coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_). Confirm.
255 const bool is_src_valid =
257
258 src_buf.template DirectCopyToLds<remove_cvref_t<decltype(dst_buf)>, ScalarPerVector>(
259 dst_buf, src_offset, dst_offset, is_src_valid);
260
// move_on_dim[i] is true only for the innermost dimension that has not reached its last
// index while every dimension inside it has; at most one entry is true per iteration.
261 constexpr auto move_on_dim = [&]() constexpr {
263
264 static_for<0, nDim, 1>{}([&](auto i) {
265 move_on_dim_(i) = ordered_dst_access_idx[i] < dst_access_lengths[i] - 1;
266
267 static_for<i + 1, nDim, 1>{}([&](auto j) {
268 move_on_dim_(i) &= ordered_dst_access_idx[j] == dst_access_lengths[j] - 1;
269 });
270 });
271
272 return move_on_dim_;
273 }();
274
275 // Decide whether to move forward or backward.
// Dimension 0 always moves forward. Dimension i moves forward when the linearized index of
// dimensions [0, i) is even, which gives the zig-zag traversal.
276 constexpr auto forward_sweep = [&]() {
278
279 forward_sweep_(I0) = true;
280
281 static_for<1, nDim, 1>{}([&](auto i) {
282 index_t tmp = ordered_dst_access_idx[I0];
283
284 static_for<1, i, 1>{}([&](auto j) {
285 tmp = tmp * dst_access_lengths[j] + ordered_dst_access_idx[j];
286 });
287
288 forward_sweep_(i) = tmp % 2 == 0;
289 });
290
291 return forward_sweep_;
292 }();
293
// Apply the selected step to both coordinates so the source and destination stay in step.
294 static_for<0, nDim, 1>{}([&](auto i) {
295 if constexpr(move_on_dim[i])
296 {
297 if constexpr(forward_sweep[i])
298 {
299 move_tensor_coordinate(dst_desc, dst_coord_, dst_forward_steps[i]);
300 move_tensor_coordinate(src_desc, src_coord_, src_forward_steps[i]);
301 }
302 else
303 {
304 move_tensor_coordinate(dst_desc, dst_coord_, dst_backward_steps[i]);
305 move_tensor_coordinate(src_desc, src_coord_, src_backward_steps[i]);
306 }
307 }
308 });
309 });
310
311 // Reset the destination slice since the entire buffer has been already filled.
312 ResetDstSliceWindow(dst_desc);
313 }
314
315 __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
316 {
317 src_slice_origin_ = src_slice_origin_ + step;
318 src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_);
319 }
320
321 template <typename DescType>
322 __device__ auto generate_steps(const DescType& desc, int sign)
323 {
324 return generate_tuple(
325 [&](auto i) {
326 Index step_idx;
327
328 static_for<0, nDim, 1>{}([&](auto j) {
329 step_idx(j) = (i.value == j.value) ? sign * thread_steps[i] : 0;
330 });
331
332 return make_tensor_coordinate_step(desc, step_idx);
333 },
334 Number<nDim>{});
335 }
336
337 private:
// Maps a flat thread id to this thread's multi-index in the thread cluster, as laid out by
// ThreadClusterLengths and ThreadClusterArrangeOrder.
338 static constexpr auto thread_cluster_desc_ =
339 make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
340
// Current source / destination coordinates; Run advances both while copying.
341 SrcCoord src_coord_;
342 DstCoord dst_coord_;
// Window origins: MoveSrcSliceWindow offsets src_slice_origin_, and ResetDstSliceWindow
// rebuilds dst_coord_ from dst_slice_origin_.
343 Index src_slice_origin_;
344 Index dst_slice_origin_;
345};
346
347} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc &, const VisibleIndex &idx_diff_visible, UpdateLowerIndexHack)
Definition tensor_description/tensor_descriptor.hpp:444
__host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const TensorCoordStep &coord_step)
Definition tensor_description/tensor_descriptor.hpp:508
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
__host__ __device__ constexpr bool coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc &tensor_desc, const TensorCoord &coord)
Definition tensor_description/tensor_descriptor.hpp:560
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Lds
Definition amd_address_space.hpp:18
@ Global
Definition amd_address_space.hpp:17
__host__ __device__ constexpr auto generate_sequence_v2(F &&f, Number< N >)
Definition sequence_helper.hpp:25
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
__host__ __device__ constexpr auto generate_sequence(F, Number< N >)
Definition sequence_helper.hpp:18
__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc &tensor_desc, const VisibleIndex &idx_visible)
Definition tensor_description/tensor_descriptor.hpp:407
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
__device__ void Run(const SrcDesc &src_desc, const SrcBuffer &src_buf, const DstDesc &dst_desc, DstBuffer &dst_buf)
Definition thread_group_tensor_slice_transfer_direct_load.hpp:225
static constexpr auto thread_single_load_size
Definition thread_group_tensor_slice_transfer_direct_load.hpp:71
static constexpr index_t nDim
Definition thread_group_tensor_slice_transfer_direct_load.hpp:56
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &step)
Definition thread_group_tensor_slice_transfer_direct_load.hpp:315
decltype(make_tensor_coordinate(SrcDesc{}, Index{})) SrcCoord
Definition thread_group_tensor_slice_transfer_direct_load.hpp:59
static constexpr auto thread_cluster_lengths
Definition thread_group_tensor_slice_transfer_direct_load.hpp:69
static __device__ constexpr bool AreThreadClusterLengthsValid()
Definition thread_group_tensor_slice_transfer_direct_load.hpp:78
__device__ void ResetDstSliceWindow(const DstDesc &dst_desc)
Definition thread_group_tensor_slice_transfer_direct_load.hpp:219
decltype(make_tensor_coordinate(DstDesc{}, Index{})) DstCoord
Definition thread_group_tensor_slice_transfer_direct_load.hpp:60
static constexpr auto block_slice_lengths
Definition thread_group_tensor_slice_transfer_direct_load.hpp:68
static constexpr auto thread_slice_lengths
Definition thread_group_tensor_slice_transfer_direct_load.hpp:76
static constexpr auto I0
Definition thread_group_tensor_slice_transfer_direct_load.hpp:65
__device__ constexpr ThreadGroupTensorSliceTransfer_DirectLoad(const SrcDesc &src_desc, const Index &src_block_slice_origin, const DstDesc &dst_desc, const Index &dst_block_slice_origin)
Definition thread_group_tensor_slice_transfer_direct_load.hpp:132
decltype(make_tensor_coordinate_step(DstDesc{}, Index{})) DstCoordStep
Definition thread_group_tensor_slice_transfer_direct_load.hpp:63
__device__ void SetDstSliceOrigin(const DstDesc &dst_desc, const Index &dst_slice_origin_idx)
Definition thread_group_tensor_slice_transfer_direct_load.hpp:213
__device__ auto generate_steps(const DescType &desc, int sign)
Definition thread_group_tensor_slice_transfer_direct_load.hpp:322
MultiIndex< nDim > Index
Definition thread_group_tensor_slice_transfer_direct_load.hpp:57
static constexpr auto thread_steps
Definition thread_group_tensor_slice_transfer_direct_load.hpp:75
decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})) SrcCoordStep
Definition thread_group_tensor_slice_transfer_direct_load.hpp:62
__device__ void SetSrcSliceOrigin(const SrcDesc &src_desc, const Index &src_slice_origin_idx)
Definition thread_group_tensor_slice_transfer_direct_load.hpp:207
static constexpr auto I1
Definition thread_group_tensor_slice_transfer_direct_load.hpp:66
Definition threadwise_tensor_slice_transfer_util.hpp:20
Definition functional2.hpp:33
Definition functional3.hpp:97