static_tensor.hpp Source File

static_tensor.hpp Source File

Composable Kernel: static_tensor.hpp Source File
static_tensor.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#ifndef CK_STATIC_TENSOR_HPP
5#define CK_STATIC_TENSOR_HPP
6
7namespace ck {
8
9// StaticTensor for Scalar
10template <AddressSpaceEnum AddressSpace,
11 typename T,
12 typename TensorDesc,
13 bool InvalidElementUseNumericalZeroValue,
14 typename enable_if<TensorDesc::IsKnownAtCompileTime(), bool>::type = false>
16{
17 static constexpr auto desc_ = TensorDesc{};
18 static constexpr index_t ndim_ = TensorDesc::GetNumOfDimension();
19 static constexpr index_t element_space_size_ = desc_.GetElementSpaceSize();
20
// Default constructor: initializes invalid_element_scalar_value_ with 0.
21 __host__ __device__ constexpr StaticTensor() : invalid_element_scalar_value_{0} {}
22
// Constructs the tensor and stores invalid_element_value in
// invalid_element_scalar_value_.
// NOTE(review): presumably this is the value handed back for reads at invalid
// coordinates when InvalidElementUseNumericalZeroValue is false -- the
// corresponding return statement is not visible in this listing; confirm.
23 __host__ __device__ constexpr StaticTensor(T invalid_element_value)
24 : invalid_element_scalar_value_{invalid_element_value}
25 {
26 }
27
28 // read access
29 template <typename Idx,
31 bool>::type = false>
32 __host__ __device__ constexpr const T& operator[](Idx) const
33 {
34 constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
35
36 constexpr index_t offset = coord.GetOffset();
37
38 constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);
39
40 if constexpr(is_valid)
41 {
42 return data_[Number<offset>{}];
43 }
44 else
45 {
46 if constexpr(InvalidElementUseNumericalZeroValue)
47 {
48 return zero_scalar_value_;
49 }
50 else
51 {
53 }
54 }
55 }
56
57 // write access
58 template <typename Idx,
60 bool>::type = false>
61 __host__ __device__ constexpr T& operator()(Idx)
62 {
63 constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
64
65 constexpr index_t offset = coord.GetOffset();
66
67 constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);
68
69 if constexpr(is_valid)
70 {
71 return data_(Number<offset>{});
72 }
73 else
74 {
76 }
77 }
78
80 static constexpr T zero_scalar_value_ = T{0};
83};
84
85// StaticTensor for vector
86template <AddressSpaceEnum AddressSpace,
87 typename S,
88 index_t ScalarPerVector,
89 typename TensorDesc,
90 bool InvalidElementUseNumericalZeroValue,
91 typename enable_if<TensorDesc::IsKnownAtCompileTime(), bool>::type = false>
93{
94 static constexpr auto desc_ = TensorDesc{};
95 static constexpr index_t ndim_ = TensorDesc::GetNumOfDimension();
96 static constexpr index_t element_space_size_ = desc_.GetElementSpaceSize();
97
98 static constexpr index_t num_of_vector_ =
100
102
103 __host__ __device__ constexpr StaticTensorTupleOfVectorBuffer()
105 {
106 }
107
// Constructs the vector-backed tensor and stores invalid_element_value (a
// scalar of type S) in invalid_element_scalar_value_.
// NOTE(review): presumably this value is what invalid-coordinate reads return
// when InvalidElementUseNumericalZeroValue is false -- the corresponding
// return statements are not visible in this listing; confirm.
108 __host__ __device__ constexpr StaticTensorTupleOfVectorBuffer(S invalid_element_value)
109 : invalid_element_scalar_value_{invalid_element_value}
110 {
111 }
112
113 // Get S
114 // Idx is for S, not V
115 template <typename Idx,
116 typename enable_if<is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
117 bool>::type = false>
118 __host__ __device__ constexpr const S& operator[](Idx) const
119 {
120 constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
121
122 constexpr index_t offset = coord.GetOffset();
123
124 constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);
125
126 if constexpr(is_valid)
127 {
128 return data_[Number<offset>{}];
129 }
130 else
131 {
132 if constexpr(InvalidElementUseNumericalZeroValue)
133 {
134 return zero_scalar_value_;
135 }
136 else
137 {
139 }
140 }
141 }
142
143 // Set S
144 // Idx is for S, not V
145 template <typename Idx,
146 typename enable_if<is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
147 bool>::type = false>
148 __host__ __device__ constexpr S& operator()(Idx)
149 {
150 constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
151
152 constexpr index_t offset = coord.GetOffset();
153
154 constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);
155
156 if constexpr(is_valid)
157 {
158 return data_(Number<offset>{});
159 }
160 else
161 {
163 }
164 }
165
166 // Get X
167 // Idx is for S, not X. Idx should be aligned with X
168 template <typename X,
169 typename Idx,
172 bool>::type = false>
173 __host__ __device__ constexpr X GetAsType(Idx) const
174 {
175 constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
176
177 constexpr index_t offset = coord.GetOffset();
178
179 constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);
180
181 if constexpr(is_valid)
182 {
183 return data_.template GetAsType<X>(Number<offset>{});
184 }
185 else
186 {
187 if constexpr(InvalidElementUseNumericalZeroValue)
188 {
189 // TODO: is this the right way to initialize a vector?
190 return X{0};
191 }
192 else
193 {
194 // TODO: is this the right way to initialize a vector?
196 }
197 }
198 }
199
200 // Set X
201 // Idx is for S, not X. Idx should be aligned with X
202 template <typename X,
203 typename Idx,
206 bool>::type = false>
207 __host__ __device__ constexpr void SetAsType(Idx, X x)
208 {
209 constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
210
211 constexpr index_t offset = coord.GetOffset();
212
213 constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);
214
215 if constexpr(is_valid)
216 {
217 data_.template SetAsType<X>(Number<offset>{}, x);
218 }
219 }
220
221 // Get read-only access to V. No is_valid check
222 // Idx is for S, not V. Idx should be aligned with V
// Const overload: returns the const reference produced by
// data_.GetVectorTypeReference for the offset derived from Idx.
// The Idx argument is unnamed; only its type is used (through Idx{}), so the
// coordinate and the offset are compile-time constants.
// Unlike operator[] / GetAsType, coordinate_has_valid_offset is not consulted.
// NOTE(review): "aligned with V" presumably means the offset must land on a
// vector boundary of the underlying buffer -- confirm against data_'s type.
223 template <typename Idx>
224 __host__ __device__ constexpr const V& GetVectorTypeReference(Idx) const
225 {
// Build the tensor coordinate of desc_ for the multi-index encoded in Idx.
226 constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
227
// Offset reported by that coordinate.
228 constexpr index_t offset = coord.GetOffset();
229
// Forward to the underlying buffer, passing the offset as a Number<> type.
230 return data_.GetVectorTypeReference(Number<offset>{});
231 }
232
233 // Get mutable (read/write) access to V. No is_valid check
234 // Idx is for S, not V. Idx should be aligned with V
// Non-const overload: returns the non-const reference produced by
// data_.GetVectorTypeReference for the offset derived from Idx, so the caller
// can modify the referenced vector in place.
// The Idx argument is unnamed; only its type is used (through Idx{}), so the
// coordinate and the offset are compile-time constants.
// Unlike operator() / SetAsType, coordinate_has_valid_offset is not consulted.
// NOTE(review): "aligned with V" presumably means the offset must land on a
// vector boundary of the underlying buffer -- confirm against data_'s type.
235 template <typename Idx>
236 __host__ __device__ constexpr V& GetVectorTypeReference(Idx)
237 {
// Build the tensor coordinate of desc_ for the multi-index encoded in Idx.
238 constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
239
// Offset reported by that coordinate.
240 constexpr index_t offset = coord.GetOffset();
241
// Forward to the underlying buffer, passing the offset as a Number<> type.
242 return data_.GetVectorTypeReference(Number<offset>{});
243 }
244
246 static constexpr S zero_scalar_value_ = S{0};
249};
250
251template <AddressSpaceEnum AddressSpace,
252 typename T,
253 typename TensorDesc,
254 typename enable_if<TensorDesc::IsKnownAtCompileTime(), bool>::type = false>
255__host__ __device__ constexpr auto make_static_tensor(TensorDesc)
256{
258}
259
260template <
261 AddressSpaceEnum AddressSpace,
262 typename T,
263 typename TensorDesc,
264 typename X,
265 typename enable_if<TensorDesc::IsKnownAtCompileTime(), bool>::type = false,
267__host__ __device__ constexpr auto make_static_tensor(TensorDesc, X invalid_element_value)
268{
269 return StaticTensor<AddressSpace, T, TensorDesc, true>{invalid_element_value};
270}
271
272} // namespace ck
273#endif
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
constexpr bool is_native_type()
Definition data_type.hpp:203
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
AddressSpaceEnum
Definition amd_address_space.hpp:15
std::enable_if< B, T > enable_if
Definition enable_if.hpp:24
__host__ __device__ constexpr auto make_static_tensor(TensorDesc)
Definition static_tensor.hpp:255
__host__ __device__ constexpr auto to_multi_index(const T &x)
Definition array_multi_index.hpp:28
__host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc &tensor_desc, const TensorCoord &coord)
Definition tensor_description/tensor_descriptor.hpp:587
__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc &tensor_desc, const VisibleIndex &idx_visible)
Definition tensor_description/tensor_descriptor.hpp:407
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
Definition static_buffer.hpp:16
Definition static_buffer.hpp:75
Definition static_tensor.hpp:16
__host__ __device__ constexpr StaticTensor(T invalid_element_value)
Definition static_tensor.hpp:23
T ignored_element_scalar_
Definition static_tensor.hpp:82
static constexpr T zero_scalar_value_
Definition static_tensor.hpp:80
static constexpr index_t ndim_
Definition static_tensor.hpp:18
static constexpr index_t element_space_size_
Definition static_tensor.hpp:19
__host__ __device__ constexpr const T & operator[](Idx) const
Definition static_tensor.hpp:32
static constexpr auto desc_
Definition static_tensor.hpp:17
__host__ __device__ constexpr T & operator()(Idx)
Definition static_tensor.hpp:61
const T invalid_element_scalar_value_
Definition static_tensor.hpp:81
__host__ __device__ constexpr StaticTensor()
Definition static_tensor.hpp:21
StaticBuffer< AddressSpace, T, element_space_size_, true > data_
Definition static_tensor.hpp:79
__host__ __device__ constexpr const S & operator[](Idx) const
Definition static_tensor.hpp:118
__host__ __device__ constexpr S & operator()(Idx)
Definition static_tensor.hpp:148
vector_type< S, ScalarPerVector > V
Definition static_tensor.hpp:101
__host__ __device__ constexpr void SetAsType(Idx, X x)
Definition static_tensor.hpp:207
__host__ __device__ constexpr StaticTensorTupleOfVectorBuffer(S invalid_element_value)
Definition static_tensor.hpp:108
StaticBufferTupleOfVector< AddressSpace, DstData, num_of_vector_, ScalarPerVector, true > data_
Definition static_tensor.hpp:245
__host__ __device__ constexpr V & GetVectorTypeReference(Idx)
Definition static_tensor.hpp:236
__host__ __device__ constexpr const V & GetVectorTypeReference(Idx) const
Definition static_tensor.hpp:224
__host__ __device__ constexpr X GetAsType(Idx) const
Definition static_tensor.hpp:173
__host__ __device__ constexpr StaticTensorTupleOfVectorBuffer()
Definition static_tensor.hpp:103
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition is_known_at_compile_time.hpp:14
Definition dtype_vector.hpp:10