device_gemm_mx.hpp Source File

device_gemm_mx.hpp Source File

Composable Kernel: device_gemm_mx.hpp Source File
device_gemm_mx.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9namespace tensor_operation {
10namespace device {
11
// Abstract device-operation interface for a GEMM in which the A and B operands
// are each accompanied by a separate scale tensor (AScaleDataType /
// BScaleDataType). Concrete device implementations provide the two
// pure-virtual factories declared below.
//
// Template parameters:
//   ALayout, BLayout, CLayout        - layout tag types for A, B and C.
//   ADataType, BDataType, CDataType  - data types for A, B and C.
//   AScaleDataType, BScaleDataType   - data types for the A and B scale tensors.
//   ScaleBlockSize                   - compile-time ck index_t. NOTE(review):
//                                      presumably the number of A/B elements
//                                      that share one scale value (MX-style
//                                      block scaling) - confirm.
//   A/B/CElementwiseOperation        - functor types, passed by value to
//                                      MakeArgumentPointer.
//
// NOTE(review): this rendered listing omits the struct declaration line
// (source line 24) and source lines 32-34. The member signature at the end of
// this page shows that the omitted parameters are ck::index_t M, N, K, located
// between p_c and StrideA.
12template <typename ALayout,
13 typename BLayout,
14 typename CLayout,
15 typename ADataType,
16 typename AScaleDataType,
17 typename BDataType,
18 typename BScaleDataType,
19 typename CDataType,
20 index_t ScaleBlockSize,
21 typename AElementwiseOperation,
22 typename BElementwiseOperation,
23 typename CElementwiseOperation>
25{
// Builds a type-erased argument object for one GEMM problem. Ownership of the
// result passes to the caller through std::unique_ptr<BaseArgument>.
// All buffers are passed untyped (const void* / void*). NOTE(review):
// presumably they must hold ADataType, AScaleDataType, BDataType,
// BScaleDataType and CDataType data respectively - confirm with implementations.
// StrideAScale / StrideBScale are the strides of the two scale tensors.
// KBatch: NOTE(review): presumably the split-K factor - confirm.
26 virtual std::unique_ptr<BaseArgument>
27 MakeArgumentPointer(const void* p_a,
28 const void* p_a_scale,
29 const void* p_b,
30 const void* p_b_scale,
31 void* p_c,
// NOTE(review): source lines 32-34 (ck::index_t M, N, K) are missing here.
35 ck::index_t StrideA,
36 ck::index_t StrideAScale,
37 ck::index_t StrideB,
38 ck::index_t StrideBScale,
39 ck::index_t StrideC,
40 ck::index_t KBatch,
41 AElementwiseOperation a_element_op,
42 BElementwiseOperation b_element_op,
43 CElementwiseOperation c_element_op) = 0;
44
// Creates an invoker object; ownership passes to the caller. NOTE(review):
// presumably used to run the operation on an argument from
// MakeArgumentPointer - the semantics live in the implementations.
45 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
46};
47
// Second abstract device-operation interface with the same template
// parameters and the same two factories as the interface declared just above
// in this file. It adds GetPreShuffleParameters(), which implementations must
// also provide.
//
// NOTE(review): this rendered listing omits the struct declaration line
// (source line 60) and source lines 68-70. The member signature at the end of
// this page shows that the omitted parameters are ck::index_t M, N, K, located
// between p_c and StrideA.
48template <typename ALayout,
49 typename BLayout,
50 typename CLayout,
51 typename ADataType,
52 typename AScaleDataType,
53 typename BDataType,
54 typename BScaleDataType,
55 typename CDataType,
56 index_t ScaleBlockSize,
57 typename AElementwiseOperation,
58 typename BElementwiseOperation,
59 typename CElementwiseOperation>
61{
// Builds a type-erased argument object for one GEMM problem. Ownership of the
// result passes to the caller through std::unique_ptr<BaseArgument>.
// All buffers are passed untyped (const void* / void*). NOTE(review):
// presumably they must hold ADataType, AScaleDataType, BDataType,
// BScaleDataType and CDataType data respectively - confirm with implementations.
// KBatch: NOTE(review): presumably the split-K factor - confirm.
62 virtual std::unique_ptr<BaseArgument>
63 MakeArgumentPointer(const void* p_a,
64 const void* p_a_scale,
65 const void* p_b,
66 const void* p_b_scale,
67 void* p_c,
// NOTE(review): source lines 68-70 (ck::index_t M, N, K) are missing here.
71 ck::index_t StrideA,
72 ck::index_t StrideAScale,
73 ck::index_t StrideB,
74 ck::index_t StrideBScale,
75 ck::index_t StrideC,
76 ck::index_t KBatch,
77 AElementwiseOperation a_element_op,
78 BElementwiseOperation b_element_op,
79 CElementwiseOperation c_element_op) = 0;
80
// Creates an invoker object; ownership passes to the caller.
81 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
82
// Returns an int describing the implementation's pre-shuffle configuration.
// NOTE(review): the encoding is not visible in this header; presumably host
// code uses it to pre-shuffle an operand before launch - confirm against the
// concrete implementations.
83 virtual int GetPreShuffleParameters() = 0;
84};
85
86} // namespace device
87} // namespace tensor_operation
88} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_a_scale, const void *p_b, const void *p_b_scale, void *p_c, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideAScale, ck::index_t StrideB, ck::index_t StrideBScale, ck::index_t StrideC, ck::index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
Definition device_gemm_mx.hpp:25
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_a_scale, const void *p_b, const void *p_b_scale, void *p_c, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideAScale, ck::index_t StrideB, ck::index_t StrideBScale, ck::index_t StrideC, ck::index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0