threadwise_tensor_slice_transfer_v6r1r2.hpp Source File

threadwise_tensor_slice_transfer_v6r1r2.hpp Source File

Composable Kernel: threadwise_tensor_slice_transfer_v6r1r2.hpp Source File
threadwise_tensor_slice_transfer_v6r1r2.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
10
11namespace ck {
12
13// Do the following things to avoid "alloca" in LLVM-IR, which would cause scratch memory usage
14// and sometimes useless instructions:
15// 1. Don't store a reference to the tensor descriptor in the class; pass the tensor descriptor in
16// as an argument instead
17// 2. Don't construct a new tensor coordinate every time it is used; update and reuse the same
18// tensor coordinate instead
19// 3. Don't use a pointer to a VGPR buffer; use a vector instead
20
21// Assume:
22// 1. src_desc and dst_desc are not known at compile time
23// 2. SrcBuffer and DstBuffer are DynamicBuffer
24// 3. src_slice_origin and dst_slice_origin are not known at compile time
25template <typename SrcData,
26 typename DstData,
27 typename SrcDesc,
28 typename DstDesc,
29 typename ElementwiseOperation,
30 typename SliceLengths,
31 typename DimAccessOrder,
32 index_t VectorDim,
33 index_t ScalarPerVector,
34 bool SrcResetCoordinateAfterRun,
35 bool DstResetCoordinateAfterRun>
37{
38 static constexpr index_t nDim = SliceLengths::Size();
39
41
42 using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
43 using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
44
45 static constexpr auto I0 = Number<0>{};
46
48 const SrcDesc& src_desc,
49 const Index& src_slice_origin,
50 const DstDesc& dst_desc,
51 const Index& dst_slice_origin,
52 const ElementwiseOperation& element_op)
53 : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
54 dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
55 element_op_(element_op)
56 {
57 static_assert(SliceLengths::At(Number<VectorDim>{}) % ScalarPerVector == 0,
58 "wrong! cannot evenly divide");
59 }
60
61 __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
62 {
63 src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
64 }
65
66 __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
67 {
68 dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
69 }
70
71 template <typename SrcBuffer, typename DstBuffer, InMemoryDataOperationEnum DstInMemOp>
72 __device__ void Run(const SrcDesc& src_desc,
73 const SrcBuffer& src_buf,
74 const DstDesc& dst_desc,
75 DstBuffer& dst_buf)
76 {
77 // scalar per access on each dim
78 // TODO: don't use lambda_scalar_per_access
79 constexpr auto scalar_per_access = generate_sequence(
81
82 using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
83 DimAccessOrder,
84 remove_cv_t<decltype(scalar_per_access)>>;
85
86 // loop over space-filling curve
87 constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
88
89 static_for<0, num_access, 1>{}([&](auto idx_1d) {
91 using src_vector_t = typename src_vector_type::type;
92
94 using dst_vector_t = typename dst_vector_type::type;
95
96 const bool is_src_valid =
98
99 // copy data from src_buf into src_vector_container
100 auto src_vector_container = src_vector_type{
101 src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
102
103 auto dst_vector_container = dst_vector_type{};
104
105 // apply pointwise operation
107 SrcData v;
108
109 // apply element-wise operation
110 element_op_(v, src_vector_container.template AsType<SrcData>()[i]);
111
112 // apply type convert
113 dst_vector_container.template AsType<DstData>()(i) = type_convert<DstData>(v);
114 });
115
116 const bool is_dst_valid =
118
119 // copy data from dst_vector into dst_buf
120 dst_buf.template Update<DstInMemOp, dst_vector_t>(
121 dst_coord_.GetOffset(),
122 is_dst_valid,
123 dst_vector_container.template AsType<dst_vector_t>()[I0]);
124
125 // move coordinate
126 if constexpr(idx_1d.value != num_access - 1)
127 {
128 constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
130 src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
132 dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
133 }
134 });
135
136 // move coordinate back to slice origin (or not)
137 if constexpr(SrcResetCoordinateAfterRun)
138 {
139 const auto src_reset_step =
141
142 move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
143 }
144
145 if constexpr(DstResetCoordinateAfterRun)
146 {
147 const auto dst_reset_step =
149
150 move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
151 }
152 }
153
154 __device__ static constexpr auto GetCoordinateResetStep()
155 {
156 constexpr auto scalar_per_access = generate_sequence(
158
159 using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
160 DimAccessOrder,
161 remove_cv_t<decltype(scalar_per_access)>>;
162
163 constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
164 if constexpr(num_access == 0)
165 {
166 return typename SpaceFillingCurve::Index{};
167 }
168 else
169 {
170 constexpr auto reset_step =
172
173 return reset_step;
174 }
175 }
176
177 // src_slice_origin_step_idx need to be known at compile-time, for performance reason
178 __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
179 const Index& src_slice_origin_step_idx)
180 {
181 // if src coord was not reset by RunRead(), then need to adjust the step here
182 const auto adjusted_step_idx = SrcResetCoordinateAfterRun
183 ? src_slice_origin_step_idx
184 : src_slice_origin_step_idx + GetCoordinateResetStep();
185
186 // is it OK to construct a new step every time?
187 const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
188
189 move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
190 }
191
192 // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
193 __device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
194 const Index& dst_slice_origin_step_idx)
195 {
196 // if dst coord was not reset by Run(), then need to adjust the step here
197 const auto adjusted_step_idx = DstResetCoordinateAfterRun
198 ? dst_slice_origin_step_idx
199 : dst_slice_origin_step_idx + GetCoordinateResetStep();
200
201 // is it OK to construct a new step every time?
202 const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
203
204 move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
205 }
206
207 private:
208 SrcCoord src_coord_;
209 DstCoord dst_coord_;
210 const ElementwiseOperation element_op_;
211};
212
213} // namespace ck
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc &, const VisibleIndex &idx_diff_visible, UpdateLowerIndexHack)
Definition tensor_description/tensor_descriptor.hpp:444
__host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const TensorCoordStep &coord_step)
Definition tensor_description/tensor_descriptor.hpp:508
__host__ __device__ constexpr bool coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc &tensor_desc, const TensorCoord &coord)
Definition tensor_description/tensor_descriptor.hpp:560
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto generate_sequence(F, Number< N >)
Definition sequence_helper.hpp:18
typename remove_cv< T >::type remove_cv_t
Definition type.hpp:295
__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc &tensor_desc, const VisibleIndex &idx_visible)
Definition tensor_description/tensor_descriptor.hpp:407
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
typename vector_type_maker< T, N >::type vector_type_maker_t
Definition dtype_vector.hpp:54
Definition tensor_space_filling_curve.hpp:20
static __device__ __host__ constexpr auto GetStepBetween(Number< AccessIdx1dBegin >, Number< AccessIdx1dEnd >)
Definition tensor_space_filling_curve.hpp:52
__host__ static __device__ constexpr index_t GetNumOfAccess()
Definition tensor_space_filling_curve.hpp:41
static __device__ __host__ constexpr auto GetForwardStep(Number< AccessIdx1d >)
Definition tensor_space_filling_curve.hpp:66
MultiIndex< nDim > Index
Definition tensor_space_filling_curve.hpp:23
__device__ void MoveDstSliceWindow(const DstDesc &dst_desc, const Index &dst_slice_origin_step_idx)
Definition threadwise_tensor_slice_transfer_v6r1r2.hpp:193
__device__ void SetDstSliceOrigin(const DstDesc &dst_desc, const Index &dst_slice_origin_idx)
Definition threadwise_tensor_slice_transfer_v6r1r2.hpp:66
__device__ void SetSrcSliceOrigin(const SrcDesc &src_desc, const Index &src_slice_origin_idx)
Definition threadwise_tensor_slice_transfer_v6r1r2.hpp:61
__device__ constexpr ThreadwiseTensorSliceTransfer_v6r1r2(const SrcDesc &src_desc, const Index &src_slice_origin, const DstDesc &dst_desc, const Index &dst_slice_origin, const ElementwiseOperation &element_op)
Definition threadwise_tensor_slice_transfer_v6r1r2.hpp:47
__device__ void Run(const SrcDesc &src_desc, const SrcBuffer &src_buf, const DstDesc &dst_desc, DstBuffer &dst_buf)
Definition threadwise_tensor_slice_transfer_v6r1r2.hpp:72
static __device__ constexpr auto GetCoordinateResetStep()
Definition threadwise_tensor_slice_transfer_v6r1r2.hpp:154
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &src_slice_origin_step_idx)
Definition threadwise_tensor_slice_transfer_v6r1r2.hpp:178
Definition threadwise_tensor_slice_transfer_util.hpp:20
Definition functional2.hpp:33