31template <
typename SrcDatas,
35 typename ElementwiseOperation,
37 typename SliceLengths,
38 typename SrcDimAccessOrder,
39 typename DstDimAccessOrder,
44 typename SrcResetCoordinateAfterRunFlags,
45 typename DstResetCoordinateAfterRunFlags,
59 template <
typename Descs,
61 enable_if_t<Descs::Size() == Indices::Size(),
bool> =
false>
90 const SrcDescs& src_descs,
92 const DstDescs& dst_descs,
94 const ElementwiseOperation& element_op)
97 element_op_(element_op)
100 "wrong! cannot evenly divide");
103 "wrong! cannot evenly divide");
106 template <
typename Indices, enable_if_t<SrcDescs::Size() == Indices::Size(),
bool> = false>
108 const Indices& src_slice_origin_idxs)
115 template <
typename Indices, enable_if_t<DstDescs::Size() == Indices::Size(),
bool> = false>
117 const Indices& dst_slice_origin_idxs)
124 template <
typename DataTypes, index_t ScalarPerVector>
127 auto data_types = DataTypes{};
129 constexpr index_t num = data_types.Size();
142 template <
typename SrcBuffers,
144 enable_if_t<SrcDescs::Size() == SrcBuffers::Size(),
bool> =
false>
145 __device__
void RunRead(
const SrcDescs& src_descs,
146 const SrcBuffers& src_bufs,
157 static_for<0, nSrc, 1>{}([&](
auto i) {
158 using src_vector_t =
typename remove_cvref_t<
decltype(src_vectors[i])>::type;
160 const bool is_src_valid =
164 oob_val = oob_val & is_src_valid;
166 src_vectors(i).template AsType<src_vector_t>()(
I0) =
167 src_bufs[i].
template Get<src_vector_t>(src_coords_[i].GetOffset(),
true);
170 constexpr auto get_elem_op_vec_len = []() {
173 if constexpr(
decltype(element_op_)::is_pack8_invocable)
178 if constexpr(
decltype(element_op_)::is_pack4_invocable)
183 if constexpr(
decltype(element_op_)::is_pack2_invocable)
189 constexpr index_t elem_op_vec_len = get_elem_op_vec_len();
192 static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](
auto i) {
196 [&](
auto iSrc) ->
const auto& {
199 using elem_op_vec_t =
typename vector_type<SrcData, elem_op_vec_len>::type;
201 return src_vectors[iSrc].template AsType<elem_op_vec_t>()[i];
208 [&](
auto iDst) ->
auto& {
211 using elem_op_vec_t =
typename vector_type<DstData, elem_op_vec_len>::type;
213 return elm_vectors(iDst).template AsType<elem_op_vec_t>()(i);
225 unpack2(element_op_, dst_data_refs, src_data_refs);
228 elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors;
229 oob_vectors_tuple_(thread_scratch_id)(iAccess) = oob_val;
232 if constexpr(iAccess.value != src_num_access - 1)
236 static_for<0, nSrc, 1>{}([&](
auto i) {
245 static_for<0, nSrc, 1>{}([&](
auto i) {
246 if constexpr(SrcResetCoordinateAfterRunFlags::At(i))
248 const auto src_reset_step =
257 template <index_t ThreadScratchId = 0>
262 auto elm_vectors = elm_vectors_tuple_[thread_scratch_id][iAccess];
263 auto oob_val = oob_vectors_tuple_[thread_scratch_id][iAccess];
265 static_for<0, nDst, 1>{}([&](
auto i) {
266 using elm_vector_t =
typename remove_cvref_t<
decltype(elm_vectors[i])>::type;
267 elm_vectors(i).template AsType<elm_vector_t>()(
I0) =
268 oob_val ? elm_vectors(i).template AsType<elm_vector_t>()[
I0] : elm_vector_t{0};
271 elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors;
276 template <index_t ThreadScratchId = 0>
282 using ElmThreadScratch =
288 using DstThreadScratch =
295 ElmThreadScratch elm_thread_scratch_;
296 DstThreadScratch dst_thread_scratch_;
298 elm_thread_scratch_.data_ =
299 bit_cast<
decltype(elm_thread_scratch_.data_)>(elm_vectors_tuple_[thread_scratch_id]);
301 if constexpr(SrcVectorDim != DstVectorDim &&
302 ((is_same<half_t, remove_cvref_t<DstData>>
::value &&
303 SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
304 (is_same<f8_t, remove_cvref_t<DstData>>
::value &&
305 SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) ||
306 (is_same<int8_t, remove_cvref_t<DstData>>
::value &&
307 SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
319 detail::lambda_scalar_step_in_vector<SrcVectorDim>{},
Number<nDim>{});
322 detail::lambda_scalar_step_in_vector<DstVectorDim>{},
Number<nDim>{});
325 detail::lambda_scalar_per_access_for_src_and_dst<SrcVectorDim,
328 DstScalarPerVector>{},
331 constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
333 static_ford<
decltype(access_lengths)>{}([&](
auto access_idx) {
334 constexpr auto data_idx = access_idx * scalar_per_access;
345 [&](
auto i) ->
const src_vector_t& {
347 return elm_thread_scratch_.GetVectorTypeReference(
348 data_idx_seq + i * dst_scalar_step_in_vector);
354 [&](
auto i) -> dst_vector_t& {
356 return dst_thread_scratch_.GetVectorTypeReference(
357 data_idx_seq + i * src_scalar_step_in_vector);
362 transpose_vectors<DstData, DstScalarPerVector, SrcScalarPerVector>{}(
363 src_vector_refs, dst_vector_refs);
368 static_ford<SliceLengths>{}(
369 [&](
auto idx) { dst_thread_scratch_(idx) = elm_thread_scratch_[idx]; });
377 template <
typename DstBuffers,
379 enable_if_t<DstDescs::Size() == 1 && DstBuffers::Size() == 1,
bool> =
false>
380 __device__
void RunWrite(
const DstDescs& dst_descs,
389 auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess];
392 static_for<0, nDst, 1>{}([&](
auto i) {
393 using dst_vector_t =
typename remove_cvref_t<
decltype(dst_vectors[i])>::type;
395 const bool is_dst_valid =
402 dst_bufs(i).template Update<DstInMemOp, dst_vector_t>(
403 dst_coords_[i].GetOffset(),
405 dst_vectors[i].
template AsType<dst_vector_t>()[
I0]);
409 if constexpr(iAccess.value != dst_num_access - 1)
413 static_for<0, nDst, 1>{}([&](
auto i) {
421 static_for<0, nDst, 1>{}([&](
auto i) {
422 if constexpr(DstResetCoordinateAfterRunFlags::At(i))
424 const auto dst_reset_step =
436 template <
typename SrcBuffers,
438 enable_if_t<SrcDescs::Size() == SrcBuffers::Size() &&
439 DstDescs::Size() == DstBuffers::Size(),
441 __device__
void Run(
const SrcDescs& src_descs,
442 const SrcBuffers& src_bufs,
443 const DstDescs& dst_descs,
452 if constexpr(src_num_access == 0)
464 if constexpr(dst_num_access == 0)
485 constexpr auto desc0 =
491 if constexpr(i == SrcVectorDim)
494 make_tuple(src_access_lengths_and_vector_length[i],
506 if constexpr(i == SrcVectorDim)
517 constexpr auto up_dim_idss =
534 constexpr auto desc0 =
540 if constexpr(i == DstVectorDim)
543 make_tuple(dst_access_lengths_and_vector_length[i],
555 if constexpr(i == DstVectorDim)
566 constexpr auto up_dim_idss =
573 template <index_t ISrc>
576 const Index& src_slice_origin_step_idx)
579 const auto adjusted_step_idx =
580 SrcResetCoordinateAfterRunFlags::At(iSrc)
581 ? src_slice_origin_step_idx
591 template <index_t IDst>
594 const Index& dst_slice_origin_step_idx)
597 const auto adjusted_step_idx =
598 DstResetCoordinateAfterRunFlags::At(iDst)
599 ? dst_slice_origin_step_idx
627 const ElementwiseOperation element_op_;
__host__ __device__ constexpr T min(T x)
Definition utility/math.hpp:116
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
decltype(ck::declval< T & >().is_pack8_invocable) is_pack8_invocable_t
Definition is_detected.hpp:43
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc &, const VisibleIndex &idx_diff_visible, UpdateLowerIndexHack)
Definition tensor_description/tensor_descriptor.hpp:444
__host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const TensorCoordStep &coord_step)
Definition tensor_description/tensor_descriptor.hpp:508
__host__ __device__ constexpr auto container_push_back(const Array< TData, NSize > &a, const TData &x)
Definition utility/container_helper.hpp:18
InMemoryDataOperationEnum
Definition ck.hpp:277
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
decltype(ck::declval< T & >().is_pack4_invocable) is_pack4_invocable_t
Definition is_detected.hpp:40
__host__ __device__ constexpr bool coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc &tensor_desc, const TensorCoord &coord)
Definition tensor_description/tensor_descriptor.hpp:560
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition is_detected.hpp:34
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence< Is... >)
Definition utility/container_helper.hpp:380
__host__ __device__ constexpr auto generate_sequence_v2(F &&f, Number< N >)
Definition sequence_helper.hpp:25
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
decltype(ck::declval< T & >().is_pack2_invocable) is_pack2_invocable_t
Definition is_detected.hpp:37
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto generate_sequence(F, Number< N >)
Definition sequence_helper.hpp:18
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr Y bit_cast(const X &x)
Definition type.hpp:306
typename remove_cv< T >::type remove_cv_t
Definition type.hpp:295
__host__ __device__ constexpr auto unpack2(F &&f, X &&x, Y &&y)
Definition functional4.hpp:55
typename std::enable_if< B, T >::type enable_if_t
Definition enable_if.hpp:27
__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc &tensor_desc, const VisibleIndex &idx_visible)
Definition tensor_description/tensor_descriptor.hpp:407
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
typename vector_type_maker< T, N >::type vector_type_maker_t
Definition dtype_vector.hpp:54
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
static __device__ __host__ constexpr auto GetStepBetween(Number< AccessIdx1dBegin >, Number< AccessIdx1dEnd >)
Definition tensor_space_filling_curve.hpp:52
__host__ static __device__ constexpr index_t GetNumOfAccess()
Definition tensor_space_filling_curve.hpp:41
static __device__ __host__ constexpr auto GetForwardStep(Number< AccessIdx1d >)
Definition tensor_space_filling_curve.hpp:66
MultiIndex< nDim > Index
Definition tensor_space_filling_curve.hpp:23
static __device__ auto generate_vectors()
Definition threadwise_tensor_slice_transfer_v7r2.hpp:125
ck::ThreadwiseTensorSliceTransfer_v7r2< SrcDatas, DstDatas, SrcDescs, DstDescs, ElementwiseOperation, DstInMemOps, decltype(thread_slice_lengths), SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcScalarPerVector, DstScalarPerVector, ThreadTransferSrcResetCoordinateAfterRunFlags, ThreadTransferDstResetCoordinateAfterRunFlags, NumThreadScratch >::src_scalar_per_access static constexpr auto src_scalar_per_access
Definition threadwise_tensor_slice_transfer_v7r2.hpp:73
__device__ constexpr ThreadwiseTensorSliceTransfer_v7r2(const SrcDescs &src_descs, const StaticallyIndexedArray< Index, nSrc > &src_slice_origins, const DstDescs &dst_descs, const StaticallyIndexedArray< Index, nDst > &dst_slice_origins, const ElementwiseOperation &element_op)
Definition threadwise_tensor_slice_transfer_v7r2.hpp:89
ck::ThreadwiseTensorSliceTransfer_v7r2< SrcDatas, DstDatas, SrcDescs, DstDescs, ElementwiseOperation, DstInMemOps, decltype(thread_slice_lengths), SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcScalarPerVector, DstScalarPerVector, ThreadTransferSrcResetCoordinateAfterRunFlags, ThreadTransferDstResetCoordinateAfterRunFlags, NumThreadScratch >::SrcCoords decltype(MakeCoordinates(SrcDescs{}, StaticallyIndexedArray< Index, nSrc >{})) SrcCoords
Definition threadwise_tensor_slice_transfer_v7r2.hpp:68
ck::ThreadwiseTensorSliceTransfer_v7r2< SrcDatas, DstDatas, SrcDescs, DstDescs, ElementwiseOperation, DstInMemOps, decltype(thread_slice_lengths), SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcScalarPerVector, DstScalarPerVector, ThreadTransferSrcResetCoordinateAfterRunFlags, ThreadTransferDstResetCoordinateAfterRunFlags, NumThreadScratch >::nDim static constexpr index_t nDim
Definition threadwise_tensor_slice_transfer_v7r2.hpp:51
static __device__ constexpr auto GetSrcCoordinateResetStep()
Definition threadwise_tensor_slice_transfer_v7r2.hpp:450
__device__ void MoveSrcSliceWindow(const SrcDescs &src_descs, Number< ISrc > iSrc, const Index &src_slice_origin_step_idx)
Definition threadwise_tensor_slice_transfer_v7r2.hpp:574
static __device__ constexpr auto GetDstCoordinateResetStep()
Definition threadwise_tensor_slice_transfer_v7r2.hpp:462
ck::ThreadwiseTensorSliceTransfer_v7r2< SrcDatas, DstDatas, SrcDescs, DstDescs, ElementwiseOperation, DstInMemOps, decltype(thread_slice_lengths), SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcScalarPerVector, DstScalarPerVector, ThreadTransferSrcResetCoordinateAfterRunFlags, ThreadTransferDstResetCoordinateAfterRunFlags, NumThreadScratch >::DstCoords decltype(MakeCoordinates(DstDescs{}, StaticallyIndexedArray< Index, nDst >{})) DstCoords
Definition threadwise_tensor_slice_transfer_v7r2.hpp:69
__device__ void RunWrite(const DstDescs &dst_descs, DstBuffers dst_bufs, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v7r2.hpp:380
ck::ThreadwiseTensorSliceTransfer_v7r2< SrcDatas, DstDatas, SrcDescs, DstDescs, ElementwiseOperation, DstInMemOps, decltype(thread_slice_lengths), SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcScalarPerVector, DstScalarPerVector, ThreadTransferSrcResetCoordinateAfterRunFlags, ThreadTransferDstResetCoordinateAfterRunFlags, NumThreadScratch >::dst_scalar_per_access static constexpr auto dst_scalar_per_access
Definition threadwise_tensor_slice_transfer_v7r2.hpp:76
ck::ThreadwiseTensorSliceTransfer_v7r2< SrcDatas, DstDatas, SrcDescs, DstDescs, ElementwiseOperation, DstInMemOps, decltype(thread_slice_lengths), SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcScalarPerVector, DstScalarPerVector, ThreadTransferSrcResetCoordinateAfterRunFlags, ThreadTransferDstResetCoordinateAfterRunFlags, NumThreadScratch >::nDst static constexpr index_t nDst
Definition threadwise_tensor_slice_transfer_v7r2.hpp:54
__device__ void Run(const SrcDescs &src_descs, const SrcBuffers &src_bufs, const DstDescs &dst_descs, DstBuffers dst_bufs)
Definition threadwise_tensor_slice_transfer_v7r2.hpp:441
ck::ThreadwiseTensorSliceTransfer_v7r2< SrcDatas, DstDatas, SrcDescs, DstDescs, ElementwiseOperation, DstInMemOps, decltype(thread_slice_lengths), SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcScalarPerVector, DstScalarPerVector, ThreadTransferSrcResetCoordinateAfterRunFlags, ThreadTransferDstResetCoordinateAfterRunFlags, NumThreadScratch >::nSrc static constexpr index_t nSrc
Definition threadwise_tensor_slice_transfer_v7r2.hpp:53
__device__ void RunRead(const SrcDescs &src_descs, const SrcBuffers &src_bufs, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v7r2.hpp:145
__device__ void MoveDstSliceWindow(const DstDescs &dst_descs, Number< IDst > iDst, const Index &dst_slice_origin_step_idx)
Definition threadwise_tensor_slice_transfer_v7r2.hpp:592
ck::ThreadwiseTensorSliceTransfer_v7r2< SrcDatas, DstDatas, SrcDescs, DstDescs, ElementwiseOperation, DstInMemOps, decltype(thread_slice_lengths), SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcScalarPerVector, DstScalarPerVector, ThreadTransferSrcResetCoordinateAfterRunFlags, ThreadTransferDstResetCoordinateAfterRunFlags, NumThreadScratch >::DstSpaceFillingCurve SpaceFillingCurve< decltype(thread_slice_lengths), DstDimAccessOrder, remove_cv_t< decltype(dst_scalar_per_access)>, false > DstSpaceFillingCurve
Definition threadwise_tensor_slice_transfer_v7r2.hpp:84
static __device__ constexpr auto GetDstThreadScratchDescriptor()
Definition threadwise_tensor_slice_transfer_v7r2.hpp:523
ck::ThreadwiseTensorSliceTransfer_v7r2< SrcDatas, DstDatas, SrcDescs, DstDescs, ElementwiseOperation, DstInMemOps, decltype(thread_slice_lengths), SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcScalarPerVector, DstScalarPerVector, ThreadTransferSrcResetCoordinateAfterRunFlags, ThreadTransferDstResetCoordinateAfterRunFlags, NumThreadScratch >::SrcSpaceFillingCurve SpaceFillingCurve< decltype(thread_slice_lengths), SrcDimAccessOrder, remove_cv_t< decltype(src_scalar_per_access)>, false > SrcSpaceFillingCurve
Definition threadwise_tensor_slice_transfer_v7r2.hpp:79
ck::ThreadwiseTensorSliceTransfer_v7r2< SrcDatas, DstDatas, SrcDescs, DstDescs, ElementwiseOperation, DstInMemOps, decltype(thread_slice_lengths), SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcScalarPerVector, DstScalarPerVector, ThreadTransferSrcResetCoordinateAfterRunFlags, ThreadTransferDstResetCoordinateAfterRunFlags, NumThreadScratch >::Index MultiIndex< nDim > Index
Definition threadwise_tensor_slice_transfer_v7r2.hpp:56
__device__ void SetSrcSliceOrigins(const SrcDescs &src_descs, const Indices &src_slice_origin_idxs)
Definition threadwise_tensor_slice_transfer_v7r2.hpp:107
__device__ void SetDstSliceOrigins(const DstDescs &dst_descs, const Indices &dst_slice_origin_idxs)
Definition threadwise_tensor_slice_transfer_v7r2.hpp:116
static __device__ constexpr auto GetSrcThreadScratchDescriptor()
Definition threadwise_tensor_slice_transfer_v7r2.hpp:474
static constexpr auto MakeCoordinates(const Descs &descs, const Indices &indices)
Definition threadwise_tensor_slice_transfer_v7r2.hpp:62
__device__ void OOBCheck(Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v7r2.hpp:258
__device__ void TransposeFromElmToDst(Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v7r2.hpp:278
ck::ThreadwiseTensorSliceTransfer_v7r2< SrcDatas, DstDatas, SrcDescs, DstDescs, ElementwiseOperation, DstInMemOps, decltype(thread_slice_lengths), SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcScalarPerVector, DstScalarPerVector, ThreadTransferSrcResetCoordinateAfterRunFlags, ThreadTransferDstResetCoordinateAfterRunFlags, NumThreadScratch >::I0 static constexpr auto I0
Definition threadwise_tensor_slice_transfer_v7r2.hpp:49
Definition threadwise_tensor_slice_transfer_util.hpp:20
Definition functional2.hpp:33