warp_gemm_attribute_wmma_impl.hpp Source File

warp_gemm_attribute_wmma_impl.hpp Source File

Composable Kernel: warp_gemm_attribute_wmma_impl.hpp Source File
warp_gemm_attribute_wmma_impl.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
11// Base traits for WMMA operations
12template <typename Arch,
13 typename AType,
14 typename BType,
15 typename CType,
16 index_t M,
17 index_t N,
18 index_t K>
20
21// Generic WMMA implementation using traits
22template <typename Traits>
24{
25 using ADataType = typename Traits::ADataType;
26 using BDataType = typename Traits::BDataType;
27 using CDataType = typename Traits::CDataType;
28
29 using AVecType = typename Traits::AVecType;
30 using BVecType = typename Traits::BVecType;
31 using CVecType = typename Traits::CVecType;
32
33 // Forward all static constants and type aliases
34 static constexpr index_t kM = Traits::kM;
35 static constexpr index_t kN = Traits::kN;
36 static constexpr index_t kK = Traits::kK;
37
38 static constexpr index_t kAMBlock = Traits::kAMBlock;
39 static constexpr index_t kBNBlock = Traits::kBNBlock;
40
41 static constexpr index_t kRepeat = Traits::kRepeat;
42 static constexpr index_t kAMLane = Traits::kAMLane;
43 static constexpr index_t kBNLane = Traits::kBNLane;
44 static constexpr index_t kABK0PerLane = Traits::kABK0PerLane;
45 static constexpr index_t kABKLane = Traits::kABKLane;
46 static constexpr index_t kABK1PerLane = Traits::kABK1PerLane;
47
48 static constexpr index_t kCMLane = Traits::kCMLane;
49 static constexpr index_t kCNLane = Traits::kCNLane;
50 static constexpr index_t kCM0PerLane = Traits::kCM0PerLane;
51 static constexpr index_t kCM1PerLane = Traits::kCM1PerLane;
52
53 using kABPs2RHssMajor = typename Traits::kABPs2RHssMajor;
54 using kABPs2RHssMinor = typename Traits::kABPs2RHssMinor;
55 using kABYs2RHsMajor = typename Traits::kABYs2RHsMajor;
56 using kABYs2RHsMinor = typename Traits::kABYs2RHsMinor;
57
58 using kCPs2RHssMajor = typename Traits::kCPs2RHssMajor;
59 using kCPs2RHssMinor = typename Traits::kCPs2RHssMinor;
60 using kCYs2RHsMajor = typename Traits::kCYs2RHsMajor;
61 using kCYs2RHsMinor = typename Traits::kCYs2RHsMinor;
62
63 using kCTPs2RHssMajor = typename Traits::kCTPs2RHssMajor;
64 using kCTPs2RHssMinor = typename Traits::kCTPs2RHssMinor;
65 using kCTYs2RHsMajor = typename Traits::kCTYs2RHsMajor;
66 using kCTYs2RHsMinor = typename Traits::kCTYs2RHsMinor;
67
68 // c_vec += a_vec * b_vec
69 template <bool clamp = false, bool post_nop_ = false>
71 const AVecType& a_vec,
72 const BVecType& b_vec,
73 bool_constant<post_nop_> = {}) const
74 {
75 c_vec = Traits::template wmma_intrinsic<clamp>(a_vec, b_vec, c_vec);
76 }
77
78 // c_vec = a_vec * b_vec
79 template <bool clamp = false>
80 CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
81 {
82 return bit_cast<CVecType>(
83 Traits::template wmma_intrinsic<clamp>(a_vec, b_vec, CVecType{0.f}));
84 }
85};
86
87using DeviceIp = remove_cvref_t<decltype(ck_tile::get_device_arch())>;
90
93
96
99
102
105
108
109template <typename Arch,
110 typename AType,
111 typename BType,
112 typename CType,
113 index_t warp_m,
114 index_t warp_n,
115 index_t warp_k>
117{
118 template <typename T>
119 static auto
120 test(int) -> decltype(std::declval<
122 ADataType>(),
123 std::true_type{});
124
125 template <typename>
126 static std::false_type test(...);
127
128 static constexpr bool value = decltype(test<Arch>(0))::value;
129};
130
131template <typename Arch,
132 typename AType,
133 typename BType,
134 typename CType,
135 index_t warp_m,
136 index_t warp_n,
137 index_t warp_k>
138constexpr bool has_wmma_traits_v =
140} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
Definition tile/core/algorithm/cluster_descriptor.hpp:13
constexpr bool has_wmma_traits_v
Definition warp_gemm_attribute_wmma_impl.hpp:138
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
WarpGemmAttributeWmmaImpl< WmmaTraits< DeviceIp, bf16_t, bf16_t, float, 16, 16, 16 > > WarpGemmAttributeWmmaImpl_f32_16x16x16_bf16_bf16
Definition warp_gemm_attribute_wmma_impl.hpp:91
WarpGemmAttributeWmmaImpl< WmmaTraits< DeviceIp, int8_t, int8_t, int32_t, 16, 16, 16 > > WarpGemmAttributeWmmaImpl_i32_16x16x16_i8_i8
Definition warp_gemm_attribute_wmma_impl.hpp:94
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
WarpGemmAttributeWmmaImpl< WmmaTraits< gfx12_t, fp8_t, bf8_t, float, 16, 16, 16 > > WarpGemmAttributeWmmaImpl_f32_16x16x16_f8_bf8
Definition warp_gemm_attribute_wmma_impl.hpp:103
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X &x)
Definition bit_cast.hpp:11
remove_cvref_t< decltype(ck_tile::get_device_arch())> DeviceIp
Definition warp_gemm_attribute_wmma_impl.hpp:87
WarpGemmAttributeWmmaImpl< WmmaTraits< gfx12_t, bf8_t, fp8_t, float, 16, 16, 16 > > WarpGemmAttributeWmmaImpl_f32_16x16x16_bf8_f8
Definition warp_gemm_attribute_wmma_impl.hpp:106
WarpGemmAttributeWmmaImpl< WmmaTraits< gfx12_t, fp8_t, fp8_t, float, 16, 16, 16 > > WarpGemmAttributeWmmaImpl_f32_16x16x16_f8_f8
Definition warp_gemm_attribute_wmma_impl.hpp:97
WarpGemmAttributeWmmaImpl< WmmaTraits< gfx12_t, bf8_t, bf8_t, float, 16, 16, 16 > > WarpGemmAttributeWmmaImpl_f32_16x16x16_bf8_bf8
Definition warp_gemm_attribute_wmma_impl.hpp:100
WarpGemmAttributeWmmaImpl< WmmaTraits< DeviceIp, fp16_t, fp16_t, float, 16, 16, 16 > > WarpGemmAttributeWmmaImpl_f32_16x16x16_f16_f16
Definition warp_gemm_attribute_wmma_impl.hpp:88
int32_t index_t
Definition integer.hpp:9
WarpGemmAttributeWmmaImpl
Definition warp_gemm_attribute_wmma_impl.hpp:24
static constexpr index_t kK
Definition warp_gemm_attribute_wmma_impl.hpp:36
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition warp_gemm_attribute_wmma_impl.hpp:80
typename Traits::CVecType CVecType
Definition warp_gemm_attribute_wmma_impl.hpp:31
static constexpr index_t kAMLane
Definition warp_gemm_attribute_wmma_impl.hpp:42
typename Traits::kCPs2RHssMajor kCPs2RHssMajor
Definition warp_gemm_attribute_wmma_impl.hpp:58
static constexpr index_t kBNBlock
Definition warp_gemm_attribute_wmma_impl.hpp:39
typename Traits::BVecType BVecType
Definition warp_gemm_attribute_wmma_impl.hpp:30
typename Traits::kABYs2RHsMinor kABYs2RHsMinor
Definition warp_gemm_attribute_wmma_impl.hpp:56
typename Traits::kCYs2RHsMajor kCYs2RHsMajor
Definition warp_gemm_attribute_wmma_impl.hpp:60
typename Traits::kABPs2RHssMinor kABPs2RHssMinor
Definition warp_gemm_attribute_wmma_impl.hpp:54
typename Traits::BDataType BDataType
Definition warp_gemm_attribute_wmma_impl.hpp:26
typename Traits::kCTYs2RHsMajor kCTYs2RHsMajor
Definition warp_gemm_attribute_wmma_impl.hpp:65
typename Traits::AVecType AVecType
Definition warp_gemm_attribute_wmma_impl.hpp:29
typename Traits::kABPs2RHssMajor kABPs2RHssMajor
Definition warp_gemm_attribute_wmma_impl.hpp:53
typename Traits::kCTYs2RHsMinor kCTYs2RHsMinor
Definition warp_gemm_attribute_wmma_impl.hpp:66
typename Traits::kCYs2RHsMinor kCYs2RHsMinor
Definition warp_gemm_attribute_wmma_impl.hpp:61
typename Traits::kCTPs2RHssMinor kCTPs2RHssMinor
Definition warp_gemm_attribute_wmma_impl.hpp:64
static constexpr index_t kAMBlock
Definition warp_gemm_attribute_wmma_impl.hpp:38
static constexpr index_t kM
Definition warp_gemm_attribute_wmma_impl.hpp:34
static constexpr index_t kCNLane
Definition warp_gemm_attribute_wmma_impl.hpp:49
typename Traits::kCPs2RHssMinor kCPs2RHssMinor
Definition warp_gemm_attribute_wmma_impl.hpp:59
static constexpr index_t kABK0PerLane
Definition warp_gemm_attribute_wmma_impl.hpp:44
typename Traits::kCTPs2RHssMajor kCTPs2RHssMajor
Definition warp_gemm_attribute_wmma_impl.hpp:63
static constexpr index_t kN
Definition warp_gemm_attribute_wmma_impl.hpp:35
static constexpr index_t kRepeat
Definition warp_gemm_attribute_wmma_impl.hpp:41
typename Traits::CDataType CDataType
Definition warp_gemm_attribute_wmma_impl.hpp:27
static constexpr index_t kCM0PerLane
Definition warp_gemm_attribute_wmma_impl.hpp:50
typename Traits::kABYs2RHsMajor kABYs2RHsMajor
Definition warp_gemm_attribute_wmma_impl.hpp:55
static constexpr index_t kBNLane
Definition warp_gemm_attribute_wmma_impl.hpp:43
typename Traits::ADataType ADataType
Definition warp_gemm_attribute_wmma_impl.hpp:25
static constexpr index_t kABK1PerLane
Definition warp_gemm_attribute_wmma_impl.hpp:46
static constexpr index_t kCMLane
Definition warp_gemm_attribute_wmma_impl.hpp:48
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_wmma_impl.hpp:70
static constexpr index_t kCM1PerLane
Definition warp_gemm_attribute_wmma_impl.hpp:51
static constexpr index_t kABKLane
Definition warp_gemm_attribute_wmma_impl.hpp:45
WmmaTraits
Definition warp_gemm_attribute_wmma_impl.hpp:19
Definition warp_gemm_attribute_wmma_impl.hpp:117
static constexpr bool value
Definition warp_gemm_attribute_wmma_impl.hpp:128
static auto test(int) -> decltype(std::declval< typename WmmaTraits< T, AType, BType, CType, warp_m, warp_n, warp_k >::ADataType >(), std::true_type{})
static std::false_type test(...)