transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp Source File

transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp Source File

Composable Kernel: transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp Source File
transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION3D_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP
5#define CK_TRANSFORM_FORWARD_CONVOLUTION3D_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP
6
7#include "common_header.hpp"
8#include "tensor_descriptor.hpp"
10
11namespace ck {
12
13// A: in
14// B: wei
15// C: out
16// GemmM = N * Do * Ho * Wo
17// GemmN = K
18// GemmK = Z * Y * X * C
// Builds implicit-GEMM views of a 3-D forward convolution (NDHWC input,
// KZYXC weight, NDHWK output) by composing tensor-descriptor transforms.
// Returns make_tuple(input desc, weight desc, output desc). Per the local
// variable names, the input/weight views are (GemmK0, GemmM|GemmN, GemmK1)
// and the output view is (GemmM, GemmN).
//
// Lengths read from the arguments:
//   in_grid_desc_n_di_hi_wi_c  : dims 0..4 -> N, Di, Hi, Wi, C
//   wei_k_z_y_x_c_grid_desc    : dims 1..3 -> Z, Y, X (K is taken from the
//                                output descriptor and C from the input one)
//   out_n_do_ho_wo_k_grid_desc : dims 1..4 -> Do, Ho, Wo, K
//   conv_strides, conv_dilations, in_left_pads, in_right_pads:
//                                elements [0], [1], [2] -> D, H, W
//   trailing Number<GemmK1Value> tag: compile-time GemmK1, with
//                                GemmK0 = (Z * Y * X * C) / GemmK1
//
// NOTE(review): this listing lacks the function-name line and the final
// parameter line. The full name is
// transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk_pad,
// and the last parameter is an unnamed Number<GemmK1Value>. Several transform
// argument lines further down are also absent from this listing.
19template <typename... In,
20 typename... Wei,
21 typename... Out,
22 typename ConvStrides,
23 typename ConvDilations,
24 typename InLeftPads,
25 typename InRightPads,
26 index_t GemmK1Value>
27__host__ __device__ constexpr auto
29 const TensorDescriptor<In...>& in_grid_desc_n_di_hi_wi_c,
30 const TensorDescriptor<Wei...>& wei_k_z_y_x_c_grid_desc,
31 const TensorDescriptor<Out...>& out_n_do_ho_wo_k_grid_desc,
32 const ConvStrides& conv_strides,
33 const ConvDilations& conv_dilations,
34 const InLeftPads& in_left_pads,
35 const InRightPads& in_right_pads,
37{
38 constexpr auto I0 = Number<0>{};
39 constexpr auto I1 = Number<1>{};
40 constexpr auto I2 = Number<2>{};
41 constexpr auto I3 = Number<3>{};
42 constexpr auto I4 = Number<4>{};
43
// Compile-time GemmK1 taken from the template tag parameter.
44 constexpr auto GemmK1 = Number<GemmK1Value>{};
45
46 const auto N = in_grid_desc_n_di_hi_wi_c.GetLength(I0);
47 const auto K = out_n_do_ho_wo_k_grid_desc.GetLength(I4);
48 const auto C = in_grid_desc_n_di_hi_wi_c.GetLength(I4);
49
50 const auto Di = in_grid_desc_n_di_hi_wi_c.GetLength(I1);
51 const auto Hi = in_grid_desc_n_di_hi_wi_c.GetLength(I2);
52 const auto Wi = in_grid_desc_n_di_hi_wi_c.GetLength(I3);
53
// Output spatial lengths come from the output descriptor rather than being
// derived from Di/Hi/Wi, the pads, the strides and the dilations. No
// consistency check between the two is visible here.
54 const auto Do = out_n_do_ho_wo_k_grid_desc.GetLength(I1);
55 const auto Ho = out_n_do_ho_wo_k_grid_desc.GetLength(I2);
56 const auto Wo = out_n_do_ho_wo_k_grid_desc.GetLength(I3);
57
58 const auto Z = wei_k_z_y_x_c_grid_desc.GetLength(I1);
59 const auto Y = wei_k_z_y_x_c_grid_desc.GetLength(I2);
60 const auto X = wei_k_z_y_x_c_grid_desc.GetLength(I3);
61
62 const auto ConvStrideD = conv_strides[I0];
63 const auto ConvStrideH = conv_strides[I1];
64 const auto ConvStrideW = conv_strides[I2];
65
66 const auto ConvDilationD = conv_dilations[I0];
67 const auto ConvDilationH = conv_dilations[I1];
68 const auto ConvDilationW = conv_dilations[I2];
69
70 const auto InLeftPadD = in_left_pads[I0];
71 const auto InLeftPadH = in_left_pads[I1];
72 const auto InLeftPadW = in_left_pads[I2];
73
74 const auto InRightPadD = in_right_pads[I0];
75 const auto InRightPadH = in_right_pads[I1];
76 const auto InRightPadW = in_right_pads[I2];
77
78 const auto GemmM = N * Do * Ho * Wo;
79 const auto GemmN = K;
80 const auto GemmK = Z * Y * X * C;
// NOTE(review): this is integer division, so any remainder of
// (Z * Y * X * C) modulo GemmK1 is silently dropped. Nothing visible here
// checks divisibility; presumably the caller must guarantee it. TODO confirm.
81 const auto GemmK0 = GemmK / GemmK1;
82
83 // A: input tensor
// Step 1: pad Di/Hi/Wi with the per-dimension left/right pads.
84 const auto in_grid_desc_n_dip_hip_wip_c = transform_tensor_descriptor(
85 in_grid_desc_n_di_hi_wi_c,
87 make_pad_transform(Di, InLeftPadD, InRightPadD),
88 make_pad_transform(Hi, InLeftPadH, InRightPadH),
89 make_pad_transform(Wi, InLeftPadW, InRightPadW),
93
// Step 2: expand each padded spatial dim into (filter tap, output position)
// using an embed transform with (dilation, stride) coefficients.
94 const auto in_grid_desc_n_z_do_y_ho_x_wo_c = transform_tensor_descriptor(
95 in_grid_desc_n_dip_hip_wip_c,
97 make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
98 make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
99 make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
104
// Step 3: merge into a 2-D (GemmK, GemmM) view. The visible merge groups
// (N, Do, Ho, Wo) into GemmM.
105 const auto in_grid_desc_gemmk_gemmm =
106 transform_tensor_descriptor(in_grid_desc_n_z_do_y_ho_x_wo_c,
108 make_merge_transform(make_tuple(N, Do, Ho, Wo))),
111
// Step 4: split GemmK into (GemmK0, GemmK1), judging by the variable name.
// The transform arguments are not present in this listing.
112 const auto in_grid_desc_gemmk0_gemmm_gemmk1 =
113 transform_tensor_descriptor(in_grid_desc_gemmk_gemmm,
118
119 // B: weight tensor
// The weight view's transform arguments are missing from this listing. The
// variable names indicate a (GemmK, GemmN) view, later split into
// (GemmK0, GemmN, GemmK1).
120 const auto wei_grid_desc_gemmk_gemmn = transform_tensor_descriptor(
125
126 const auto wei_grid_desc_gemmk0_gemmn_gemmk1 =
127 transform_tensor_descriptor(wei_grid_desc_gemmk_gemmn,
132
133 // C: output tensor
134 const auto out_grid_desc_gemmm_gemmn = transform_tensor_descriptor(
139
// NOTE(review): the commented-out alternative below is wrong if revived.
// K is dimension 4 of the 5-D output descriptor (it is read via
// GetLength(I4) above), so the pass-through for K should take Sequence<4>{}.
// As written it takes Sequence<3>{}, which is Wo, already consumed by the merge.
140 // const auto out_grid_desc_gemmm_gemmn = transform_tensor_descriptor(
141 // out_n_do_ho_wo_k_grid_desc,
142 // make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
143 // make_pass_through_transform(K)),
144 // make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<3>{}),
145 // make_tuple(Sequence<0>{}, Sequence<1>{}));
146
// Returned in order: input (A), weight (B), output (C) descriptors.
147 return make_tuple(in_grid_desc_gemmk0_gemmm_gemmk1,
148 wei_grid_desc_gemmk0_gemmn_gemmk1,
149 out_grid_desc_gemmm_gemmn);
150}
151
152} // namespace ck
153#endif
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, integral_constant< bool, SkipIsValidCheck > = integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:19
__host__ __device__ constexpr auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition multi_index_transform_helper.hpp:48
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk_pad(const TensorDescriptor< In... > &in_grid_desc_n_di_hi_wi_c, const TensorDescriptor< Wei... > &wei_k_z_y_x_c_grid_desc, const TensorDescriptor< Out... > &out_n_do_ho_wo_k_grid_desc, const ConvStrides &conv_strides, const ConvDilations &conv_dilations, const InLeftPads &in_left_pads, const InRightPads &in_right_pads, Number< GemmK1Value >)
Definition transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp:28
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation > = integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
Definition utility/sequence.hpp:43
Definition tensor_description/tensor_descriptor.hpp:28
__host__ __device__ constexpr auto GetLength(Number< IDim >) const
Definition tensor_description/tensor_descriptor.hpp:147