device_grouped_conv_bwd_weight.hpp Source File

device_grouped_conv_bwd_weight.hpp Source File

Composable Kernel: device_grouped_conv_bwd_weight.hpp Source File
device_grouped_conv_bwd_weight.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <array>
7
9
10namespace ck {
11namespace tensor_operation {
12namespace device {
13
14#define DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS 1
15
16template <ck::index_t NDimSpatial,
17 typename InLayout,
18 typename WeiLayout,
19 typename OutLayout,
20 typename InDataType,
21 typename WeiDataType,
22 typename OutDataType,
23 typename InElementwiseOperation,
24 typename WeiElementwiseOperation,
25 typename OutElementwiseOperation,
26 typename ComputeTypeA = InDataType,
27 typename ComputeTypeB = ComputeTypeA>
29{
30 virtual std::unique_ptr<BaseArgument>
31 MakeArgumentPointer(const void* p_in,
32 void* p_wei,
33 const void* p_out,
34 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
35 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
36 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
37 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
38 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
39 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
40 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
41 const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
42 const std::array<ck::index_t, NDimSpatial>& input_left_pads,
43 const std::array<ck::index_t, NDimSpatial>& input_right_pads,
44 InElementwiseOperation in_element_op,
45 WeiElementwiseOperation wei_element_op,
46 OutElementwiseOperation out_element_op,
47 ck::index_t split_k) = 0;
48
49 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
50};
51
52} // namespace device
53} // namespace tensor_operation
54} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
Definition device_grouped_conv_bwd_weight.hpp:29
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_in, void *p_wei, const void *p_out, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_dilations, const std::array< ck::index_t, NDimSpatial > &input_left_pads, const std::array< ck::index_t, NDimSpatial > &input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, ck::index_t split_k)=0