reference_topk.hpp Source File

reference_topk.hpp Source File

Composable Kernel: reference_topk.hpp Source File
reference_topk.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8#include <thread>
9#include <numeric>
10#include <functional>
11#include <utility>
12#include <algorithm>
13
14namespace ck_tile {
15
16/*
17 similiar to torch.topk()
18 x (Tensor) – the input tensor.
19 k (int) – the k in “top-k”
20 dim (int, optional) – the dimension to sort along
21 largest (bool, optional) – largest or smallest elements
22 sorted (bool, optional) – elements in sorted order or not
23
24 output:
25 y_values
26 y_indices
27
28 https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/TopKImpl.h
29*/
30template <typename DataType, typename IndexType = index_t>
32 HostTensor<DataType>& y_values,
33 HostTensor<IndexType>& y_indices,
34 index_t k,
35 index_t dim = -1,
36 bool largest = true,
37 bool sorted = true)
38{
39 // rank must be the same
41 assert(static_cast<std::size_t>(rank) == y_values.get_num_of_dimension());
42 assert(static_cast<size_t>(rank) == y_indices.get_num_of_dimension());
43 assert(dim == -1 || dim < rank);
44
45 index_t topk_dim = dim == -1 ? (rank - 1) : dim;
46 index_t topk_src_len = x.get_length(topk_dim);
47 auto x_len = x.get_lengths();
48
49 assert(k <= topk_src_len);
50 assert(static_cast<size_t>(k) == y_values.get_length(topk_dim) &&
51 static_cast<size_t>(k) == y_indices.get_length(topk_dim));
52
53 index_t n_parallel = x.get_element_size() / topk_src_len;
54
55 // clang-format off
56 auto f = [&](auto i_element) {
57 std::vector<size_t> topk_coord = [&](){
58 std::vector<size_t> t_(rank, 0);
59 size_t r = i_element;
60 for(index_t i = rank - 1; i >= 0; i--) {
61 if(i == topk_dim) continue; // topk dim should be zero
62 t_[i] = r % x_len[i]; r = r / x_len[i];
63 }
64 return t_;
65 }();
66
67 using elem_t = std::pair<DataType, IndexType>;
68 std::vector<elem_t> q = [&](){
69 std::vector<elem_t> t_(topk_src_len);
70 for(index_t i = 0; i < topk_src_len; i++) {
71 auto c_ = topk_coord; c_[topk_dim] = i;
72 t_[i].first = x(c_); t_[i].second = i;
73 }
74 return t_;
75 }();
76
77 // run topk
78 if(largest) {
79 std::nth_element(q.begin(), q.begin() + k - 1, q.end(),
80 [](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first > rhs.first; });
81 if(sorted) {
82 std::sort(q.begin(), q.begin() + k - 1,
83 [](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first > rhs.first; });
84 }
85 } else {
86 std::nth_element(q.begin(), q.begin() + k - 1, q.end(),
87 [](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first < rhs.first; });
88 if(sorted) {
89 std::sort(q.begin(), q.begin() + k - 1,
90 [](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first < rhs.first; });
91 }
92 }
93
94 // write out
95 for(index_t i = 0; i < k; i++) {
96 auto c_ = topk_coord; c_[topk_dim] = i;
97 y_values(c_) = q[i].first; y_indices(c_) = q[i].second;
98 }
99 };
100 // clang-format on
101
102 make_ParallelTensorFunctor(f, n_parallel)(std::thread::hardware_concurrency());
103}
104
// Convenience overload of reference_topk: instead of taking output tensors, it
// allocates y_values / y_indices itself and returns them as
// ck_tile::make_tuple(y_values, y_indices).
// The outputs get x's lengths with the selected dim replaced by k; dim == -1 selects
// the last dim. The actual selection is delegated to the overload that takes
// y_values / y_indices by reference.
// NOTE(review): the declarator line (return type, function name and the `x`
// parameter) is missing from this listing — confirm against the original header.
// NOTE(review): only dim == -1 is special-cased here; any other negative dim is used
// as-is and will trip the target_dim assert — confirm whether that is intended.
105// TODO: if using this method, the return tensor would be dense(no stride)
106template <typename DataType, typename IndexType = index_t>
108 index_t k,
109 index_t dim = -1,
110 bool largest = true,
111 bool sorted = true)
112{
    // `auto` makes a copy of x's lengths, so overwriting one entry below does not
    // touch x.
113 auto lens = x.get_lengths();
114 index_t target_dim = (dim == -1) ? (lens.size() - 1) : dim;
115 assert(target_dim < lens.size());
116 assert(k <= lens[target_dim]);
    // Output shape: same as x except k entries along the selected dim.
117 lens[target_dim] = k;
118 HostTensor<DataType> y_values(lens);
119 HostTensor<IndexType> y_indices(lens);
120
    // The original `dim` (not target_dim) is forwarded; the callee resolves -1 itself.
121 reference_topk<DataType, IndexType>(x, y_values, y_indices, k, dim, largest, sorted);
122
123 return ck_tile::make_tuple(y_values, y_indices);
124}
125} // namespace ck_tile
#define CK_TILE_HOST
Definition config.hpp:40
__host__ __device__ constexpr auto rank(const Layout< Shape, UnrolledDescriptorType > &layout)
Get layout rank (num elements in shape).
Definition layout_utils.hpp:310
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition tile/host/host_tensor.hpp:329
CK_TILE_HOST void reference_topk(const HostTensor< DataType > &x, HostTensor< DataType > &y_values, HostTensor< IndexType > &y_indices, index_t k, index_t dim=-1, bool largest=true, bool sorted=true)
Definition reference_topk.hpp:31
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition tile/host/host_tensor.hpp:336
decltype(auto) get_lengths() const
Definition tile/host/host_tensor.hpp:390
std::size_t get_num_of_dimension() const
Definition tile/host/host_tensor.hpp:396
std::size_t get_length(std::size_t dim) const
Definition tile/host/host_tensor.hpp:388
std::size_t get_element_size() const
Definition tile/host/host_tensor.hpp:398