check_err.hpp Source File

check_err.hpp Source File#

Composable Kernel: check_err.hpp Source File
tile/host/check_err.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <algorithm>
7#include <cmath>
8#include <cstdlib>
9#include <iostream>
10#include <iomanip>
11#include <iterator>
12#include <limits>
13#include <type_traits>
14#include <vector>
15
16#include "ck_tile/core.hpp"
18
19namespace ck_tile {
20
22constexpr int ERROR_DETAIL_LIMIT = 128;
23
33using F32 = float;
35using I8 = int8_t;
37using I32 = int32_t;
38
// Computes a relative error threshold from the precision of three types.
//
// For each of ComputeDataType, OutDataType and AccDataType the visible code derives
// an error term 2^(-numeric_traits<T>::mant) * 0.5; the accumulator term is further
// multiplied by number_of_accumulations. The result is the maximum of the three terms.
// Each stage also has a branch that returns 0 immediately.
//
// NOTE(review): this listing is incomplete. Its own numbering skips lines 56, 60, 69,
// 73, 83 and 87, which is where the static_assert conditions and the `if constexpr`
// headers guarding the `return 0` branches would sit. Which types take the early
// return cannot be determined from here - consult the original header.
// NOTE(review): `mant` is presumably the mantissa bit count of the type - confirm in
// numeric_traits.
//
// number_of_accumulations: scales only the accumulator error term (default 1).
// Returns: the relative threshold as a double, or 0 via one of the early returns.
51template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
52CK_TILE_HOST double get_relative_threshold(const int number_of_accumulations = 1)
53{
54
55 static_assert(
57 "Warning: Unhandled ComputeDataType for setting up the relative threshold!");
58
// Error contributed by the compute type.
59 double compute_error = 0;
61 {
62 return 0;
63 }
64 else
65 {
66 compute_error = std::pow(2, -numeric_traits<ComputeDataType>::mant) * 0.5;
67 }
68
70 "Warning: Unhandled OutDataType for setting up the relative threshold!");
71
// Error contributed by the output type.
72 double output_error = 0;
74 {
75 return 0;
76 }
77 else
78 {
79 output_error = std::pow(2, -numeric_traits<OutDataType>::mant) * 0.5;
80 }
// The larger of the compute and output terms; compared against the accumulator term below.
81 double midway_error = std::max(compute_error, output_error);
82
84 "Warning: Unhandled AccDataType for setting up the relative threshold!");
85
// Error contributed by the accumulator type, scaled by the number of accumulations.
86 double acc_error = 0;
88 {
89 return 0;
90 }
91 else
92 {
93 acc_error = std::pow(2, -numeric_traits<AccDataType>::mant) * 0.5 * number_of_accumulations;
94 }
95 return std::max(acc_error, midway_error);
96}
97
// Computes an absolute error threshold from the precision of three types and the
// magnitude of the largest value expected.
//
// expo is log2(|max_possible_num|). For each of ComputeDataType, OutDataType and
// AccDataType the visible code derives an error term
// 2^(expo - numeric_traits<T>::mant) * 0.5; the accumulator term is further multiplied
// by number_of_accumulations. The result is the maximum of the three terms.
// Each stage also has a branch that returns 0 immediately.
//
// NOTE(review): this listing is incomplete. Its own numbering skips lines 117, 122,
// 131, 135, 145 and 149, which is where the static_assert conditions and the
// `if constexpr` headers guarding the `return 0` branches would sit. Which types take
// the early return cannot be determined from here - consult the original header.
// NOTE(review): `mant` is presumably the mantissa bit count of the type - confirm in
// numeric_traits.
//
// max_possible_num: magnitude reference; only its absolute value is used.
// number_of_accumulations: scales only the accumulator error term (default 1).
// Returns: the absolute threshold as a double, or 0 via one of the early returns.
111template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
112CK_TILE_HOST double get_absolute_threshold(const double max_possible_num,
113 const int number_of_accumulations = 1)
114{
115
116 static_assert(
118 "Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
119
// Binary exponent of the largest expected magnitude; shifts every error term below.
120 auto expo = std::log2(std::abs(max_possible_num));
121 double compute_error = 0;
123 {
124 return 0;
125 }
126 else
127 {
128 compute_error = std::pow(2, expo - numeric_traits<ComputeDataType>::mant) * 0.5;
129 }
130
132 "Warning: Unhandled OutDataType for setting up the absolute threshold!");
133
// Error contributed by the output type.
134 double output_error = 0;
136 {
137 return 0;
138 }
139 else
140 {
141 output_error = std::pow(2, expo - numeric_traits<OutDataType>::mant) * 0.5;
142 }
// The larger of the compute and output terms; compared against the accumulator term below.
143 double midway_error = std::max(compute_error, output_error);
144
146 "Warning: Unhandled AccDataType for setting up the absolute threshold!");
147
// Error contributed by the accumulator type, scaled by the number of accumulations.
148 double acc_error = 0;
150 {
151 return 0;
152 }
153 else
154 {
155 acc_error =
156 std::pow(2, expo - numeric_traits<AccDataType>::mant) * 0.5 * number_of_accumulations;
157 }
158 return std::max(acc_error, midway_error);
159}
160
/// @brief Stream a vector as a bracketed, comma-separated list, e.g. `[1, 2, 3]`.
///
/// An empty vector is written as `[]`. Elements are written with their own
/// `operator<<`.
///
/// @param os Destination stream.
/// @param v  Vector to print.
/// @return `os`, to allow chaining.
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
    os << "[";
    bool first = true;
    for(const auto& item : v)
    {
        // Separator goes before every element except the first.
        if(!first)
        {
            os << ", ";
        }
        first = false;
        os << item;
    }
    return os << "]";
}
187
200template <typename Range, typename RefRange>
201CK_TILE_HOST bool check_size_mismatch(const Range& out,
202 const RefRange& ref,
203 const std::string& msg = "Error: Incorrect results!")
204{
205 if(out.size() != ref.size())
206 {
207 std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
208 << std::endl;
209 return true;
210 }
211 return false;
212}
213
223CK_TILE_HOST void report_error_stats(int err_count, double max_err, std::size_t total_size)
224{
225 const float error_percent =
226 static_cast<float>(err_count) / static_cast<float>(total_size) * 100.f;
227 std::cerr << "max err: " << max_err;
228 std::cerr << ", number of errors: " << err_count;
229 std::cerr << ", " << error_percent << "% wrong values" << std::endl;
230}
231
248template <typename Range, typename RefRange>
249typename std::enable_if<
250 std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
251 std::is_floating_point_v<ranges::range_value_t<Range>> &&
252 !std::is_same_v<ranges::range_value_t<Range>, half_t>,
253 bool>::type CK_TILE_HOST
254check_err(const Range& out,
255 const RefRange& ref,
256 const std::string& msg = "Error: Incorrect results!",
257 double rtol = 1e-5,
258 double atol = 3e-6,
259 bool allow_infinity_ref = false)
260{
261
262 if(check_size_mismatch(out, ref, msg))
263 return false;
264
265 const auto is_infinity_error = [=](auto o, auto r) {
266 const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
267 const bool both_infinite_and_same =
268 std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
269
270 return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
271 };
272
273 bool res{true};
274 int err_count = 0;
275 double err = 0;
276 double max_err = std::numeric_limits<double>::min();
277 for(std::size_t i = 0; i < ref.size(); ++i)
278 {
279 const double o = *std::next(std::begin(out), i);
280 const double r = *std::next(std::begin(ref), i);
281 err = std::abs(o - r);
282 if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
283 {
284 max_err = err > max_err ? err : max_err;
285 err_count++;
286 if(err_count < ERROR_DETAIL_LIMIT)
287 {
288 std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
289 << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
290 }
291 res = false;
292 }
293 }
294 if(!res)
295 {
296 report_error_stats(err_count, max_err, ref.size());
297 }
298 return res;
299}
300
317template <typename Range, typename RefRange>
318typename std::enable_if<
319 std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
320 std::is_same_v<ranges::range_value_t<Range>, bf16_t>,
321 bool>::type CK_TILE_HOST
322check_err(const Range& out,
323 const RefRange& ref,
324 const std::string& msg = "Error: Incorrect results!",
325 double rtol = 1e-3,
326 double atol = 1e-3,
327 bool allow_infinity_ref = false)
328{
329 if(check_size_mismatch(out, ref, msg))
330 return false;
331
332 const auto is_infinity_error = [=](auto o, auto r) {
333 const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
334 const bool both_infinite_and_same =
335 std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
336
337 return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
338 };
339
340 bool res{true};
341 int err_count = 0;
342 double err = 0;
343 // TODO: This is a hack. We should have proper specialization for bf16_t data type.
344 double max_err = std::numeric_limits<float>::min();
345 for(std::size_t i = 0; i < ref.size(); ++i)
346 {
347 const double o = type_convert<float>(*std::next(std::begin(out), i));
348 const double r = type_convert<float>(*std::next(std::begin(ref), i));
349 err = std::abs(o - r);
350 if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
351 {
352 max_err = err > max_err ? err : max_err;
353 err_count++;
354 if(err_count < ERROR_DETAIL_LIMIT)
355 {
356 std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
357 << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
358 }
359 res = false;
360 }
361 }
362 if(!res)
363 {
364 report_error_stats(err_count, max_err, ref.size());
365 }
366 return res;
367}
368
386template <typename Range, typename RefRange>
387typename std::enable_if<
388 std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
389 std::is_same_v<ranges::range_value_t<Range>, half_t>,
390 bool>::type CK_TILE_HOST
391check_err(const Range& out,
392 const RefRange& ref,
393 const std::string& msg = "Error: Incorrect results!",
394 double rtol = 1e-3,
395 double atol = 1e-3,
396 bool allow_infinity_ref = false)
397{
398 if(check_size_mismatch(out, ref, msg))
399 return false;
400
401 const auto is_infinity_error = [=](auto o, auto r) {
402 const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
403 const bool both_infinite_and_same =
404 std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
405
406 return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
407 };
408
409 bool res{true};
410 int err_count = 0;
411 double err = 0;
412 double max_err = static_cast<double>(std::numeric_limits<ranges::range_value_t<Range>>::min());
413 for(std::size_t i = 0; i < ref.size(); ++i)
414 {
415 const double o = type_convert<float>(*std::next(std::begin(out), i));
416 const double r = type_convert<float>(*std::next(std::begin(ref), i));
417 err = std::abs(o - r);
418 if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
419 {
420 max_err = err > max_err ? err : max_err;
421 err_count++;
422 if(err_count < ERROR_DETAIL_LIMIT)
423 {
424 std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
425 << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
426 }
427 res = false;
428 }
429 }
430 if(!res)
431 {
432 report_error_stats(err_count, max_err, ref.size());
433 }
434 return res;
435}
436
452template <typename Range, typename RefRange>
453std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
454 std::is_integral_v<ranges::range_value_t<Range>> &&
455 !std::is_same_v<ranges::range_value_t<Range>, bf16_t>)
456#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
457 || std::is_same_v<ranges::range_value_t<Range>, int4_t>
458#endif
459 ,
460 bool>
461 CK_TILE_HOST check_err(const Range& out,
462 const RefRange& ref,
463 const std::string& msg = "Error: Incorrect results!",
464 double = 0,
465 double atol = 0)
466{
467 if(check_size_mismatch(out, ref, msg))
468 return false;
469
470 bool res{true};
471 int err_count = 0;
472 int64_t err = 0;
473 int64_t max_err = std::numeric_limits<int64_t>::min();
474 for(std::size_t i = 0; i < ref.size(); ++i)
475 {
476 const int64_t o = *std::next(std::begin(out), i);
477 const int64_t r = *std::next(std::begin(ref), i);
478 err = std::abs(o - r);
479
480 if(err > atol)
481 {
482 max_err = err > max_err ? err : max_err;
483 err_count++;
484 if(err_count < ERROR_DETAIL_LIMIT)
485 {
486 std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r
487 << std::endl;
488 }
489 res = false;
490 }
491 }
492 if(!res)
493 {
494 report_error_stats(err_count, static_cast<double>(max_err), ref.size());
495 }
496 return res;
497}
498
516template <typename Range, typename RefRange>
517std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
518 std::is_same_v<ranges::range_value_t<Range>, fp8_t>),
519 bool>
520 CK_TILE_HOST check_err(const Range& out,
521 const RefRange& ref,
522 const std::string& msg = "Error: Incorrect results!",
523 unsigned max_rounding_point_distance = 1,
524 double atol = 1e-1,
525 bool allow_infinity_ref = false)
526{
527 if(check_size_mismatch(out, ref, msg))
528 return false;
529
530 const auto is_infinity_error = [=](auto o, auto r) {
531 const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
532 const bool both_infinite_and_same =
533 std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
534
535 return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
536 };
537
538 static const auto get_rounding_point_distance = [](fp8_t o, fp8_t r) -> unsigned {
539 static const auto get_sign_bit = [](fp8_t v) -> bool {
540 return 0x80 & bit_cast<uint8_t>(v);
541 };
542
543 if(get_sign_bit(o) ^ get_sign_bit(r))
544 {
545 return std::numeric_limits<unsigned>::max();
546 }
547 else
548 {
549 return std::abs(bit_cast<int8_t>(o) - bit_cast<int8_t>(r));
550 }
551 };
552
553 bool res{true};
554 int err_count = 0;
555 double err = 0;
556 double max_err = std::numeric_limits<float>::min();
557 for(std::size_t i = 0; i < ref.size(); ++i)
558 {
559 const fp8_t o_fp8 = *std::next(std::begin(out), i);
560 const fp8_t r_fp8 = *std::next(std::begin(ref), i);
561 const double o_fp64 = type_convert<float>(o_fp8);
562 const double r_fp64 = type_convert<float>(r_fp8);
563 err = std::abs(o_fp64 - r_fp64);
564 if(!(less_equal<double>{}(err, atol) ||
565 get_rounding_point_distance(o_fp8, r_fp8) <= max_rounding_point_distance) ||
566 is_infinity_error(o_fp64, r_fp64))
567 {
568 max_err = err > max_err ? err : max_err;
569 err_count++;
570 if(err_count < ERROR_DETAIL_LIMIT)
571 {
572 std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
573 << "] != ref[" << i << "]: " << o_fp64 << " != " << r_fp64 << std::endl;
574 }
575 res = false;
576 }
577 }
578 if(!res)
579 {
580 report_error_stats(err_count, max_err, ref.size());
581 }
582 return res;
583}
584
601template <typename Range, typename RefRange>
602std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
603 std::is_same_v<ranges::range_value_t<Range>, bf8_t>),
604 bool>
605 CK_TILE_HOST check_err(const Range& out,
606 const RefRange& ref,
607 const std::string& msg = "Error: Incorrect results!",
608 double rtol = 1e-3,
609 double atol = 1e-3,
610 bool allow_infinity_ref = false)
611{
612 if(check_size_mismatch(out, ref, msg))
613 return false;
614
615 const auto is_infinity_error = [=](auto o, auto r) {
616 const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
617 const bool both_infinite_and_same =
618 std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
619
620 return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
621 };
622
623 bool res{true};
624 int err_count = 0;
625 double err = 0;
626 double max_err = std::numeric_limits<float>::min();
627 for(std::size_t i = 0; i < ref.size(); ++i)
628 {
629 const double o = type_convert<float>(*std::next(std::begin(out), i));
630 const double r = type_convert<float>(*std::next(std::begin(ref), i));
631 err = std::abs(o - r);
632 if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
633 {
634 max_err = err > max_err ? err : max_err;
635 err_count++;
636 if(err_count < ERROR_DETAIL_LIMIT)
637 {
638 std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
639 << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
640 }
641 res = false;
642 }
643 }
644 if(!res)
645 {
646 report_error_stats(err_count, max_err, ref.size());
647 }
648 return res;
649}
650
664template <typename Range, typename RefRange>
665std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
666 std::is_same_v<ranges::range_value_t<Range>, pk_fp4_t>),
667 bool>
668 CK_TILE_HOST check_err(const Range& out,
669 const RefRange& ref,
670 const std::string& msg = "Error: Incorrect results!",
671 double = 0,
672 double = 0)
673{
674 if(check_size_mismatch(out, ref, msg))
675 return false;
676
677 int err_count = 0;
678
679 auto update_err = [&](pk_fp4_raw_t o, pk_fp4_raw_t r, std::size_t index) {
680 if(o != r)
681 {
682 std::cerr << msg << " out[" << index << "] != ref[" << index
683 << "]: " << type_convert<float>(pk_fp4_t{o})
684 << " != " << type_convert<float>(pk_fp4_t{r}) << std::endl;
685 ++err_count;
686 }
687 };
688
689 for(std::size_t i = 0; i < ref.size(); ++i)
690 {
691 const pk_fp4_t o = *std::next(std::begin(out), i);
692 const pk_fp4_t r = *std::next(std::begin(ref), i);
693 update_err(o._unpack(number<0>{}), r._unpack(number<0>{}), i * 2);
694 update_err(o._unpack(number<1>{}), r._unpack(number<1>{}), i * 2 + 1);
695 }
696 if(err_count > 0)
697 {
698 report_error_stats(err_count, numeric<pk_fp4_t>::max(), ref.size());
699 }
700 return err_count == 0;
701}
702
703} // namespace ck_tile
#define CK_TILE_HOST
Definition config.hpp:40
iter_value_t< ranges::iterator_t< R > > range_value_t
Definition tile/host/ranges.hpp:37
Definition tile/core/algorithm/cluster_descriptor.hpp:13
int8_t I8
8-bit signed integer type
Definition tile/host/check_err.hpp:35
ck_tile::bf16_t BF16
16-bit brain floating point type
Definition tile/host/check_err.hpp:31
_Float16 half_t
Definition half.hpp:111
CK_TILE_HOST bool check_size_mismatch(const Range &out, const RefRange &ref, const std::string &msg="Error: Incorrect results!")
Check for size mismatch between output and reference ranges.
Definition tile/host/check_err.hpp:201
ck_tile::half_t F16
16-bit floating point (half precision) type
Definition tile/host/check_err.hpp:29
float F32
32-bit floating point (single precision) type
Definition tile/host/check_err.hpp:33
int8_t int8_t
Definition int8.hpp:20
int32_t I32
32-bit signed integer type
Definition tile/host/check_err.hpp:37
bfloat16_t bf16_t
Definition bfloat16.hpp:113
_BitInt(8) fp8_t
Definition float8.hpp:204
typename pk_fp4_t::type pk_fp4_raw_t
Definition pk_fp4.hpp:152
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X &x)
Definition bit_cast.hpp:11
CK_TILE_HOST double get_relative_threshold(const int number_of_accumulations=1)
Calculate relative error threshold for numerical comparisons.
Definition tile/host/check_err.hpp:52
std::enable_if< std::is_same_v< ranges::range_value_t< Range >, ranges::range_value_t< RefRange > > &&std::is_floating_point_v< ranges::range_value_t< Range > > &&!std::is_same_v< ranges::range_value_t< Range >, half_t >, bool >::type CK_TILE_HOST check_err(const Range &out, const RefRange &ref, const std::string &msg="Error: Incorrect results!", double rtol=1e-5, double atol=3e-6, bool allow_infinity_ref=false)
Check errors between floating point ranges using the specified tolerances.
Definition tile/host/check_err.hpp:254
CK_TILE_HOST void report_error_stats(int err_count, double max_err, std::size_t total_size)
Report error statistics for numerical comparisons.
Definition tile/host/check_err.hpp:223
CK_TILE_HOST double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations=1)
Calculate absolute error threshold for numerical comparisons.
Definition tile/host/check_err.hpp:112
std::ostream & operator<<(std::ostream &os, const std::vector< T > &v)
Stream operator overload for vector output.
Definition tile/host/check_err.hpp:172
pk_float4_e2m1_t pk_fp4_t
Definition pk_fp4.hpp:151
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
ck_tile::bf8_t BF8
8-bit brain floating point type
Definition tile/host/check_err.hpp:27
int32_t int32_t
Definition integer.hpp:10
ck_tile::fp8_t F8
8-bit floating point type
Definition tile/host/check_err.hpp:25
unsigned _BitInt(8) bf8_t
Definition float8.hpp:206
CK_TILE_HOST_DEVICE constexpr T min(T x)
Definition tile/core/numeric/math.hpp:210
constexpr int ERROR_DETAIL_LIMIT
Maximum number of error values to display when checking errors.
Definition tile/host/check_err.hpp:22
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
signed __int64 int64_t
Definition stdint.h:135
Definition type_traits.hpp:115
Definition tile/core/numeric/math.hpp:395
Definition tile/core/numeric/numeric.hpp:81
static CK_TILE_HOST_DEVICE constexpr T max()
Definition tile/core/numeric/numeric.hpp:26
CK_TILE_HOST_DEVICE constexpr type _unpack(number< I >) const