unary_element_wise_operation.hpp Source File

unary_element_wise_operation.hpp Source File

Composable Kernel: unary_element_wise_operation.hpp Source File
tensor_operation/gpu/element/unary_element_wise_operation.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7#include "ck/utility/math.hpp"
11#include <cassert>
12
13namespace ck {
14
15// Fast int4x4 to half8_t data type conversion based on paper
16// [Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production]
17// (https://arxiv.org/abs/2211.10017) and implementation:
18// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
19// Converts four of the eight int4 values packed in q to half (details below).
//
// i4_to_half4 reads the 4-bit fields of q at bits [3:0] and [19:16] (LO mask) and at
// bits [7:4] and [23:20] (HI mask); all other bits are ignored. Each field is treated
// as an unsigned nibble n and mapped to the half value n - 8. PassThroughPack8 calls
// it on x and on x >> 8 to cover all eight nibbles of a pk_i4x4_t.
//
// Bit trick, per 16-bit lane:
//   EX = 0x6400 is fp16 1024.0 (zero mantissa). amd_assembly_and_or_b32 presumably
//   computes (q & mask) | EX, which drops the nibble into the mantissa:
//     nibble in bits [3:0] -> 1024 + n
//     nibble in bits [7:4] -> 1024 + 16 * n
//   low pair:  (1024 + n) - 1032             = n - 8
//   high pair: (1024 + 16 * n) * (1/16) - 72 = 64 + n - 72 = n - 8
//
// NOTE(review): original lines 35, 39 and 43 (presumably the declaration of `res` and
// the right-hand sides of the two assignments) are absent from this listing.
20__device__ inline half4_t i4_to_half4(int q)
21{
22 const int LO = 0x000f000f;
23 const int HI = 0x00f000f0;
24 const int EX = 0x64006400;
25
26 // Extract the two int4 at low bit and create two fp16 number.
27 int lo = amd_assembly_and_or_b32(q, LO, EX);
28 // Extract the two int4 at the high bits and create two fp16 numbers.
29 int hi = amd_assembly_and_or_b32(q, HI, EX);
30
31 const int SUB = 0xE408E408; // half2 {-1032, -1032}
32 const int MUL = 0x2c002c00; // half2 {1 / 16, 1 / 16}
33 const int ADD = 0xd480d480; // half2 {-72, -72}
34
36
37 // for two fp16 from lowbit, subtract 1032 to get correct fp16 value
38 res.template AsType<half2_t>()(Number<0>{}) =
40
41 // for two fp16 from highbit, divide 16 and subtract 72 to get correct fp16 value
42 res.template AsType<half2_t>()(Number<1>{}) = amd_assembly_pk_fma_f16(
44
// Pair 0 holds the LO-mask values and pair 1 the HI-mask values (per the comments above).
45 return res.template AsType<half4_t>()[Number<0>{}];
46}
47
// i4_to_half4_scale: the same int4 -> half conversion as i4_to_half4 (same nibbles of q,
// each unsigned nibble n becomes n - 8), followed by a packed fp16 multiply of each half2
// pair by `scale`. v_pk_mul_f16 multiplies lane-wise, so element 0 of each pair is scaled
// by scale[0] and element 1 by scale[1].
//
// NOTE(review): original lines 63, 66 and 69 (presumably the declaration of `res` and the
// right-hand sides of the two conversion assignments) are absent from this listing.
48__device__ inline half4_t i4_to_half4_scale(int q, const ck::half2_t& scale)
49{
50 const int LO = 0x000f000f;
51 const int HI = 0x00f000f0;
52 const int EX = 0x64006400;
53
54 // Extract the two int4 at low bit and create two fp16 number.
55 int lo = amd_assembly_and_or_b32(q, LO, EX);
56 // Extract the two int4 at the high bits and create two fp16 numbers.
57 int hi = amd_assembly_and_or_b32(q, HI, EX);
58
59 const int SUB = 0xE408E408; // half2 {-1032, -1032}
60 const int MUL = 0x2c002c00; // half2 {1 / 16, 1 / 16}
61 const int ADD = 0xd480d480; // half2 {-72, -72}
62
64
65 res.template AsType<half2_t>()(Number<0>{}) =
67
68 res.template AsType<half2_t>()(Number<1>{}) = amd_assembly_pk_fma_f16(
70
// Apply the per-lane scale in place to both half2 pairs.
71 asm volatile("v_pk_mul_f16 %0, %1, %2"
72 : "=v"(res.template AsType<half2_t>()(Number<0>{}))
73 : "v"(res.template AsType<half2_t>()(Number<0>{})), "v"(scale));
74
75 asm volatile("v_pk_mul_f16 %0, %1, %2"
76 : "=v"(res.template AsType<half2_t>()(Number<1>{}))
77 : "v"(res.template AsType<half2_t>()(Number<1>{})), "v"(scale));
78
79 return res.template AsType<half4_t>()[Number<0>{}];
80}
81
82__device__ inline f8x4_t i4_to_f8x4(int q)
83{
84 const int LO = 0x000f000f;
85 const int HI = 0x00f000f0;
86
87 int lo = amd_assembly_and_b32(q, LO);
88 int hi = amd_assembly_and_b32(q, HI);
89
90 float f32_0 = amd_assemble_cvt_f32_i4(lo);
91 float f32_1 = amd_assemble_cvt_f32_i4(lo >> 16);
92 float f32_2 = amd_assemble_cvt_f32_i4(hi);
93 float f32_3 = amd_assemble_cvt_f32_i4(hi >> 16);
94
95 return amd_assembly_cvt_f8_to_f32(f32_0, f32_1, f32_2, f32_3);
96}
97
98__device__ inline f8x8_t i4_to_fp8x8(int q)
99{
100#if defined(__gfx12__)
101 uint32_t fp8x4_0;
102 uint32_t fp8x4_1;
103 // todo: replace amd_assemble_cvt_f32_i4 with __builtin_amdgcn_cvt_off_f32_i4
104 float f32_0 = amd_assemble_cvt_f32_i4(q);
105 float f32_1 = amd_assemble_cvt_f32_i4(q >> 16);
106 fp8x4_0 = __builtin_amdgcn_cvt_pk_fp8_f32(f32_0, f32_1, 0, 0);
107 float f32_2 = amd_assemble_cvt_f32_i4(q >> 8);
108 float f32_3 = amd_assemble_cvt_f32_i4(q >> 24);
109 fp8x4_1 = __builtin_amdgcn_cvt_pk_fp8_f32(f32_2, f32_3, 0, 0);
110 q = q >> 4;
111 f32_0 = amd_assemble_cvt_f32_i4(q);
112 f32_1 = amd_assemble_cvt_f32_i4(q >> 16);
113 fp8x4_0 = __builtin_amdgcn_cvt_pk_fp8_f32(f32_0, f32_1, fp8x4_0, 1);
114 f32_2 = amd_assemble_cvt_f32_i4(q >> 8);
115 f32_3 = amd_assemble_cvt_f32_i4(q >> 24);
116 fp8x4_1 = __builtin_amdgcn_cvt_pk_fp8_f32(f32_2, f32_3, fp8x4_1, 1);
117 return bit_cast<f8x8_t>(((static_cast<uint64_t>(fp8x4_1) << 32) | fp8x4_0));
118#elif defined(__gfx11__)
119 ignore = q;
120 return f8x8_t{};
121#else
122 return amd_assembly_i4_to_fp8x8(q);
123#endif
124}
125
// i4_to_bhalf4: converts the four nibbles in the low 16 bits of q to bhalf_t, mapping each
// unsigned nibble n to n - 8. Bits [31:16] of q are ignored, which is why PassThroughPack8
// calls it on x and on x >> 16.
//
// NOTE(review): original line 146 (presumably the declaration of `res`) is absent from
// this listing.
126__device__ inline bhalf4_t i4_to_bhalf4(int q)
127{
// Spread nibbles 0..3 of q into the low nibble of bytes 0..3 respectively.
128 uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);
129
// 0x4B000000 is the fp32 bit pattern of 2^23 (8388608.0f). At that magnitude one
// mantissa ulp is 1.0, so a small integer placed in the low byte is represented exactly.
130 static constexpr uint32_t fp32_base = 0x4B000000;
131
132 float fp32_intermediates[4];
133
// NOTE(review): writing a float array through a uint32_t* is type punning that ISO C++
// does not guarantee (strict aliasing); this presumably relies on compiler behaviour.
134 uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
135
// __byte_perm(x, y, s): result byte k is chosen by selector nibble k from the 8-byte pool
// {y:x} (x = bytes 0-3, y = bytes 4-7). Selector 0x765k takes byte k of i8s as byte 0 and
// bytes 1-3 of fp32_base (0x00, 0x00, 0x4B), producing the float 2^23 + n for nibble n.
136 fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650);
137 fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7651);
138 fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7652);
139 fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653);
140
// 8388616 = 2^23 + 8: removes the 2^23 base and the +8 offset, leaving exactly n - 8.
141 fp32_intermediates[0] -= 8388616.f;
142 fp32_intermediates[1] -= 8388616.f;
143 fp32_intermediates[2] -= 8388616.f;
144 fp32_intermediates[3] -= 8388616.f;
145
// Selector 0x7632 keeps the upper 16 bits of each float. This is a truncating fp32 -> bf16
// conversion, which is exact for integers in [-8, 7]. The low half-word comes from the
// first operand, so the pairs hold {[1], [0]} and {[3], [2]} of fp32_intermediates.
// NOTE(review): confirm this swapped order within each pair is the intended layout.
147 res.template AsType<bhalf2_t>()(Number<0>{}) = bit_cast<bhalf2_t>(
148 __byte_perm(fp32_intermediates_casted[1], fp32_intermediates_casted[0], 0x7632));
149 res.template AsType<bhalf2_t>()(Number<1>{}) = bit_cast<bhalf2_t>(
150 __byte_perm(fp32_intermediates_casted[3], fp32_intermediates_casted[2], 0x7632));
151
152 return res.template AsType<bhalf4_t>()[Number<0>{}];
153}
154
155namespace tensor_operation {
156namespace element_wise {
157
159{
// PassThroughPack8: converts the eight int4 values of one pk_i4x4_t into eight half_t,
// f8_t or bhalf_t values. With CK_USE_PK4_LAYOUT_SHUFFLE the bit-trick helpers
// (i4_to_half4 / i4_to_bhalf4) are used; otherwise each pk_i4_t is converted separately
// through type_convert.
// NOTE(review): several lines (original 168, 175-176, 194, 198, 200, 243, 250-251),
// presumably the result/src/dst vector declarations, are absent from this listing.
160 static constexpr const char* name = "PassThroughPack8";
161
// Declared only; the overloads below provide the supported conversions.
162 template <typename Y, typename X>
163 __host__ __device__ void operator()(Y& y, const X& x) const;
164
165 __host__ __device__ constexpr void operator()(ck::half8_t& y, const ck::pk_i4x4_t& x) const
166 {
167#if CK_USE_PK4_LAYOUT_SHUFFLE
169
// i4_to_half4 reads bits [7:0] and [23:16]; shifting by 8 feeds it bits [15:8] and [31:24].
170 result.template AsType<half4_t>()(Number<0>{}) = i4_to_half4(bit_cast<int>(x));
171 result.template AsType<half4_t>()(Number<1>{}) = i4_to_half4(bit_cast<int>(x) >> 8);
172
173 y = result.template AsType<half8_t>()[Number<0>{}];
174#else
177
178 dst.template AsType<half2_t>()(Number<0>{}) =
179 type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<0>{}]);
180 dst.template AsType<half2_t>()(Number<1>{}) =
181 type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<1>{}]);
182 dst.template AsType<half2_t>()(Number<2>{}) =
183 type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<2>{}]);
184 dst.template AsType<half2_t>()(Number<3>{}) =
185 type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<3>{}]);
186
187 y = dst.template AsType<half8_t>()[Number<0>{}];
188#endif
189 }
190
191 __host__ __device__ constexpr void operator()(ck::f8x8_t& y, const ck::pk_i4x4_t& x) const
192 {
193#if CK_USE_PK4_LAYOUT_SHUFFLE
// NOTE(review): the body of the shuffled-layout path (original line 194) is absent here.
195
196#else
197 // Added pk_i4_t to f8x2_fnuz_t conversion
// Two-step conversion: each pk_i4_t -> float2_t, then each float -> f8_t.
199 vector_type<float, 8> dst_tmp;
201
202 // pk_i4_t to float2_t conversion
203 dst_tmp.template AsType<float2_t>()(Number<0>{}) =
204 type_convert<float2_t>(src.template AsType<pk_i4_t>()[Number<0>{}]);
205
206 dst_tmp.template AsType<float2_t>()(Number<1>{}) =
207 type_convert<float2_t>(src.template AsType<pk_i4_t>()[Number<1>{}]);
208
209 dst_tmp.template AsType<float2_t>()(Number<2>{}) =
210 type_convert<float2_t>(src.template AsType<pk_i4_t>()[Number<2>{}]);
211
212 dst_tmp.template AsType<float2_t>()(Number<3>{}) =
213 type_convert<float2_t>(src.template AsType<pk_i4_t>()[Number<3>{}]);
214
215 // float to f8_t conversion
216 dst.template AsType<f8_t>()(Number<0>{}) =
217 type_convert<f8_t>(dst_tmp.template AsType<float>()[Number<0>{}]);
218 dst.template AsType<f8_t>()(Number<1>{}) =
219 type_convert<f8_t>(dst_tmp.template AsType<float>()[Number<1>{}]);
220
221 dst.template AsType<f8_t>()(Number<2>{}) =
222 type_convert<f8_t>(dst_tmp.template AsType<float>()[Number<2>{}]);
223 dst.template AsType<f8_t>()(Number<3>{}) =
224 type_convert<f8_t>(dst_tmp.template AsType<float>()[Number<3>{}]);
225
226 dst.template AsType<f8_t>()(Number<4>{}) =
227 type_convert<f8_t>(dst_tmp.template AsType<float>()[Number<4>{}]);
228 dst.template AsType<f8_t>()(Number<5>{}) =
229 type_convert<f8_t>(dst_tmp.template AsType<float>()[Number<5>{}]);
230
231 dst.template AsType<f8_t>()(Number<6>{}) =
232 type_convert<f8_t>(dst_tmp.template AsType<float>()[Number<6>{}]);
233 dst.template AsType<f8_t>()(Number<7>{}) =
234 type_convert<f8_t>(dst_tmp.template AsType<float>()[Number<7>{}]);
235
236 y = dst.template AsType<f8x8_t>()[Number<0>{}];
237#endif
238 }
239
240 __host__ __device__ constexpr void operator()(ck::bhalf8_t& y, const ck::pk_i4x4_t& x) const
241 {
242#if CK_USE_PK4_LAYOUT_SHUFFLE
244
// i4_to_bhalf4 reads only the low 16 bits; shifting by 16 feeds it the high 16 bits.
245 result.template AsType<bhalf4_t>()(Number<0>{}) = i4_to_bhalf4(bit_cast<int>(x));
246 result.template AsType<bhalf4_t>()(Number<1>{}) = i4_to_bhalf4(bit_cast<int>(x) >> 16);
247
248 y = result.template AsType<bhalf8_t>()[Number<0>{}];
249#else
252
253 dst.template AsType<bhalf2_t>()(Number<0>{}) =
254 type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[Number<0>{}]);
255 dst.template AsType<bhalf2_t>()(Number<1>{}) =
256 type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[Number<1>{}]);
257 dst.template AsType<bhalf2_t>()(Number<2>{}) =
258 type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[Number<2>{}]);
259 dst.template AsType<bhalf2_t>()(Number<3>{}) =
260 type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[Number<3>{}]);
261
262 y = dst.template AsType<bhalf8_t>()[Number<0>{}];
263#endif
264 }
// Marks this functor as operating on 8 packed elements per call.
265 constexpr const static bool is_pack8_invocable = true;
266};
267
269{
// DequantPack8: converts the eight int4 values of one pk_i4x4_t to half_t and applies the
// half2 scale z (via i4_to_half4_scale in the shuffled-layout path).
// NOTE(review): in the visible lines of the #else path z is never used. Confirm whether the
// missing lines 287-288 apply the scale; otherwise that path returns unscaled values.
270 static constexpr const char* name = "DequantPack8";
271
// Declared only; the overload below provides the supported conversion.
272 template <typename Y, typename X, typename Z>
273 __host__ __device__ void operator()(Y& y, const X& x, const Z& z) const;
274
275 __host__ __device__ constexpr void
276 operator()(ck::half8_t& y, const ck::pk_i4x4_t& x, const ck::half2_t& z) const
277 {
278#if CK_USE_PK4_LAYOUT_SHUFFLE
// NOTE(review): original lines 279 and 283 (presumably the result declaration and the
// second i4_to_half4_scale call) are absent from this listing.
280
281 result.template AsType<half4_t>()(Number<0>{}) = i4_to_half4_scale(bit_cast<int>(x), z);
282 result.template AsType<half4_t>()(Number<1>{}) =
284
285 y = result.template AsType<half8_t>()[Number<0>{}];
286#else
289
290 dst.template AsType<half2_t>()(Number<0>{}) =
291 type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<0>{}]);
292 dst.template AsType<half2_t>()(Number<1>{}) =
293 type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<1>{}]);
294 dst.template AsType<half2_t>()(Number<2>{}) =
295 type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<2>{}]);
296 dst.template AsType<half2_t>()(Number<3>{}) =
297 type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<3>{}]);
298
299 y = dst.template AsType<half8_t>()[Number<0>{}];
300#endif
301 }
302
// Marks this functor as operating on 8 packed elements per call.
303 constexpr const static bool is_pack8_invocable = true;
304};
305
307{
// PassThroughPack2: converts two packed source elements into a two-element destination.
// NOTE(review): original lines 316, 322 and 331-332 are absent from this listing.
308 static constexpr const char* name = "PassThroughPack2";
309
// Declared only; the overloads below provide the supported conversions.
310 template <typename Y, typename X>
311 __host__ __device__ void operator()(Y& y, const X& x) const;
312
// f8x2_t -> half2_t goes through float2_t (the line assigning y is not visible here).
313 __host__ __device__ constexpr void operator()(half2_t& y, const f8x2_t& x) const
314 {
315 auto t = type_convert<float2_t>(x);
317 }
318
319 __host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::pk_i4_t& x) const
320 {
321#if CK_USE_PK4_LAYOUT_SHUFFLE
// Low nibble -> element 0, high nibble -> element 1.
// NOTE(review): x_l / x_h are raw nibbles in [0, 15] converted directly, whereas
// i4_to_half4 / i4_to_bhalf4 map each nibble n to n - 8. Confirm the two paths are meant
// to use different encodings.
323 uint8_t x_l = (x_u8 & 0x0f) >> 0;
324 uint8_t x_h = (x_u8 & 0xf0) >> 4;
325
326 auto l_f16 = ck::type_convert<ck::half_t>(x_l);
327 auto h_f16 = ck::type_convert<ck::half_t>(x_h);
328
329 y = {l_f16, h_f16};
330#else
333#endif
334 }
335
// Marks this functor as operating on 2 packed elements per call.
336 constexpr const static bool is_pack2_invocable = true;
337};
338
340{
// PassThrough: identity / scalar type-conversion functor. Same-type specialisations copy;
// the visible cross-type specialisations convert with type_convert. The primary template
// is only declared, so only the (Y, X) pairs specialised below are usable.
// NOTE(review): several one-line bodies (e.g. original lines 374, 392, 398, 422, 434, 452,
// 458, 470, 476, 482, 500, 525, 555) are absent from this listing.
341 static constexpr const char* name = "PassThrough";
342
343 template <typename Y, typename X>
344 __host__ __device__ void operator()(Y& y, const X& x) const;
345
// Packed 4-bit types: copied as-is.
346 template <>
347 __host__ __device__ void operator()<pk_i4_t, pk_i4_t>(pk_i4_t& y, const pk_i4_t& x) const
348 {
349 y = x;
350 }
351
352 template <>
353 __host__ __device__ void operator()<f4x2_pk_t, f4x2_pk_t>(f4x2_pk_t& y,
354 const f4x2_pk_t& x) const
355 {
356 y = x;
357 }
358
359 template <>
360 __host__ __device__ void operator()<double, double>(double& y, const double& x) const
361 {
362 y = x;
363 }
364
365 template <>
366 __host__ __device__ void operator()<float, double>(float& y, const double& x) const
367 {
368 y = type_convert<float>(x);
369 }
370
371 template <>
372 __host__ __device__ void operator()<double, float>(double& y, const float& x) const
373 {
375 }
376
377 template <>
378 __host__ __device__ void operator()<float, float>(float& y, const float& x) const
379 {
380 y = x;
381 }
382
383 template <>
384 __host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
385 {
386 y = x;
387 }
388
389 template <>
390 __host__ __device__ void operator()<half_t, float>(half_t& y, const float& x) const
391 {
393 }
394
395 template <>
396 __host__ __device__ void operator()<half_t, int32_t>(half_t& y, const int32_t& x) const
397 {
399 }
400
401 template <>
402 __host__ __device__ void operator()<float, int32_t>(float& y, const int32_t& x) const
403 {
404 y = type_convert<float>(x);
405 }
406
407 template <>
408 __host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
409 {
410 y = x;
411 }
412
413 template <>
414 __host__ __device__ void operator()<int32_t, int32_t>(int32_t& y, const int32_t& x) const
415 {
416 y = x;
417 }
418
419 template <>
420 __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
421 {
423 }
424
425 template <>
426 __host__ __device__ void operator()<float, bhalf_t>(float& y, const bhalf_t& x) const
427 {
428 y = type_convert<float>(x);
429 }
430
431 template <>
432 __host__ __device__ void operator()<bhalf_t, half_t>(bhalf_t& y, const half_t& x) const
433 {
435 }
436
437 template <>
438 __host__ __device__ void operator()<float, half_t>(float& y, const half_t& x) const
439 {
440 y = type_convert<float>(x);
441 }
442
443 template <>
444 __host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
445 {
446 y = x;
447 }
448
449 template <>
450 __host__ __device__ void operator()<half_t, int8_t>(half_t& y, const int8_t& x) const
451 {
453 }
454
455 template <>
456 __host__ __device__ void operator()<bhalf_t, int8_t>(bhalf_t& y, const int8_t& x) const
457 {
459 }
460
461 template <>
462 __host__ __device__ void operator()<uint8_t, uint8_t>(uint8_t& y, const uint8_t& x) const
463 {
464 y = x;
465 }
466
467 template <>
468 __host__ __device__ void operator()<int8_t, int32_t>(int8_t& y, const int32_t& x) const
469 {
471 }
472
473 template <>
474 __host__ __device__ void operator()<int32_t, int8_t>(int32_t& y, const int8_t& x) const
475 {
477 }
478
479 template <>
480 __host__ __device__ void operator()<int8_t, float>(int8_t& y, const float& x) const
481 {
483 }
484
485 template <>
486 __host__ __device__ void operator()<float, int8_t>(float& y, const int8_t& x) const
487 {
488 y = type_convert<float>(x);
489 }
490
// Experimental _BitInt-based int4 support.
491#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
492 template <>
493 __host__ __device__ void operator()<int4_t, int4_t>(int4_t& y, const int4_t& x) const
494 {
495 y = x;
496 }
497 template <>
498 __host__ __device__ void operator()<int4_t, int>(int4_t& y, const int& x) const
499 {
501 }
502#endif
503
// 8-bit float (f8_t / bf8_t) conversions.
504 template <>
505 __host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const
506 {
507 y = x;
508 }
509
510 template <>
511 __host__ __device__ void operator()<float, f8_t>(float& y, const f8_t& x) const
512 {
513 y = type_convert<float>(x);
514 }
515
516 template <>
517 __host__ __device__ void operator()<f8_t, float>(f8_t& y, const float& x) const
518 {
519 y = type_convert<f8_t>(x);
520 }
521
522 template <>
523 __host__ __device__ void operator()<half_t, f8_t>(half_t& y, const f8_t& x) const
524 {
526 }
527
528 template <>
529 __host__ __device__ void operator()<f8_t, half_t>(f8_t& y, const half_t& x) const
530 {
531 y = type_convert<f8_t>(x);
532 }
533
534 template <>
535 __host__ __device__ void operator()<bf8_t, bf8_t>(bf8_t& y, const bf8_t& x) const
536 {
537 y = x;
538 }
539
540 template <>
541 __host__ __device__ void operator()<float, bf8_t>(float& y, const bf8_t& x) const
542 {
543 y = type_convert<float>(x);
544 }
545
546 template <>
547 __host__ __device__ void operator()<bf8_t, float>(bf8_t& y, const float& x) const
548 {
549 y = type_convert<bf8_t>(x);
550 }
551
552 template <>
553 __host__ __device__ void operator()<half_t, bf8_t>(half_t& y, const bf8_t& x) const
554 {
556 }
557
558 template <>
559 __host__ __device__ void operator()<bf8_t, half_t>(bf8_t& y, const half_t& x) const
560 {
561 y = type_convert<bf8_t>(x);
562 }
563};
564
566{
// UnaryConvert: y = type_convert<Y>(x) for any (Y, X) pair that type_convert supports.
567 static constexpr const char* name = "UnaryConvert";
568
569 template <typename Y, typename X>
570 __host__ __device__ void operator()(Y& y, const X& x) const
571 {
572 y = type_convert<Y>(x);
573 }
574};
575
577{
// ConvertBF16RTN: converts to bhalf_t with round-to-nearest via bf16_convert_rtn.
// Y must be bhalf_t. The accepted X types are checked by a static_assert whose condition
// line (original 588) is absent from this listing.
578 static constexpr const char* name = "ConvertBF16RTN";
579
580 // convert to bf16 using round to nearest (rtn)
581 template <typename Y, typename X>
582 __host__ __device__ void operator()(Y& y, const X& x) const
583 {
584 // check Y datatype
585 static_assert(is_same<Y, bhalf_t>::value, "Data type is not supported by this operation!");
586
587 // check X datatype
589 "Data type is not supported by this operation!");
590
591 y = bf16_convert_rtn<Y>(x);
592 }
593};
594
596{
// ConvertF8SR: converts to an 8-bit float type with stochastic rounding (f8_convert_sr).
// The allowed Y and X types are checked by static_asserts whose condition lines (original
// 604 and 608) are absent from this listing.
597 static constexpr const char* name = "ConvertF8SR";
598
599 // convert to fp8 using stochastic rounding (SR)
600 template <typename Y, typename X>
601 __host__ __device__ void operator()(Y& y, const X& x) const
602 {
603 // check Y datatype
605 "Data type is not supported by this operation!");
606
607 // check X datatype
609 "Data type is not supported by this operation!");
610
611 y = f8_convert_sr<Y>(x);
612 }
613};
614
616{
// ConvertF8RNE: converts to an 8-bit float type with round-to-nearest-even (f8_convert_rne).
// The allowed Y and X types are checked by static_asserts whose condition lines (original
// 624 and 628) are absent from this listing.
617 static constexpr const char* name = "ConvertF8RNE";
618
619 // convert to fp8 using rounding to nearest even
620 template <typename Y, typename X>
621 __host__ __device__ void operator()(Y& y, const X& x) const
622 {
623 // check Y datatype
625 "Data type is not supported by this operation!");
626
627 // check X datatype
629 "Data type is not supported by this operation!");
630
631 y = f8_convert_rne<Y>(x);
632 }
633};
634
// Scale: y = scale_ * x, with scale_ defaulting to 1. The bhalf_t specialisation widens to
// float for the multiply and narrows once.
// NOTE(review): the bodies of the generic, half_t and int8_t overloads (original lines 644,
// 650 and 676) are absent from this listing.
635struct Scale
636{
637 static constexpr const char* name = "Scale";
638
639 __host__ __device__ Scale(float scale = 1.f) : scale_(scale) {}
640
641 template <typename Y, typename X>
642 __host__ __device__ void operator()(Y& y, const X& x) const
643 {
645 }
646
647 template <>
648 __host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
649 {
651 };
652
653 template <>
654 __host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
655 {
656 const float x_tmp = type_convert<float>(x);
657 const float y_tmp = scale_ * x_tmp;
658 y = type_convert<bhalf_t>(y_tmp);
659 };
660
661 template <>
662 __host__ __device__ void operator()<float, float>(float& y, const float& x) const
663 {
664 y = scale_ * x;
665 };
666
// scale_ (float) is promoted to double for the multiply.
667 template <>
668 __host__ __device__ void operator()<double, double>(double& y, const double& x) const
669 {
670 y = scale_ * x;
671 };
672
673 template <>
674 __host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
675 {
677 };
678
679 float scale_;
680};
681
683{
// ScaleAndResetNaNToMinusInfinity: y = -inf if x is NaN, otherwise scale_ * x.
// Only the float/float pair is defined; there is no default scale.
684 static constexpr const char* name = "ScaleAndResetNaNToMinusInfinity";
685
686 __host__ __device__ ScaleAndResetNaNToMinusInfinity(float scale) : scale_(scale) {}
687
688 template <typename Y, typename X>
689 __host__ __device__ void operator()(Y& y, const X& x) const;
690
691 template <>
692 __host__ __device__ void operator()<float, float>(float& y, const float& x) const
693 {
694 y = math::isnan(x) ? -NumericLimits<float>::Infinity() : scale_ * x;
695 };
696
697 float scale_;
698};
699
701{
// UnaryDivide: y = x / divider_ (divider defaults to 1). half_t, bhalf_t and f8_t are
// divided in float and converted back once.
// NOTE(review): divider_ == 0 is not guarded. For integer T that is undefined behaviour;
// for floating types it yields inf/NaN.
// NOTE(review): the static_assert type list (original 709-710) and the divider_ member
// declaration (original 743) are absent from this listing.
702 static constexpr const char* name = "UnaryDivide";
703
704 __host__ __device__ UnaryDivide(const int32_t divider = 1) : divider_(divider) {}
705
706 template <typename T>
707 __host__ __device__ void operator()(T& y, const T& x) const
708 {
711 "Data type is not supported by this operation!");
712
713 y = x / type_convert<T>(divider_);
714 };
715
716 template <>
717 __host__ __device__ void operator()<half_t>(half_t& y, const half_t& x) const
718 {
719 float x_ = type_convert<float>(x);
720 float divider_f_ = type_convert<float>(divider_);
721
722 y = type_convert<half_t>(x_ / divider_f_);
723 };
724
725 template <>
726 __host__ __device__ void operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x) const
727 {
728 float x_ = type_convert<float>(x);
729 float divider_f_ = type_convert<float>(divider_);
730
731 y = type_convert<bhalf_t>(x_ / divider_f_);
732 };
733
734 template <>
735 __host__ __device__ void operator()<f8_t>(f8_t& y, const f8_t& x) const
736 {
737 float x_ = type_convert<float>(x);
738 float divider_f_ = type_convert<float>(divider_);
739
740 y = type_convert<f8_t>(x_ / divider_f_);
741 };
742
744};
745
747{
// UnarySquare: y = x * x, computed in T itself. The supported-type list (original
// 753-754, 756) is absent from this listing; it is extended under
// CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4.
748 static constexpr const char* name = "UnarySquare";
749
750 template <typename T>
751 __host__ __device__ void operator()(T& y, const T& x) const
752 {
755#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
757#endif
758 ,
759 "Data type is not supported by this operation!");
760 y = x * x;
761 };
762};
763
765{
// UnaryAbs: y = |x| via math::abs. f8_t is routed through float, and a bhalf_t <- float
// pair is also provided. The supported-type list (original 772-774) is absent here.
766 static constexpr const char* name = "UnaryAbs";
767
768 template <typename T>
769 __host__ __device__ void operator()(T& y, const T& x) const
770 {
771
775 "Data type is not supported by this operation!");
776
777 y = math::abs(x);
778 };
779
// Declared before the (Y, X) template below, so this specialises the single-T template.
780 template <>
781 __host__ __device__ void operator()(f8_t& y, const f8_t& x) const
782 {
783 y = ck::type_convert<f8_t>(ck::math::abs(ck::type_convert<float>(x)));
784 };
785
786 template <typename Y, typename X>
787 __host__ __device__ constexpr void operator()(Y& y, const X& x) const;
788
789 template <>
790 __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
791 {
792 y = ck::type_convert<bhalf_t>(ck::math::abs(x));
793 };
794};
795
797{
// UnarySqrt: y = math::sqrt(x). The supported-type list (original 803) is absent here.
798 static constexpr const char* name = "UnarySqrt";
799
800 template <typename T>
801 __host__ __device__ void operator()(T& y, const T& x) const
802 {
804 "Data type is not supported by this operation!");
805
806 y = math::sqrt(x);
807 };
808};
809
810struct Clamp
811{
812 static constexpr const char* name = "Clamp";
813
814 Clamp(float floor = 0.f, float ceil = NumericLimits<float>::Max())
815 : floor_(floor), ceil_(ceil){};
816
817 template <typename Y, typename X>
818 __host__ __device__ constexpr void operator()(Y& y, const X& x) const;
819
820 template <>
821 __host__ __device__ constexpr void operator()<float, float>(float& y, const float& x) const
822 {
823 const float& a = x;
824 y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
825 };
826
827 template <>
828 __host__ __device__ constexpr void operator()<double, double>(double& y, const double& x) const
829 {
830 const double& a = x;
831 y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
832 };
833
834 template <>
835 __host__ __device__ constexpr void operator()<half_t, half_t>(half_t& y, const half_t& x) const
836 {
837 const float a = type_convert<half_t>(x);
838 const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
840 };
841
842 template <>
843 __host__ __device__ constexpr void operator()<half_t, float>(half_t& y, const float& x) const
844 {
845 const float& a = x;
846 const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
848 };
849
850 template <>
851 __host__ __device__ constexpr void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
852 {
853 const float& a = x;
854 const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
856 };
857
858 template <>
859 __host__ __device__ constexpr void operator()<bhalf_t, bhalf_t>(bhalf_t& y,
860 const bhalf_t& x) const
861 {
862 const float a = type_convert<float>(x);
863 const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
865 };
866
867 template <>
868 __host__ __device__ constexpr void operator()<int, int>(int& y, const int& x) const
869 {
870 const int8_t& a = x;
871 y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
872 };
873
874 template <>
875 __host__ __device__ constexpr void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
876 {
877 const int8_t& a = x;
878 y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
879 };
880
881 const float floor_;
882 const float ceil_;
883};
884
// Relu: y = x > 0 ? x : 0. A NaN input yields 0 because NaN > 0 is false.
// bhalf_t inputs and the bhalf_t <- float pair are evaluated in float.
// NOTE(review): the static_assert type list (original 892-894) is absent from this listing.
885struct Relu
886{
887 static constexpr const char* name = "Relu";
888
889 template <typename T>
890 __host__ __device__ void operator()(T& y, const T& x) const
891 {
895 "Data type is not supported by this operation!");
896 y = x > 0 ? x : 0;
897 }
898
899 template <typename Y, typename X>
900 __host__ __device__ constexpr void operator()(Y& y, const X& x) const;
901
// No explicit template arguments: partial ordering selects the single-T template.
902 template <>
903 __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const
904 {
905 float x_f32 = type_convert<float>(x);
906 float y_f32 = x_f32 > 0 ? x_f32 : 0;
907 y = type_convert<bhalf_t>(y_f32);
908 }
909
910 template <>
911 __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
912 {
913 float y_f32 = x > 0 ? x : 0;
914 y = type_convert<bhalf_t>(y_f32);
915 };
916};
917
918// Fast GeLU (tanh approximation)
919// https://paperswithcode.com/method/gelu
920// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
921// Host code uses the higher-accuracy "exp" and a true division.
922// GPU code uses the lower-accuracy "__ocml_exp_f32" and "rcp" functions.
924{
// FastGelu evaluates the tanh approximation in its equivalent logistic form:
//   0.5 * x * (1 + tanh(z)) == x / (1 + exp(-2 * z)),
//   z = 0.797885 * x + 0.035677 * x^3   (0.797885 ~ sqrt(2/pi), 0.035677 ~ 0.044715 * 0.797885)
// so u = -2 * z = x * (c1 * x * x + c2) with c1 = -2 * 0.035677 and c2 = -2 * 0.797885.
// Separate __host__ and __device__ specialisations let the device path use __ocml_exp_f32
// and math::rcp. The half_t / bhalf_t specialisations widen to float, call the float
// specialisation, and convert the result once.
925 static constexpr const char* name = "FastGelu";
926
927 template <typename Y, typename X>
928 __host__ void operator()(Y& y, const X& x) const;
929
930 template <typename Y, typename X>
931 __device__ void operator()(Y& y, const X& x) const;
// NOTE(review): the host half_t / bhalf_t specialisations below call the host float
// specialisation, which only exists when this condition holds. It is false only when both
// __HIPCC_RTC__ and CK_CODE_GEN_RTC are defined; confirm those host paths are never
// instantiated in that configuration.
932#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
933 template <>
934 __host__ void operator()<float, float>(float& y, const float& x) const
935 {
936 // const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
937 const float c1 = -2.0 * 0.035677f;
938 const float c2 = -2.0 * 0.797885f;
939 const float u = x * (c1 * x * x + c2);
940 const float emu = exp(u);
941 y = x / (1.f + emu);
942 }
943#endif
944 // device code, use lower precision "__ocml_exp_f32" and "rcp"
945 template <>
946 __device__ void operator()<float, float>(float& y, const float& x) const
947 {
948 // const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
949 const float c1 = -2.0 * 0.035677f;
950 const float c2 = -2.0 * 0.797885f;
951 const float u = x * (c1 * x * x + c2);
952 const float emu = __ocml_exp_f32(u);
953
954 y = x * math::rcp(1.f + emu);
955 }
956
957 template <>
958 __host__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
959 {
960 float y_f;
961
962 this->operator()<float, float>(y_f, type_convert<float>(x));
963
964 y = type_convert<half_t>(y_f);
965 }
966
967 template <>
968 __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
969 {
970 float y_f;
971
972 this->operator()<float, float>(y_f, type_convert<float>(x));
973
974 y = type_convert<half_t>(y_f);
975 }
976
977 template <>
978 __host__ void operator()<half_t, float>(half_t& y, const float& x) const
979 {
980 float y_f;
981
982 this->operator()<float, float>(y_f, x);
983
984 y = type_convert<half_t>(y_f);
985 }
986
987 template <>
988 __device__ void operator()<half_t, float>(half_t& y, const float& x) const
989 {
990 float y_f;
991
992 this->operator()<float, float>(y_f, x);
993
994 y = type_convert<half_t>(y_f);
995 }
996
997 template <>
998 __host__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
999 {
1000 float y_f;
1001
1002 this->operator()<float, float>(y_f, x);
1003
1004 y = type_convert<bhalf_t>(y_f);
1005 }
1006
1007 template <>
1008 __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
1009 {
1010 float y_f;
1011
1012 this->operator()<float, float>(y_f, x);
1013
1014 y = type_convert<bhalf_t>(y_f);
1015 }
1016
1017 template <>
1018 __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
1019 {
1020 float y_f;
1021
1022 this->operator()<float, float>(y_f, type_convert<float>(x));
1023
1024 y = type_convert<bhalf_t>(y_f);
1025 }
1026
1027 template <>
1028 __host__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
1029 {
1030 float y_f;
1031
1032 this->operator()<float, float>(y_f, type_convert<float>(x));
1033
1034 y = type_convert<bhalf_t>(y_f);
1035 }
1036};
1037
1038// https://paperswithcode.com/method/gelu
1039// y = 0.5*x*(1+erf(x/sqrt(2)))
1040struct Gelu
1041{
1042 static constexpr const char* name = "Gelu";
1043
1044 template <typename Y, typename X>
1045 __host__ __device__ void operator()(Y& y, const X& x) const;
1046
1047 template <>
1048 __host__ __device__ void operator()<float, float>(float& y, const float& x) const
1049 {
1050 y = 0.5f * x * (1.f + erf(float(0.70710678118f * x)));
1051 }
1052
1053 template <>
1054 __host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
1055 {
1056 y = half_t(0.5) * x * (half_t(1) + half_t(erf(float(0.70710678118f * x))));
1057 }
1058};
1059
1061{
// Sigmoid: y = 1 / (1 + exp(-x)), computed in T. The bhalf_t <- float pair computes in float
// and converts once. The supported-type list (original 1067-1069) is absent here.
1062 static constexpr const char* name = "Sigmoid";
1063
1064 template <typename T>
1065 __host__ __device__ void operator()(T& y, const T& x) const
1066 {
1070 "Data type is not supported by this operation!");
1071 constexpr T one = type_convert<T>(1);
1072 y = one / (one + math::exp(-x));
1073 };
1074
1075 template <typename Y, typename X>
1076 __host__ __device__ constexpr void operator()(Y& y, const X& x) const;
1077
1078 template <>
1079 __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
1080 {
1081 constexpr float one = 1.f;
1082 y = type_convert<bhalf_t>(one / (one + math::exp(-x)));
1083 };
1084};
1085
// Silu (SiLU / swish): y = x * sigmoid(x) = x / (1 + exp(-x)), computed in T.
// The supported-type list (original 1093-1094) is absent from this listing.
1086struct Silu
1087{
1088 static constexpr const char* name = "SiLU";
1089
1090 template <typename T>
1091 __host__ __device__ void operator()(T& y, const T& x) const
1092 {
1095 "Data type is not supported by this operation!");
1096 constexpr T one = type_convert<T>(1);
1097 y = x * (one / (one + math::exp(-x)));
1098 };
1099};
1100
// TanH: y = math::tanh(x). A bhalf_t <- float pair is also declared.
// NOTE(review): the supported-type list (original 1108-1110) and the bhalf_t body
// (original 1122) are absent from this listing.
1101struct TanH
1102{
1103 static constexpr const char* name = "TanH";
1104
1105 template <typename T>
1106 __host__ __device__ void operator()(T& y, const T& x) const
1107 {
1111 "Data type is not supported by this operation!");
1112
1113 y = math::tanh(x);
1114 };
1115
1116 template <typename Y, typename X>
1117 __host__ __device__ constexpr void operator()(Y& y, const X& x) const;
1118
1119 template <>
1120 __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
1121 {
1123 };
1124};
1125
// ACos: y = math::acos(x). The supported-type list (original 1133-1135) is absent here.
1126struct ACos
1127{
1128 static constexpr const char* name = "ACos";
1129
1130 template <typename T>
1131 __host__ __device__ void operator()(T& y, const T& x) const
1132 {
1136 "Data type is not supported by this operation!");
1137
1138 y = math::acos(x);
1139 };
1140};
1141
// Neg: y = math::neg(x). The supported-type list (original 1149-1151) is absent here.
1142struct Neg
1143{
1144 static constexpr const char* name = "Neg";
1145
1146 template <typename T>
1147 __host__ __device__ void operator()(T& y, const T& x) const
1148 {
1152 "Data type is not supported by this operation!");
1153
1154 y = math::neg(x);
1155 };
1156};
1157
// ATan: y = math::atan(x). The supported-type list (original 1165-1167) is absent here.
1158struct ATan
1159{
1160 static constexpr const char* name = "ATan";
1161
1162 template <typename T>
1163 __host__ __device__ void operator()(T& y, const T& x) const
1164 {
1168 "Data type is not supported by this operation!");
1169
1170 y = math::atan(x);
1171 };
1172};
1173
// Sin: y = math::sin(x). The supported-type list (original 1181-1183) is absent here.
1174struct Sin
1175{
1176 static constexpr const char* name = "Sin";
1177
1178 template <typename T>
1179 __host__ __device__ void operator()(T& y, const T& x) const
1180 {
1184 "Data type is not supported by this operation!");
1185
1186 y = math::sin(x);
1187 };
1188};
1189
// ASinH: y = math::asinh(x). The supported-type list (original 1197-1199) is absent here.
1190struct ASinH
1191{
1192 static constexpr const char* name = "ASinH";
1193
1194 template <typename T>
1195 __host__ __device__ void operator()(T& y, const T& x) const
1196 {
1200 "Data type is not supported by this operation!");
1201
1202 y = math::asinh(x);
1203 };
1204};
1205
1206struct Cos
1207{
1208 static constexpr const char* name = "Cos";
1209
1210 template <typename T>
1211 __host__ __device__ void operator()(T& y, const T& x) const
1212 {
1216 "Data type is not supported by this operation!");
1217
1218 y = cos(x);
1219 };
1220};
1221
1222struct ACosH
1223{
1224 static constexpr const char* name = "ACosH";
1225
1226 template <typename T>
1227 __host__ __device__ void operator()(T& y, const T& x) const
1228 {
1232 "Data type is not supported by this operation!");
1233
1234 y = math::acosh(x);
1235 };
1236};
1237
1238struct Tan
1239{
1240 static constexpr const char* name = "Tan";
1241
1242 template <typename T>
1243 __host__ __device__ void operator()(T& y, const T& x) const
1244 {
1248 "Data type is not supported by this operation!");
1249
1250 y = math::tan(x);
1251 };
1252};
1253
1254struct ATanH
1255{
1256 static constexpr const char* name = "ATanH";
1257
1258 template <typename T>
1259 __host__ __device__ void operator()(T& y, const T& x) const
1260 {
1264 "Data type is not supported by this operation!");
1265
1266 y = math::atanh(x);
1267 };
1268};
1269
1270struct SinH
1271{
1272 static constexpr const char* name = "SinH";
1273
1274 template <typename T>
1275 __host__ __device__ void operator()(T& y, const T& x) const
1276 {
1280 "Data type is not supported by this operation!");
1281
1282 y = math::sinh(x);
1283 };
1284};
1285
1286struct Ceil
1287{
1288 static constexpr const char* name = "Ceil";
1289
1290 template <typename T>
1291 __host__ __device__ void operator()(T& y, const T& x) const
1292 {
1296 "Data type is not supported by this operation!");
1297
1298 y = math::ceil(x);
1299 };
1300};
1301
1302struct Exp
1303{
1304 static constexpr const char* name = "Exp";
1305
1306 template <typename T>
1307 __host__ __device__ void operator()(T& y, const T& x) const
1308 {
1312 "Data type is not supported by this operation!");
1313
1314 y = math::exp(x);
1315 };
1316};
1317
1318struct CosH
1319{
1320 static constexpr const char* name = "CosH";
1321
1322 template <typename T>
1323 __host__ __device__ void operator()(T& y, const T& x) const
1324 {
1328 "Data type is not supported by this operation!");
1329
1330 y = math::cosh(x);
1331 };
1332};
1333
1334struct Floor
1335{
1336 static constexpr const char* name = "Floor";
1337
1338 template <typename T>
1339 __host__ __device__ void operator()(T& y, const T& x) const
1340 {
1344 "Data type is not supported by this operation!");
1345
1346 y = math::floor(x);
1347 };
1348};
1349
1350struct Log
1351{
1352 static constexpr const char* name = "Log";
1353
1354 template <typename T>
1355 __host__ __device__ void operator()(T& y, const T& x) const
1356 {
1360 "Data type is not supported by this operation!");
1361
1362 y = math::log(x);
1363 };
1364};
1365
1366struct ASin
1367{
1368 static constexpr const char* name = "ASin";
1369
1370 template <typename T>
1371 __host__ __device__ void operator()(T& y, const T& x) const
1372 {
1376 "Data type is not supported by this operation!");
1377
1378 y = math::asin(x);
1379 };
1380};
1381
1382struct Rcp
1383{
1384 static constexpr const char* name = "Rcp";
1385
1386 template <typename T>
1387 __host__ __device__ void operator()(T& y, const T& x) const
1388 {
1392 "Data type is not supported by this operation!");
1393
1394 y = math::rcp(x);
1395 };
1396};
1397
1398struct Swish
1399{
1400 static constexpr const char* name = "Swish";
1401
1402 Swish(float beta = 1.0f) : beta_(beta) {}
1403
1404 template <typename Y, typename X>
1405 __host__ __device__ void operator()(Y& y, const X& x) const
1406 {
1409 "Data type is not supported by this operation!");
1410
1413 "Data type is not supported by this operation!");
1414
1415 float bx = -beta_ * type_convert<float>(x);
1416 y = type_convert<Y>(x / (1.f + math::exp(bx)));
1417 };
1418
1419 template <>
1420 __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
1421 {
1422 float bx = -beta_ * x;
1423 y = type_convert<bhalf_t>(x / (1.f + math::exp(bx)));
1424 };
1425
1426 const float beta_;
1427};
1428
1430{
1431 static constexpr const char* name = "SoftRelu";
1432
1433 SoftRelu(float alpha = 1.f) : alpha_(alpha){};
1434
1435 template <typename T>
1436 __host__ __device__ void operator()(T& y, const T& x) const
1437 {
1441 "Data type is not supported by this operation!");
1442 T casted_alpha = type_convert<T>(alpha_);
1443 constexpr T one = type_convert<T>(1);
1444 y = math::log(one + math::exp(x * casted_alpha)) / casted_alpha;
1445 }
1446
// Mixed-type overload. It is only declared here; within this struct the sole definition is the
// explicit specialization for (bhalf_t output, float input) that follows. Other (Y, X)
// combinations would fail to link unless specialized elsewhere.
template <typename Y, typename X>
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
1449
1450 template <>
1451 __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
1452 {
1453 constexpr float one = 1.f;
1455 };
1456 const float alpha_;
1457};
1458
1459struct Power
1460{
1461 static constexpr const char* name = "Power";
1462
1463 Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
1464 : alpha_(alpha), beta_(beta), gamma_(gamma){};
1465
1466 template <typename T>
1467 __host__ __device__ void operator()(T& y, const T& x) const
1468 {
1472 "Data type is not supported by this operation!");
1473 T casted_alpha = type_convert<T>(alpha_);
1474 T casted_beta = type_convert<T>(beta_);
1475 T casted_gamma = type_convert<T>(gamma_);
1476 T shifted_scaled_x = casted_alpha + casted_beta * x;
1477 y = math::pow(shifted_scaled_x, casted_gamma);
1478 }
1479
1480 template <typename Y, typename X>
1481 __host__ __device__ constexpr void operator()(Y& y, const X& x) const;
1482
1483 template <>
1484 __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
1485 {
1486 const float shifted_scaled_x = alpha_ + beta_ * x;
1487 y = type_convert<bhalf_t>(math::pow(shifted_scaled_x, gamma_));
1488 };
1489
1490 const float alpha_;
1491 const float beta_;
1492 const float gamma_;
1493};
1494
1496{
1497 static constexpr const char* name = "ClippedRelu";
1498
1499 ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
1500
1501 template <typename T>
1502 __host__ __device__ void operator()(T& y, const T& x) const
1503 {
1507 "Data type is not supported by this operation!");
1508 T casted_alpha = type_convert<T>(alpha_);
1509 T casted_beta = type_convert<T>(beta_);
1510 y = math::min(casted_beta, math::max(casted_alpha, x));
1511 }
1512
// Mixed-type overload. It is only declared here; within this struct the sole definition is the
// explicit specialization for (bhalf_t output, float input) that follows. Other (Y, X)
// combinations would fail to link unless specialized elsewhere.
template <typename Y, typename X>
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
1515
1516 template <>
1517 __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
1518 {
1520 };
1521
1522 const float alpha_;
1523 const float beta_;
1524};
1525
1527{
1528 static constexpr const char* name = "LeakyRelu";
1529
1530 LeakyRelu(float alpha = 0.01f) : alpha_(alpha){};
1531
1532 template <typename T>
1533 __host__ __device__ void operator()(T& y, const T& x) const
1534 {
1538 "Data type is not supported by this operation!");
1539 T casted_alpha = type_convert<T>(alpha_);
1540 y = x >= 0 ? x : x * casted_alpha;
1541 }
1542
1543 template <typename Y, typename X>
1544 __host__ __device__ constexpr void operator()(Y& y, const X& x) const;
1545
1546 template <>
1547 __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
1548 {
1549 y = type_convert<bhalf_t>(x >= 0 ? x : x * alpha_);
1550 };
1551
1552 const float alpha_;
1553};
1554
1555struct Elu
1556{
1557 static constexpr const char* name = "Elu";
1558
1559 Elu(float alpha = 1.f) : alpha_(alpha){};
1560
1561 template <typename T>
1562 __host__ __device__ void operator()(T& y, const T& x) const
1563 {
1567 "Data type is not supported by this operation!");
1568 T casted_alpha = type_convert<T>(alpha_);
1569 y = x > 0 ? x : casted_alpha * math::expm1(x);
1570 }
1571
1572 template <typename Y, typename X>
1573 __host__ __device__ constexpr void operator()(Y& y, const X& x) const;
1574
1575 template <>
1576 __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
1577 {
1578 y = type_convert<bhalf_t>(x > 0 ? x : alpha_ * math::expm1(x));
1579 };
1580
1581 const float alpha_;
1582};
1583
1585{
1586 static constexpr const char* name = "Logistic";
1587
1588 Logistic(float alpha = 1.f) : alpha_(alpha){};
1589
1590 template <typename T>
1591 __host__ __device__ void operator()(T& y, const T& x) const
1592 {
1596 "Data type is not supported by this operation!");
1597 T casted_alpha = type_convert<T>(alpha_);
1598 constexpr T one = type_convert<T>(1);
1599 y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
1600 }
1601
1602 template <typename Y, typename X>
1603 __host__ __device__ constexpr void operator()(Y& y, const X& x) const;
1604
1605 template <>
1606 __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
1607 {
1608 constexpr float one = 1.f;
1609 y = type_convert<bhalf_t>(alpha_ / (one + ck::math::exp(-x) * alpha_));
1610 };
1611 const float alpha_;
1612};
1613
1615{
1616 static constexpr const char* name = "ConvInvscale";
1617
1618 __host__ __device__ ConvInvscale(float scale_in = 1.f,
1619 float scale_wei = 1.f,
1620 float scale_out = 1.f)
1621 : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
1622 {
1623 }
1624
1625 template <typename E, typename C>
1626 __host__ __device__ void operator()(E& e, const C& c) const;
1627
1628 template <>
1629 __host__ __device__ void operator()<f8_t, float>(f8_t& e, const float& c) const
1630 {
1632 };
1633
1637};
1638
1640{
1641 static constexpr const char* name = "ConvScale";
1642
1643 __host__ __device__ ConvScale(float scale_in = 1.f,
1644 float scale_wei = 1.f,
1645 float scale_out = 1.f)
1646 : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
1647 {
1648 }
1649
1650 template <typename E, typename C>
1651 __host__ __device__ void operator()(E& e, const C& c) const;
1652
1653 template <>
1654 __host__ __device__ void operator()<f8_t, float>(f8_t& e, const float& c) const
1655 {
1657 };
1658
1662};
1663
1665{
1666 static constexpr const char* name = "ConvScaleRelu";
1667
1668 __host__ __device__ ConvScaleRelu(float scale_in = 1.f,
1669 float scale_wei = 1.f,
1670 float scale_out = 1.f)
1671 : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
1672 {
1673 }
1674
1675 template <typename E, typename C>
1676 __host__ __device__ void operator()(E& e, const C& c) const;
1677
1678 template <>
1679 __host__ __device__ void operator()<f8_t, float>(f8_t& e, const float& c) const
1680 {
1681 float x;
1682 Relu{}.template operator()<float>(x, c * scale_in_ * scale_wei_);
1684 };
1685
1689};
1690
1691// support fastconvert of int8 to fp16
1692
1693template <typename InputDataType, typename OutputDataType, index_t RegPackNumber>
1695{
1696};
1697
1698template <>
1700{
1703
// Converts 4 packed 8-bit values into 4 fp16 values. It uses two byte permutes and one packed
// fp16 subtract per pair, instead of per-element integer->float conversions.
// From the arithmetic below, output[i] == fp16(input_byte[i]) - 128: the bytes are read as
// unsigned values that are assumed to carry a +128 bias.
// NOTE(review): the comment above this converter family says "int8 to fp16". Confirm that
// callers store signed data biased by +128; raw two's-complement int8 would convert incorrectly.
__device__ static OutputArray convert(InputArray const& Input)
{
    OutputArray Output;

    // Reinterpret the output as two 32-bit words (each a packed pair of fp16 values) and the
    // input as one 32-bit word.
    // NOTE(review): assumes sizeof(InputArray) == 4 and sizeof(OutputArray) == 8; the type
    // aliases are declared on lines not shown here.
    uint32_t* half_2 = reinterpret_cast<uint32_t*>(&Output);
    uint32_t const uint8_4 = reinterpret_cast<uint32_t const&>(Input);

    // __builtin_amdgcn_perm(a, b, sel) builds each result byte from the 8 bytes {a:b}. Selector
    // values 0-3 pick bytes of b (the input); 4-7 pick bytes of a, and every byte of fp16_adder
    // is 0x64. So 0x05010500 yields the fp16 pair {0x64:in[0], 0x64:in[1]}, and 0x05030502
    // yields {0x64:in[2], 0x64:in[3]}; element order is preserved. The fp16 bit pattern 0x64XX
    // encodes 1024 + XX exactly, so each lane now holds 1024 + byte.
    static constexpr uint32_t byte_selector_01 = 0x05010500;
    static constexpr uint32_t byte_selector_23 = 0x05030502;
    static constexpr uint32_t fp16_adder = 0x64646464;
    half_2[0] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_01);
    half_2[1] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_23);

    // 0x6480 is fp16 1152 (= 1024 + 128). neg_lo:[0,1] neg_hi:[0,1] negates the second source of
    // v_pk_add_f16 in both halves, so each lane computes (1024 + byte) - 1152 = byte - 128.
    // This is exact in fp16 for byte in [0, 255].
    static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
    asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]"
                 : "=v"(half_2[0])
                 : "v"(half_2[0]), "s"(I8s_TO_F16s_MAGIC_NUM));
    asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]"
                 : "=v"(half_2[1])
                 : "v"(half_2[1]), "s"(I8s_TO_F16s_MAGIC_NUM));

    return Output;
}

// Functor form of convert().
__device__ OutputArray operator()(InputArray const& Input) { return convert(Input); }
1729};
1730
1731template <index_t N>
1733{
1734 static constexpr int VEC_WIDTH = 4;
1735 static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");
1736
1739
1740 __device__ static OutputArray convert(InputArray const& Input)
1741 {
1743
1744 OutputArray Output;
1745
1746 using Vec_InputArray = vector_type<uint8_t, 4>;
1747 using Vec_OutputArray = vector_type<half_t, 4>;
1748
1749 Vec_OutputArray* half_4_ptr = reinterpret_cast<Vec_OutputArray*>(&Output);
1750 Vec_InputArray const* uint8_4_ptr = reinterpret_cast<Vec_InputArray const*>(&Input);
1751
1752 static_for<0, N / VEC_WIDTH, 1>{}(
1753 [&](auto i) { half_4_ptr[i] = converter(uint8_4_ptr[i]); });
1754
1755 return Output;
1756 }
1757
// Functor form of convert().
__device__ OutputArray operator()(InputArray const& Input) { return convert(Input); }
1759};
1760
1762{
1763 static constexpr const char* name = "DynamicUnaryOp";
1764
1765 __host__ __device__ DynamicUnaryOp() = delete;
1766
1767 __host__ __device__ DynamicUnaryOp(const Swish& swish)
1768 : unary_op_type_(UnaryOpType::Swish), swish_{swish.beta_}
1769 {
1770 }
1771
1772 __host__ __device__ DynamicUnaryOp(const Swish&& swish)
1773 : unary_op_type_(UnaryOpType::Swish), swish_{swish.beta_}
1774 {
1775 }
1776
1777 __host__ __device__ DynamicUnaryOp(const Sigmoid&) : unary_op_type_(UnaryOpType::Sigmoid) {}
1778
1779 __host__ __device__ DynamicUnaryOp(const Sigmoid&&) : unary_op_type_(UnaryOpType::Sigmoid) {}
1780
1781 __host__ __device__ DynamicUnaryOp(const PassThrough&)
1782 : unary_op_type_(UnaryOpType::PassThrough)
1783 {
1784 }
1785
1786 __host__ __device__ DynamicUnaryOp(const PassThrough&&)
1787 : unary_op_type_(UnaryOpType::PassThrough)
1788 {
1789 }
1790
1791 __host__ __device__ DynamicUnaryOp(const Logistic& logistic)
1792 : unary_op_type_(UnaryOpType::Logistic), logistic_{logistic.alpha_}
1793 {
1794 }
1795
1796 __host__ __device__ DynamicUnaryOp(const Logistic&& logistic)
1797 : unary_op_type_(UnaryOpType::Logistic), logistic_{logistic.alpha_}
1798 {
1799 }
1800
1801 __host__ __device__ DynamicUnaryOp(const TanH&) : unary_op_type_(UnaryOpType::TanH) {}
1802
1803 __host__ __device__ DynamicUnaryOp(const TanH&&) : unary_op_type_(UnaryOpType::TanH) {}
1804
1805 __host__ __device__ DynamicUnaryOp(const Relu&) : unary_op_type_(UnaryOpType::Relu) {}
1806
1807 __host__ __device__ DynamicUnaryOp(const Relu&&) : unary_op_type_(UnaryOpType::Relu) {}
1808
1809 __host__ __device__ DynamicUnaryOp(const SoftRelu& softrelu)
1810 : unary_op_type_(UnaryOpType::SoftRelu), soft_relu_{softrelu.alpha_}
1811 {
1812 }
1813
1814 __host__ __device__ DynamicUnaryOp(const SoftRelu&& softrelu)
1815 : unary_op_type_(UnaryOpType::SoftRelu), soft_relu_{softrelu.alpha_}
1816 {
1817 }
1818
1819 __host__ __device__ DynamicUnaryOp(const UnaryAbs&) : unary_op_type_(UnaryOpType::UnaryAbs) {}
1820
1821 __host__ __device__ DynamicUnaryOp(const UnaryAbs&&) : unary_op_type_(UnaryOpType::UnaryAbs) {}
1822
1823 __host__ __device__ DynamicUnaryOp(const Power& pow)
1824 : unary_op_type_(UnaryOpType::Power), power_(pow.alpha_, pow.beta_, pow.gamma_)
1825 {
1826 }
1827
1828 __host__ __device__ DynamicUnaryOp(const Power&& pow)
1829 : unary_op_type_(UnaryOpType::Power), power_(pow.alpha_, pow.beta_, pow.gamma_)
1830 {
1831 }
1832
1833 __host__ __device__ DynamicUnaryOp(const ClippedRelu& clippedrelu)
1834 : unary_op_type_(UnaryOpType::ClippedRelu),
1835 clipped_relu_{clippedrelu.alpha_, clippedrelu.beta_}
1836 {
1837 }
1838
1839 __host__ __device__ DynamicUnaryOp(const ClippedRelu&& clippedrelu)
1840 : unary_op_type_(UnaryOpType::ClippedRelu),
1841 clipped_relu_{clippedrelu.alpha_, clippedrelu.beta_}
1842 {
1843 }
1844
1845 __host__ __device__ DynamicUnaryOp(const LeakyRelu& leakyrelu)
1846 : unary_op_type_(UnaryOpType::LeakyRelu), leaky_relu_{leakyrelu.alpha_}
1847 {
1848 }
1849
1850 __host__ __device__ DynamicUnaryOp(const LeakyRelu&& leakyrelu)
1851 : unary_op_type_(UnaryOpType::LeakyRelu), leaky_relu_{leakyrelu.alpha_}
1852 {
1853 }
1854
1855 __host__ __device__ DynamicUnaryOp(const Elu& elu)
1856 : unary_op_type_(UnaryOpType::Elu), elu_{elu.alpha_}
1857 {
1858 }
1859
1860 __host__ __device__ DynamicUnaryOp(const Elu&& elu)
1861 : unary_op_type_(UnaryOpType::Elu), elu_{elu.alpha_}
1862 {
1863 }
1864
1865 __host__ __device__ DynamicUnaryOp(const DynamicUnaryOp& dynamic_op) = default;
1866
1867 __host__ __device__ ~DynamicUnaryOp() {}
1868
1869 template <typename Y, typename X>
1870 __host__ __device__ void operator()(Y& y, const X& x) const
1871 {
1872 switch(unary_op_type_)
1873 {
1874 case(UnaryOpType::Swish): swish_(y, x); break;
1875 case(UnaryOpType::Sigmoid): sigmoid_(y, x); break;
1876 case(UnaryOpType::PassThrough): pass_through_(y, x); break;
1877 case(UnaryOpType::Logistic): logistic_(y, x); break;
1878 case(UnaryOpType::TanH): tanh_(y, x); break;
1879 case(UnaryOpType::Relu): relu_(y, x); break;
1880 case(UnaryOpType::SoftRelu): soft_relu_(y, x); break;
1881 case(UnaryOpType::UnaryAbs): unary_abs_(y, x); break;
1882 case(UnaryOpType::Power): power_(y, x); break;
1883 case(UnaryOpType::ClippedRelu): clipped_relu_(y, x); break;
1884 case(UnaryOpType::LeakyRelu): leaky_relu_(y, x); break;
1885 case(UnaryOpType::Elu): elu_(y, x); break;
1886 default: break;
1887 }
1888 }
1889
1890 template <>
1891 __host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
1892 {
1893 float y_float;
1894 float x_float = type_convert<float>(x);
1895 this->operator()(y_float, x_float);
1896 y = type_convert<bhalf_t>(y_float);
1897 }
1898
1899 private:
1900 enum class UnaryOpType
1901 {
1902 Swish,
1903 Sigmoid,
1905 Logistic,
1906 TanH,
1907 Relu,
1908 SoftRelu,
1909 UnaryAbs,
1910 Power,
1912 LeakyRelu,
1913 Elu
1914 };
1915
1916 public:
1917 UnaryOpType unary_op_type_;
1918
1931};
1932
1933} // namespace element_wise
1934} // namespace tensor_operation
1935} // namespace ck
__host__ T log(T x)
Definition math_v2.hpp:409
__host__ T cosh(T x)
Definition math_v2.hpp:349
__host__ T exp(T x)
Definition math_v2.hpp:391
__host__ T rcp(T x)
Definition math_v2.hpp:385
__host__ T tan(T x)
Definition math_v2.hpp:277
__host__ T sin(T x)
Definition math_v2.hpp:187
__host__ T expm1(T x)
Definition math_v2.hpp:446
__host__ T atan(T x)
Definition math_v2.hpp:169
__host__ T acos(T x)
Definition math_v2.hpp:121
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ T ceil(T x)
Definition math_v2.hpp:331
__host__ T asin(T x)
Definition math_v2.hpp:205
__host__ T pow(T x, T gamma)
Definition math_v2.hpp:427
__host__ T neg(T x)
Definition math_v2.hpp:139
__host__ T floor(T x)
Definition math_v2.hpp:367
__host__ T asinh(T x)
Definition math_v2.hpp:223
__host__ T atanh(T x)
Definition math_v2.hpp:295
__host__ T tanh(T x)
Definition math_v2.hpp:103
__host__ T acosh(T x)
Definition math_v2.hpp:259
__host__ T sinh(T x)
Definition math_v2.hpp:313
__host__ __device__ constexpr T min(T x)
Definition utility/math.hpp:116
Definition ck.hpp:268
ushort bhalf_t
Definition data_type.hpp:30
f8_fnuz_t f8_t
Definition amd_ck_fp8.hpp:1762
__device__ half2_t amd_assembly_pk_add_f16(half2_t a, half2_t b)
Definition amd_inline_asm.hpp:35
__device__ half4_t i4_to_half4(int q)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:20
__device__ f8x8_t amd_assembly_i4_to_fp8x8(int a)
Definition amd_inline_asm.hpp:59
typename vector_type< bhalf_t, 8 >::type bhalf8_t
Definition dtype_vector.hpp:2162
__host__ __device__ constexpr Y bf16_convert_rtn(X x)
__device__ half4_t i4_to_half4_scale(int q, const ck::half2_t &scale)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:48
__host__ __device__ constexpr Y f8_convert_rne(X x)
_Float16 half_t
Definition data_type.hpp:31
__device__ f8x4_t amd_assembly_cvt_f8_to_f32(float b0, float b1, float b2, float b3)
Definition amd_inline_asm.hpp:49
integral_constant< index_t, N > Number
Definition number.hpp:12
__device__ int amd_assembly_and_b32(int a, int b)
Definition amd_inline_asm.hpp:14
__device__ half2_t amd_assembly_pk_fma_f16(half2_t a, half2_t b, half2_t c)
Definition amd_inline_asm.hpp:28
@ MUL
Definition reduction_enums.hpp:11
@ ADD
Definition reduction_enums.hpp:10
__host__ __device__ constexpr Y f8_convert_sr(X x)
typename vector_type< half_t, 8 >::type half8_t
Definition dtype_vector.hpp:2155
__device__ f8x4_t i4_to_f8x4(int q)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:82
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
typename vector_type< bhalf_t, 4 >::type bhalf4_t
Definition dtype_vector.hpp:2161
bf8_fnuz_t bf8_t
Definition amd_ck_fp8.hpp:1763
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
typename vector_type< half_t, 2 >::type half2_t
Definition dtype_vector.hpp:2153
constexpr bool is_same_v
Definition type.hpp:283
_BitInt(4) int4_t
Definition data_type.hpp:32
typename vector_type< pk_i4_t, 4 >::type pk_i4x4_t
Definition dtype_vector.hpp:2282
__device__ int amd_assembly_and_or_b32(int a, int b, int d)
Definition amd_inline_asm.hpp:21
__host__ __device__ constexpr Y bit_cast(const X &x)
Definition type.hpp:306
__device__ float amd_assemble_cvt_f32_i4(int b)
Definition amd_inline_asm.hpp:42
typename vector_type< half_t, 4 >::type half4_t
Definition dtype_vector.hpp:2154
__device__ bhalf4_t i4_to_bhalf4(int q)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:126
__device__ f8x8_t i4_to_fp8x8(int q)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:98
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
unsigned int uint32_t
Definition stdint.h:126
signed int int32_t
Definition stdint.h:123
unsigned char uint8_t
Definition stdint.h:124
unsigned __int64 uint64_t
Definition stdint.h:136
signed char int8_t
Definition stdint.h:121
__host__ static __device__ constexpr T Max()
Definition numeric_limits.hpp:311
__host__ static __device__ constexpr T Infinity()
Definition numeric_limits.hpp:317
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition data_type.hpp:42
Definition data_type.hpp:187
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1223
__host__ __device__ void operator()(T &y, const T &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1227
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1224
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1127
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1128
__host__ __device__ void operator()(T &y, const T &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1131
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1191
__host__ __device__ void operator()(T &y, const T &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1195
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1192
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1367
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1368
__host__ __device__ void operator()(T &y, const T &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1371
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1255
__host__ __device__ void operator()(T &y, const T &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1259
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1256
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1159
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1160
__host__ __device__ void operator()(T &y, const T &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1163
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1287
__host__ __device__ void operator()(T &y, const T &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1291
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1288
Clamp(float floor=0.f, float ceil=NumericLimits< float >::Max())
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:814
__host__ __device__ constexpr void operator()(Y &y, const X &x) const
const float floor_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:881
const float ceil_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:882
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:812
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1496
__host__ __device__ void operator()(T &y, const T &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1502
__host__ __device__ constexpr void operator()(Y &y, const X &x) const
const float beta_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1523
const float alpha_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1522
ClippedRelu(float alpha=0.f, float beta=1.f)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1499
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1497
float scale_wei_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1635
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1616
float scale_out_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1636
float scale_in_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1634
__host__ __device__ void operator()(E &e, const C &c) const
__host__ __device__ ConvInvscale(float scale_in=1.f, float scale_wei=1.f, float scale_out=1.f)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1618
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1641
__host__ __device__ void operator()(E &e, const C &c) const
float scale_out_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1661
float scale_wei_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1660
float scale_in_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1659
__host__ __device__ ConvScale(float scale_in=1.f, float scale_wei=1.f, float scale_out=1.f)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1643
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1666
float scale_in_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1686
float scale_wei_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1687
__host__ __device__ void operator()(E &e, const C &c) const
__host__ __device__ ConvScaleRelu(float scale_in=1.f, float scale_wei=1.f, float scale_out=1.f)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1668
float scale_out_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1688
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:577
__host__ __device__ void operator()(Y &y, const X &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:582
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:578
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:616
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:617
__host__ __device__ void operator()(Y &y, const X &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:621
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:596
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:597
__host__ __device__ void operator()(Y &y, const X &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:601
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1319
__host__ __device__ void operator()(T &y, const T &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1323
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1320
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1207
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1208
__host__ __device__ void operator()(T &y, const T &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1211
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:269
constexpr static const bool is_pack8_invocable
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:303
__host__ __device__ void operator()(Y &y, const X &x, const Z &z) const
__host__ __device__ constexpr void operator()(ck::half8_t &y, const ck::pk_i4x4_t &x, const ck::half2_t &z) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:276
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:270
__host__ __device__ DynamicUnaryOp(const TanH &&)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1803
__host__ __device__ DynamicUnaryOp(const LeakyRelu &leakyrelu)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1845
__host__ __device__ DynamicUnaryOp(const Power &&pow)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1828
__host__ __device__ DynamicUnaryOp(const UnaryAbs &&)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1821
__host__ __device__ DynamicUnaryOp(const Swish &&swish)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1772
__host__ __device__ DynamicUnaryOp(const TanH &)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1801
__host__ __device__ DynamicUnaryOp(const DynamicUnaryOp &dynamic_op)=default
__host__ __device__ DynamicUnaryOp(const Sigmoid &&)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1779
__host__ __device__ DynamicUnaryOp(const Swish &swish)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1767
TanH tanh_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1923
__host__ __device__ DynamicUnaryOp(const Sigmoid &)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1777
__host__ __device__ DynamicUnaryOp(const SoftRelu &&softrelu)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1814
__host__ __device__ void operator()(Y &y, const X &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1870
__host__ __device__ DynamicUnaryOp(const UnaryAbs &)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1819
__host__ __device__ DynamicUnaryOp(const LeakyRelu &&leakyrelu)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1850
__host__ __device__ DynamicUnaryOp(const Relu &&)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1807
LeakyRelu leaky_relu_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1929
Logistic logistic_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1922
Swish swish_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1919
__host__ __device__ DynamicUnaryOp(const Elu &elu)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1855
UnaryOpType unary_op_type_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1917
UnaryAbs unary_abs_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1926
__host__ __device__ ~DynamicUnaryOp()
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1867
__host__ __device__ DynamicUnaryOp(const PassThrough &)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1781
__host__ __device__ DynamicUnaryOp(const Power &pow)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1823
Relu relu_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1924
__host__ __device__ DynamicUnaryOp(const PassThrough &&)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1786
__host__ __device__ DynamicUnaryOp(const Logistic &logistic)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1791
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1763
SoftRelu soft_relu_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1925
__host__ __device__ DynamicUnaryOp(const SoftRelu &softrelu)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1809
__host__ __device__ DynamicUnaryOp(const ClippedRelu &clippedrelu)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1833
Power power_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1927
__host__ __device__ DynamicUnaryOp(const Elu &&elu)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1860
__host__ __device__ DynamicUnaryOp(const ClippedRelu &&clippedrelu)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1839
Elu elu_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1930
__host__ __device__ DynamicUnaryOp(const Relu &)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1805
Sigmoid sigmoid_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1920
__host__ __device__ DynamicUnaryOp(const Logistic &&logistic)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1796
ClippedRelu clipped_relu_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1928
PassThrough pass_through_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1921
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1556
const float alpha_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1581
__host__ __device__ void operator()(T &y, const T &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1562
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1557
Elu(float alpha=1.f)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1559
__host__ __device__ constexpr void operator()(Y &y, const X &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1303
__host__ __device__ void operator()(T &y, const T &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1307
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1304
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:924
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:925
__device__ void operator()(Y &y, const X &x) const
__host__ void operator()(Y &y, const X &x) const
vector_type< uint8_t, 4 > InputArray
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1701
vector_type< half_t, 4 > OutputArray
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1702
__device__ OutputArray operator()(InputArray const &Input)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1728
static __device__ OutputArray convert(InputArray const &Input)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1704
vector_type< uint8_t, N > InputArray
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1737
static constexpr int VEC_WIDTH
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1734
vector_type< half_t, N > OutputArray
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1738
static __device__ OutputArray convert(InputArray const &Input)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1740
__device__ OutputArray operator()(InputArray const &Input)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1758
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1695
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1335
__host__ __device__ void operator()(T &y, const T &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1339
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1336
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1041
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1042
__host__ __device__ void operator()(Y &y, const X &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1527
const float alpha_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1552
__host__ __device__ void operator()(T &y, const T &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1533
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1528
LeakyRelu(float alpha=0.01f)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1530
__host__ __device__ constexpr void operator()(Y &y, const X &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1351
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1352
__host__ __device__ void operator()(T &y, const T &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1355
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1585
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1586
Logistic(float alpha=1.f)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1588
const float alpha_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1611
__host__ __device__ constexpr void operator()(Y &y, const X &x) const
__host__ __device__ void operator()(T &y, const T &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1591
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1143
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1144
__host__ __device__ void operator()(T &y, const T &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1147
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:341
__host__ __device__ void operator()(Y &y, const X &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:307
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:308
constexpr static const bool is_pack2_invocable
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:336
__host__ __device__ constexpr void operator()(ck::half2_t &y, const ck::pk_i4_t &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:319
__host__ __device__ constexpr void operator()(half2_t &y, const f8x2_t &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:313
__host__ __device__ void operator()(Y &y, const X &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:159
__host__ __device__ constexpr void operator()(ck::f8x8_t &y, const ck::pk_i4x4_t &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:191
__host__ __device__ void operator()(Y &y, const X &x) const
__host__ __device__ constexpr void operator()(ck::bhalf8_t &y, const ck::pk_i4x4_t &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:240
constexpr static const bool is_pack8_invocable
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:265
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:160
__host__ __device__ constexpr void operator()(ck::half8_t &y, const ck::pk_i4x4_t &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:165
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1460
const float gamma_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1492
const float alpha_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1490
Power(float alpha=0.f, float beta=1.f, float gamma=2.f)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1463
__host__ __device__ constexpr void operator()(Y &y, const X &x) const
__host__ __device__ void operator()(T &y, const T &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1467
const float beta_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1491
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1461
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1383
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1384
__host__ __device__ void operator()(T &y, const T &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1387
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:886
__host__ __device__ void operator()(bhalf_t &y, const bhalf_t &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:903
__host__ __device__ constexpr void operator()(Y &y, const X &x) const
__host__ __device__ void operator()(T &y, const T &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:890
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:887
__host__ __device__ void operator()(Y &y, const X &x) const
float scale_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:697
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:684
__host__ __device__ ScaleAndResetNaNToMinusInfinity(float scale)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:686
float scale_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:679
__host__ __device__ void operator()(Y &y, const X &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:642
__host__ __device__ Scale(float scale=1.f)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:639
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:637
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1061
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1062
__host__ __device__ constexpr void operator()(Y &y, const X &x) const
__host__ __device__ void operator()(T &y, const T &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1065
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1087
__host__ __device__ void operator()(T &y, const T &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1091
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1088
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1271
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1272
__host__ __device__ void operator()(T &y, const T &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1275
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1175
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1176
__host__ __device__ void operator()(T &y, const T &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1179
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1430
__host__ __device__ constexpr void operator()(Y &y, const X &x) const
const float alpha_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1456
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1431
SoftRelu(float alpha=1.f)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1433
__host__ __device__ void operator()(T &y, const T &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1436
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1399
__host__ __device__ void operator()(Y &y, const X &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1405
Swish(float beta=1.0f)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1402
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1400
const float beta_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1426
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1102
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1103
__host__ __device__ void operator()(T &y, const T &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1106
__host__ __device__ constexpr void operator()(Y &y, const X &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1239
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1240
__host__ __device__ void operator()(T &y, const T &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1243
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:765
__host__ __device__ constexpr void operator()(Y &y, const X &x) const
__host__ __device__ void operator()(T &y, const T &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:769
__host__ __device__ void operator()(f8_t &y, const f8_t &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:781
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:766
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:566
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:567
__host__ __device__ void operator()(Y &y, const X &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:570
__host__ __device__ void operator()(T &y, const T &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:707
__host__ __device__ UnaryDivide(const int32_t divider=1)
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:704
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:702
int32_t divider_
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:743
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:797
__host__ __device__ void operator()(T &y, const T &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:801
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:798
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:747
__host__ __device__ void operator()(T &y, const T &x) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:751
static constexpr const char * name
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:748
Definition dtype_vector.hpp:10