device_batched_gemm_softmax_gemm_permute.hpp Source File

device_batched_gemm_softmax_gemm_permute.hpp Source File

Composable Kernel: device_batched_gemm_softmax_gemm_permute.hpp Source File
device_batched_gemm_softmax_gemm_permute.hpp
Go to the documentation of this file.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <array>
#include <iostream>
#include <memory>
#include <vector>

#include "device_base.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"

12namespace ck {
13namespace tensor_operation {
14namespace device {
15
16template <index_t NumDimG,
17 index_t NumDimM,
18 index_t NumDimN,
19 index_t NumDimK,
20 index_t NumDimO,
21 typename ADataType,
22 typename B0DataType,
23 typename B1DataType,
24 typename CDataType,
25 typename Acc0BiasDataType,
26 typename Acc1BiasDataType,
27 typename AElementwiseOperation,
28 typename B0ElementwiseOperation,
29 typename C0DEElementwiseOperation,
30 typename B1ElementwiseOperation,
31 typename C1DEElementwiseOperation,
32 MaskingSpecialization MaskingSpec>
34{
35 static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
36 static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
37
38 virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
39 const void* p_a,
40 const void* p_b0,
41 const void* p_b1,
42 void* p_c,
43 const std::array<void*, NumAcc0Bias> p_acc0_biases,
44 const std::array<void*, NumAcc1Bias> p_acc1_biases,
45 const std::vector<index_t>& a_gs_ms_ks_lengths,
46 const std::vector<index_t>& a_gs_ms_ks_strides,
47 const std::vector<index_t>& b_gs_ns_ks_lengths,
48 const std::vector<index_t>& b_gs_ns_ks_strides,
49 const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
50 const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
51 const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
52 const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
53 const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
54 const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
55 const std::array<std::vector<index_t>, NumAcc1Bias>
56 acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
57 const std::array<std::vector<index_t>, NumAcc1Bias>
58 acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
59 AElementwiseOperation a_element_op,
60 B0ElementwiseOperation b0_element_op,
61 C0DEElementwiseOperation c0de_element_op,
62 B1ElementwiseOperation b1_element_op,
63 C1DEElementwiseOperation c1de_element_op) = 0;
64
65 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
66};
67
68} // namespace device
69} // namespace tensor_operation
70} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
MaskingSpecialization
Definition masking_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
Definition device_batched_gemm_softmax_gemm_permute.hpp:34
static constexpr index_t NumAcc1Bias
Definition device_batched_gemm_softmax_gemm_permute.hpp:36
static constexpr index_t NumAcc0Bias
Definition device_batched_gemm_softmax_gemm_permute.hpp:35
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b0, const void *p_b1, void *p_c, const std::array< void *, NumAcc0Bias > p_acc0_biases, const std::array< void *, NumAcc1Bias > p_acc1_biases, const std::vector< index_t > &a_gs_ms_ks_lengths, const std::vector< index_t > &a_gs_ms_ks_strides, const std::vector< index_t > &b_gs_ns_ks_lengths, const std::vector< index_t > &b_gs_ns_ks_strides, const std::vector< index_t > &b1_gs_gemm1ns_gemm1ks_lengths, const std::vector< index_t > &b1_gs_gemm1ns_gemm1ks_strides, const std::vector< index_t > &c_gs_ms_gemm1ns_lengths, const std::vector< index_t > &c_gs_ms_gemm1ns_strides, const std::array< std::vector< index_t >, NumAcc0Bias > acc0_biases_gs_ms_ns_lengths, const std::array< std::vector< index_t >, NumAcc0Bias > acc0_biases_gs_ms_ns_strides, const std::array< std::vector< index_t >, NumAcc1Bias > acc1_biases_gs_ms_gemm1ns_lengths, const std::array< std::vector< index_t >, NumAcc1Bias > acc1_biases_gs_ms_gemm1ns_strides, AElementwiseOperation a_element_op, B0ElementwiseOperation b0_element_op, C0DEElementwiseOperation c0de_element_op, B1ElementwiseOperation b1_element_op, C1DEElementwiseOperation c1de_element_op)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0