51template <
typename ComputeDataType,
typename OutDataType,
typename AccDataType = ComputeDataType>
56 is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
57 "Warning: Unhandled ComputeDataType for setting up the relative threshold!");
59 double compute_error = 0;
69 static_assert(
is_any_of<OutDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
70 "Warning: Unhandled OutDataType for setting up the relative threshold!");
72 double output_error = 0;
81 double midway_error = std::max(compute_error, output_error);
83 static_assert(
is_any_of<AccDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
84 "Warning: Unhandled AccDataType for setting up the relative threshold!");
95 return std::max(acc_error, midway_error);
111template <
typename ComputeDataType,
typename OutDataType,
typename AccDataType = ComputeDataType>
113 const int number_of_accumulations = 1)
117 is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
118 "Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
120 auto expo = std::log2(std::abs(max_possible_num));
121 double compute_error = 0;
131 static_assert(
is_any_of<OutDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
132 "Warning: Unhandled OutDataType for setting up the absolute threshold!");
134 double output_error = 0;
143 double midway_error = std::max(compute_error, output_error);
145 static_assert(
is_any_of<AccDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
146 "Warning: Unhandled AccDataType for setting up the absolute threshold!");
148 double acc_error = 0;
158 return std::max(acc_error, midway_error);
172std::ostream&
operator<<(std::ostream& os,
const std::vector<T>& v)
174 using size_type =
typename std::vector<T>::size_type;
177 for(size_type idx = 0; idx < v.size(); ++idx)
200template <
typename Range,
typename RefRange>
203 const std::string& msg =
"Error: Incorrect results!")
205 if(out.size() != ref.size())
207 std::cerr << msg <<
" out.size() != ref.size(), :" << out.size() <<
" != " << ref.size()
225 const float error_percent =
226 static_cast<float>(err_count) /
static_cast<float>(total_size) * 100.f;
227 std::cerr <<
"max err: " << max_err;
228 std::cerr <<
", number of errors: " << err_count;
229 std::cerr <<
", " << error_percent <<
"% wrong values" << std::endl;
248template <
typename Range,
typename RefRange>
249typename std::enable_if<
251 std::is_floating_point_v<ranges::range_value_t<Range>> &&
252 !std::is_same_v<ranges::range_value_t<Range>,
half_t>,
256 const std::string& msg =
"Error: Incorrect results!",
259 bool allow_infinity_ref =
false)
265 const auto is_infinity_error = [=](
auto o,
auto r) {
266 const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
267 const bool both_infinite_and_same =
270 return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
276 double max_err = std::numeric_limits<double>::min();
277 for(std::size_t i = 0; i < ref.size(); ++i)
279 const double o = *std::next(std::begin(out), i);
280 const double r = *std::next(std::begin(ref), i);
281 err = std::abs(o - r);
282 if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
284 max_err = err > max_err ? err : max_err;
288 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
289 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
317template <
typename Range,
typename RefRange>
318typename std::enable_if<
320 std::is_same_v<ranges::range_value_t<Range>,
bf16_t>,
324 const std::string& msg =
"Error: Incorrect results!",
327 bool allow_infinity_ref =
false)
332 const auto is_infinity_error = [=](
auto o,
auto r) {
333 const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
334 const bool both_infinite_and_same =
337 return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
344 double max_err = std::numeric_limits<float>::min();
345 for(std::size_t i = 0; i < ref.size(); ++i)
349 err = std::abs(o - r);
350 if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
352 max_err = err > max_err ? err : max_err;
356 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
357 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
386template <
typename Range,
typename RefRange>
387typename std::enable_if<
389 std::is_same_v<ranges::range_value_t<Range>,
half_t>,
393 const std::string& msg =
"Error: Incorrect results!",
396 bool allow_infinity_ref =
false)
401 const auto is_infinity_error = [=](
auto o,
auto r) {
402 const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
403 const bool both_infinite_and_same =
406 return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
412 double max_err =
static_cast<double>(std::numeric_limits<ranges::range_value_t<Range>>
::min());
413 for(std::size_t i = 0; i < ref.size(); ++i)
417 err = std::abs(o - r);
418 if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
420 max_err = err > max_err ? err : max_err;
424 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
425 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
452template <
typename Range,
typename RefRange>
454 std::is_integral_v<ranges::range_value_t<Range>> &&
455 !std::is_same_v<ranges::range_value_t<Range>,
bf16_t>)
456#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
463 const std::string& msg =
"Error: Incorrect results!",
473 int64_t max_err = std::numeric_limits<int64_t>::min();
474 for(std::size_t i = 0; i < ref.size(); ++i)
476 const int64_t o = *std::next(std::begin(out), i);
477 const int64_t r = *std::next(std::begin(ref), i);
478 err = std::abs(o - r);
482 max_err = err > max_err ? err : max_err;
486 std::cerr << msg <<
" out[" << i <<
"] != ref[" << i <<
"]: " << o <<
" != " << r
516template <
typename Range,
typename RefRange>
518 std::is_same_v<ranges::range_value_t<Range>,
fp8_t>),
522 const std::string& msg =
"Error: Incorrect results!",
523 unsigned max_rounding_point_distance = 1,
525 bool allow_infinity_ref =
false)
530 const auto is_infinity_error = [=](
auto o,
auto r) {
531 const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
532 const bool both_infinite_and_same =
535 return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
538 static const auto get_rounding_point_distance = [](
fp8_t o,
fp8_t r) ->
unsigned {
539 static const auto get_sign_bit = [](
fp8_t v) ->
bool {
543 if(get_sign_bit(o) ^ get_sign_bit(r))
545 return std::numeric_limits<unsigned>::max();
556 double max_err = std::numeric_limits<float>::min();
557 for(std::size_t i = 0; i < ref.size(); ++i)
559 const fp8_t o_fp8 = *std::next(std::begin(out), i);
560 const fp8_t r_fp8 = *std::next(std::begin(ref), i);
563 err = std::abs(o_fp64 - r_fp64);
565 get_rounding_point_distance(o_fp8, r_fp8) <= max_rounding_point_distance) ||
566 is_infinity_error(o_fp64, r_fp64))
568 max_err = err > max_err ? err : max_err;
572 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
573 <<
"] != ref[" << i <<
"]: " << o_fp64 <<
" != " << r_fp64 << std::endl;
601template <
typename Range,
typename RefRange>
603 std::is_same_v<ranges::range_value_t<Range>,
bf8_t>),
607 const std::string& msg =
"Error: Incorrect results!",
610 bool allow_infinity_ref =
false)
615 const auto is_infinity_error = [=](
auto o,
auto r) {
616 const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
617 const bool both_infinite_and_same =
620 return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
626 double max_err = std::numeric_limits<float>::min();
627 for(std::size_t i = 0; i < ref.size(); ++i)
631 err = std::abs(o - r);
632 if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
634 max_err = err > max_err ? err : max_err;
638 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
639 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
664template <
typename Range,
typename RefRange>
666 std::is_same_v<ranges::range_value_t<Range>,
pk_fp4_t>),
670 const std::string& msg =
"Error: Incorrect results!",
682 std::cerr << msg <<
" out[" << index <<
"] != ref[" << index
689 for(std::size_t i = 0; i < ref.size(); ++i)
691 const pk_fp4_t o = *std::next(std::begin(out), i);
692 const pk_fp4_t r = *std::next(std::begin(ref), i);
700 return err_count == 0;
#define CK_TILE_HOST
Definition config.hpp:40
iter_value_t< ranges::iterator_t< R > > range_value_t
Definition tile/host/ranges.hpp:37
Definition tile/core/algorithm/cluster_descriptor.hpp:13
int8_t I8
8-bit signed integer type
Definition tile/host/check_err.hpp:35
ck_tile::bf16_t BF16
16-bit brain floating point type
Definition tile/host/check_err.hpp:31
_Float16 half_t
Definition half.hpp:111
CK_TILE_HOST bool check_size_mismatch(const Range &out, const RefRange &ref, const std::string &msg="Error: Incorrect results!")
Check for size mismatch between output and reference ranges.
Definition tile/host/check_err.hpp:201
ck_tile::half_t F16
16-bit floating point (half precision) type
Definition tile/host/check_err.hpp:29
float F32
32-bit floating point (single precision) type
Definition tile/host/check_err.hpp:33
int8_t int8_t
Definition int8.hpp:20
int32_t I32
32-bit signed integer type
Definition tile/host/check_err.hpp:37
bfloat16_t bf16_t
Definition bfloat16.hpp:113
_BitInt(8) fp8_t
Definition float8.hpp:204
typename pk_fp4_t::type pk_fp4_raw_t
Definition pk_fp4.hpp:152
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X &x)
Definition bit_cast.hpp:11
CK_TILE_HOST double get_relative_threshold(const int number_of_accumulations=1)
Calculate relative error threshold for numerical comparisons.
Definition tile/host/check_err.hpp:52
std::enable_if< std::is_same_v< ranges::range_value_t< Range >, ranges::range_value_t< RefRange > > &&std::is_floating_point_v< ranges::range_value_t< Range > > &&!std::is_same_v< ranges::range_value_t< Range >, half_t >, bool >::type CK_TILE_HOST check_err(const Range &out, const RefRange &ref, const std::string &msg="Error: Incorrect results!", double rtol=1e-5, double atol=3e-6, bool allow_infinity_ref=false)
Check errors between floating point ranges using the specified tolerances.
Definition tile/host/check_err.hpp:254
CK_TILE_HOST void report_error_stats(int err_count, double max_err, std::size_t total_size)
Report error statistics for numerical comparisons.
Definition tile/host/check_err.hpp:223
CK_TILE_HOST double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations=1)
Calculate absolute error threshold for numerical comparisons.
Definition tile/host/check_err.hpp:112
std::ostream & operator<<(std::ostream &os, const std::vector< T > &v)
Stream operator overload for vector output.
Definition tile/host/check_err.hpp:172
pk_float4_e2m1_t pk_fp4_t
Definition pk_fp4.hpp:151
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
ck_tile::bf8_t BF8
8-bit brain floating point type
Definition tile/host/check_err.hpp:27
int32_t int32_t
Definition integer.hpp:10
ck_tile::fp8_t F8
8-bit floating point type
Definition tile/host/check_err.hpp:25
unsigned _BitInt(8) bf8_t
Definition float8.hpp:206
CK_TILE_HOST_DEVICE constexpr T min(T x)
Definition tile/core/numeric/math.hpp:210
constexpr int ERROR_DETAIL_LIMIT
Maximum number of error values to display when checking errors.
Definition tile/host/check_err.hpp:22
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
signed __int64 int64_t
Definition stdint.h:135
Definition type_traits.hpp:115
Definition tile/core/numeric/math.hpp:395
Definition tile/core/numeric/numeric.hpp:81
static CK_TILE_HOST_DEVICE constexpr T max()
Definition tile/core/numeric/numeric.hpp:26
CK_TILE_HOST_DEVICE constexpr type _unpack(number< I >) const