block_flatmm_asmem_bsmem_creg_v1.hpp Source File

block_flatmm_asmem_bsmem_creg_v1.hpp Source File#

Composable Kernel: block_flatmm_asmem_bsmem_creg_v1.hpp Source File
block_flatmm_asmem_bsmem_creg_v1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
11// A is block window on shared memory
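12// B is consumed in operator() as pre-indexed warp tensors, b_warp_tensor(nIter)(kIter), with no load_tile in this file (the file name says "bsmem"; confirm where B is actually staged)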
13// C is block distributed tensor
14template <typename Problem_, typename BlockPolicy_>
16{
23
// Compile-time index constants (instances of number<N>).
24 static constexpr auto I0 = number<0>();
25 static constexpr auto I1 = number<1>();
26 static constexpr auto I2 = number<2>();
// Named aliases for the M / N / K positions. In this file only idxN is used,
// as the argument to BlockTile::at, WarpTile::at and BlockWarps::at in operator().
27 static constexpr auto idxM = I0;
28 static constexpr auto idxN = I1;
29 static constexpr auto idxK = I2;
33
// Block size forwarded unchanged from the Problem type
// (presumably the number of threads per block - confirm against Problem).
34 static constexpr index_t kBlockSize = Problem::kBlockSize;
35
// Builds and returns the block-level C accumulator: a static distributed tensor of
// CDataType whose distribution embeds the warp GEMM's C distribution encoding
// (WG::CWarpDstrEncoding) inside a block-level outer encoding.
// Takes no arguments; everything is derived at compile time from Problem / BlockPolicy.
// The returned tensor is not explicitly initialized in this function.
36 CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
37 {
// Block tile extents along M and N.
38 constexpr index_t MPerBlock = BlockGemmShape::kM;
39 constexpr index_t NPerBlock = BlockGemmShape::kN;
40
// The policy returns a tuple-like config: element 0 carries the warp GEMM type,
// elements 1 and 2 are the warp counts along M and N.
41 constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
42
43 using WG = remove_cvref_t<decltype(config.template at<0>())>;
44
45 constexpr index_t MWarp = config.template at<1>();
46 constexpr index_t NWarp = config.template at<2>();
47
// Number of warp-GEMM tiles each warp covers along M and N.
// Integer division: presumably the block tile is an exact multiple of
// (warps * warp tile) - TODO confirm a static_assert exists elsewhere.
48 constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
49 constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
50
51 constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
58
// Embed the per-warp C encoding into the block-level outer encoding.
59 constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
60 c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
61
62 constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
63
64 auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
65 return c_block_tensor;
66 }
67
68 // C += A * B
// Runs the block-level GEMM step: for every (kIter, mIter, nIter) it applies one
// warp GEMM to the matching slice of c_block_tensor.
//   c_block_tensor : in/out accumulator; slices are read, updated by WG, and written back.
//   a_warp_windows : indexed as a_warp_windows(mIter)(kIter); each entry is read with load_tile.
//   b_warp_tensor  : indexed as b_warp_tensor(nIter)(kIter); each entry is passed directly to WG.
// All loops are static_for, with bounds that are compile-time constants.
69 template <typename CBlockTensor, typename ABlockWindow, typename BFlatBlockTensor>
70 CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
71 ABlockWindow& a_warp_windows,
72 BFlatBlockTensor& b_warp_tensor) const
73 {
74 constexpr index_t MPerBlock = BlockGemmShape::kM;
75 constexpr index_t KPerBlock = BlockGemmShape::kK;
76
// Element 0 of the policy config carries the warp GEMM type; element 1 is the warp count along M.
77 constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
78 using WG = remove_cvref_t<decltype(config.template at<0>())>;
79
80 constexpr index_t MWarp = config.template at<1>();
81
// Iteration counts per warp along M, N and K.
// NOTE(review): NIterPerWarp is derived here from BlockTile / WarpTile / BlockWarps,
// while MakeCBlockTile derives it from NPerBlock / (NWarp * WG::kN). The two must agree
// for the (mIter, nIter) slices below to match the C tile layout - confirm they always do.
82 constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
83 constexpr index_t NIterPerWarp =
84 BlockTile::at(idxN) / (WarpTile::at(idxN) * BlockWarps::at(idxN));
85 constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
86
87 using CWarpDstr = typename WG::CWarpDstr;
88 using CWarpTensor = typename WG::CWarpTensor;
89
// Y-dimension lengths of one warp-level C tile, and an all-zero origin of the same rank;
// both are appended to the (mIter, nIter) coordinates when slicing c_block_tensor.
90 constexpr auto c_warp_y_lengths =
91 to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
92 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
93
94 // hot loop:
// Loop order is K (outer), M, N (inner): each A tile is loaded once per (kIter, mIter)
// and reused for every nIter.
95 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
96 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
97 // read A warp tensor from A block window
98 const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
99
100 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
101 // read C warp tensor from C block tensor
102 CWarpTensor c_warp_tensor;
103
// Slice origin is (mIter, nIter, 0, ...); slice lengths are (1, 1, c_warp_y_lengths...).
104 c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
105 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
106 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
107
108 // warp GEMM
109 WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor(nIter)(kIter));
110
111 // write C warp tensor into C block tensor
// Same origin and lengths as the read above, so the updated slice lands where it came from.
112 c_block_tensor.set_y_sliced_thread_data(
113 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
114 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
115 c_warp_tensor.get_thread_buffer());
// Compiler scheduling barrier issued after every warp GEMM. The 0x7F6 mask presumably
// selects which instruction classes may still be scheduled across this point -
// TODO confirm the intended mask against the LLVM AMDGPU documentation.
116 __builtin_amdgcn_sched_barrier(0x7F6);
117 });
118 });
119 });
120 }
121};
122
123} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition tile_distribution_encoding.hpp:457
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_HOST_DEVICE constexpr auto merge_sequences(Seqs...)
Definition tile/core/container/sequence.hpp:826
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple< number< Is >... >)
Definition tile/core/container/sequence.hpp:1055
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition tile/core/container/sequence.hpp:1026
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
Definition block_flatmm_asmem_bsmem_creg_v1.hpp:16
remove_cvref_t< typename BlockGemmShape::BlockTile > BlockTile
Definition block_flatmm_asmem_bsmem_creg_v1.hpp:30
remove_cvref_t< typename Problem::CDataType > CDataType
Definition block_flatmm_asmem_bsmem_creg_v1.hpp:21
static constexpr index_t kBlockSize
Definition block_flatmm_asmem_bsmem_creg_v1.hpp:34
remove_cvref_t< typename BlockGemmShape::WarpTile > WarpTile
Definition block_flatmm_asmem_bsmem_creg_v1.hpp:32
remove_cvref_t< typename Problem::ADataType > ADataType
Definition block_flatmm_asmem_bsmem_creg_v1.hpp:19
remove_cvref_t< Problem_ > Problem
Definition block_flatmm_asmem_bsmem_creg_v1.hpp:17
static constexpr auto I0
Definition block_flatmm_asmem_bsmem_creg_v1.hpp:24
static constexpr auto idxM
Definition block_flatmm_asmem_bsmem_creg_v1.hpp:27
static constexpr auto idxK
Definition block_flatmm_asmem_bsmem_creg_v1.hpp:29
remove_cvref_t< typename BlockGemmShape::BlockWarps > BlockWarps
Definition block_flatmm_asmem_bsmem_creg_v1.hpp:31
remove_cvref_t< BlockPolicy_ > BlockPolicy
Definition block_flatmm_asmem_bsmem_creg_v1.hpp:18
static constexpr auto I2
Definition block_flatmm_asmem_bsmem_creg_v1.hpp:26
static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
Definition block_flatmm_asmem_bsmem_creg_v1.hpp:36
static constexpr auto idxN
Definition block_flatmm_asmem_bsmem_creg_v1.hpp:28
remove_cvref_t< typename Problem::BDataType > BDataType
Definition block_flatmm_asmem_bsmem_creg_v1.hpp:20
static constexpr auto I1
Definition block_flatmm_asmem_bsmem_creg_v1.hpp:25
remove_cvref_t< typename Problem::BlockGemmShape > BlockGemmShape
Definition block_flatmm_asmem_bsmem_creg_v1.hpp:22
CK_TILE_DEVICE void operator()(CBlockTensor &c_block_tensor, ABlockWindow &a_warp_windows, BFlatBlockTensor &b_warp_tensor) const
Definition block_flatmm_asmem_bsmem_creg_v1.hpp:70
Definition tile/core/utility/functional.hpp:43
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192