grouped_convolution_forward_kernel.hpp Source File#
grouped_convolution_forward_kernel.hpp
Go to the documentation of this file.
Definition tile/ops/common/tensor_layout.hpp:27
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
void CK_TILE_ERROR(Args &&... args) noexcept
Definition tile/core/utility/env.hpp:12
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
ConvolutionSpecialization
Definition convolution_specialization.hpp:11
@ Filter1x1Stride1Pad0
Definition convolution_specialization.hpp:14
@ Filter3x3
Definition convolution_specialization.hpp:15
@ Filter1x1Pad0
Definition convolution_specialization.hpp:13
GroupedConvHostArgs< const void *, const void *, void *, CDElementwise > GroupedConvFwdHostArgs
Definition grouped_convolution_utils.hpp:50
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition grouped_convolution_forward_kernel.hpp:384
index_t w_size
Definition grouped_convolution_forward_kernel.hpp:388
index_t h_start
Definition grouped_convolution_forward_kernel.hpp:387
index_t w_start
Definition grouped_convolution_forward_kernel.hpp:387
index_t d_size
Definition grouped_convolution_forward_kernel.hpp:388
index_t h_size
Definition grouped_convolution_forward_kernel.hpp:388
index_t block_start
Definition grouped_convolution_forward_kernel.hpp:385
index_t block_end
Definition grouped_convolution_forward_kernel.hpp:386
index_t d_start
Definition grouped_convolution_forward_kernel.hpp:387
Definition grouped_convolution_forward_kernel.hpp:376
index_t num_d_pieces
Definition grouped_convolution_forward_kernel.hpp:380
index_t total_w
Definition grouped_convolution_forward_kernel.hpp:378
index_t total_d
Definition grouped_convolution_forward_kernel.hpp:378
std::array< PieceInfo, MaxPieces > pieces
Definition grouped_convolution_forward_kernel.hpp:392
static constexpr index_t MaxPieces
Definition grouped_convolution_forward_kernel.hpp:391
index_t total_spatial
Definition grouped_convolution_forward_kernel.hpp:379
index_t num_w_pieces
Definition grouped_convolution_forward_kernel.hpp:380
index_t total_h
Definition grouped_convolution_forward_kernel.hpp:378
index_t num_h_pieces
Definition grouped_convolution_forward_kernel.hpp:380
The Grouped Convolution kernel device arguments.
Definition grouped_convolution_forward_kernel.hpp:24
long_index_t group_stride_c
Definition grouped_convolution_forward_kernel.hpp:353
remove_cvref_t< decltype(ConvToGemmFwdTransformer{} .template MakeADescriptor_M_K< typename GroupedConvTraitsType_::InLayout >())> AGridDescMK
Definition grouped_convolution_forward_kernel.hpp:315
index_t input_batch_stride
Definition grouped_convolution_forward_kernel.hpp:359
static constexpr index_t NonSpatialDims
Definition grouped_convolution_forward_kernel.hpp:325
index_t n_per_split
Definition grouped_convolution_forward_kernel.hpp:357
const CDElementwise elfunc
Definition grouped_convolution_forward_kernel.hpp:344
AGridDescMK a_grid_desc_m_k
Definition grouped_convolution_forward_kernel.hpp:347
TransformConvFwdToGemm< GroupedConvTraitsType_::NDimSpatial, GroupedConvTraitsType_::ConvSpecialization, GroupedConvTraitsType_::VectorSizeA, GroupedConvTraitsType_::VectorSizeB, GroupedConvTraitsType_::VectorSizeC, GroupedConvTraitsType_::NumGroupsToMerge, true > ConvToGemmFwdTransformer
Definition grouped_convolution_forward_kernel.hpp:26
CGridDescMN CGridDescMN_t
Definition grouped_convolution_forward_kernel.hpp:372
const void * in_ptr
Definition grouped_convolution_forward_kernel.hpp:341
index_t GemmM
Definition grouped_convolution_forward_kernel.hpp:336
index_t original_n
Definition grouped_convolution_forward_kernel.hpp:358
long_index_t group_stride_b
Definition grouped_convolution_forward_kernel.hpp:352
ConvToGemmFwdTransformer ConvToGemmFwdTransformer_t
Definition grouped_convolution_forward_kernel.hpp:370
CGridDescMN c_grid_desc_m_n
Definition grouped_convolution_forward_kernel.hpp:349
CDElementwise_ CDElementwise
Definition grouped_convolution_forward_kernel.hpp:34
index_t n_splits
Definition grouped_convolution_forward_kernel.hpp:356
std::array< const void *, NumDTensor > ds_ptr
Definition grouped_convolution_forward_kernel.hpp:343
array< index_t, GroupedConvTraitsType_::NDimSpatial > input_left_pads
Definition grouped_convolution_forward_kernel.hpp:332
AGridDescMK AGridDescMK_t
Definition grouped_convolution_forward_kernel.hpp:371
const void * wei_ptr
Definition grouped_convolution_forward_kernel.hpp:342
BGridDescNK b_grid_desc_n_k
Definition grouped_convolution_forward_kernel.hpp:348
remove_cvref_t< decltype(ConvToGemmFwdTransformer{} .template MakeBDescriptor_N_K< typename GroupedConvTraitsType_::WeiLayout >())> BGridDescNK
Definition grouped_convolution_forward_kernel.hpp:318
remove_cvref_t< decltype(ConvToGemmFwdTransformer{} .template MakeCDescriptor_M_N< typename GroupedConvTraitsType_::OutLayout >())> CGridDescMN
Definition grouped_convolution_forward_kernel.hpp:321
index_t num_spatial_pieces
Definition grouped_convolution_forward_kernel.hpp:395
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > out_g_n_k_wos_lengths
Definition grouped_convolution_forward_kernel.hpp:328
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > wei_g_k_c_xs_lengths
Definition grouped_convolution_forward_kernel.hpp:327
index_t GemmN
Definition grouped_convolution_forward_kernel.hpp:337
long_index_t spatial_offset_in
Definition grouped_convolution_forward_kernel.hpp:363
SplitImageInfo split_image
Definition grouped_convolution_forward_kernel.hpp:396
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs< CDElementwise > &args)
Definition grouped_convolution_forward_kernel.hpp:45
array< index_t, GroupedConvTraitsType_::NDimSpatial > input_right_pads
Definition grouped_convolution_forward_kernel.hpp:333
index_t output_batch_stride
Definition grouped_convolution_forward_kernel.hpp:360
long_index_t group_stride_a
Definition grouped_convolution_forward_kernel.hpp:351
index_t GemmK
Definition grouped_convolution_forward_kernel.hpp:338
void * out_ptr
Definition grouped_convolution_forward_kernel.hpp:345
ConvToGemmFwdTransformer transformer_
Definition grouped_convolution_forward_kernel.hpp:367
index_t GemmBatch
Definition grouped_convolution_forward_kernel.hpp:339
long_index_t spatial_offset_out
Definition grouped_convolution_forward_kernel.hpp:364
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > in_g_n_c_wis_lengths
Definition grouped_convolution_forward_kernel.hpp:326
static constexpr index_t NumDTensor
Definition grouped_convolution_forward_kernel.hpp:35
array< index_t, GroupedConvTraitsType_::NDimSpatial > conv_filter_dilations
Definition grouped_convolution_forward_kernel.hpp:331
index_t k_batch
Definition grouped_convolution_forward_kernel.hpp:335
array< index_t, GroupedConvTraitsType_::NDimSpatial > conv_filter_strides
Definition grouped_convolution_forward_kernel.hpp:330
const std::vector< const void * > ds_ptr
Definition grouped_convolution_utils.hpp:41
Definition grouped_convolution_forward_kernel.hpp:490
index_t h
Definition grouped_convolution_forward_kernel.hpp:491
index_t d
Definition grouped_convolution_forward_kernel.hpp:491
index_t w
Definition grouped_convolution_forward_kernel.hpp:491
The Grouped Convolution Forward kernel template.
Definition grouped_convolution_forward_kernel.hpp:442
static CK_TILE_DEVICE index_t FindPieceId(index_t block_id, const SplitImageInfo &split_info, index_t num_pieces)
Definition grouped_convolution_forward_kernel.hpp:536
remove_cvref_t< typename EpiloguePipeline::DsLayout > GemmDsLayout
Definition grouped_convolution_forward_kernel.hpp:459
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition grouped_convolution_forward_kernel.hpp:448
static CK_TILE_HOST constexpr GroupedConvFwdKernelArgsSpecialized MakeKernelArgs(const GroupedConvFwdHostArgs< CDElementwise > &hostArgs)
Definition grouped_convolution_forward_kernel.hpp:583
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition grouped_convolution_forward_kernel.hpp:447
typename EpiloguePipeline::CDElementwise CDElementwise
Definition grouped_convolution_forward_kernel.hpp:470
static constexpr auto I1
Definition grouped_convolution_forward_kernel.hpp:478
static constexpr auto I2
Definition grouped_convolution_forward_kernel.hpp:479
static CK_TILE_DEVICE index_t FlattenSpatial(index_t d, index_t h, index_t w, index_t total_h, index_t total_w)
Definition grouped_convolution_forward_kernel.hpp:517
static CK_TILE_DEVICE void RunGemm(const InDataType *a_ptr, const WeiDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, OutDataType *c_ptr, void *smem_ptr_0, const ADescType &a_desc, const BDescType &b_desc, const CDescType &c_desc, const index_t gemm_k, const index_t block_idx_m, const index_t block_idx_n)
Runs single GEMM problem cooperatively by whole workgroup.
Definition grouped_convolution_forward_kernel.hpp:860
remove_cvref_t< typename GroupedConvTraitsType_::OutLayout > OutLayout
Definition grouped_convolution_forward_kernel.hpp:456
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition grouped_convolution_forward_kernel.hpp:589
static CK_TILE_DEVICE auto MakeGemmTensorViews(const InDataType *a_ptr, const WeiDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, OutDataType *c_ptr, const ADescType &a_desc, const BDescType &b_desc, const CDescType &c_desc)
Definition grouped_convolution_forward_kernel.hpp:723
static constexpr auto I0
Definition grouped_convolution_forward_kernel.hpp:477
CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized kargs) const
Definition grouped_convolution_forward_kernel.hpp:954
static constexpr bool EnableSplitImage
Definition grouped_convolution_forward_kernel.hpp:443
GroupedConvFwdKernelArgs< GroupedConvTraitsType_, CDElementwise > GroupedConvFwdKernelArgsSpecialized
Definition grouped_convolution_forward_kernel.hpp:472
remove_cvref_t< typename GroupedConvTraitsType_::WeiLayout > WeiLayout
Definition grouped_convolution_forward_kernel.hpp:455
remove_cvref_t< typename EpiloguePipeline::ODataType > OutDataType
Definition grouped_convolution_forward_kernel.hpp:468
remove_cvref_t< typename GroupedConvTraitsType_::DsLayout > DsLayout
Definition grouped_convolution_forward_kernel.hpp:457
static constexpr index_t kBlockSize
Definition grouped_convolution_forward_kernel.hpp:462
remove_cvref_t< typename EpiloguePipeline::DsDataType > DsDataType
Definition grouped_convolution_forward_kernel.hpp:466
remove_cvref_t< typename GemmPipeline::BLayout > GemmBLayout
Definition grouped_convolution_forward_kernel.hpp:451
static constexpr index_t NDimSpatial
Definition grouped_convolution_forward_kernel.hpp:444
static CK_TILE_HOST auto BlockSize()
Definition grouped_convolution_forward_kernel.hpp:577
static constexpr auto I3
Definition grouped_convolution_forward_kernel.hpp:480
static CK_TILE_HOST const std::string GetName()
Definition grouped_convolution_forward_kernel.hpp:559
static CK_TILE_HOST bool IsSupportedArgument(const GroupedConvFwdKernelArgsSpecialized &kargs)
Definition grouped_convolution_forward_kernel.hpp:594
remove_cvref_t< typename GemmPipeline::BDataType > WeiDataType
Definition grouped_convolution_forward_kernel.hpp:465
static constexpr index_t NumDTensor
Definition grouped_convolution_forward_kernel.hpp:460
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n)
Definition grouped_convolution_forward_kernel.hpp:805
remove_cvref_t< typename GemmPipeline::ALayout > GemmALayout
Definition grouped_convolution_forward_kernel.hpp:450
static constexpr bool IsSplitKSupported
Definition grouped_convolution_forward_kernel.hpp:475
remove_cvref_t< typename GemmPipeline::CLayout > GemmCLayout
Definition grouped_convolution_forward_kernel.hpp:452
static CK_TILE_DEVICE void RunGemm2LDS(const InDataType *a_ptr, const WeiDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, OutDataType *c_ptr, void *__restrict__ smem_ptr_0, void *__restrict__ smem_ptr_1, const ADescType &a_desc, const BDescType &b_desc, const CDescType &c_desc, const index_t gemm_k, const index_t block_idx_m, const index_t block_idx_n)
Runs single GEMM problem cooperatively by whole workgroup.
Definition grouped_convolution_forward_kernel.hpp:917
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views)
Definition grouped_convolution_forward_kernel.hpp:764
remove_cvref_t< typename GroupedConvTraitsType_::InLayout > InLayout
Definition grouped_convolution_forward_kernel.hpp:454
static constexpr ConvolutionSpecialization ConvSpecialization
Definition grouped_convolution_forward_kernel.hpp:445
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition grouped_convolution_forward_kernel.hpp:449
static CK_TILE_DEVICE SpatialCoords UnflattenSpatial(index_t flat, index_t h_size, index_t w_size)
Definition grouped_convolution_forward_kernel.hpp:496
static CK_TILE_HOST auto GridSize(const GroupedConvFwdKernelArgsSpecialized &kargs)
Definition grouped_convolution_forward_kernel.hpp:571
remove_cvref_t< typename GemmPipeline::ADataType > InDataType
Definition grouped_convolution_forward_kernel.hpp:464
Definition tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp:28
A fixed-size array container similar to std::array with additional utilities.
Definition tile/core/container/array.hpp:43
std::vector< ck_tile::long_index_t > input_spatial_lengths_
Definition tile/host/convolution_parameter.hpp:130
ck_tile::long_index_t K_
Definition tile/host/convolution_parameter.hpp:126
std::vector< ck_tile::long_index_t > output_spatial_lengths_
Definition tile/host/convolution_parameter.hpp:131
std::vector< ck_tile::long_index_t > input_right_pads_
Definition tile/host/convolution_parameter.hpp:137
ck_tile::long_index_t G_
Definition tile/host/convolution_parameter.hpp:124
std::vector< ck_tile::long_index_t > conv_filter_strides_
Definition tile/host/convolution_parameter.hpp:133
std::vector< ck_tile::long_index_t > filter_spatial_lengths_
Definition tile/host/convolution_parameter.hpp:129
ck_tile::long_index_t C_
Definition tile/host/convolution_parameter.hpp:127
ck_tile::long_index_t N_
Definition tile/host/convolution_parameter.hpp:125
std::vector< ck_tile::long_index_t > input_left_pads_
Definition tile/host/convolution_parameter.hpp:136
std::vector< ck_tile::long_index_t > conv_filter_dilations_
Definition tile/host/convolution_parameter.hpp:134
Definition type_traits.hpp:115
Definition tile/core/container/sequence.hpp:49