mxf6_utils.hpp Source File

mxf6_utils.hpp Source File

Composable Kernel: mxf6_utils.hpp Source File
mxf6_utils.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#ifndef CK_CODE_GEN_RTC
5#pragma once
6
9
10namespace ck::utils {
11
22template <>
23__host__ __device__ inline bool is_nan<f6_t>(e8m0_bexp_t const scale,
24 f6_t const dataBytes [[maybe_unused]])
25{
26 // no need to check for data as it does not have NaN representation
27 return scale.is_nan();
28}
29
40template <>
41__host__ __device__ inline bool is_nan<bf6_t>(e8m0_bexp_t const scale,
42 bf6_t const dataBytes [[maybe_unused]])
43{
44 // no need to check for data as it does not have NaN representation
45 return scale.is_nan();
46}
47
57template <>
58__host__ __device__ inline bool is_inf<f6_t>(e8m0_bexp_t const scale [[maybe_unused]],
59 f6_t const data [[maybe_unused]])
60{
61 // no inf representation for fp6
62 return false;
63}
64
74template <>
75__host__ __device__ inline bool is_inf<bf6_t>(e8m0_bexp_t const scale [[maybe_unused]],
76 bf6_t const data [[maybe_unused]])
77{
78 // no inf representation for bf6
79 return false;
80}
81
93template <>
94__host__ __device__ inline bool is_zero<f6_t>(e8m0_bexp_t const scale, f6_t const data)
95{
96 if(is_nan<f6_t>(scale, data))
97 return false;
98
99 // no need to check for scale as it does not have a 0 representation
100 f6_t result = (data & 0b00111111) & NumericUtils<f6_t>::set_sign_mask;
101
102 return result == 0b0;
103}
104
116template <>
117__host__ __device__ inline bool is_zero<bf6_t>(e8m0_bexp_t const scale, bf6_t const data)
118{
119 if(is_nan<bf6_t>(scale, data))
120 return false;
121
122 // no need to check for scale as it does not have a 0 representation
123 bf6_t result = (data & 0b00111111) & NumericUtils<bf6_t>::set_sign_mask;
124
125 return result == 0b0;
126}
127
/**
 * @brief Converts an f6_t value to a float based on an e8m0_bexp_t scale factor.
 *
 * Check order: a NaN scale first (is_nan<f6_t>), then zero (is_zero<f6_t>,
 * which yields 0.0f), then a general decode. The general decode passes the
 * element's low six bits and the scale's exponent (from
 * get_exponent_value<e8m0_bexp_t>) to convert_to_float<f6_t>.
 *
 * @param scale e8m0 scale factor applied to the element.
 * @param data  The f6_t element; only its low six bits are decoded.
 * @return The decoded, scaled float.
 *
 * NOTE(review): listing line 142 (the statement taken when is_nan<f6_t> is
 * true) is missing from this extract. As shown, the `if` on line 141 appears to
 * guard the `if` on line 144. The page's glossary references
 * NumericLimits<...>::QuietNaN(), which is presumably what line 142 returns;
 * confirm against the real header.
 */
138template <>
139__host__ __device__ inline float to_float<f6_t>(e8m0_bexp_t const scale, f6_t const data)
140{
141 if(is_nan<f6_t>(scale, data))
143
144 if(is_zero<f6_t>(scale, data))
145 return 0.0f;
146
147 f6_t prepared_data = data & 0b00111111; // keep only the low six element bits
148
149 int scale_exp = get_exponent_value<e8m0_bexp_t>(scale); // exponent carried by the e8m0 scale
150
151 return convert_to_float<f6_t>(prepared_data, scale_exp);
152}
153
/**
 * @brief Converts a bf6_t value to a float based on an e8m0_bexp_t scale factor.
 *
 * Check order: a NaN scale first (is_nan<bf6_t>), then zero (is_zero<bf6_t>,
 * which yields 0.0f), then a general decode. The general decode passes the
 * element's low six bits and the scale's exponent (from
 * get_exponent_value<e8m0_bexp_t>) to convert_to_float<bf6_t>.
 *
 * @param scale e8m0 scale factor applied to the element.
 * @param data  The bf6_t element; only its low six bits are decoded.
 * @return The decoded, scaled float.
 *
 * NOTE(review): listing line 168 (the statement taken when is_nan<bf6_t> is
 * true) is missing from this extract. As shown, the `if` on line 167 appears to
 * guard the `if` on line 170. The page's glossary references
 * NumericLimits<...>::QuietNaN(), which is presumably what line 168 returns;
 * confirm against the real header.
 */
164template <>
165__host__ __device__ inline float to_float<bf6_t>(e8m0_bexp_t const scale, bf6_t const data)
166{
167 if(is_nan<bf6_t>(scale, data))
169
170 if(is_zero<bf6_t>(scale, data))
171 return 0.0f;
172
173 bf6_t prepared_data = data & 0b00111111; // keep only the low six element bits
174
175 int scale_exp = get_exponent_value<e8m0_bexp_t>(scale); // exponent carried by the e8m0 scale
176
177 return convert_to_float<bf6_t>(prepared_data, scale_exp);
178}
179
/**
 * @brief Converts a float to f6_t with saturation.
 *
 * Visible steps:
 *  1. Read the float's bit pattern through the cvt helper (fields value_float
 *     and value_bitwise, declared in mxfp_utils.hpp; presumably overlaid so
 *     that writing one and reading the other reinterprets the bits).
 *  2. Take bit 31 as the sign.
 *  3. Special-case NaN input.
 *  4. Special-case magnitudes above NumericLimits<f6_t>::DataMaxNorm(); per
 *     the inline comment, this branch also catches infinities.
 *  5. Return `res`.
 *
 * NOTE(review): this extract omits listing lines 200-201 (the NaN branch body),
 * 205-206 (the out-of-range result), 208 and 210-213. `res` is returned on
 * line 215, but its declaration is not visible. The glossary lists
 * convert_to_type<T>(float), which presumably produces it. Confirm against the
 * real header before relying on this listing.
 *
 * @param value Input float.
 * @return The f6_t encoding of @p value.
 */
190template <>
191__host__ __device__ inline f6_t sat_convert_to_type<f6_t>(float value)
192{
193 cvt t;
194 t.value_float = value;
195 uint32_t sign = t.value_bitwise >> 31; // top bit of the 32-bit float pattern
196
197 if(std::isnan(value))
198 {
199
202 }
203
204 if(std::abs(value) > NumericLimits<f6_t>::DataMaxNorm()) // covers inf case as well
207
209
214
215 return res;
216}
217
// sat_convert_to_type<bf6_t>(float value): converts a float to bf6_t with
// saturation. NOTE(review): the signature and body (listing lines 229-254)
// are elided from this extract; only the surrounding line markers remain.
228template <>
255
// sat_convert_to_type_sr<f6_t>(float value, uint32_t seed): converts a float to
// f6_t with saturation and stochastic rounding (seeded by `seed`).
// NOTE(review): the signature and body (listing lines 267-289) are elided from
// this extract; only the surrounding line markers remain.
266template <>
290
// sat_convert_to_type_sr<bf6_t>(float value, uint32_t seed): converts a float to
// bf6_t with saturation and stochastic rounding (seeded by `seed`).
// NOTE(review): the signature and body (listing lines 302-323) are elided from
// this extract; only the line marker below remains.
301template <>
324} // namespace ck::utils
325#endif
Definition library/utility/check_err.hpp:24
__host__ __device__ f6_t sat_convert_to_type_sr< f6_t >(float value, uint32_t seed)
Converts a float to f6_t with saturation and stochastic rounding.
Definition mxf6_utils.hpp:267
__host__ __device__ bool is_zero< bf6_t >(e8m0_bexp_t const scale, bf6_t const data)
Checks whether a bf6_t value is zero.
Definition mxf6_utils.hpp:117
__host__ __device__ bf6_t sat_convert_to_type_sr< bf6_t >(float value, uint32_t seed)
Converts a float to bf6_t with saturation and stochastic rounding.
Definition mxf6_utils.hpp:302
__host__ __device__ f6_t sat_convert_to_type< f6_t >(float value)
Converts a float to f6_t with saturation.
Definition mxf6_utils.hpp:191
__host__ __device__ float convert_to_float(T data, int scale_exp)
Definition mxfp_utils.hpp:73
__host__ __device__ T convert_to_type_sr(float value, uint32_t seed)
Definition mxfp_utils.hpp:261
__host__ __device__ bf6_t sat_convert_to_type< bf6_t >(float value)
Converts a float to bf6_t with saturation.
Definition mxf6_utils.hpp:229
__host__ __device__ bool is_nan< f6_t >(e8m0_bexp_t const scale, f6_t const dataBytes)
Checks if an f6_t value is NaN based on the provided scale.
Definition mxf6_utils.hpp:23
__host__ __device__ float to_float< bf6_t >(e8m0_bexp_t const scale, bf6_t const data)
Converts a bf6_t value to a float based on an e8m0_bexp_t scale factor.
Definition mxf6_utils.hpp:165
__host__ __device__ bool is_inf< f6_t >(e8m0_bexp_t const scale, f6_t const data)
Checks if an f6_t value is infinite.
Definition mxf6_utils.hpp:58
__host__ __device__ bool is_zero(e8m0_bexp_t const scale, T const data)
__host__ __device__ T convert_to_type(float value)
Definition mxfp_utils.hpp:102
__host__ __device__ bool is_inf< bf6_t >(e8m0_bexp_t const scale, bf6_t const data)
Checks if a bf6_t value is infinite.
Definition mxf6_utils.hpp:75
__host__ __device__ bool is_nan(e8m0_bexp_t const scale, T const data)
__host__ __device__ bool is_nan< bf6_t >(e8m0_bexp_t const scale, bf6_t const dataBytes)
Checks if a bf6_t value is NaN based on the provided scale.
Definition mxf6_utils.hpp:41
__host__ __device__ float to_float< f6_t >(e8m0_bexp_t const scale, f6_t const data)
Converts an f6_t value to a float based on an e8m0_bexp_t scale factor.
Definition mxf6_utils.hpp:139
__host__ __device__ float to_float(e8m0_bexp_t const scale, T const data)
__host__ __device__ constexpr int32_t get_exponent_value< e8m0_bexp_t >(e8m0_bexp_t x)
Definition utility/e8m0.hpp:74
__host__ __device__ bool is_zero< f6_t >(e8m0_bexp_t const scale, f6_t const data)
Checks whether an f6_t value is zero.
Definition mxf6_utils.hpp:94
_BitInt(6) f6_t
Definition data_type.hpp:34
unsigned _BitInt(6) bf6_t
Definition data_type.hpp:35
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
unsigned int uint32_t
Definition stdint.h:126
Definition numeric_limits.hpp:309
__host__ static __device__ constexpr T QuietNaN()
Definition numeric_limits.hpp:313
Definition numeric_utils.hpp:10
Unsigned representation of a conventional biased Float32 exponent.
Definition utility/e8m0.hpp:26
__host__ __device__ constexpr bool is_nan() const
Definition utility/e8m0.hpp:65
Definition mxfp_utils.hpp:14
float value_float
Definition mxfp_utils.hpp:15
uint32_t value_bitwise
Definition mxfp_utils.hpp:16