batched_transpose_lds_policy.hpp Source File

batched_transpose_lds_policy.hpp Source File

Composable Kernel: batched_transpose_lds_policy.hpp Source File
batched_transpose_lds_policy.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
12{
// Size computation built from the LDS store block descriptor.
//
// NOTE(review): listing lines 14 and 16 are missing from this extract. The
// symbol index at the end of this page places
// `static CK_TILE_DEVICE constexpr index_t GetSmemSize()` at line 14 and also
// lists integer_least_multiple(x, y), so the visible expression is presumably
// the argument list of `return integer_least_multiple(<size>, 16);` --
// confirm against the real header before relying on this.
13 template <typename Problem>
15 {
   // First visible operand: size of one element times the element-space size
   // reported by the store descriptor.
17 sizeof(typename Problem::DataType) *
18 MakeLdsStoreBlockDescriptor<Problem>().get_element_space_size(),
   // Second argument of the (not visible) enclosing call; presumably the
   // multiple the size is rounded up to -- TODO confirm.
19 16);
20 }
21
// Builds the static tile distribution returned for the output side.
//
// The distribution is derived from the LDS load tile distribution: its
// DstrEncode type is passed, together with Problem::DataType, to
// OutputTileDistributionTraits, and the resulting TransposedDstrEncode is
// turned into a distribution with make_static_tile_distribution.
//
// Template parameter:
//   Problem - must provide DataType and whatever
//             MakeLdsLoadTileDistribution<Problem>() reads from it.
// Returns: the value of make_static_tile_distribution(OutTileDstrEncode{}).
22 template <typename Problem>
23 CK_TILE_DEVICE static constexpr auto MakeOutputDistribution()
24 {
   // Only the type of input_dstr is used below (via decltype).
25 constexpr auto input_dstr = MakeLdsLoadTileDistribution<Problem>();
26
   // Encoding type selected by the traits class for this input encoding and
   // element type.
27 using OutTileDstrEncode =
28 typename OutputTileDistributionTraits<typename decltype(input_dstr)::DstrEncode,
29 typename Problem::DataType>::TransposedDstrEncode;
   // The encoding is a stateless tag type; a default-constructed instance is
   // enough to build the distribution.
30 constexpr auto block_dstr = make_static_tile_distribution(OutTileDstrEncode{});
31
32 return block_dstr;
33 }
34
// Builds and returns a tensor descriptor (lds_block_desc) in two steps:
// a naive descriptor, then a transform_tensor_descriptor over it.
//
// NOTE(review): listing lines 36, 43, 45-47 and 52-56 are missing from this
// extract. The symbol index at the end of this page places
// `static CK_TILE_DEVICE constexpr auto MakeLdsStoreBlockDescriptor()` at
// line 36. The missing argument lines presumably hold the make_tuple(...)
// lengths/strides and the pass-through/merge transforms that the index also
// lists -- confirm against the real header.
35 template <typename Problem>
37 {
   // Compile-time tile parameters taken from Problem.
38 constexpr index_t kLeadDimPerBlock = Problem::kLeadSizePerBlock;
   // kSecondDimPerBlock is not referenced in the visible lines; presumably it
   // is used in the missing argument lines.
39 constexpr index_t kSecondDimPerBlock = Problem::kSecondSizePerBlock;
40 constexpr index_t kVectorSize = Problem::LDSVectorSize;
41
   // Step 1: naive descriptor. Only two of its arguments survive here.
42 constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
   // One entry is the lead dimension divided by the vector size.
44 number<kLeadDimPerBlock / kVectorSize>{},
   // Final argument of the call; which parameter it binds to cannot be
   // determined from the visible lines.
48 number<1>{});
49
   // Step 2: re-view lds_block_desc_0 through transforms (not visible here).
50 constexpr auto lds_block_desc = transform_tensor_descriptor(
51 lds_block_desc_0,
57
58 return lds_block_desc;
59 }
60
// Builds and returns a tensor descriptor (lds_block_desc) in two steps:
// a naive descriptor, then a transform_tensor_descriptor over it.
//
// NOTE(review): listing lines 62, 69, 71-73 and 78-82 are missing from this
// extract. The symbol index at the end of this page places
// `static CK_TILE_DEVICE constexpr auto MakeLdsLoadBlockDescriptor()` at
// line 62. The visible lines are textually the same as those of the block at
// listing lines 35-59, so whatever distinguishes the two descriptors lies in
// the missing argument lines -- confirm against the real header.
61 template <typename Problem>
63 {
   // Compile-time tile parameters taken from Problem.
64 constexpr index_t kLeadDimPerBlock = Problem::kLeadSizePerBlock;
   // kSecondDimPerBlock is not referenced in the visible lines; presumably it
   // is used in the missing argument lines.
65 constexpr index_t kSecondDimPerBlock = Problem::kSecondSizePerBlock;
66 constexpr index_t kVectorSize = Problem::LDSVectorSize;
67
   // Step 1: naive descriptor. Only two of its arguments survive here.
68 constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
   // One entry is the lead dimension divided by the vector size.
70 number<kLeadDimPerBlock / kVectorSize>{},
   // Final argument of the call; which parameter it binds to cannot be
   // determined from the visible lines.
74 number<1>{});
75
   // Step 2: re-view lds_block_desc_0 through transforms (not visible here).
76 constexpr auto lds_block_desc = transform_tensor_descriptor(
77 lds_block_desc_0,
83
84 return lds_block_desc;
85 }
86
// Builds the static tile distribution returned for the LDS load side.
//
// Steps, all evaluated at compile time:
//   1. make_transposed_distr_encode<...>() produces a base encoding from the
//      element type and the repetition/iteration counts taken from Problem.
//   2. InputTileDistributionEncoding<...>() combines the type of that
//      encoding with the per-warp iteration counts and the warp counts.
//   3. make_static_tile_distribution turns the result into the distribution.
//
// NOTE(review): the signature (listing line 88) is missing from this extract.
// The symbol index at the end of this page gives it as
// `static CK_TILE_DEVICE constexpr auto MakeLdsLoadTileDistribution()`.
87 template <typename Problem>
89 {
90 using DataType = typename Problem::DataType;
91
92 // Calculate block-level dimensions
   // Per-warp iteration counts are fixed to 1 here; warp counts come from
   // Problem.
93 constexpr index_t kLeadIterPerWarp = 1;
94 constexpr index_t kSecondIterPerWarp = 1;
95 constexpr index_t kLeadNumWarps = Problem::kLeadNumWarps;
96 constexpr index_t kSecondNumWarps = Problem::kSecondNumWarps;
97
98 // Calculate repetitions of base pattern
99 constexpr index_t kLeadRepetitions = Problem::kQuadNumPerLeadDim;
100 constexpr index_t kSecondRepetitions = Problem::kQuadNumPerSecondDim;
101 constexpr index_t kSecondDimIterations = Problem::kIterationsInSecondDim;
   // Integer division: second-dimension repetitions per iteration. Any
   // remainder is discarded -- presumably Problem guarantees divisibility;
   // verify.
102 constexpr index_t kSecondDimStrSub = kSecondRepetitions / kSecondDimIterations;
103
   // Hard-coded lane-group size handed to the encoder.
104 constexpr index_t kLaneGroupSize = 16;
   // The meaning of each template argument (including the trailing literal 1)
   // is defined by make_transposed_distr_encode and is not visible here.
105 constexpr auto xdllevel_dstr_encoding = make_transposed_distr_encode<DataType,
106 kLaneGroupSize,
107 kSecondDimStrSub,
108 kSecondDimIterations,
109 kLeadRepetitions,
110 1>();
111
   // Only the type of xdllevel_dstr_encoding is passed on (via decltype).
112 constexpr auto input_tile_encode =
113 InputTileDistributionEncoding<decltype(xdllevel_dstr_encoding),
114 kLeadIterPerWarp,
115 kSecondIterPerWarp,
116 kLeadNumWarps,
117 kSecondNumWarps>();
118 constexpr auto block_dstr = make_static_tile_distribution(input_tile_encode);
119 return block_dstr;
120 }
121};
122
123} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_DEVICE constexpr auto make_transposed_distr_encode()
Definition amd_transpose_load_encoding.hpp:82
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
TransposeTileDistributionTraits< TileDistributionEncoding_, DataType_, Policy, false > OutputTileDistributionTraits
Definition load_tile_transpose.hpp:338
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto integer_least_multiple(X x, Y y)
Definition tile/core/numeric/math.hpp:155
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
CK_TILE_HOST_DEVICE constexpr auto InputTileDistributionEncoding()
Definition load_tile_transpose.hpp:351
Definition batched_transpose_common_policy.hpp:11
Definition batched_transpose_lds_policy.hpp:12
static CK_TILE_DEVICE constexpr auto MakeLdsLoadBlockDescriptor()
Definition batched_transpose_lds_policy.hpp:62
static CK_TILE_DEVICE constexpr auto MakeLdsStoreBlockDescriptor()
Definition batched_transpose_lds_policy.hpp:36
static CK_TILE_DEVICE constexpr index_t GetSmemSize()
Definition batched_transpose_lds_policy.hpp:14
static CK_TILE_DEVICE constexpr auto MakeOutputDistribution()
Definition batched_transpose_lds_policy.hpp:23
static CK_TILE_DEVICE constexpr auto MakeLdsLoadTileDistribution()
Definition batched_transpose_lds_policy.hpp:88
Definition tile/core/container/sequence.hpp:49