gridwise_put_element_1d.hpp Source File

gridwise_put_element_1d.hpp Source File

Composable Kernel: gridwise_put_element_1d.hpp Source File
gridwise_put_element_1d.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
8
9namespace ck {
10
11template <typename GridwisePutElementwise1dFunctor,
12 typename InGrid1dDesc,
13 typename InDataType,
14 typename IndexDataType,
15 typename OutDataType,
16 typename ElementwiseOperation>
17__global__ void kernel_put_element_1d(const InGrid1dDesc in_grid_1d_desc,
18 const InDataType* __restrict__ p_in_global,
19 const IndexDataType* __restrict__ p_indices_global,
20 OutDataType* __restrict__ p_out_global,
21 const ElementwiseOperation elementwise_op)
22{
23 GridwisePutElementwise1dFunctor::Run(
24 in_grid_1d_desc, p_in_global, p_indices_global, p_out_global, elementwise_op);
25}
26
27// Scatter write: output[indices[i]] <MemOp>= elementwise_op(input[i]) for every i with indices[i] >= 0
// Device-side functor for a 1-D "put": each value loaded from p_in_global is passed
// through elementwise_op and then written to p_out_global at the offset given by the
// matching loaded index; elements whose index is negative are not written. How the
// write combines with the existing output value is chosen at compile time by MemOp
// (Set, AtomicAdd, AtomicMax or Add).
//
// NOTE(review): this listing is missing several original lines (the embedded numbering
// jumps, e.g. 32 -> 34 -> 36 and 39 -> 41), including the struct header and the
// declaration of MemOp that the body below refers to. Consult the real header before
// editing any code here.
28template <typename InGrid1dDesc,
29 typename InDataType,
30 typename IndexDataType,
31 typename OutDataType,
32 typename ElementwiseOperation,
34 index_t InVectorSize>
36{
 // Compile-time index 0; used below to query the descriptor length and as the
 // destination origin of the thread-buffer loads.
37 static constexpr auto I0 = Number<0>{};
38
 // Descriptor of the per-thread buffer (initializer line not present in this listing).
39 static constexpr auto thread_buffer_desc_m =
41
 // Run: performs the put for the calling thread.
 //   in_grid_1d_desc   descriptor used for BOTH the input and the index loads
 //   p_in_global       source values
 //   p_indices_global  index values (presumably wrapped by indices_global_buf; the
 //                     wrapping line is missing from this listing)
 //   p_out_global      output; written through plain pointer arithmetic
 //                     (p_out_global + index), not through a descriptor
 //   elementwise_op    callable applied to every loaded value before it is stored
42 __device__ static void Run(const InGrid1dDesc& in_grid_1d_desc,
43 const InDataType* __restrict__ p_in_global,
44 const IndexDataType* __restrict__ p_indices_global,
45 OutDataType* __restrict__ p_out_global,
46 const ElementwiseOperation& elementwise_op)
47 {
48 // Global Memory
49 const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
50 p_in_global, in_grid_1d_desc.GetElementSpaceSize());
51
 // The index buffer is sized with the input descriptor's element space size too.
52 const auto indices_global_buf =
54 in_grid_1d_desc.GetElementSpaceSize(),
56
57 // VGPR
 // NOTE(review): the declarations of in_thread_buf / indices_thread_buf (used below)
 // are among the lines missing from this listing.
60
61 // Thread id, Block id and index
 // Each thread starts at element thread_global_id * InVectorSize; after every loop
 // iteration both source windows advance by
 // loop_step = get_grid_size() * get_block_size() * InVectorSize elements.
62 const index_t thread_global_id = get_thread_global_1d_id();
63 const auto thread_global_offset = make_multi_index(thread_global_id * InVectorSize);
64 const index_t blockSize = get_block_size();
65 const index_t blockPerGrid = get_grid_size();
66 const auto M = in_grid_1d_desc.GetLength(I0);
67 const index_t loop_step = blockPerGrid * blockSize * InVectorSize;
68 const auto loop_step_index = make_multi_index(loop_step);
69
 // Loader for the input values: a slice of InVectorSize elements per Run() call,
 // positioned at this thread's starting offset.
70 auto in_global_load =
72 InDataType,
73 decltype(in_grid_1d_desc),
74 decltype(thread_buffer_desc_m),
75 Sequence<InVectorSize>, // SliceLengths
76 Sequence<0>, // DimAccessOrder
77 0, // SrcVectorDim
78 InVectorSize, // ScalarPerVector
79 1, // SrcScalarStrideInVector
80 false>{in_grid_1d_desc, thread_global_offset};
81
 // Loader for the indices. It is built on in_grid_1d_desc as well, so indices are
 // addressed with the same descriptor and the same offsets as the input values.
82 auto indices_global_load =
84 IndexDataType,
85 decltype(in_grid_1d_desc),
86 decltype(thread_buffer_desc_m),
87 Sequence<InVectorSize>, // SliceLengths
88 Sequence<0>, // DimAccessOrder
89 0, // SrcVectorDim
90 InVectorSize, // ScalarPerVector
91 1, // SrcScalarStrideInVector
92 false>{in_grid_1d_desc, thread_global_offset};
93
 // Trip count is the integer quotient M / loop_step, and the do/while body always
 // runs at least once.
 // NOTE(review): if M < loop_step the quotient is 0, so `--num_iter` yields -1
 // (non-zero) after the first pass and the loop does not stop as intended; a
 // remainder of M % loop_step elements is also never visited. Presumably the caller
 // guarantees M is a positive multiple of loop_step - confirm against the host-side
 // argument check.
94 index_t num_iter = M / loop_step;
95 do
96 {
 // Load this iteration's input slice into the thread buffer.
97 in_global_load.Run(in_grid_1d_desc,
98 in_global_buf,
100 make_tuple(I0),
101 in_thread_buf);
102
 // Advance the input window for the next iteration.
103 in_global_load.MoveSrcSliceWindow(in_grid_1d_desc, loop_step_index);
104
 // Apply elementwise_op to each loaded element; in_thread_buf(iM) is passed first and
 // in_thread_buf[iM] second (presumably destination and source, i.e. in place -
 // confirm). The enclosing loop header line is missing from this listing.
106 [&](auto iM) { elementwise_op(in_thread_buf(iM), in_thread_buf[iM]); });
107
 // Load the indices that belong to the same elements.
108 indices_global_load.Run(in_grid_1d_desc,
109 indices_global_buf,
111 make_tuple(I0),
112 indices_thread_buf);
113
 // Advance the index window in lock-step with the input window.
114 indices_global_load.MoveSrcSliceWindow(in_grid_1d_desc, loop_step_index);
115
 // Write each element to p_out_global at its index; the write flavour is resolved at
 // compile time from MemOp.
116 static_for<0, InVectorSize, 1>{}([&](auto iM) {
 // Elements whose index is negative are not written.
117 if(indices_thread_buf[iM] >= 0)
118 {
119 if constexpr(MemOp == InMemoryDataOperationEnum::Set)
120 {
121 // User should guarantee each index in p_indices_global is different
122 *(p_out_global + indices_thread_buf[iM]) =
123 ck::type_convert<OutDataType>(in_thread_buf[iM]);
124 }
125 else if constexpr(MemOp == InMemoryDataOperationEnum::AtomicAdd)
126 {
127 atomic_add<OutDataType>(p_out_global + indices_thread_buf[iM],
128 ck::type_convert<OutDataType>(in_thread_buf[iM]));
129 }
130 else if constexpr(MemOp == InMemoryDataOperationEnum::AtomicMax)
131 {
132 atomic_max<OutDataType>(p_out_global + indices_thread_buf[iM],
133 ck::type_convert<OutDataType>(in_thread_buf[iM]));
134 }
135 else if constexpr(MemOp == InMemoryDataOperationEnum::Add)
136 {
 // Plain (non-atomic) read-modify-write through the output pointer.
137 // User should guarantee each index in p_indices_global is different
138 *(p_out_global + indices_thread_buf[iM]) +=
139 ck::type_convert<OutDataType>(in_thread_buf[iM]);
140 }
141 else
142 {
 // Any other MemOp value is rejected at compile time (the rest of the assertion is
 // missing from this listing).
143 static_assert(MemOp == InMemoryDataOperationEnum::Set ||
147 }
148 }
149 });
150
151 } while(--num_iter);
152 }
153};
154
155} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__device__ index_t get_grid_size()
Definition get_id.hpp:49
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
@ AtomicMax
Definition ck.hpp:280
@ AtomicAdd
Definition ck.hpp:279
@ Add
Definition ck.hpp:281
__device__ index_t get_block_size()
Definition get_id.hpp:51
integral_constant< index_t, N > Number
Definition number.hpp:12
__device__ index_t get_thread_global_1d_id()
Definition get_id.hpp:43
__device__ X atomic_max(X *p_dst, const X &x)
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__global__ void kernel_put_element_1d(const InGrid1dDesc in_grid_1d_desc, const InDataType *__restrict__ p_in_global, const IndexDataType *__restrict__ p_indices_global, OutDataType *__restrict__ p_out_global, const ElementwiseOperation elementwise_op)
Definition gridwise_put_element_1d.hpp:17
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ X atomic_add(X *p_dst, const X &x)
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
Definition gridwise_put_element_1d.hpp:36
static __device__ void Run(const InGrid1dDesc &in_grid_1d_desc, const InDataType *__restrict__ p_in_global, const IndexDataType *__restrict__ p_indices_global, OutDataType *__restrict__ p_out_global, const ElementwiseOperation &elementwise_op)
Definition gridwise_put_element_1d.hpp:42
__host__ static __device__ constexpr T Lowest()
Definition numeric_limits.hpp:312
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:16
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition functional2.hpp:33