elementwise_kernel.hpp Source File

elementwise_kernel.hpp Source File

Composable Kernel: elementwise_kernel.hpp Source File
elementwise_kernel.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
10namespace ck_tile {
11
// Element-wise kernel functor. As operator() below shows, each workgroup loads
// one register tile from every input tensor, converts each input element to
// ComputeDataType, applies ElementWiseOperation per element, and stores the
// result tile through p_y.
// NOTE(review): this rendered listing omits several source lines (13, 15-16,
// 18-21, 47, 51, 53-56, 60, 85, 89, 91-94, 97, 105). The definition entries at
// the end of this page place the Problem / Policy / XDataType /
// ComputeDataType / YDataType / ElementWiseOperation aliases on lines 15-21
// and IsSupportedArgument on line 105.
12template <typename Problem_, typename Policy_>
14{
17
22
// Block size as declared by Problem::BlockShape.
23 static constexpr index_t kBlockSize = Problem::BlockShape::kBlockSize;
// Host-side block size: kBlockSize / 2 when is_wave32() returns true,
// otherwise kBlockSize.
// NOTE(review): is_wave32() is declared `CK_TILE_HOST bool is_wave32()`
// (not constexpr), so this function cannot be used in a constant expression
// despite its `constexpr` specifier. The reason for halving on wave32 is not
// visible here -- confirm it matches the thread count assumed by
// Policy::MakeXBlockTileDistribution.
24 CK_TILE_HOST static constexpr auto BlockSize()
25 {
26 return is_wave32() ? kBlockSize / 2 : kBlockSize;
27 }
28
// Device entry point.
//   lens           - per-dimension lengths; used for every input view, the
//                    output view and the merge transform.
//   input_strides  - strides applied to every input tensor (all inputs are
//                    viewed with the same strides).
//   output_strides - strides of the output tensor.
//   input_tensors  - tuple of inputs. Element i is the first argument of the
//                    view factory for input i (presumably a device pointer
//                    passed to make_naive_tensor_view -- that call line is
//                    missing from this listing; confirm).
//   p_y            - output pointer, viewed with lens / output_strides.
// NOTE(review): inside this function the pack name XDataType hides the
// class-level XDataType alias; a distinct pack name would avoid confusion.
29 template <typename... XDataType, typename Dims>
30 CK_TILE_DEVICE void operator()(const Dims lens,
31 const Dims input_strides,
32 const Dims output_strides,
33 const tuple<XDataType...>& input_tensors,
34 YDataType* p_y) const
35 {
36 using S = typename Problem::BlockShape;
37
38 // Setup block-level coordinates and transforms
// The window origin is {iM}, a single coordinate in the index space produced
// by merging all of lens. Block b starts at b * S::kBlockM (the window-length
// line is missing here; presumably it is S::kBlockM).
39 const index_t iM = get_block_id() * S::kBlockM;
40 const auto merge_transform = make_merge_transform(lens);
41
42 // Load all input tiles into registers.
43 // The lambda structure here is intended to minimize the lifetime
44 // of intermediate objects (views, windows) used for loading.
// Every input uses the same Policy tile distribution, so all loaded tiles
// share x_tile0's layout.
45 const auto x_tiles = ck_tile::generate_tuple(
46 [&](auto i) {
48 input_tensors.get(i), lens, input_strides, number<S::kVectorM>{}, number<1>{});
49
50 const auto transformed_tensor = pad_tensor_view(
52 ck_tile::make_tuple(merge_transform),
57
58 const auto x_window =
59 make_tile_window(transformed_tensor,
61 {iM},
62 Policy::template MakeXBlockTileDistribution<Problem>());
63
64 return load_tile(x_window);
65 },
66 number<sizeof...(XDataType)>{});
67
68 // Setup output tile in registers.
// y_tile reuses x_tile0's distribution and already stores YDataType elements.
69 const auto& x_tile0 = x_tiles.get(number<0>{});
70 auto y_tile = make_static_distributed_tensor<YDataType>(x_tile0.get_tile_distribution());
71
72 // Perform element-wise computation.
// Only spans[0] is swept, so the tile has a single distributed dimension.
// For each index, every input element is converted to ComputeDataType and
// ElementWiseOperation receives y_tile(tile_idx) as its first argument
// (presumably a mutable reference it writes the result into).
73 const auto spans = x_tile0.get_distributed_spans();
74 sweep_tile_span(spans[number<0>{}], [&](auto idx) {
75 const auto tile_idx = make_tuple(idx);
76 apply(
77 [&](auto&&... tiles) {
78 ElementWiseOperation{}(y_tile(tile_idx),
79 type_convert<ComputeDataType>(tiles[tile_idx])...);
80 },
81 x_tiles);
82 });
83
84 // Setup output window and store the result tile.
86 p_y, lens, output_strides, number<S::kVectorM>{});
87
88 const auto transformed_y_m_n = pad_tensor_view(
90 ck_tile::make_tuple(merge_transform),
95
96 auto y_window = make_tile_window(transformed_y_m_n,
98 {iM},
99 y_tile.get_tile_distribution());
100
// NOTE(review): y_tile was created with make_static_distributed_tensor<YDataType>,
// so cast_tile<YDataType> here is a same-type conversion and looks redundant.
101 store_tile(y_window, cast_tile<YDataType>(y_tile));
102 }
103
// Argument check. input_sizes is ignored and true is returned unconditionally.
// The signature line (105) is missing from this listing; the definition entry
// gives it as IsSupportedArgument(const ck_tile::tuple<Ints...>& input_sizes).
104 template <typename... Ints>
106 {
107 // when total elements % kVectorM != 0; should use Pad instead of unsupported
// i.e. sizes that are not a multiple of kVectorM are meant to be handled by
// padding rather than rejected here.
108 ignore = input_sizes;
109 return true;
110 }
111};
112
113} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
Definition tile/core/algorithm/cluster_descriptor.hpp:13
constexpr decltype(auto) apply(F &&f, Tuple &&t)
Definition tile/core/container/tuple.hpp:526
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
typename __make_integer_seq< impl::__integer_sequence, index_t, N >::seq_type make_index_sequence
Definition tile/core/container/sequence.hpp:230
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
constexpr detail::ignore_t ignore
Definition tile/core/utility/ignore.hpp:20
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_DEVICE index_t get_block_id()
Definition arch.hpp:119
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_view.hpp:511
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
CK_TILE_HOST bool is_wave32()
Definition arch.hpp:72
Definition elementwise_kernel.hpp:14
ck_tile::remove_cvref_t< Problem_ > Problem
Definition elementwise_kernel.hpp:15
ck_tile::remove_cvref_t< typename Problem::XDataType > XDataType
Definition elementwise_kernel.hpp:18
ck_tile::remove_cvref_t< Policy_ > Policy
Definition elementwise_kernel.hpp:16
ck_tile::remove_cvref_t< typename Problem::ElementWiseOperation > ElementWiseOperation
Definition elementwise_kernel.hpp:21
CK_TILE_DEVICE void operator()(const Dims lens, const Dims input_strides, const Dims output_strides, const tuple< XDataType... > &input_tensors, YDataType *p_y) const
Definition elementwise_kernel.hpp:30
static constexpr index_t kBlockSize
Definition elementwise_kernel.hpp:23
static CK_TILE_HOST constexpr auto BlockSize()
Definition elementwise_kernel.hpp:24
static CK_TILE_HOST bool IsSupportedArgument(const ck_tile::tuple< Ints... > &input_sizes)
Definition elementwise_kernel.hpp:105
ck_tile::remove_cvref_t< typename Problem::YDataType > YDataType
Definition elementwise_kernel.hpp:20
ck_tile::remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition elementwise_kernel.hpp:19
Definition tile/core/container/sequence.hpp:49
Definition tensor_view.hpp:41
Definition tile/core/container/tuple.hpp:192
CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const &
Definition tile/core/container/tuple.hpp:269