device_batched_gemm_softmax_gemm.hpp Source File

device_batched_gemm_softmax_gemm.hpp Source File

Composable Kernel: device_batched_gemm_softmax_gemm.hpp Source File
device_batched_gemm_softmax_gemm.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5#ifndef __HIPCC_RTC__
6#include <iostream>
7#include <vector>
8#endif
9
10#include "device_base.hpp"
11
12namespace ck {
13namespace tensor_operation {
14namespace device {
15
16template <typename ALayout,
17 typename B0Layout,
18 typename B1Layout,
19 typename CLayout,
20 typename ADataType,
21 typename B0DataType,
22 typename B1DataType,
23 typename CDataType,
24 typename AElementwiseOperation,
25 typename B0ElementwiseOperation,
26 typename Acc0ElementwiseOperation,
27 typename B1ElementwiseOperation,
28 typename CElementwiseOperation,
29 bool MaskOutUpperTriangle> // TODO: enum for mask type
31{
32#ifndef __HIPCC_RTC__
33 virtual std::unique_ptr<BaseArgument>
34 MakeArgumentPointer(const void* p_a,
35 const void* p_b0,
36 const void* p_b1,
37 void* p_c,
42 ck::index_t Batch,
43 ck::index_t StrideA,
44 ck::index_t StrideB0,
45 ck::index_t StrideB1,
46 ck::index_t StrideC,
47 ck::index_t BatchStrideA,
48 ck::index_t BatchStrideB0,
49 ck::index_t BatchStrideB1,
50 ck::index_t BatchStrideC,
51 AElementwiseOperation a_element_op,
52 B0ElementwiseOperation b0_element_op,
53 Acc0ElementwiseOperation acc0_element_op,
54 B1ElementwiseOperation b1_element_op,
55 CElementwiseOperation c_element_op) = 0;
56
57 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
58#endif
59};
60
61} // namespace device
62} // namespace tensor_operation
63} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
@ MaskOutUpperTriangle
Definition masking_specialization.hpp:13
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
Definition device_batched_gemm_softmax_gemm.hpp:31
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b0, const void *p_b1, void *p_c, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t O, ck::index_t Batch, ck::index_t StrideA, ck::index_t StrideB0, ck::index_t StrideB1, ck::index_t StrideC, ck::index_t BatchStrideA, ck::index_t BatchStrideB0, ck::index_t BatchStrideB1, ck::index_t BatchStrideC, AElementwiseOperation a_element_op, B0ElementwiseOperation b0_element_op, Acc0ElementwiseOperation acc0_element_op, B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op)=0