transpose_vectors.hpp Source File

transpose_vectors.hpp Source File#

Composable Kernel: transpose_vectors.hpp Source File
utility/transpose_vectors.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck/ck.hpp"
8#include "data_type.hpp"
9
10namespace ck {
11
12template <typename S,
13 index_t NX,
14 index_t NY,
15 typename enable_if<is_scalar_type<S>::value, bool>::type = false>
17
18// transpose fp16 2x2
19__device__ void transpose_fp16_2x2(const half2_t& x0, const half2_t& x1, half2_t& y0, half2_t& y1)
20{
21#if 0
22 static constexpr auto I0 = Number<0>{};
23 static constexpr auto I1 = Number<1>{};
24
25 const vector_type<half_t, 2> vx0{x0}, vx1{x1};
27
28 vy0.template AsType<half_t>()(I0) = vx0.template AsType<half_t>()[I0];
29 vy0.template AsType<half_t>()(I1) = vx1.template AsType<half_t>()[I0];
30
31 vy1.template AsType<half_t>()(I0) = vx0.template AsType<half_t>()[I1];
32 vy1.template AsType<half_t>()(I1) = vx1.template AsType<half_t>()[I1];
33
34 y0 = vy0.template AsType<half2_t>()[I0];
35 y1 = vy1.template AsType<half2_t>()[I0];
36#else
37 constexpr int32_t m0 = 0x05040100;
38 constexpr int32_t m1 = 0x07060302;
39
40 // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
41 // -- -- -- -- -- -- -- -- - - - -
42 // index 7 6 5 4 3 2 1 0 33 77 44 88
43 // index is reversed because of little endianness (least significant bits first)
44 y0 = bit_cast<half2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0));
45 y1 = bit_cast<half2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m1));
46#endif
47}
48
49template <index_t NX, index_t NY>
51{
52 // we got [NY * NX] amount of S data to be transposed
53 static constexpr index_t s_per_x = NY;
54 static constexpr index_t s_per_y = NX;
55
56 using S = half_t;
59
60 __device__ void operator()(const StaticallyIndexedArray<const VX&, NX>& vx_tuple,
62 {
63 static constexpr auto I1 = Number<1>{};
64 static constexpr auto I2 = Number<2>{};
65
66 static_assert((NX % 2 == 0 && NY % 2 == 0), "wrong!");
67
68 // loop over 2x2 tile and transpose data from vx_tuple into vy_tuple
69 static_for<0, NY, 2>{}([&](auto iy) {
70 static_for<0, NX, 2>{}([&](auto ix) {
71 // reference to 2 half2_t data from vx_tuple
72 const auto& x_s2_0 = vx_tuple[ix].template AsType<half2_t>()[iy / I2];
73 const auto& x_s2_1 = vx_tuple[ix + I1].template AsType<half2_t>()[iy / I2];
74
75 // reference to 2 half2_t data from vy_tuple
76 auto& y_s2_0 = vy_tuple(iy).template AsType<half2_t>()(ix / I2);
77 auto& y_s2_1 = vy_tuple(iy + I1).template AsType<half2_t>()(ix / I2);
78
79 // transpose
80 transpose_fp16_2x2(x_s2_0, x_s2_1, y_s2_0, y_s2_1);
81 });
82 });
83 }
84};
85
86// transpose int8 4x4
87__device__ void transpose_int8_4x4(const int8x4_t& x0,
88 const int8x4_t& x1,
89 const int8x4_t& x2,
90 const int8x4_t& x3,
91 int8x4_t& y0,
92 int8x4_t& y1,
93 int8x4_t& y2,
94 int8x4_t& y3)
95{
96 int32_t t0, t1;
97 int32_t z0, z1, z2, z3;
98 constexpr int32_t m0 = 0x05010400;
99 constexpr int32_t m1 = 0x05040100;
100 constexpr int32_t m2 = 0x07060302;
101 constexpr int32_t m3 = 0x07030602;
102
103 // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
104 // -- -- -- -- -- -- -- -- - - - -
105 // index 7 6 5 4 3 2 1 0 33 77 44 88
106 // index is reversed because of little endianness (least significant bits first)
107 t0 = __builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0);
108 t1 = __builtin_amdgcn_perm(bit_cast<int32_t>(x3), bit_cast<int32_t>(x2), m0);
109 z0 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
110 z1 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);
111 t0 = __builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m3);
112 t1 = __builtin_amdgcn_perm(bit_cast<int32_t>(x3), bit_cast<int32_t>(x2), m3);
113 z2 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
114 z3 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);
115
116 y0 = bit_cast<int8x4_t>(z0);
117 y1 = bit_cast<int8x4_t>(z1);
118 y2 = bit_cast<int8x4_t>(z2);
119 y3 = bit_cast<int8x4_t>(z3);
120}
121
122template <index_t NX, index_t NY>
124{
125 // we got [NY * NX] amount of S data to be transposed
126 static constexpr index_t s_per_x = NY;
127 static constexpr index_t s_per_y = NX;
128
129 using S = int8_t;
132
133 __device__ void operator()(const StaticallyIndexedArray<const VX&, NX>& vx_tuple,
135 {
136 static constexpr auto I1 = Number<1>{};
137 static constexpr auto I2 = Number<2>{};
138 static constexpr auto I3 = Number<3>{};
139 static constexpr auto I4 = Number<4>{};
140
141 static_assert((NX % 4 == 0 && NY % 4 == 0), "wrong!");
142
143 // loop over 4x4 tile and transpose data from vx_tuple into vy_tuple
144 static_for<0, NY, 4>{}([&](auto iy) {
145 static_for<0, NX, 4>{}([&](auto ix) {
146 // reference to 4 int8 data from vx_tuple
147 const auto& x_s4_0 = vx_tuple[ix].template AsType<int8x4_t>()[iy / I4];
148 const auto& x_s4_1 = vx_tuple[ix + I1].template AsType<int8x4_t>()[iy / I4];
149 const auto& x_s4_2 = vx_tuple[ix + I2].template AsType<int8x4_t>()[iy / I4];
150 const auto& x_s4_3 = vx_tuple[ix + I3].template AsType<int8x4_t>()[iy / I4];
151
152 // reference to 4 int8 data from vy_tuple
153 auto& y_s4_0 = vy_tuple(iy).template AsType<int8x4_t>()(ix / I4);
154 auto& y_s4_1 = vy_tuple(iy + I1).template AsType<int8x4_t>()(ix / I4);
155 auto& y_s4_2 = vy_tuple(iy + I2).template AsType<int8x4_t>()(ix / I4);
156 auto& y_s4_3 = vy_tuple(iy + I3).template AsType<int8x4_t>()(ix / I4);
157
158 // transpose
159 transpose_int8_4x4(x_s4_0, x_s4_1, x_s4_2, x_s4_3, y_s4_0, y_s4_1, y_s4_2, y_s4_3);
160 });
161 });
162 }
163};
164
165// transpose f8 4x4
166__device__ void transpose_f8_4x4(const f8x4_t& x0,
167 const f8x4_t& x1,
168 const f8x4_t& x2,
169 const f8x4_t& x3,
170 f8x4_t& y0,
171 f8x4_t& y1,
172 f8x4_t& y2,
173 f8x4_t& y3)
174{
175 int32_t t0, t1;
176 int32_t z0, z1, z2, z3;
177 constexpr int32_t m0 = 0x05010400;
178 constexpr int32_t m1 = 0x05040100;
179 constexpr int32_t m2 = 0x07060302;
180 constexpr int32_t m3 = 0x07030602;
181
182 // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
183 // -- -- -- -- -- -- -- -- - - - -
184 // index 7 6 5 4 3 2 1 0 33 77 44 88
185 // index is reversed because of little endianness (least significant bits first)
186 t0 = __builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0);
187 t1 = __builtin_amdgcn_perm(bit_cast<int32_t>(x3), bit_cast<int32_t>(x2), m0);
188 z0 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
189 z1 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);
190 t0 = __builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m3);
191 t1 = __builtin_amdgcn_perm(bit_cast<int32_t>(x3), bit_cast<int32_t>(x2), m3);
192 z2 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
193 z3 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);
194
195 y0 = bit_cast<f8x4_t>(z0);
196 y1 = bit_cast<f8x4_t>(z1);
197 y2 = bit_cast<f8x4_t>(z2);
198 y3 = bit_cast<f8x4_t>(z3);
199}
200
201template <index_t NX, index_t NY>
202struct transpose_vectors<f8_t, NX, NY>
203{
204 // we got [NY * NX] amount of S data to be transposed
205 static constexpr index_t s_per_x = NY;
206 static constexpr index_t s_per_y = NX;
207
208 using S = f8_t;
211
212 __device__ void operator()(const StaticallyIndexedArray<const VX&, NX>& vx_tuple,
214 {
215 static constexpr auto I1 = Number<1>{};
216 static constexpr auto I2 = Number<2>{};
217 static constexpr auto I3 = Number<3>{};
218 static constexpr auto I4 = Number<4>{};
219
220 static_assert((NX % 4 == 0 && NY % 4 == 0), "wrong!");
221
222 // loop over 4x4 tile and transpose data from vx_tuple into vy_tuple
223 static_for<0, NY, 4>{}([&](auto iy) {
224 static_for<0, NX, 4>{}([&](auto ix) {
225 // reference to 4 f8 data from vx_tuple
226 const auto& x_s4_0 = vx_tuple[ix].template AsType<f8x4_t>()[iy / I4];
227 const auto& x_s4_1 = vx_tuple[ix + I1].template AsType<f8x4_t>()[iy / I4];
228 const auto& x_s4_2 = vx_tuple[ix + I2].template AsType<f8x4_t>()[iy / I4];
229 const auto& x_s4_3 = vx_tuple[ix + I3].template AsType<f8x4_t>()[iy / I4];
230
231 // reference to 4 f8 data from vy_tuple
232 auto& y_s4_0 = vy_tuple(iy).template AsType<f8x4_t>()(ix / I4);
233 auto& y_s4_1 = vy_tuple(iy + I1).template AsType<f8x4_t>()(ix / I4);
234 auto& y_s4_2 = vy_tuple(iy + I2).template AsType<f8x4_t>()(ix / I4);
235 auto& y_s4_3 = vy_tuple(iy + I3).template AsType<f8x4_t>()(ix / I4);
236
237 // transpose
238 transpose_f8_4x4(x_s4_0, x_s4_1, x_s4_2, x_s4_3, y_s4_0, y_s4_1, y_s4_2, y_s4_3);
239 });
240 });
241 }
242};
243
244} // namespace ck
Definition ck.hpp:268
f8_fnuz_t f8_t
Definition amd_ck_fp8.hpp:1762
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
typename vector_type< int8_t, 4 >::type int8x4_t
Definition dtype_vector.hpp:2177
_Float16 half_t
Definition data_type.hpp:31
integral_constant< index_t, N > Number
Definition number.hpp:12
__device__ void transpose_f8_4x4(const f8x4_t &x0, const f8x4_t &x1, const f8x4_t &x2, const f8x4_t &x3, f8x4_t &y0, f8x4_t &y1, f8x4_t &y2, f8x4_t &y3)
Definition utility/transpose_vectors.hpp:166
std::enable_if< B, T > enable_if
Definition enable_if.hpp:24
typename vector_type< half_t, 2 >::type half2_t
Definition dtype_vector.hpp:2153
__device__ void transpose_int8_4x4(const int8x4_t &x0, const int8x4_t &x1, const int8x4_t &x2, const int8x4_t &x3, int8x4_t &y0, int8x4_t &y1, int8x4_t &y2, int8x4_t &y3)
Definition utility/transpose_vectors.hpp:87
__host__ __device__ constexpr Y bit_cast(const X &x)
Definition type.hpp:306
__device__ void transpose_fp16_2x2(const half2_t &x0, const half2_t &x1, half2_t &y0, half2_t &y1)
Definition utility/transpose_vectors.hpp:19
signed int int32_t
Definition stdint.h:123
signed char int8_t
Definition stdint.h:121
Definition functional2.hpp:33
vector_type< f8_t, s_per_x > VX
Definition utility/transpose_vectors.hpp:209
f8_t S
Definition utility/transpose_vectors.hpp:208
__device__ void operator()(const StaticallyIndexedArray< const VX &, NX > &vx_tuple, StaticallyIndexedArray< VY &, NY > &vy_tuple)
Definition utility/transpose_vectors.hpp:212
vector_type< f8_t, s_per_y > VY
Definition utility/transpose_vectors.hpp:210
static constexpr index_t s_per_x
Definition utility/transpose_vectors.hpp:205
static constexpr index_t s_per_y
Definition utility/transpose_vectors.hpp:206
vector_type< half_t, s_per_x > VX
Definition utility/transpose_vectors.hpp:57
static constexpr index_t s_per_x
Definition utility/transpose_vectors.hpp:53
static constexpr index_t s_per_y
Definition utility/transpose_vectors.hpp:54
half_t S
Definition utility/transpose_vectors.hpp:56
vector_type< half_t, s_per_y > VY
Definition utility/transpose_vectors.hpp:58
__device__ void operator()(const StaticallyIndexedArray< const VX &, NX > &vx_tuple, StaticallyIndexedArray< VY &, NY > &vy_tuple)
Definition utility/transpose_vectors.hpp:60
__device__ void operator()(const StaticallyIndexedArray< const VX &, NX > &vx_tuple, StaticallyIndexedArray< VY &, NY > &vy_tuple)
Definition utility/transpose_vectors.hpp:133
vector_type< int8_t, s_per_y > VY
Definition utility/transpose_vectors.hpp:131
static constexpr index_t s_per_x
Definition utility/transpose_vectors.hpp:126
int8_t S
Definition utility/transpose_vectors.hpp:129
static constexpr index_t s_per_y
Definition utility/transpose_vectors.hpp:127
vector_type< int8_t, s_per_x > VX
Definition utility/transpose_vectors.hpp:130
Definition utility/transpose_vectors.hpp:16
Definition dtype_vector.hpp:10