amd_inline_asm.hpp Source File

amd_inline_asm.hpp Source File#

Composable Kernel: amd_inline_asm.hpp Source File
amd_inline_asm.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#ifndef CK_AMD_INLINE_ASM_HPP
5#define CK_AMD_INLINE_ASM_HPP
6
8#include "dtype_vector.hpp"
9
10// TODO: deprecate all amd_assembly_outer_product_xxx
11
12namespace ck {
13
14inline __device__ int amd_assembly_and_b32(int a, int b)
15{
16 int c;
17 asm volatile("v_and_b32 %0, %1, %2" : "=v"(c) : "v"(a), "v"(b));
18 return c;
19}
20
21inline __device__ int amd_assembly_and_or_b32(int a, int b, int d)
22{
23 int c;
24 asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(c) : "v"(a), "v"(b), "v"(d));
25 return c;
26}
27
29{
30 half2_t d;
31 asm volatile("v_pk_fma_f16 %0, %1, %2, %3" : "=v"(d) : "v"(a), "v"(b), "v"(c));
32 return d;
33}
34
36{
37 half2_t c;
38 asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(c) : "v"(a), "v"(b));
39 return c;
40}
41
42inline __device__ float amd_assemble_cvt_f32_i4(int b)
43{
44 float a;
45 asm volatile("v_cvt_off_f32_i4 %0, %1" : "=v"(a) : "v"(b));
46 return a;
47}
48
49inline __device__ f8x4_t amd_assembly_cvt_f8_to_f32(float b0, float b1, float b2, float b3)
50{
51 f8x4_t a;
52 asm volatile("v_cvt_pk_fp8_f32 %0, %1, %2\n"
53 "v_cvt_pk_fp8_f32 %0, %3, %4, op_sel:[0, 0, 1]\n"
54 : "=v"(a)
55 : "v"(b0), "v"(b1), "v"(b2), "v"(b3));
56 return a;
57}
58
59inline __device__ f8x8_t amd_assembly_i4_to_fp8x8(int a)
60{
61 uint32_t i4x8 = static_cast<uint32_t>(a);
62 uint32_t fp8x4_0;
63 uint32_t fp8x4_1;
64 float tmp_0, tmp_1, tmp_2;
65
66 asm volatile("v_cvt_off_f32_i4 %[v_tmp_0], %[v_src]\n"
67 "v_cvt_off_f32_i4 %[v_tmp_1], %[v_src], src0_sel:BYTE_2\n"
68 "v_cvt_pk_fp8_f32 %[v_dst_0], %[v_tmp_0], %[v_tmp_1]\n"
69 "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_1\n"
70 "v_cvt_off_f32_i4 %[v_tmp_1], %[v_src], src0_sel:BYTE_3\n"
71 "v_cvt_pk_fp8_f32 %[v_dst_1], %[v_tmp_0], %[v_tmp_1]\n"
72 "v_lshrrev_b32 %[v_tmp_2], 4, %[v_src]\n"
73 "v_cvt_off_f32_i4 %[v_tmp_0], %[v_tmp_2]\n"
74 "v_cvt_off_f32_i4 %[v_tmp_1], %[v_tmp_2], src0_sel:BYTE_2\n"
75 "v_cvt_pk_fp8_f32 %[v_dst_0], %[v_tmp_0], %[v_tmp_1], op_sel:[0, 0, 1]\n"
76 "v_cvt_off_f32_i4 %[v_tmp_0], %[v_tmp_2], src0_sel:BYTE_1\n"
77 "v_cvt_off_f32_i4 %[v_tmp_1], %[v_tmp_2], src0_sel:BYTE_3\n"
78 "v_cvt_pk_fp8_f32 %[v_dst_1], %[v_tmp_0], %[v_tmp_1], op_sel:[0, 0, 1]\n"
79 : [v_tmp_0] "+v"(tmp_0),
80 [v_tmp_1] "+v"(tmp_1),
81 [v_tmp_2] "+v"(tmp_2),
82 [v_dst_0] "+v"(fp8x4_0),
83 [v_dst_1] "+v"(fp8x4_1),
84 [v_src] "+v"(i4x8)
85 :);
86
87 return bit_cast<f8x8_t>(((static_cast<uint64_t>(fp8x4_1) << 32) | fp8x4_0));
88}
89
90// c0 += inner_product(a, b0)
91// c1 += inner_product(a, b1)
92__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
93{
94 asm volatile("\n \
95 v_fmac_f32 %0, %2, %3 \n \
96 v_fmac_f32 %1, %2, %4 \n \
97 "
98 : "=v"(c0), "=v"(c1)
99 : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
100}
101
102// c0 += inner_product(a, b0)
103// c1 += inner_product(a, b1)
104// c2 += inner_product(a, b2)
105// c3 += inner_product(a, b3)
107 float a, float b0, float b1, float b2, float b3, float& c0, float& c1, float& c2, float& c3)
108{
109 asm volatile("\n \
110 v_fmac_f32 %0, %4, %5 \n \
111 v_fmac_f32 %1, %4, %6 \n \
112 v_fmac_f32 %2, %4, %7 \n \
113 v_fmac_f32 %3, %4, %8 \n \
114 "
115 : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
116 : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
117}
118
119// c0 += inner_product(a, b0)
120// c1 += inner_product(a, b1)
121__device__ void
123{
124 asm volatile("\n \
125 v_dot2_f32_f16 %0, %2, %3, %0\n \
126 v_dot2_f32_f16 %1, %2, %4, %1\n \
127 "
128 : "=v"(c0), "=v"(c1)
129 : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
130}
131
132// c0 += inner_product(a, b0)
133// c1 += inner_product(a, b1)
134__device__ void
135amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1)
136{
137 // TODO remove pointer casting
139 const half2_t* p_b0_half2 = c_style_pointer_cast<const half2_t*>(&b0);
140 const half2_t* p_b1_half2 = c_style_pointer_cast<const half2_t*>(&b1);
141
142 // do dot2 two times
143 asm volatile("\n \
144 v_dot2_f32_f16 %0, %2, %4, %0\n \
145 v_dot2_f32_f16 %1, %2, %6, %1\n \
146 v_dot2_f32_f16 %0, %3, %5, %0\n \
147 v_dot2_f32_f16 %1, %3, %7, %1\n \
148 "
149 : "=v"(c0), "=v"(c1)
150 : "v"(p_a_half2[0]),
151 "v"(p_a_half2[1]),
152 "v"(p_b0_half2[0]),
153 "v"(p_b0_half2[1]),
154 "v"(p_b1_half2[0]),
155 "v"(p_b1_half2[1]),
156 "0"(c0),
157 "1"(c1));
158}
159
160// c0 += inner_product(a, b0)
161// c1 += inner_product(a, b1)
162// c2 += inner_product(a, b2)
163// c3 += inner_product(a, b3)
165 half2_t b0,
166 half2_t b1,
167 half2_t b2,
168 half2_t b3,
169 float& c0,
170 float& c1,
171 float& c2,
172 float& c3)
173{
174 asm volatile("\n \
175 v_dot2_f32_f16 %0, %4, %5, %0\n \
176 v_dot2_f32_f16 %1, %4, %6, %1\n \
177 v_dot2_f32_f16 %2, %4, %7, %2\n \
178 v_dot2_f32_f16 %3, %4, %8, %3\n \
179 "
180 : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
181 : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
182}
183
184// c0 += inner_product(a, b0)
185// c1 += inner_product(a, b1)
186// c2 += inner_product(a, b2)
187// c3 += inner_product(a, b3)
189 half4_t b0,
190 half4_t b1,
191 half4_t b2,
192 half4_t b3,
193 float& c0,
194 float& c1,
195 float& c2,
196 float& c3)
197{
198 // TODO remove pointer casting
200 const half2_t* p_b0_half2 = c_style_pointer_cast<const half2_t*>(&b0);
201 const half2_t* p_b1_half2 = c_style_pointer_cast<const half2_t*>(&b1);
202 const half2_t* p_b2_half2 = c_style_pointer_cast<const half2_t*>(&b2);
203 const half2_t* p_b3_half2 = c_style_pointer_cast<const half2_t*>(&b3);
204
205 // do dot2 two times
206 asm volatile("\n \
207 v_dot2_f32_f16 %0, %4, %6, %0\n \
208 v_dot2_f32_f16 %1, %4, %8, %1\n \
209 v_dot2_f32_f16 %2, %4, %10, %2\n \
210 v_dot2_f32_f16 %3, %4, %12, %3\n \
211 v_dot2_f32_f16 %0, %5, %7, %0\n \
212 v_dot2_f32_f16 %1, %5, %9, %1\n \
213 v_dot2_f32_f16 %2, %5, %11, %2\n \
214 v_dot2_f32_f16 %3, %5, %13, %3\n \
215 "
216 : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
217 : "v"(p_a_half2[0]),
218 "v"(p_a_half2[1]),
219 "v"(p_b0_half2[0]),
220 "v"(p_b0_half2[1]),
221 "v"(p_b1_half2[0]),
222 "v"(p_b1_half2[1]),
223 "v"(p_b2_half2[0]),
224 "v"(p_b2_half2[1]),
225 "v"(p_b3_half2[0]),
226 "v"(p_b3_half2[1]),
227 "0"(c0),
228 "1"(c1),
229 "2"(c2),
230 "3"(c3));
231}
232
234 half8_t b0,
235 half8_t b1,
236 half8_t b2,
237 half8_t b3,
238 float& c0,
239 float& c1,
240 float& c2,
241 float& c3)
242{
243
244 // TODO remove pointer casting
246 const half4_t* p_b0_half4 = c_style_pointer_cast<const half4_t*>(&b0);
247 const half4_t* p_b1_half4 = c_style_pointer_cast<const half4_t*>(&b1);
248 const half4_t* p_b2_half4 = c_style_pointer_cast<const half4_t*>(&b2);
249 const half4_t* p_b3_half4 = c_style_pointer_cast<const half4_t*>(&b3);
250
252 p_a_half4[0], p_b0_half4[0], p_b1_half4[0], p_b2_half4[0], p_b3_half4[0], c0, c1, c2, c3);
253
255 p_a_half4[1], p_b0_half4[1], p_b1_half4[1], p_b2_half4[1], p_b3_half4[1], c0, c1, c2, c3);
256}
257
259 half16_t b0,
260 half16_t b1,
261 half16_t b2,
262 half16_t b3,
263 float& c0,
264 float& c1,
265 float& c2,
266 float& c3)
267{
268 // TODO remove pointer casting
270 const half8_t* p_b0_half8 = c_style_pointer_cast<const half8_t*>(&b0);
271 const half8_t* p_b1_half8 = c_style_pointer_cast<const half8_t*>(&b1);
272 const half8_t* p_b2_half8 = c_style_pointer_cast<const half8_t*>(&b2);
273 const half8_t* p_b3_half8 = c_style_pointer_cast<const half8_t*>(&b3);
274
276 p_a_half8[0], p_b0_half8[0], p_b1_half8[0], p_b2_half8[0], p_b3_half8[0], c0, c1, c2, c3);
277
279 p_a_half8[1], p_b0_half8[1], p_b1_half8[1], p_b2_half8[1], p_b3_half8[1], c0, c1, c2, c3);
280}
281
282// c0 += inner_product(a, b0)
283// c1 += inner_product(a, b1)
284__device__ void
286{
287#if 1
288 asm volatile("\n \
289 v_dot4_i32_i8 %0, %2, %3, %0\n \
290 v_dot4_i32_i8 %1, %2, %4, %1\n \
291 "
292 : "=v"(c0), "=v"(c1)
293 : "v"(bit_cast<int32_t>(a)),
294 "v"(bit_cast<int32_t>(b0)),
295 "v"(bit_cast<int32_t>(b1)),
296 "0"(c0),
297 "1"(c1));
298#else
299 c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
300 c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
301#endif
302}
303
304// c0 += inner_product(a, b0)
305// c1 += inner_product(a, b1)
306// c2 += inner_product(a, b2)
307// c3 += inner_product(a, b3)
309 int8x4_t b0,
310 int8x4_t b1,
311 int8x4_t b2,
312 int8x4_t b3,
313 int32_t& c0,
314 int32_t& c1,
315 int32_t& c2,
316 int32_t& c3)
317{
318#if 1
319 asm volatile("\n \
320 v_dot4_i32_i8 %0, %4, %5, %0\n \
321 v_dot4_i32_i8 %1, %4, %6, %1\n \
322 v_dot4_i32_i8 %2, %4, %7, %2\n \
323 v_dot4_i32_i8 %3, %4, %8, %3\n \
324 "
325 : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
326 : "v"(bit_cast<int32_t>(a)),
327 "v"(bit_cast<int32_t>(b0)),
328 "v"(bit_cast<int32_t>(b1)),
329 "v"(bit_cast<int32_t>(b2)),
330 "v"(bit_cast<int32_t>(b3)),
331 "0"(c0),
332 "1"(c1),
333 "2"(c2),
334 "3"(c3));
335#else
336 c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
337 c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
338 c2 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b2), c2, false);
339 c3 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b3), c3, false);
340#endif
341}
342
344 int8x8_t b0,
345 int8x8_t b1,
346 int8x8_t b2,
347 int8x8_t b3,
348 int32_t& c0,
349 int32_t& c1,
350 int32_t& c2,
351 int32_t& c3)
352{
353 constexpr auto I0 = Number<0>{};
354 constexpr auto I1 = Number<1>{};
355
357 vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I0],
358 vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I0],
359 vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I0],
360 vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I0],
361 c0,
362 c1,
363 c2,
364 c3);
365
367 vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I1],
368 vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I1],
369 vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I1],
370 vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I1],
371 c0,
372 c1,
373 c2,
374 c3);
375}
376
378 int8x16_t b0,
379 int8x16_t b1,
380 int8x16_t b2,
381 int8x16_t b3,
382 int32_t& c0,
383 int32_t& c1,
384 int32_t& c2,
385 int32_t& c3)
386
387{
388 constexpr auto I0 = Number<0>{};
389 constexpr auto I1 = Number<1>{};
390 constexpr auto I2 = Number<2>{};
391 constexpr auto I3 = Number<3>{};
392
394 vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I0],
395 vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I0],
396 vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I0],
397 vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I0],
398 c0,
399 c1,
400 c2,
401 c3);
402
404 vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I1],
405 vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I1],
406 vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I1],
407 vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I1],
408 c0,
409 c1,
410 c2,
411 c3);
412
414 vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I2],
415 vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I2],
416 vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I2],
417 vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I2],
418 c0,
419 c1,
420 c2,
421 c3);
422
424 vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I3],
425 vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I3],
426 vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I3],
427 vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I3],
428 c0,
429 c1,
430 c2,
431 c3);
432}
433
434} // namespace ck
435#endif
Definition ck.hpp:268
__device__ half2_t amd_assembly_pk_add_f16(half2_t a, half2_t b)
Definition amd_inline_asm.hpp:35
typename vector_type< int8_t, 8 >::type int8x8_t
Definition dtype_vector.hpp:2178
__device__ f8x8_t amd_assembly_i4_to_fp8x8(int a)
Definition amd_inline_asm.hpp:59
__device__ void amd_assembly_outer_product_1x4(float a, float b0, float b1, float b2, float b3, float &c0, float &c1, float &c2, float &c3)
Definition amd_inline_asm.hpp:106
typename vector_type< half_t, 16 >::type half16_t
Definition dtype_vector.hpp:2156
typename vector_type< int8_t, 4 >::type int8x4_t
Definition dtype_vector.hpp:2177
__device__ f8x4_t amd_assembly_cvt_f8_to_f32(float b0, float b1, float b2, float b3)
Definition amd_inline_asm.hpp:49
integral_constant< index_t, N > Number
Definition number.hpp:12
__device__ int amd_assembly_and_b32(int a, int b)
Definition amd_inline_asm.hpp:14
__device__ half2_t amd_assembly_pk_fma_f16(half2_t a, half2_t b, half2_t c)
Definition amd_inline_asm.hpp:28
__host__ __device__ PY c_style_pointer_cast(PX p_x)
Definition c_style_pointer_cast.hpp:15
__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float &c0, float &c1)
Definition amd_inline_asm.hpp:92
typename vector_type< half_t, 8 >::type half8_t
Definition dtype_vector.hpp:2155
typename vector_type< int8_t, 16 >::type int8x16_t
Definition dtype_vector.hpp:2179
typename vector_type< half_t, 2 >::type half2_t
Definition dtype_vector.hpp:2153
__device__ int amd_assembly_and_or_b32(int a, int b, int d)
Definition amd_inline_asm.hpp:21
__host__ __device__ constexpr Y bit_cast(const X &x)
Definition type.hpp:306
__device__ float amd_assemble_cvt_f32_i4(int b)
Definition amd_inline_asm.hpp:42
typename vector_type< half_t, 4 >::type half4_t
Definition dtype_vector.hpp:2154
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
unsigned int uint32_t
Definition stdint.h:126
signed int int32_t
Definition stdint.h:123
unsigned __int64 uint64_t
Definition stdint.h:136
Definition dtype_vector.hpp:10