functional3.hpp Source File

functional3.hpp Source File#

Composable Kernel: functional3.hpp Source File
functional3.hpp
Go to the documentation of this file.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck/ck.hpp"
11
12namespace ck {
13
14namespace detail {
15
16// RemainLengths: Sequence<...>
17// Orders: Sequence<...>
18template <class RemainLengths, class Orders>
20{
21 __host__ __device__ constexpr static_ford_impl()
22 {
23 static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here");
24 }
25
26 // F signature: F(Sequence<...>)
27 // CurrentOrderedId: Sequence<...>
28 template <class F, class CurrentOrderedId>
29 __host__ __device__ constexpr void operator()(F f, CurrentOrderedId) const
30 {
31 static_for<0, RemainLengths::Front(), 1>{}([=](auto I) {
32 static_ford_impl<decltype(RemainLengths::PopFront()), Orders>{}(
33 f, CurrentOrderedId::PushBack(I));
34 });
35 }
36};
37
38template <class Orders>
39struct static_ford_impl<Sequence<>, Orders>
40{
41 // F signature: F(Sequence<...>)
42 // OrderedId: Sequence<...>
43 template <class F, class OrderedId>
44 __host__ __device__ constexpr void operator()(F f, OrderedId) const
45 {
46 // retrive unordered Id
47 f(OrderedId::ReorderGivenOld2New(Orders{}));
48 }
49};
50
51// RemainLengths: Sequence<...>
52// Orders: Sequence<...>
53template <class RemainLengths, class Orders>
55{
56 __host__ __device__ constexpr ford_impl()
57 {
58 static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here");
59 }
60
61 // F signature: F(Array<...> multi_id)
62 // CurrentOrderdId: Array<...>
63 template <class F, class CurrentOrderedId>
64 __host__ __device__ constexpr void operator()(F f, CurrentOrderedId current_ordered_id) const
65 {
66 for(index_t i = 0; i < RemainLengths::Front(); ++i)
67 {
68 ford_impl<decltype(RemainLengths::PopFront()), Orders>{}(
69 f, container_push_back(current_ordered_id, i));
70 }
71 }
72};
73
74template <class Orders>
75struct ford_impl<Sequence<>, Orders>
76{
77 // F signature: F(Array<...> multi_id)
78 // CurrentOrderdId: Array<...>
79 template <class F, class CurrentOrderedId>
80 __host__ __device__ constexpr void operator()(F f, CurrentOrderedId current_ordered_id) const
81 {
82 // retrive unordered Id
83 f(container_reorder_given_old2new(current_ordered_id, Orders{}));
84 }
85};
86
87} // namespace detail
88
89// Lengths is Sequence<...>, it is the length of each dimension for
90// N-dimensional loop
91// Orders is Sequence<...>, it is the order of dimension in which static_ford
92// will loop over each
93// dimension
94template <class Lengths,
95 class Orders = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::type>
97{
98 __host__ __device__ constexpr static_ford()
99 {
100 static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
101 static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size");
102 }
103
104 // F signature: F(Sequence<...> multi_id)
105 // multi_id is the unordered multi-index
106 template <class F>
107 __host__ __device__ constexpr void operator()(F f) const
108 {
109 constexpr auto ordered_lengths = Lengths::ReorderGivenNew2Old(Orders{});
110 detail::static_ford_impl<decltype(ordered_lengths), Orders>{}(f, Sequence<>{});
111 }
112};
113
114// Lengths is Sequence<...>, it is the length of each dimension for
115// N-dimensional loop
116// Orders is Sequence<...>, it is the order of dimension in which ford will loop
117// over each
118// dimension
119template <class Lengths,
120 class Orders = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::type>
121struct ford
122{
123 __host__ __device__ constexpr ford()
124 {
125 static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
126 static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size");
127 }
128
129 // F signature: F(Array<...> multi_id)
130 // multi_id is the unordered multi-index
131 template <class F>
132 __host__ __device__ constexpr void operator()(F f) const
133 {
134 constexpr auto ordered_lengths = Lengths::ReorderGivenNew2Old(Orders{});
135
136 for(index_t i = 0; i < ordered_lengths.Front(); ++i)
137 {
138 detail::ford_impl<decltype(ordered_lengths.PopFront()), Orders>{}(f,
140 }
141 }
142};
143
144} // namespace ck
Definition threadwise_tensor_slice_transfer_util.hpp:15
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto container_push_back(const Array< TData, NSize > &a, const TData &x)
Definition utility/container_helper.hpp:18
__host__ __device__ constexpr auto container_reorder_given_old2new(const Array< TData, NSize > &old_array, Sequence< IRs... > old2new)
Definition utility/container_helper.hpp:54
Definition utility/sequence.hpp:43
__host__ __device__ constexpr void operator()(F f, CurrentOrderedId current_ordered_id) const
Definition functional3.hpp:80
Definition functional3.hpp:55
__host__ __device__ constexpr ford_impl()
Definition functional3.hpp:56
__host__ __device__ constexpr void operator()(F f, CurrentOrderedId current_ordered_id) const
Definition functional3.hpp:64
__host__ __device__ constexpr void operator()(F f, OrderedId) const
Definition functional3.hpp:44
Definition functional3.hpp:20
__host__ __device__ constexpr static_ford_impl()
Definition functional3.hpp:21
__host__ __device__ constexpr void operator()(F f, CurrentOrderedId) const
Definition functional3.hpp:29
__host__ __device__ constexpr ford()
Definition functional3.hpp:123
__host__ __device__ constexpr void operator()(F f) const
Definition functional3.hpp:132
Definition functional2.hpp:33
__host__ __device__ constexpr static_ford()
Definition functional3.hpp:98
__host__ __device__ constexpr void operator()(F f) const
Definition functional3.hpp:107