warp_gemm_attribute_wmma_impl_8bit_traits.hpp Source File

warp_gemm_attribute_wmma_impl_8bit_traits.hpp Source File

Composable Kernel: warp_gemm_attribute_wmma_impl_8bit_traits.hpp Source File
warp_gemm_attribute_wmma_impl_8bit_traits.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7namespace ck_tile {
8// int8 specialization - GFX11
// Explicit specialization of WmmaTraits for the gfx11_t tag with int8_t A and B
// element types and an int32_t accumulator. It inherits from WmmaTraitsBase
// instantiated with the same tag and element types; AVecType, BVecType and
// CVecType are not declared here and presumably come from that base.
// NOTE(review): the trailing 16, 16, 16 presumably are the M/N/K tile extents
// (matching "16x16x16" in the builtin name) - confirm against the primary template.
9template <>
10struct WmmaTraits<gfx11_t, int8_t, int8_t, int32_t, 16, 16, 16>
11 : WmmaTraitsBase<gfx11_t, int8_t, int8_t, int32_t>
12{
// Issues one int8 WMMA operation through the gfx11 compiler builtin.
//   clamp : compile-time flag (default false), passed as the builtin's last argument.
//   a_vec, b_vec, c_vec : A, B and accumulator operand vectors.
// Returns the builtin's result when __gfx11__ is defined; otherwise returns
// CVecType{0} without touching the hardware builtin.
13 template <bool clamp = false>
14 CK_TILE_DEVICE static CVecType
15 wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
16 {
// Only reference the gfx11-specific builtin when compiling for that target.
17#ifdef __gfx11__
// NOTE(review): this listing's own numbering jumps 18 -> 20 -> 23, so the lines
// carrying the a_vec / b_vec / c_vec arguments are not visible in this view;
// consult the original header for the full call.
// NOTE(review): for the iu8 builtin the two leading booleans are believed to
// select signed interpretation of A and B - confirm against the LLVM AMDGPU
// builtin definitions.
18 return __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, // neg_a
20 true, // neg_b
23 clamp);
24#else
// Fallback when __gfx11__ is not defined: assign the inputs to ck_tile::ignore
// (presumably to silence unused-parameter warnings) and return a placeholder.
25 ck_tile::ignore = a_vec;
26 ck_tile::ignore = b_vec;
27 ck_tile::ignore = c_vec;
28 return CVecType{0};
29#endif
30 }
31};
32
33// int8 specialization - GFX12
// Explicit specialization of WmmaTraits for the gfx12_t tag with int8_t A and B
// element types and an int32_t accumulator. It inherits from WmmaTraitsBase
// instantiated with the same tag and element types; AVecType, BVecType and
// CVecType are not declared here and presumably come from that base.
// NOTE(review): the trailing 16, 16, 16 presumably are the M/N/K tile extents
// (matching "16x16x16" in the builtin name) - confirm against the primary template.
34template <>
35struct WmmaTraits<gfx12_t, int8_t, int8_t, int32_t, 16, 16, 16>
36 : WmmaTraitsBase<gfx12_t, int8_t, int8_t, int32_t>
37{
// Issues one int8 WMMA operation through the gfx12 variant of the compiler builtin.
//   clamp : compile-time flag (default false), passed as the builtin's last argument.
//   a_vec, b_vec, c_vec : A, B and accumulator operand vectors.
// Returns the builtin's result when __gfx12__ is defined; otherwise returns
// CVecType{0} without touching the hardware builtin.
38 template <bool clamp = false>
39 CK_TILE_DEVICE static CVecType
40 wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
41 {
// Only reference the gfx12-specific builtin when compiling for that target.
42#ifdef __gfx12__
// NOTE(review): this listing's own numbering jumps 43 -> 45 -> 48, so the lines
// carrying the a_vec / b_vec / c_vec arguments are not visible in this view;
// consult the original header for the full call.
// NOTE(review): for the iu8 builtin the two leading booleans are believed to
// select signed interpretation of A and B - confirm against the LLVM AMDGPU
// builtin definitions.
43 return __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, // neg_a
45 true, // neg_b
48 clamp);
49#else
// Fallback when __gfx12__ is not defined: assign the inputs to ck_tile::ignore
// (presumably to silence unused-parameter warnings) and return a placeholder.
50 ck_tile::ignore = a_vec;
51 ck_tile::ignore = b_vec;
52 ck_tile::ignore = c_vec;
53 return CVecType{0};
54#endif
55 }
56};
57
58// fp8/bf8 specialization - GFX12
// Explicit specialization of WmmaTraits for the gfx12_t tag with fp8_t A and
// fp8_t B element types and a float accumulator. It inherits from WmmaTraitsBase
// instantiated with the same tag and element types; AVecType, BVecType and
// CVecType are not declared here and presumably come from that base.
59template <>
60struct WmmaTraits<gfx12_t, fp8_t, fp8_t, float, 16, 16, 16>
61 : WmmaTraitsBase<gfx12_t, fp8_t, fp8_t, float>
62{
// Issues one fp8 x fp8 WMMA operation through the gfx12 compiler builtin.
//   clamp : compile-time flag (default false); it is not referenced in the
//           lines visible in this view.
//   a_vec, b_vec, c_vec : A, B and accumulator operand vectors.
// Returns the builtin's result when __gfx12__ is defined; otherwise returns
// CVecType{0} without touching the hardware builtin.
63 template <bool clamp = false>
64 CK_TILE_DEVICE static CVecType
65 wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
66 {
// Only reference the gfx12-specific builtin when compiling for that target.
67#ifdef __gfx12__
// NOTE(review): this listing's own numbering jumps 68 -> 70, so the line with
// the builtin's arguments and closing parenthesis is not visible in this view;
// consult the original header for the full call.
68 return __builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12(
70#else
// Fallback when __gfx12__ is not defined: assign the inputs to ck_tile::ignore
// (presumably to silence unused-parameter warnings) and return a placeholder.
71 ck_tile::ignore = a_vec;
72 ck_tile::ignore = b_vec;
73 ck_tile::ignore = c_vec;
74 return CVecType{0};
75#endif
76 }
77};
78
// Explicit specialization of WmmaTraits for the gfx12_t tag with bf8_t A and
// bf8_t B element types and a float accumulator. It inherits from WmmaTraitsBase
// instantiated with the same tag and element types; AVecType, BVecType and
// CVecType are not declared here and presumably come from that base.
79template <>
80struct WmmaTraits<gfx12_t, bf8_t, bf8_t, float, 16, 16, 16>
81 : WmmaTraitsBase<gfx12_t, bf8_t, bf8_t, float>
82{
// Issues one bf8 x bf8 WMMA operation through the gfx12 compiler builtin.
//   clamp : compile-time flag (default false); it is not referenced in the
//           lines visible in this view.
//   a_vec, b_vec, c_vec : A, B and accumulator operand vectors.
// Returns the builtin's result when __gfx12__ is defined; otherwise returns
// CVecType{0} without touching the hardware builtin.
83 template <bool clamp = false>
84 CK_TILE_DEVICE static CVecType
85 wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
86 {
// Only reference the gfx12-specific builtin when compiling for that target.
87#ifdef __gfx12__
// NOTE(review): this listing's own numbering jumps 88 -> 90, so the line with
// the builtin's arguments and closing parenthesis is not visible in this view;
// consult the original header for the full call.
88 return __builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12(
90#else
// Fallback when __gfx12__ is not defined: assign the inputs to ck_tile::ignore
// (presumably to silence unused-parameter warnings) and return a placeholder.
91 ck_tile::ignore = a_vec;
92 ck_tile::ignore = b_vec;
93 ck_tile::ignore = c_vec;
94 return CVecType{0};
95#endif
96 }
97};
98
// Explicit specialization of WmmaTraits for the gfx12_t tag with fp8_t A and
// bf8_t B element types (mixed 8-bit float formats) and a float accumulator.
// It inherits from WmmaTraitsBase instantiated with the same tag and element
// types; AVecType, BVecType and CVecType are not declared here and presumably
// come from that base.
99template <>
100struct WmmaTraits<gfx12_t, fp8_t, bf8_t, float, 16, 16, 16>
101 : WmmaTraitsBase<gfx12_t, fp8_t, bf8_t, float>
102{
// Issues one fp8 x bf8 WMMA operation through the gfx12 compiler builtin.
//   clamp : compile-time flag (default false); it is not referenced in the
//           lines visible in this view.
//   a_vec, b_vec, c_vec : A, B and accumulator operand vectors.
// Returns the builtin's result when __gfx12__ is defined; otherwise returns
// CVecType{0} without touching the hardware builtin.
103 template <bool clamp = false>
104 CK_TILE_DEVICE static CVecType
105 wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
106 {
// Only reference the gfx12-specific builtin when compiling for that target.
107#ifdef __gfx12__
// NOTE(review): this listing's own numbering jumps 108 -> 110, so the line with
// the builtin's arguments and closing parenthesis is not visible in this view;
// consult the original header for the full call.
108 return __builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12(
110#else
// Fallback when __gfx12__ is not defined: assign the inputs to ck_tile::ignore
// (presumably to silence unused-parameter warnings) and return a placeholder.
111 ck_tile::ignore = a_vec;
112 ck_tile::ignore = b_vec;
113 ck_tile::ignore = c_vec;
114 return CVecType{0};
115#endif
116 }
117};
118
// Explicit specialization of WmmaTraits for the gfx12_t tag with bf8_t A and
// fp8_t B element types (mixed 8-bit float formats) and a float accumulator.
// It inherits from WmmaTraitsBase instantiated with the same tag and element
// types; AVecType, BVecType and CVecType are not declared here and presumably
// come from that base.
119template <>
120struct WmmaTraits<gfx12_t, bf8_t, fp8_t, float, 16, 16, 16>
121 : WmmaTraitsBase<gfx12_t, bf8_t, fp8_t, float>
122{
// Issues one bf8 x fp8 WMMA operation through the gfx12 compiler builtin.
//   clamp : compile-time flag (default false); it is not referenced in the
//           lines visible in this view.
//   a_vec, b_vec, c_vec : A, B and accumulator operand vectors.
// Returns the builtin's result when __gfx12__ is defined; otherwise returns
// CVecType{0} without touching the hardware builtin.
123 template <bool clamp = false>
124 CK_TILE_DEVICE static CVecType
125 wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
126 {
// Only reference the gfx12-specific builtin when compiling for that target.
127#ifdef __gfx12__
// NOTE(review): this listing's own numbering jumps 128 -> 130, so the line with
// the builtin's arguments and closing parenthesis is not visible in this view;
// consult the original header for the full call.
128 return __builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12(
130#else
// Fallback when __gfx12__ is not defined: assign the inputs to ck_tile::ignore
// (presumably to silence unused-parameter warnings) and return a placeholder.
131 ck_tile::ignore = a_vec;
132 ck_tile::ignore = b_vec;
133 ck_tile::ignore = c_vec;
134 return CVecType{0};
135#endif
136 }
137};
138} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
Definition tile/core/algorithm/cluster_descriptor.hpp:13
int8_t int8_t
Definition int8.hpp:20
_BitInt(8) fp8_t
Definition float8.hpp:204
CK_TILE_HOST_DEVICE constexpr T clamp(const T &x, const T &lowerbound, const T &upperbound)
Definition tile/core/numeric/math.hpp:259
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X &x)
Definition bit_cast.hpp:11
constexpr detail::ignore_t ignore
Definition tile/core/utility/ignore.hpp:20
int32_t int32_t
Definition integer.hpp:10
unsigned _BitInt(8) bf8_t
Definition float8.hpp:206
static CK_TILE_DEVICE CVecType wmma_intrinsic(const AVecType &a_vec, const BVecType &b_vec, const CVecType &c_vec)
Definition warp_gemm_attribute_wmma_impl_8bit_traits.hpp:15
static CK_TILE_DEVICE CVecType wmma_intrinsic(const AVecType &a_vec, const BVecType &b_vec, const CVecType &c_vec)
Definition warp_gemm_attribute_wmma_impl_8bit_traits.hpp:85
static CK_TILE_DEVICE CVecType wmma_intrinsic(const AVecType &a_vec, const BVecType &b_vec, const CVecType &c_vec)
Definition warp_gemm_attribute_wmma_impl_8bit_traits.hpp:125
static CK_TILE_DEVICE CVecType wmma_intrinsic(const AVecType &a_vec, const BVecType &b_vec, const CVecType &c_vec)
Definition warp_gemm_attribute_wmma_impl_8bit_traits.hpp:105
static CK_TILE_DEVICE CVecType wmma_intrinsic(const AVecType &a_vec, const BVecType &b_vec, const CVecType &c_vec)
Definition warp_gemm_attribute_wmma_impl_8bit_traits.hpp:65
static CK_TILE_DEVICE CVecType wmma_intrinsic(const AVecType &a_vec, const BVecType &b_vec, const CVecType &c_vec)
Definition warp_gemm_attribute_wmma_impl_8bit_traits.hpp:40
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:7
Definition warp_gemm_attribute_wmma_impl.hpp:19
Definition arch.hpp:363
Definition arch.hpp:366