image_to_column_kernel.hpp Source File

image_to_column_kernel.hpp Source File

Composable Kernel: image_to_column_kernel.hpp Source File
image_to_column_kernel.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
11template <typename Problem_>
13{
14 static constexpr auto I0 = number<0>{};
15 static constexpr auto I1 = number<1>{};
16 static constexpr auto I2 = number<2>{};
17 static constexpr auto I3 = number<3>{};
18 static constexpr auto I4 = number<4>{};
19
21
24
// Number of spatial dimensions, read from Problem.
25 static constexpr index_t NDimSpatial = Problem::NDimSpatial;
26
// Alignment values forwarded from Problem. The "Aligment" spelling is the
// actual identifier on both sides of the assignment, so it is kept as-is.
27 static constexpr index_t AligmentIn = Problem::AligmentIn;
28 static constexpr index_t AligmentOut = Problem::AligmentOut;
29
// Only NDimSpatial == 2 is accepted; any other value fails at compile time
// with the message "Not supported.".
30 static_assert(NDimSpatial == 2, "Not supported.");
31
// Block-shape constants re-exported from Problem::BlockShape. In this file
// kMPerBlock / kKPerBlock scale blockIdx.x / blockIdx.y in ConvTensorRearrange,
// and kBlockSize is the same value that BlockSize() returns.
32 static constexpr index_t kMPerBlock = Problem::BlockShape::kMPerBlock;
33 static constexpr index_t kKPerBlock = Problem::BlockShape::kKPerBlock;
34 static constexpr index_t kBlockSize = Problem::BlockShape::kBlockSize;
35
55
// Host-side factory for the kernel argument bundle (Kargs).
// Takes the raw input/output pointers, the G, N and C values, the
// NDimSpatial-sized arrays (spatial lengths, filter strides, dilations, pads),
// the image stride array (NDimSpatial + 3 entries) and the GEMM stride array
// (3 entries), and forwards all of them unchanged, in parameter order, into a
// brace-initialized Kargs that is returned by value.
// No validation or conversion of the arguments happens here.
// NOTE(review): the Kargs definition (listing lines 37-53) is not part of this
// excerpt; the brace-init relies on its members being declared in this order.
56 CK_TILE_HOST static constexpr Kargs
57 MakeKargs(const void* p_in,
58 void* p_out,
59 const long_index_t G,
60 const long_index_t N,
61 const long_index_t C,
62 const array<long_index_t, NDimSpatial> input_spatial_lengths,
63 const array<long_index_t, NDimSpatial> filter_spatial_lengths,
64 const array<long_index_t, NDimSpatial> output_spatial_lengths,
65 const array<long_index_t, NDimSpatial + 3> image_g_n_c_wis_strides,
66 const array<long_index_t, 3> gemm_g_m_k_strides,
67 const array<long_index_t, NDimSpatial> conv_filter_strides,
68 const array<long_index_t, NDimSpatial> conv_filter_dilations,
69 const array<long_index_t, NDimSpatial> input_left_pads,
70 const array<long_index_t, NDimSpatial> input_right_pads)
71 {
72 return Kargs{p_in,
73 p_out,
74 G,
75 N,
76 C,
77 input_spatial_lengths,
78 filter_spatial_lengths,
79 output_spatial_lengths,
80 image_g_n_c_wis_strides,
81 gemm_g_m_k_strides,
82 conv_filter_strides,
83 conv_filter_dilations,
84 input_left_pads,
85 input_right_pads};
86 }
87
// Host-side helper that returns the launch grid as a dim3.
// NOTE(review): the dim3 argument line (listing line 91) is missing from this
// excerpt. Since ConvTensorRearrange multiplies blockIdx.x by kMPerBlock,
// blockIdx.y by kKPerBlock and uses blockIdx.z as the batch index, the grid is
// presumably (ceil(GemmM / kMPerBlock), ceil(GemmK / kKPerBlock), Batch) —
// confirm against the real header.
88 CK_TILE_HOST static constexpr auto GridSize(index_t GemmM, index_t GemmK, index_t Batch)
89 {
90 return dim3(
92 }
93
// Host-side accessor returning Problem::BlockShape::kBlockSize, the same value
// stored in the class-level kBlockSize constant.
// NOTE(review): presumably used as the thread-block size at launch — confirm at the call site.
94 CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::kBlockSize; }
95
// Device-side builder for the tensor descriptor that ConvTensorRearrange pairs
// with the input pointer. It is built as a chain of descriptors, each derived
// from the previous one.
// NOTE(review): this excerpt is missing many listing lines (101, 103-107,
// 112-113, 116, 119-121, 125-135, 137, 139-140, 142, 144-145), including the
// transform constructors and the return statement. The comments below describe
// only the arguments that are visible.
96 CK_TILE_DEVICE auto MakeImageMKDesc(const Kargs& kargs) const
97 {
98 static_assert(NDimSpatial == 2, "Not supported.");
99
// Step 1: naive descriptor over the input image. The visible argument group is
// (N, input_spatial_lengths[0], input_spatial_lengths[1], C), presumably the
// lengths tuple; the strides argument is not visible here.
100 const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
102 kargs.N, kargs.input_spatial_lengths[I0], kargs.input_spatial_lengths[I1], kargs.C),
108 I1);
109
// Step 2: descriptor derived from step 1. The visible arguments are the
// left/right pads for spatial dimension 0 and then dimension 1, presumably one
// pad transform per spatial axis — confirm.
110 const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
111 in_n_hi_wi_c_desc,
114 kargs.input_left_pads[I0],
115 kargs.input_right_pads[I0]),
117 kargs.input_left_pads[I1],
118 kargs.input_right_pads[I1]),
122
// Step 3: descriptor derived from the padded one; its transform list is not
// visible in this excerpt.
123 const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor(
124 in_n_hip_wip_c_desc,
136
// Step 4: final transform applied to the step-3 descriptor. The two visible
// argument groups are (N, output_spatial_lengths[0], output_spatial_lengths[1])
// and (filter_spatial_lengths[0], filter_spatial_lengths[1], C), presumably
// merged into the M and K dimensions respectively — confirm.
138 in_n_y_ho_x_wo_c_desc,
141 kargs.N, kargs.output_spatial_lengths[I0], kargs.output_spatial_lengths[I1])),
143 kargs.filter_spatial_lengths[I0], kargs.filter_spatial_lengths[I1], kargs.C))),
146 }
147
// Computes the two GEMM-side extents and returns them as make_tuple(M, K).
// M is kargs.N times a product that starts with output_spatial_lengths[0].
// K is kargs.C times a product that starts with filter_spatial_lengths[0].
// NOTE(review): the continuation lines (listing 152 and 154) are missing from
// this excerpt, so the remaining factors are not visible; presumably the [I1]
// entries of the same arrays — confirm.
// NOTE(review): kargs.N / kargs.C and the array entries are long_index_t
// (int64_t) while M, K and the static_cast target are index_t (int32_t), so the
// results are narrowed to 32 bits; worth checking for very large problem sizes.
148 CK_TILE_DEVICE auto CalculateMKDims(const Kargs& kargs) const
149 {
150 static_assert(NDimSpatial == 2, "Not supported.");
151 const index_t M = kargs.N * static_cast<index_t>(kargs.output_spatial_lengths[I0] *
153 const index_t K = kargs.C * static_cast<index_t>(kargs.filter_spatial_lengths[I0] *
155 return make_tuple(M, K);
156 }
157
159 {
160 using P = typename Problem::BlockShape;
161 // P: {kMWarpPerBlock * kKWarpPerBlock, kMThreadPerWarp * kKThreadPerWarp}
162 // Y: {kMPerThread, kKPerThread}
171 sequence<2, 2>>{});
172 }
173
// Device-side body of the kernel: each thread block loads one tile from the
// image view and stores it to the output (GEMM) view at the same {iM, iK} origin.
// NOTE(review): this excerpt is missing several listing lines (185, 187,
// 190-191, 196-197, 200-201, 207, 212), so some statements below appear
// without their opening line or without all of their arguments.
174 CK_TILE_DEVICE void ConvTensorRearrange(const Kargs& kargs) const
175 {
// Overall M x K extents of the output matrix.
176 const auto [M, K] = CalculateMKDims(kargs);
177
// Tile origin for this block: blockIdx.x / blockIdx.y scaled by the per-block
// tile sizes, and blockIdx.z as the batch index. Each value is passed through
// amd_wave_read_first_lane (presumably to make it wave-uniform — confirm).
178 const index_t iM = amd_wave_read_first_lane(blockIdx.x * kMPerBlock);
179 const index_t iK = amd_wave_read_first_lane(blockIdx.y * kKPerBlock);
180 const index_t iBatch = amd_wave_read_first_lane(blockIdx.z);
181
// Element offsets for this batch: batch index times the first (index 0) stride
// of the image and GEMM stride arrays.
182 const auto in_offset = iBatch * kargs.image_g_n_c_wis_strides[I0];
183 const auto out_offset = iBatch * kargs.gemm_g_m_k_strides[I0];
184
// Input view: offset input pointer paired with the descriptor from
// MakeImageMKDesc (the opening line of this statement is not in the excerpt).
186 static_cast<const InDataType*>(kargs.p_in) + in_offset, MakeImageMKDesc(kargs));
// Output view: offset output pointer with lengths (M, K); the strides argument
// is not visible in this excerpt.
188 static_cast<OutDataType*>(kargs.p_out) + out_offset,
189 make_tuple(M, K),
192 I1);
193
// Padded versions of both views; the tile-length and padding-flag arguments
// are not visible in this excerpt.
194 const auto image_m_k_padded =
195 pad_tensor_view(image_m_k,
198 const auto gemm_m_k_padded =
199 pad_tensor_view(gemm_m_k,
202
// Compile-time tile distribution shared by the load and store windows.
203 constexpr auto dstr = MakeBlockTileDistribution();
204
// Read window over the padded image view, anchored at {iM, iK}.
205 const auto image_tile =
206 make_tile_window(image_m_k_padded,
208 {iM, iK},
209 dstr);
210
// Write window over the padded output view, anchored at the same origin.
211 auto gemm_tile = make_tile_window(gemm_m_k_padded,
213 {iM, iK},
214 dstr);
215
216 // load from Global
217 const auto loaded_tile = load_tile(image_tile);
218 // save to Global
219 store_tile(gemm_tile, loaded_tile);
220 }
221
// Call operator: forwards kargs to ConvTensorRearrange and does nothing else.
// kargs is taken by non-const reference here although the callee accepts
// const Kargs&.
222 CK_TILE_DEVICE void operator()(Kargs& kargs) const { ConvTensorRearrange(kargs); }
223};
224
225} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
CK_TILE_HOST_DEVICE constexpr auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, bool_constant< SkipIsValidCheck >=bool_constant< false >{})
Definition coordinate_transform.hpp:1565
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
int64_t long_index_t
Definition integer.hpp:11
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
CK_TILE_HOST_DEVICE constexpr auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition coordinate_transform.hpp:1594
Definition image_to_column_kernel.hpp:37
const array< long_index_t, NDimSpatial > input_right_pads
Definition image_to_column_kernel.hpp:53
const array< long_index_t, NDimSpatial > conv_filter_strides
Definition image_to_column_kernel.hpp:50
const array< long_index_t, NDimSpatial > conv_filter_dilations
Definition image_to_column_kernel.hpp:51
const long_index_t C
Definition image_to_column_kernel.hpp:43
const long_index_t N
Definition image_to_column_kernel.hpp:42
const long_index_t G
Definition image_to_column_kernel.hpp:41
const array< long_index_t, NDimSpatial+3 > image_g_n_c_wis_strides
Definition image_to_column_kernel.hpp:48
const void * p_in
Definition image_to_column_kernel.hpp:38
const array< long_index_t, NDimSpatial > filter_spatial_lengths
Definition image_to_column_kernel.hpp:46
void * p_out
Definition image_to_column_kernel.hpp:39
const array< long_index_t, NDimSpatial > input_left_pads
Definition image_to_column_kernel.hpp:52
const array< long_index_t, 3 > gemm_g_m_k_strides
Definition image_to_column_kernel.hpp:49
const array< long_index_t, NDimSpatial > input_spatial_lengths
Definition image_to_column_kernel.hpp:45
const array< long_index_t, NDimSpatial > output_spatial_lengths
Definition image_to_column_kernel.hpp:47
Definition image_to_column_kernel.hpp:13
CK_TILE_DEVICE auto MakeImageMKDesc(const Kargs &kargs) const
Definition image_to_column_kernel.hpp:96
static constexpr auto I2
Definition image_to_column_kernel.hpp:16
static CK_TILE_HOST constexpr auto BlockSize()
Definition image_to_column_kernel.hpp:94
CK_TILE_DEVICE void ConvTensorRearrange(const Kargs &kargs) const
Definition image_to_column_kernel.hpp:174
static constexpr index_t kBlockSize
Definition image_to_column_kernel.hpp:34
static constexpr auto I4
Definition image_to_column_kernel.hpp:18
static constexpr auto I1
Definition image_to_column_kernel.hpp:15
remove_cvref_t< typename Problem::InDataType > InDataType
Definition image_to_column_kernel.hpp:22
remove_cvref_t< Problem_ > Problem
Definition image_to_column_kernel.hpp:20
static constexpr index_t AligmentOut
Definition image_to_column_kernel.hpp:28
static constexpr index_t AligmentIn
Definition image_to_column_kernel.hpp:27
static CK_TILE_HOST constexpr Kargs MakeKargs(const void *p_in, void *p_out, const long_index_t G, const long_index_t N, const long_index_t C, const array< long_index_t, NDimSpatial > input_spatial_lengths, const array< long_index_t, NDimSpatial > filter_spatial_lengths, const array< long_index_t, NDimSpatial > output_spatial_lengths, const array< long_index_t, NDimSpatial+3 > image_g_n_c_wis_strides, const array< long_index_t, 3 > gemm_g_m_k_strides, const array< long_index_t, NDimSpatial > conv_filter_strides, const array< long_index_t, NDimSpatial > conv_filter_dilations, const array< long_index_t, NDimSpatial > input_left_pads, const array< long_index_t, NDimSpatial > input_right_pads)
Definition image_to_column_kernel.hpp:57
static CK_TILE_HOST constexpr auto GridSize(index_t GemmM, index_t GemmK, index_t Batch)
Definition image_to_column_kernel.hpp:88
CK_TILE_DEVICE void operator()(Kargs &kargs) const
Definition image_to_column_kernel.hpp:222
static constexpr auto I3
Definition image_to_column_kernel.hpp:17
static constexpr index_t kKPerBlock
Definition image_to_column_kernel.hpp:33
remove_cvref_t< typename Problem::OutDataType > OutDataType
Definition image_to_column_kernel.hpp:23
CK_TILE_DEVICE auto CalculateMKDims(const Kargs &kargs) const
Definition image_to_column_kernel.hpp:148
static constexpr index_t kMPerBlock
Definition image_to_column_kernel.hpp:32
static constexpr index_t NDimSpatial
Definition image_to_column_kernel.hpp:25
static CK_TILE_DEVICE constexpr auto MakeBlockTileDistribution()
Definition image_to_column_kernel.hpp:158
static constexpr auto I0
Definition image_to_column_kernel.hpp:14
A fixed-size array container similar to std::array with additional utilities.
Definition tile/core/container/array.hpp:43
Definition tile/core/container/sequence.hpp:49
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192