reference_softmax.hpp Source File

reference_softmax.hpp Source File

Composable Kernel: reference_softmax.hpp Source File
reference_softmax.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8#include <thread>
9
10namespace ck_tile {
11
// Host-side reference softmax over one dimension of tensor `x`.
//
// Template parameters:
//   InputType   - element type named in the template header for the input tensor.
//   ComputeType - type every element is converted to before max / exp / sum / divide.
//   OutputType  - output element type; defaults to ComputeType.
//
// `dim == -1` selects the last dimension (rank - 1); otherwise `dim` is used as-is.
// For every coordinate over the remaining dimensions the code makes three passes along the
// softmax dimension: running maximum, sum of exp(x - max), then exp(x - max) / sum.
//
// NOTE(review): the line carrying the function name and parameter list (listing line 14) and
// the declaration of `rank` (listing line 16) are not present in this listing; `x`, `y`, `dim`
// and `rank` are used below without a visible declaration. Confirm against the original header.
12template <typename InputType, typename ComputeType, typename OutputType = ComputeType>
13CK_TILE_HOST void
15{
// `y` must have as many dimensions as `rank`.
17 assert(static_cast<std::size_t>(rank) == y.get_num_of_dimension());
// NOTE(review): as written, negative values other than -1 also satisfy this check
// (e.g. -2 < rank), so they are not rejected here.
18 assert(dim == -1 || dim < rank);
19
20 index_t target_dim = dim == -1 ? (rank - 1) : dim;
// Extent of the softmax dimension.
21 index_t softmax_len = x.get_length(target_dim);
// Number of independent softmax slices: total element count divided by the slice length.
22 index_t n_parallel = x.get_element_size() / softmax_len;
23 auto x_len = x.get_lengths();
24
// Work item: computes the softmax of the single slice identified by linear index `i_element`.
25 auto f = [&](auto i_element) {
// Decode `i_element` into a full-rank coordinate, skipping `target_dim` (left at 0 here and
// overwritten inside the loops below). The highest-numbered dimension varies fastest.
26 std::vector<size_t> coord = [&]() {
27 std::vector<size_t> t_(rank, 0);
28 size_t r = i_element;
29 for(index_t i = rank - 1; i >= 0; i--)
30 {
31 if(i == target_dim)
32 continue;
33 t_[i] = r % x_len[i];
34 r = r / x_len[i];
35 }
36 return t_;
37 }();
38
// Start from -infinity so that any element of the slice replaces it.
39 ComputeType v_max = -ck_tile::numeric<ComputeType>::infinity();
40
41 // compute max
42 for(auto idx = 0; idx < softmax_len; idx++)
43 {
44 auto c_ = coord;
45 c_[target_dim] = idx;
46 const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));
47 v_max = v_max < v_x ? v_x : v_max;
48 }
49
50 ComputeType v_exp_sum = static_cast<ComputeType>(0);
51
// Accumulate the normalisation term: the sum of exp(x - max) over the slice. The maximum is
// subtracted before exponentiation (presumably for numerical stability).
52 // sum
53 for(auto idx = 0; idx < softmax_len; idx++)
54 {
55 auto c_ = coord;
56 c_[target_dim] = idx;
57
58 const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));
59
60 v_exp_sum += ck_tile::exp(v_x - v_max);
61 }
62
// Per-element result: exp(x - max) divided by the sum computed above.
63 // elementwise
64 for(auto idx = 0; idx < softmax_len; idx++)
65 {
66 auto c_ = coord;
67 c_[target_dim] = idx;
68
69 const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));
70
71 auto out = ck_tile::exp(v_x - v_max) / v_exp_sum;
72
// NOTE(review): the listing jumps from line 72 to 74, so the statement that consumes `out`
// is not visible here; presumably it stores the converted value into `y` at `c_` - confirm
// against the original header.
74 }
75 };
76
// Invoke `f` over `n_parallel` work items, passing std::thread::hardware_concurrency() to the
// functor (presumably the number of worker threads - confirm in host_tensor.hpp).
77 make_ParallelTensorFunctor(f, n_parallel)(std::thread::hardware_concurrency());
78}
79
80template <typename InputType, typename ComputeType, typename OutputType = ComputeType>
89} // namespace ck_tile
#define CK_TILE_HOST
Definition config.hpp:40
__host__ __device__ constexpr auto rank(const Layout< Shape, UnrolledDescriptorType > &layout)
Get layout rank (num elements in shape).
Definition layout_utils.hpp:310
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition tile/host/host_tensor.hpp:329
CK_TILE_HOST void reference_softmax(const HostTensor< InputType > &x, HostTensor< OutputType > &y, index_t dim=-1)
Definition reference_softmax.hpp:14
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition bfloat16.hpp:419
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
Definition tile/host/host_tensor.hpp:336
decltype(auto) get_lengths() const
Definition tile/host/host_tensor.hpp:390
decltype(auto) get_strides() const
Definition tile/host/host_tensor.hpp:394
std::size_t get_num_of_dimension() const
Definition tile/host/host_tensor.hpp:396
std::size_t get_length(std::size_t dim) const
Definition tile/host/host_tensor.hpp:388
std::size_t get_element_size() const
Definition tile/host/host_tensor.hpp:398
static CK_TILE_HOST_DEVICE constexpr T infinity()
Definition tile/core/numeric/numeric.hpp:38