threadwise_tensor_slice_transfer_v4r1.hpp Source File

threadwise_tensor_slice_transfer_v4r1.hpp Source File

Composable Kernel: threadwise_tensor_slice_transfer_v4r1.hpp Source File
threadwise_tensor_slice_transfer_v4r1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
9
10namespace ck {
11// Assume:
12// 1. src:
13// 1. SrcDesc is known at compile-time
14// 2. SrcBuffer is DynamicBuffer
15// 3. src_ref_idx is known at run-time
16// 4. SrcRefToOriginDisplacement is known at compile-time
17// 5. use #-step
18// 2. dst:
19// 1. DstDesc is known at compile-time
20// 2. DstBuffer is StaticBuffer
21// 3. DstOriginIdx is known at compile-time
22// 4. use direct address calculation
23// 3. vector access on src
// ThreadwiseTensorSliceTransfer_v4r1 (name taken from its constructor; the `struct`
// line itself, source line 34, is not visible in this rendered listing).
//
// Per-thread copy of one SliceLengths-shaped slice:
//   * the source is read in vectors of shape SrcVectorTensorLengths, starting from a
//     run-time source coordinate (src_ref_coord_) plus a compile-time displacement;
//   * each vector is then written element-by-element into dst_buf at offsets that are
//     computed entirely at compile time from DstDesc.
//
// Template parameters:
//   SrcData                 - element type of the per-access source vector (src_vector)
//   DstData                 - destination element type (presumably dst_buf's element type;
//                             the type check on source lines 73-74 is not visible -- confirm)
//   SrcDesc / DstDesc       - tensor descriptors; both must be known at compile time
//                             (enforced by the enable_if below and by static_asserts)
//   SliceLengths            - per-dimension slice lengths; nDim = SliceLengths::Size()
//   DimAccessOrder          - order in which the per-vector access positions are visited
//   SrcVectorTensorLengths  - per-dimension shape of one source vector read; every
//                             SliceLengths entry must be a multiple of the matching entry
//   SrcVectorTensorContiguousDimOrder
//                           - dimension order used to derive the strides of the source
//                             vector descriptor (presumably the last listed dimension is the
//                             fastest varying -- TODO confirm; the scan call on source
//                             lines 96/99 is not visible here)
//
// NOTE(review): the numbering of this listing skips source lines 34, 41, 73-74, 79-80,
// 96, 99, 104, 136 and 153, so several declarations, static_assert conditions and helper
// calls used below are not shown.
24template <typename SrcData,
25 typename DstData,
26 typename SrcDesc,
27 typename DstDesc,
28 typename SliceLengths,
29 typename DimAccessOrder,
30 typename SrcVectorTensorLengths,
31 typename SrcVectorTensorContiguousDimOrder,
32 typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
33 bool>::type = false>
35{
36 static constexpr auto I0 = Number<0>{};
37 static constexpr auto I1 = Number<1>{};
38
    // Number of tensor dimensions handled by this transfer.
39 static constexpr index_t nDim = SliceLengths::Size();
40
    // NOTE(review): `Index` (used below) is declared on source line 41, which is not
    // visible in this listing.
42
    // Coordinate / coordinate-step types for SrcDesc, deduced from the helper functions.
43 using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
44
45 using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
46
    // Builds src_ref_coord_ on SrcDesc from the run-time index src_ref_idx.
    // Also checks at compile time that every SliceLengths entry is a multiple of the
    // matching SrcVectorTensorLengths entry, so the slice splits into whole vectors.
47 __device__ constexpr ThreadwiseTensorSliceTransfer_v4r1(const Index& src_ref_idx)
48 : src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx))
49 {
50 static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
51 "wrong! SrcDesc and DstDesc need to known at compile-time");
52
53 static_for<0, nDim, 1>{}([](auto i) {
54 static_assert(SliceLengths::At(i) % SrcVectorTensorLengths::At(i) == 0, "wrong!");
55 });
56 }
57
    // Copies the whole slice from src_buf into dst_buf.
    //
    //   SrcRefToOriginDisplacement - compile-time offset from src_ref_coord_ to the slice
    //                                origin (materialised via to_multi_index(T{}))
    //   DstOriginIdx               - compile-time index of the slice origin in DstDesc
    //   src_buf                    - source buffer; read one src_vector_t at a time via Get
    //   dst_buf                    - destination; must be a StaticBuffer (static_assert below)
    //
    // The SrcDesc/DstDesc arguments are unnamed; default-constructed descriptors are used
    // instead. The method is const: src_ref_coord_ is copied into src_data_coord for each
    // access and only the copy is moved.
58 template <typename SrcRefToOriginDisplacement,
59 typename DstOriginIdx,
60 typename SrcBuffer,
61 typename DstBuffer>
62 __device__ void Run(const SrcDesc&,
63 const SrcRefToOriginDisplacement&,
64 const SrcBuffer& src_buf,
65 const DstDesc&,
66 const DstOriginIdx&,
67 DstBuffer& dst_buf) const
68 {
69 static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
70 "wrong! SrcDesc and DstDesc need to known at compile-time");
71
        // NOTE(review): the condition of this static_assert (source lines 73-74) is not
        // visible in this listing; only its message is.
72 static_assert(
75 "wrong! SrcBuffer or DstBuffer data type is wrong");
76
77 static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");
78
        // NOTE(review): the start of this static_assert, including its condition (source
        // lines 79-80), is not visible in this listing.
81 "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
82 "at compile-time");
83
84 // SrcDesc and DstDesc are known at compile-time
85 constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
86 constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
87
88 // SrcRefToOriginDisplacement and DstOriginIdx are known at compile-time
89 constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{});
90 constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{});
91
92 // tensor descriptor for src_vector
93 constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{};
94
        // Strides of the source vector: the lengths are reordered new2old by
        // SrcVectorTensorContiguousDimOrder, combined by a call on source lines 96/99 (not
        // visible), then reordered back old2new.
95 constexpr auto src_vector_tensor_strides = container_reorder_given_old2new(
97 container_reorder_given_new2old(src_vector_tensor_lengths,
98 SrcVectorTensorContiguousDimOrder{}),
100 I1),
101 SrcVectorTensorContiguousDimOrder{});
102
        // NOTE(review): the descriptor constructor call (source line 104) is not visible;
        // only its strides argument is shown.
103 constexpr auto src_vector_desc =
105 sequence_to_tuple_of_number(src_vector_tensor_strides));
106
107 // access order and lengths
        // One access per source vector: slice lengths divided element-wise by vector lengths.
108 constexpr auto access_lengths = SliceLengths{} / src_vector_tensor_lengths;
109
110 constexpr auto dim_access_order = DimAccessOrder{};
111
112 constexpr auto ordered_access_lengths =
113 container_reorder_given_new2old(access_lengths, dim_access_order);
114
        // Visit every access position. ordered_access_idx is expressed in DimAccessOrder
        // order and is mapped back to natural dimension order just below.
115 static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
116 // position in slice window
117 constexpr auto data_to_origin_disp_idx =
118 ordered_access_idx.ReorderGivenOld2New(dim_access_order) *
119 src_vector_tensor_lengths;
120
121 // src coordinate at starting point of src_vector
122 constexpr auto src_ref_to_data_disp_idx =
123 src_ref_to_origin_disp_idx + data_to_origin_disp_idx;
124
125 constexpr auto src_ref_to_data_disp_coord_step =
126 make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx);
127
            // Work on a copy so that Run() stays const.
128 auto src_data_coord = src_ref_coord_;
129
130 move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);
131
            // Register-side holder for one source vector of SrcData elements.
132 vector_type_maker_t<SrcData, src_vector_desc.GetElementSpaceSize()> src_vector;
133
134 using src_vector_t = typename decltype(src_vector)::type;
135
            // NOTE(review): is_src_valid is declared on source line 136, which this listing
            // omits; presumably via coordinate_has_valid_offset_assuming_visible_index_is_valid
            // -- confirm.
137 src_desc, src_data_coord);
138
139 // copy data from src_buf into src_vector
            // How a read flagged invalid is handled is up to SrcBuffer::Get (not shown here).
140 src_vector.template AsType<src_vector_t>()(I0) =
141 src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(), is_src_valid);
142
143 // copy data from src_vector into dst_buf (also cast from SrcData to DstData)
144 static_ford<SrcVectorTensorLengths>{}([&](auto src_vector_idx_) {
145 constexpr auto src_vector_idx = to_multi_index(src_vector_idx_);
146
147 constexpr index_t src_vector_offset =
148 src_vector_desc.CalculateOffset(src_vector_idx);
149
                // Destination offset is fully compile-time: origin + access position +
                // position inside the vector.
150 constexpr index_t dst_offset = dst_desc.CalculateOffset(
151 dst_origin_idx + data_to_origin_disp_idx + src_vector_idx);
152
                // NOTE(review): source line 153 (the dst_buf assignment that consumes the
                // expression below) is not visible in this listing.
                // NOTE(review): src_vector is declared with element type SrcData (source line
                // 132), but it is read here through AsType<DstData>(). That only matches the
                // "cast from SrcData to DstData" intent when SrcData == DstData. For differing
                // types it reinterprets the storage rather than converting it.
                // AsType<SrcData>() looks intended -- confirm against vector_type before changing.
154 src_vector.template AsType<DstData>()[Number<src_vector_offset>{}]);
155 });
156 });
157 }
158
    // Shifts the source window by src_slice_move_step_idx (converted with
    // to_multi_index) by moving src_ref_coord_ in place. The SrcDesc argument is unnamed;
    // a default-constructed SrcDesc{} is used instead.
159 template <typename SrcSliceMoveStepIdx>
160 __device__ void MoveSrcSliceWindow(const SrcDesc&,
161 const SrcSliceMoveStepIdx& src_slice_move_step_idx)
162 {
163 constexpr auto src_desc = SrcDesc{};
164
165 const auto src_slice_move_step_iter =
166 make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx));
167
168 move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
169 }
170
171 private:
    // Run-time source reference coordinate. Set by the constructor, moved by
    // MoveSrcSliceWindow, and copied (never modified) by Run.
172 SrcCoord src_ref_coord_;
173};
174
175} // namespace ck
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc &, const VisibleIndex &idx_diff_visible, UpdateLowerIndexHack)
Definition tensor_description/tensor_descriptor.hpp:444
__host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const TensorCoordStep &coord_step)
Definition tensor_description/tensor_descriptor.hpp:508
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr bool coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc &tensor_desc, const TensorCoord &coord)
Definition tensor_description/tensor_descriptor.hpp:560
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence< Is... >)
Definition utility/container_helper.hpp:380
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
std::enable_if< B, T > enable_if
Definition enable_if.hpp:24
__host__ __device__ constexpr auto container_reorder_given_old2new(const Array< TData, NSize > &old_array, Sequence< IRs... > old2new)
Definition utility/container_helper.hpp:54
__host__ __device__ constexpr auto to_multi_index(const T &x)
Definition array_multi_index.hpp:28
__host__ __device__ constexpr auto container_reverse_exclusive_scan(const Array< TData, NSize > &x, Reduce f, TData init)
Definition utility/container_helper.hpp:213
__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc &tensor_desc, const VisibleIndex &idx_visible)
Definition tensor_description/tensor_descriptor.hpp:407
__host__ __device__ constexpr auto container_reorder_given_new2old(const Array< TData, NSize > &old_array, Sequence< IRs... >)
Definition utility/container_helper.hpp:43
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
typename vector_type_maker< T, N >::type vector_type_maker_t
Definition dtype_vector.hpp:54
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
__device__ constexpr ThreadwiseTensorSliceTransfer_v4r1(const Index &src_ref_idx)
Definition threadwise_tensor_slice_transfer_v4r1.hpp:47
__device__ void MoveSrcSliceWindow(const SrcDesc &, const SrcSliceMoveStepIdx &src_slice_move_step_idx)
Definition threadwise_tensor_slice_transfer_v4r1.hpp:160
__device__ void Run(const SrcDesc &, const SrcRefToOriginDisplacement &, const SrcBuffer &src_buf, const DstDesc &, const DstOriginIdx &, DstBuffer &dst_buf) const
Definition threadwise_tensor_slice_transfer_v4r1.hpp:62
Definition is_known_at_compile_time.hpp:14
Definition type.hpp:177
Definition utility/math.hpp:34
Definition functional2.hpp:33
Definition functional3.hpp:97