tile_distribution.hpp Source File

tile_distribution.hpp Source File#

Composable Kernel: tile_distribution.hpp Source File
tile_distribution.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
17
18namespace ck_tile {
19
20namespace detail {
21template <typename Distribution>
23{
24 return Distribution::_get_partition_index();
25}
26} // namespace detail
27
28// distributed span
29template <index_t... PartialHsLengths>
31{
32 using Impl = sequence<PartialHsLengths...>;
33
34 static constexpr auto impl_ = Impl{};
35
// Compile-time query: this type always reports itself as static (unconditionally true).
36 CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
37};
38
39// distributed index
40template <index_t... PartialHsIndices>
42{
43 using Impl = sequence<PartialHsIndices...>;
44
45 static constexpr auto impl_ = Impl{};
46
// Compile-time query: this type always reports itself as static (unconditionally true).
47 CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
48};
49
50namespace detail {
51
52template <index_t... Is>
57
58template <index_t... Is>
63
64} // namespace detail
65
66template <typename PsYs2XsAdaptor_,
67 typename Ys2DDescriptor_,
68 typename StaticTileDistributionEncoding_,
69 typename TileDistributionDetail_> // FIXME: this is for hold ad-hoc but useful info,
70 // should be more elegant
72{
77
78 static_assert(PsYs2XsAdaptor::is_static() && Ys2DDescriptor::is_static(),
79 "wrong! should be static");
80
81 static constexpr index_t NDimX = PsYs2XsAdaptor::get_num_of_bottom_dimension();
82 static constexpr index_t NDimY = Ys2DDescriptor::get_num_of_top_dimension();
83 static constexpr index_t NDimP = PsYs2XsAdaptor::get_num_of_top_dimension() - NDimY;
84 static constexpr index_t NDimR = StaticTileDistributionEncoding_::NDimR;
85
88
93
95 {
96 // only support warp-tile and block-tile
97 static_assert(NDimP == 1 or NDimP == 2, "wrong!");
98
99 if constexpr(NDimP == 1)
100 {
102 }
103 else if constexpr(NDimP == 2)
104 {
106 }
107 }
108
// Computes the length of every X dimension at compile time.
// Calls generate_tuple over NDimX indices; for index i the generator yields
// number<x_length>, where x_length is the product of all entries of
// DstrEncode::HsLengthss{}[i] (container_reduce with multiplies{}, initial value 1).
// The adaptor-based branch is compiled out (#if 0) per the FIXME below.
109 CK_TILE_HOST_DEVICE static constexpr auto get_lengths()
110 {
111#if 0
112 // FIXME: tensor_adaptor::GetBottomDimensionLengths is wrong. re-enable this after it's fixed
113 ps_ys_to_xs_.GetBottomDimensionLengths();
114#else
115 return generate_tuple(
116 [&](auto i) {
// Multiply together the H lengths that make up the i-th X dimension.
117 constexpr index_t x_length =
118 container_reduce(typename DstrEncode::HsLengthss{}[i], multiplies{}, 1);
119
// Wrap the product in number<> so the length is carried in the type.
120 return number<x_length>{};
121 },
122 number<NDimX>{});
123#endif
124 }
125
// Accessor: returns a const reference to the stored ps_ys_to_xs_ member
// (no copy is made; the reference is tied to this object).
126 CK_TILE_HOST_DEVICE constexpr const auto& get_ps_ys_to_xs_adaptor() const
127 {
128 return ps_ys_to_xs_;
129 }
130
// Accessor: returns a const reference to the stored ys_to_d_ member.
131 CK_TILE_HOST_DEVICE constexpr const auto& get_ys_to_d_descriptor() const { return ys_to_d_; }
132
134 {
135 return DstrEncode{};
136 }
137
138#if 1
139 // Calculate Replication index [R0, R1, ...] based on Partition index
140 // FIXME: very nasty implementation
141 template <typename PartitionIndex>
142 CK_TILE_HOST_DEVICE auto calculate_rs_index_from_ps_index(const PartitionIndex& ps_idx) const
143 {
144 static_assert(PartitionIndex::size() == NDimP, "wrong!");
145
146 const auto ps_ys_idx = container_concat(ps_idx, array<index_t, NDimY>{0});
147
148 const auto dummy_adaptor_coord = make_tensor_adaptor_coordinate(ps_ys_to_xs_, ps_ys_idx);
149
151
152 static_for<0, NDimP, 1>{}([&](auto idim_p) {
153 constexpr index_t ndim_low = DstrEncode::ps_to_rhss_major_[idim_p].size();
154
155 static_for<0, ndim_low, 1>{}([&](auto i) {
156 constexpr index_t rh_major = DstrEncode::ps_to_rhss_major_[idim_p][i];
157 constexpr index_t rh_minor = DstrEncode::ps_to_rhss_minor_[idim_p][i];
158
159 // 0-th rh_major is the replicate dimension
160 if constexpr(rh_major == 0)
161 {
162 constexpr index_t adaptor_hidden_id =
163 DstrDetail::rh_major_minor_to_adaptor_hidden_idss_[rh_major][rh_minor];
164
165 // fill in
166 rs_idx(rh_minor) = dummy_adaptor_coord.get_hidden_index()[adaptor_hidden_id];
167 }
168 });
169 });
170
171 return rs_idx;
172 }
173#endif
174
175 template <typename PartitionIndex = decltype(_get_partition_index())>
177 calculate_index(const PartitionIndex& ps_idx = _get_partition_index()) const
178 {
179 const auto ps_ys_idx = container_concat(ps_idx, array<index_t, NDimY>{0});
180 const auto window_adaptor_thread_coord_tmp =
182 return window_adaptor_thread_coord_tmp.get_bottom_index();
183 }
184
186 {
187 constexpr auto distributed_spans_impl = DstrEncode::detail::distributed_spans_lengthss_;
188 constexpr auto ndims_spans_minor = DstrEncode::detail::ndims_distributed_spans_minor_;
189
190 return generate_tuple(
191 [&](auto i) {
192 constexpr auto span_impl = distributed_spans_impl[i];
193 constexpr index_t ndim_span_minor = ndims_spans_minor[i];
194
195 constexpr auto span = TO_SEQUENCE(span_impl, ndim_span_minor);
196
198 },
199 number<NDimX>{});
200 }
201
202 // FIXME: it's hacky to get Y index from Distributed-Index
203 template <typename DistributedIndices>
204 CK_TILE_HOST_DEVICE static constexpr auto
206 {
207 constexpr auto ys_idx_arr = [] {
209
210 static_for<0, NDimY, 1>{}([&](auto i) {
211 constexpr index_t span_major = DstrEncode::detail::ys_to_span_major_[i];
212 constexpr index_t span_minor = DstrEncode::detail::ys_to_span_minor_[i];
213
214 constexpr auto dstr_index = DistributedIndices{}[number<span_major>{}];
215
216 ys_idx(i) = dstr_index.impl_[span_minor];
217 });
218
219 return ys_idx;
220 }();
221
222 constexpr index_t ndim_y = NDimY;
223
224 return TO_SEQUENCE(ys_idx_arr, ndim_y);
225 }
226
// Compile-time query: true only when both PsYs2XsAdaptor and Ys2DDescriptor
// report is_static() == true.
227 CK_TILE_HOST_DEVICE static constexpr bool is_static()
228 {
229 return PsYs2XsAdaptor::is_static() && Ys2DDescriptor::is_static();
230 }
231};
232
233namespace detail {
234
235template <index_t NDimMax>
237{
239
240 for(index_t i = 0; i < iend - ibegin; ++i)
241 {
242 arr(i) = ibegin + i;
243 }
244
245 return arr;
246}
247
248// this returns a constexpr encoding of tile_distribution
249template <typename StaticTileDistributionEncoding_>
250CK_TILE_HOST_DEVICE constexpr auto
251make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_)
252{
253 using RsLengths = typename StaticTileDistributionEncoding_::RsLengths;
254 using HsLengthss = typename StaticTileDistributionEncoding_::HsLengthss;
255 using Ps2RHssMajor = typename StaticTileDistributionEncoding_::Ps2RHssMajor;
256 using Ps2RHssMinor = typename StaticTileDistributionEncoding_::Ps2RHssMinor;
257 using Ys2RHsMajor = typename StaticTileDistributionEncoding_::Ys2RHsMajor;
258 using Ys2RHsMinor = typename StaticTileDistributionEncoding_::Ys2RHsMinor;
259
260 // FIXME: increase max value if fail
261 constexpr index_t kMaxNumTransforms = 20;
262 constexpr index_t kMaxMetaDataSize = 128;
263 constexpr index_t kMaxNumDim = 10;
264
265 using Name = coord_transform_enum;
266 using MetaData = meta_data_buffer<kMaxMetaDataSize>;
267 using NumDim = index_t;
268 using Dims = array<index_t, kMaxNumDim>;
269 using Lengths = array<index_t, kMaxNumDim>;
270
271 // Tile Adaptor
272 // bottom dims [x0, x1, x2, ...]
273 // top dims [p0, p1, ..., y0, y1, ...]
274 constexpr index_t ndim_x = HsLengthss::size();
275
276 // Dim Ids: [idim_x_major, idim_x_minor] to [idim_hidden]
277 array<array<index_t, kMaxNumDim>, ndim_x + 1> rh_major_minor_to_hidden_ids;
278 array<array<index_t, kMaxNumDim>, ndim_x + 1> rh_major_minor_to_hidden_lengths;
279
280 auto trans = array<tuple<Name, MetaData, NumDim, Dims, NumDim, Dims>, kMaxNumTransforms>{};
281
282 index_t num_tran = 0;
283 index_t hidden_dim_cnt = ndim_x;
284
285 // this is replicate transform
286 {
287 constexpr index_t ndim_r_minor = RsLengths::size();
288
289 constexpr auto r_minor_lengths = RsLengths{};
290
291 trans(num_tran++) = {
293 MetaData{to_array<index_t, ndim_r_minor>(r_minor_lengths)},
294 NumDim{0},
295 Dims{},
296 NumDim{ndim_r_minor},
297 make_sequential_index<kMaxNumDim>(hidden_dim_cnt, hidden_dim_cnt + ndim_r_minor)};
298
299 for(index_t i = 0; i < ndim_r_minor; ++i)
300 {
301 rh_major_minor_to_hidden_ids(0)(i) = hidden_dim_cnt;
302 rh_major_minor_to_hidden_lengths(0)(i) = r_minor_lengths[i];
303
304 hidden_dim_cnt++;
305 }
306 };
307
308 // these are Unmerge transforms for X dimensions
309 static_for<0, ndim_x, 1>{}([&trans,
310 &num_tran,
311 &hidden_dim_cnt,
312 &rh_major_minor_to_hidden_ids,
313 &rh_major_minor_to_hidden_lengths](auto idim_x) {
314 // typename HsLengthss::base{}.foo();
315 constexpr auto h_minor_lengths =
316 HsLengthss{}.get(idim_x); // std::tuple_element_t<idim_x, HsLengthss>{};
317 // constexpr auto h_minor_lengths = impl::getv<idim_x>(HsLengthss{});
318
319 constexpr index_t ndim_h_minor = h_minor_lengths.size();
320
321 trans(num_tran++) = {
323 MetaData{to_array<index_t, ndim_h_minor>(h_minor_lengths)},
324 NumDim{1},
325 Dims{idim_x},
326 NumDim{ndim_h_minor},
327 make_sequential_index<kMaxNumDim>(hidden_dim_cnt, hidden_dim_cnt + ndim_h_minor)};
328
329 for(index_t i = 0; i < ndim_h_minor; ++i)
330 {
331 rh_major_minor_to_hidden_ids(idim_x + 1)(i) = hidden_dim_cnt;
332 rh_major_minor_to_hidden_lengths(idim_x + 1)(i) = h_minor_lengths[i];
333
334 hidden_dim_cnt++;
335 }
336 });
337
338 // transform: P dimensions
339 constexpr index_t ndim_p = Ps2RHssMajor::size();
340
341 Dims hidden_dim_id_ps;
342
343 static_for<0, ndim_p, 1>{}([&](auto iDimP) {
344 //
345 index_t hidden_dim_id_p = hidden_dim_cnt++;
346
347 hidden_dim_id_ps(iDimP) = hidden_dim_id_p;
348
349 constexpr auto p2RHsMajor = Ps2RHssMajor{}[iDimP];
350 constexpr auto p2RHsMinor = Ps2RHssMinor{}[iDimP];
351
352 static_assert(p2RHsMajor.size() == p2RHsMinor.size(), "wrong!");
353
354 constexpr index_t ndim_low = p2RHsMajor.size();
355
356 Dims low_dims;
357 Lengths low_lengths;
358
359 for(index_t i = 0; i < ndim_low; ++i)
360 {
361 index_t rh_major = p2RHsMajor[i];
362 index_t rh_minor = p2RHsMinor[i];
363 low_dims(i) = rh_major_minor_to_hidden_ids[rh_major][rh_minor];
364 low_lengths(i) = rh_major_minor_to_hidden_lengths[rh_major][rh_minor];
365 }
366
367 trans(num_tran++) = {coord_transform_enum::merge,
368 MetaData{to_array<index_t, ndim_low>(low_lengths)},
369 NumDim{ndim_low},
370 low_dims,
371 NumDim{1},
372 Dims{hidden_dim_id_p}};
373 });
374
375 constexpr index_t ndim_bottom = ndim_x;
376
377 constexpr auto bottom_dim_ids = make_sequential_index<kMaxNumDim>(0, ndim_bottom);
378
379 constexpr auto ys_to_rhs_major = Ys2RHsMajor{};
380 constexpr auto ys_to_rhs_minor = Ys2RHsMinor{};
381
382 constexpr index_t ndim_y = Ys2RHsMajor::size();
383 constexpr index_t ndim_top = ndim_p + ndim_y;
384
385 auto top_dim_ids = hidden_dim_id_ps;
386
387 {
388 for(index_t i = 0; i < ndim_y; ++i)
389 {
390 index_t rh_major = ys_to_rhs_major[i];
391 index_t rh_minor = ys_to_rhs_minor[i];
392 top_dim_ids(ndim_p + i) = rh_major_minor_to_hidden_ids[rh_major][rh_minor];
393 }
394 }
395
396 //
397 const auto ps_ys_to_xs_adaptor_encoding =
398 make_tuple(trans, num_tran, bottom_dim_ids, ndim_bottom, top_dim_ids, ndim_top);
399
400 // descriptor: [y0, y1, ...] to [d]
401 Lengths y_lengths;
402 index_t d_length = 1;
403
404 for(index_t i = 0; i < ndim_y; ++i)
405 {
406 index_t rh_major = ys_to_rhs_major[i];
407 index_t rh_minor = ys_to_rhs_minor[i];
408 index_t y_length = rh_major_minor_to_hidden_lengths[rh_major][rh_minor];
409 y_lengths(i) = y_length;
410 d_length *= y_length;
411 }
412
414 MetaData{to_array<index_t, ndim_y>(y_lengths)},
415 NumDim{1},
416 Dims{0},
417 NumDim{ndim_y},
418 make_sequential_index<kMaxNumDim>(1, ndim_y + 1));
419
420 const auto ys_to_d_adaptor_encoding = make_tuple(
421 make_tuple(tran), 1, Dims{0}, 1, make_sequential_index<kMaxNumDim>(1, ndim_y + 1), ndim_y);
422
423 return make_tuple(ps_ys_to_xs_adaptor_encoding,
424 ys_to_d_adaptor_encoding,
425 d_length,
426 rh_major_minor_to_hidden_ids);
427}
428
429// FIXME: this is nasty. Move it inside TileDistributionEncoding::detail
430template <typename RhMajorMinor2AdaptorHiddenIdss> // tuple<sequence<...>, ...>
432{
434 to_array_of_array(RhMajorMinor2AdaptorHiddenIdss{});
435};
436
437} // namespace detail
438
439#if 0
440// this returns a constexpr tile_distribution
441template <typename StaticTileDistributionEncoding_>
442CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistributionEncoding_)
443{
445
446 constexpr auto adaptor_impl =
447 detail::make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_{});
448
449 constexpr auto ps_ys_to_xs_adaptor_impl = adaptor_impl.template at<0>();
450 constexpr auto ys_to_d_adaptor_impl = adaptor_impl.template at<1>();
451 constexpr index_t d_length = adaptor_impl.template at<2>();
452 constexpr auto rh_major_minor_to_hidden_ids_impl = adaptor_impl.template at<3>();
453
454 constexpr auto ps_ys_to_xs_adaptor =
455 CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(ps_ys_to_xs_adaptor_impl);
456
457 constexpr auto ys_to_d_adaptor = CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(ys_to_d_adaptor_impl);
458
459 constexpr auto ys_to_d_descriptor =
460 make_tensor_descriptor_from_adaptor(ys_to_d_adaptor, d_length);
461
462 //
463 constexpr index_t ndim_rh_major = DstrEncode::detail::ndim_rh_major_;
464 constexpr auto ndims_rhs_minor = DstrEncode::detail::ndims_rhs_minor_;
465
466 constexpr auto rh_major_minor_to_hidden_ids =
467 TO_TUPLE_OF_SEQUENCE(rh_major_minor_to_hidden_ids_impl, ndim_rh_major, ndims_rhs_minor);
468
469 return tile_distribution<
470 remove_cvref_t<decltype(ps_ys_to_xs_adaptor)>,
471 remove_cvref_t<decltype(ys_to_d_descriptor)>,
473 detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
474 ps_ys_to_xs_adaptor, ys_to_d_descriptor};
475}
476#endif
477
478// this returns a static tile_distribution
479template <typename StaticTileDistributionEncoding_>
480CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
481{
483
484 constexpr auto adaptor_impl =
485 detail::make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_{});
486
487 constexpr auto ps_ys_to_xs_adaptor_impl = adaptor_impl.template at<0>();
488 constexpr auto ys_to_d_adaptor_impl = adaptor_impl.template at<1>();
489 constexpr index_t d_length = adaptor_impl.template at<2>();
490 constexpr auto rh_major_minor_to_hidden_ids_impl = adaptor_impl.template at<3>();
491
492 constexpr auto ps_ys_to_xs_adaptor =
493 CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(ps_ys_to_xs_adaptor_impl);
494
495 constexpr auto ys_to_d_adaptor =
497
498 constexpr auto ys_to_d_descriptor =
500
501 //
502 constexpr index_t ndim_rh_major = DstrEncode::detail::ndim_rh_major_;
503 constexpr auto ndims_rhs_minor = DstrEncode::detail::ndims_rhs_minor_;
504
505 constexpr auto rh_major_minor_to_hidden_ids =
506 TO_TUPLE_OF_SEQUENCE(rh_major_minor_to_hidden_ids_impl, ndim_rh_major, ndims_rhs_minor);
507
508 return tile_distribution<
509 remove_cvref_t<decltype(ps_ys_to_xs_adaptor)>,
510 remove_cvref_t<decltype(ys_to_d_descriptor)>,
512 detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
513 ps_ys_to_xs_adaptor, ys_to_d_descriptor};
514}
515
516//***********************************************************************************
517
518namespace detail {
519//
520// slice tensor from x_dim, result in split in y_dim, not p_dim.
521// We don't support slice cross p_dim (aka, slice different threads)
522// also, sliced along y_dim need be the first dim of current dim.
523// Multiply Y dim before sliced dim does not make sense
524//
525// e.g
526// X0 X1
527// <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 32>, (-1 means the last one)
528// Y P P Y P Y P Y
529// => <1, 4, 32> - <1, 1, 4, 2, 4> -> OK
530// |--> slice along this Y dim, is the first dim of X1, totally 4 slices
531//
532// X0 X1
533// <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 8>, (-1 means the last one)
534// Y P P Y P Y P Y
535// => <1, 4, 32> - <1, 1, 1, 2, 4> -> OK
536// |--> slice along this Y dim, the P dim is 1 in the left, so is OK
537// totally 16 slices
538//
539// X0 X1
540// <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 4>, (-1 means the last one)
541// Y P P Y P Y P Y
542// => <1, 4, 32> - <1, 1, 1, 1, 4> -> Fail
543// |--> slice along this P dim, will split threads, not supported
544//
545// X0 X1
546// <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 16>, (-1 means the last one)
547// Y P P Y P Y P Y
548// => <1, 4, 32> - <1, 1, 2, 2, 4> -> OK
549// |--> slice along this Y dim, but this Y dim needs to be split into 2
550// sub-dims
551// the P dim in the left is 1, means actually not crossing P
552//
553template <typename Distribution, index_t... XSliceBegins, index_t... XSliceEnds>
555 Distribution, sequence<XSliceBegins...> x_slice_begins, sequence<XSliceEnds...> x_slice_ends)
556{
557 // NOTE: this function need to be called under constexpr context,
558 // due to https://wg21.link/p2280r0 we have to use non-reference type for distribution
559 using Encoding = decltype(Distribution::get_static_tile_distribution_encoding());
560
561 static_assert(sizeof...(XSliceBegins) == sizeof...(XSliceEnds));
562 static_assert(sizeof...(XSliceBegins) == Encoding::NDimX, "only support slice over h, not r");
563
564 constexpr auto p_len_over_h = Encoding::detail::get_uniformed_p_dim_lengths_over_h();
565
566 constexpr auto x_slice_ends_ = generate_sequence_v2(
567 [&](auto i) {
568 if constexpr(x_slice_ends[i] == -1)
569 {
570 // -1 means till the end
571 constexpr auto x_length_ =
572 container_reduce(typename Encoding::HsLengthss{}[i], multiplies{}, number<1>{});
573 return x_length_;
574 }
575 else
576 {
577 return x_slice_ends[i];
578 }
579 },
580 number<x_slice_ends.size()>{});
581
582 constexpr auto x_slice_lengths = x_slice_ends_ - x_slice_begins;
583
584 constexpr auto x_slice_lengths_without_p = generate_sequence_v2(
585 [&](auto i) constexpr {
586 constexpr auto len_ = x_slice_lengths[i];
587 static_assert(len_ % p_len_over_h[i] == 0,
588 "slice length must be dividable by p_len_over_h");
589 return number<len_ / p_len_over_h[i]>{};
590 },
591 number<x_slice_lengths.size()>{});
592
593 constexpr auto src_h_prefix_sum = Encoding::detail::get_h_dim_lengths_prefix_sum();
594 constexpr auto src_y_info = Encoding::detail::get_sorted_y_to_h_info();
595 constexpr auto src_y_dims = src_y_info[number<0>{}];
596 constexpr auto src_y_maps = src_y_info[number<1>{}];
597 constexpr auto src_y_prefix_sum = src_y_info[number<2>{}];
598
599 constexpr auto sliced_hlen_yidx_ylen = [&]() constexpr {
600 auto y_slice_sorted_origins = make_zero_multi_index<Encoding::NDimY>();
601 auto y_slice_lengths = Encoding::detail::ys_lengths_;
602 constexpr auto y_to_h_masks = Encoding::detail::get_y_to_h_masks();
603
604 // This lambda will modify some value outside, so c++ will not treat return value as
605 // constexpr
606 // TODO: ugly
607 auto new_h_lengths = transform_tuples(
608 [&](auto h_len, auto id) {
609 constexpr auto sliced_h = reverse_slice_sequence(
610 h_len, number<x_slice_lengths_without_p[id]>{}, y_to_h_masks[id]);
611
612 constexpr auto sliced_h_lens = sliced_h[number<0>{}];
613 constexpr auto sliced_h_index = sliced_h[number<2>{}];
614
615 // update y_slice_lengths
616 constexpr auto uniformed_h_index = sliced_h_index + number<src_h_prefix_sum[id]>{};
617 constexpr auto found_y_index = container_find(src_y_dims, uniformed_h_index);
618 constexpr auto y_to_h_dim_end = src_y_prefix_sum[id + 1];
619
620 static_assert(found_y_index >= 0 && found_y_index < src_y_dims.size(),
621 "not sliced at y dim, please check");
622
623 {
624 constexpr auto sliced_y_to_h_lens =
625 pick_sequence_elements_by_mask(sliced_h_lens, y_to_h_masks[id]);
626 constexpr auto sliced_y_to_h_dims = sliced_y_to_h_lens.size();
628 y_slice_lengths(src_y_maps[y_to_h_dim_end - 1 - i]) =
629 sliced_y_to_h_lens[sliced_y_to_h_dims - 1 - i];
630 });
631 }
632 // TODO: add validations not across p dim
633
634 // NOTE: this y_origin is for all dims, not only current dim
635 // will later use pick to select target dim
636 constexpr auto y_origin = [&]() {
637 // can't use Encoding::Ys2RHsMajor/Ys2RHsMinor, these are unordered
638 constexpr auto y_to_h_len =
639 pick_sequence_elements_by_mask(h_len, y_to_h_masks[id]);
640 constexpr auto y_to_h_dims = y_to_h_len.size();
641
642 constexpr auto h_trans = make_merge_transform_v3_division_mod(y_to_h_len);
643 auto h_origin_ = make_zero_multi_index<h_trans.NDimLow>();
644 constexpr auto y_begin_ = x_slice_begins[id] / p_len_over_h[id];
645 h_trans.calculate_lower_index(h_origin_, sequence<y_begin_.value>{});
646
648
649 static_for<0, y_to_h_dims, 1>{}([&](auto i) {
650 y_origin_(y_to_h_dim_end - 1 - i) = h_origin_[y_to_h_dims - 1 - i];
651 });
652 return y_origin_;
653 }();
654
655 constexpr auto y_picks = typename arithmetic_sequence_gen<src_y_prefix_sum[id],
656 src_y_prefix_sum[id + 1],
657 1>::type{};
658
660 y_slice_sorted_origins, y_picks, get_container_subset(y_origin, y_picks));
661 return sliced_h_lens;
662 },
663 typename Encoding::HsLengthss{},
664 typename arithmetic_sequence_gen<0, Encoding::HsLengthss::size(), 1>::type{});
665
666 auto y_slice_origins = container_reorder_given_old2new(y_slice_sorted_origins, src_y_maps);
667
668 return make_tuple(new_h_lengths, y_slice_origins, y_slice_lengths);
669 }();
670
671 constexpr auto sliced_h_lengths = sliced_hlen_yidx_ylen[number<0>{}];
672 constexpr auto sliced_y_origins_array = sliced_hlen_yidx_ylen[number<1>{}];
673 constexpr auto sliced_y_origins_size = sliced_y_origins_array.size();
674 constexpr auto sliced_y_lengths_array = sliced_hlen_yidx_ylen[number<2>{}];
675 constexpr auto sliced_y_lengths_size = sliced_y_lengths_array.size();
676
677 constexpr auto sliced_y_origins = TO_SEQUENCE(sliced_y_origins_array, sliced_y_origins_size);
678 constexpr auto sliced_y_lengths = TO_SEQUENCE(sliced_y_lengths_array, sliced_y_lengths_size);
679
680 return make_tuple(
682 tile_distribution_encoding<typename Encoding::RsLengths,
683 remove_cvref_t<decltype(sliced_h_lengths)>, // only need to
684 // change the
685 // h_lengths type
686 typename Encoding::Ps2RHssMajor,
687 typename Encoding::Ps2RHssMinor,
688 typename Encoding::Ys2RHsMajor,
689 typename Encoding::Ys2RHsMinor>{}),
690 sliced_y_origins,
691 sliced_y_lengths);
692}
693
694} // namespace detail
695
696// Free print function for tile_distribution
// Debug dump of a tile_distribution via printf.
// Output layout, in order:
//   "tile_distribution{"
//   "tile_distribution_encoding: " + print(default-constructed encoding type)
//   ", ps_ys_to_xs_: "             + print(distribution.ps_ys_to_xs_)
//   ", ys_to_d_: "                 + print(distribution.ys_to_d_)
//   "}" followed by a newline
// The nested print calls resolve to overloads declared outside this block.
// distribution: the object whose ps_ys_to_xs_ and ys_to_d_ members are printed.
697template <typename PsYs2XsAdaptor_,
698 typename Ys2DDescriptor_,
699 typename StaticTileDistributionEncoding_,
700 typename TileDistributionDetail_>
701CK_TILE_HOST_DEVICE void print(const tile_distribution<PsYs2XsAdaptor_,
702 Ys2DDescriptor_,
703 StaticTileDistributionEncoding_,
704 TileDistributionDetail_>& distribution)
705{
706 printf("tile_distribution{");
707 printf("tile_distribution_encoding: ");
// The encoding is printed from a default-constructed value of its type,
// not from a member of distribution.
708 print(StaticTileDistributionEncoding_{});
709 printf(", ");
710 printf("ps_ys_to_xs_: ");
711 print(distribution.ps_ys_to_xs_);
712 printf(", ");
713 printf("ys_to_d_: ");
714 print(distribution.ys_to_d_);
715 printf("}\n");
716}
717
718} // namespace ck_tile
Definition tile/core/container/span.hpp:18
Concept for encoding of Unicode characters.
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition arch.hpp:385
CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(Distribution, sequence< XSliceBegins... > x_slice_begins, sequence< XSliceEnds... > x_slice_ends)
Definition tile_distribution.hpp:554
CK_TILE_HOST_DEVICE constexpr auto make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:251
CK_TILE_HOST_DEVICE constexpr auto make_sequential_index(index_t ibegin, index_t iend)
Definition tile_distribution.hpp:236
CK_TILE_HOST_DEVICE constexpr auto make_tile_distributed_span(sequence< Is... >)
Definition tile_distribution.hpp:53
CK_TILE_HOST_DEVICE constexpr auto make_tile_distributed_index(sequence< Is... >)
Definition tile_distribution.hpp:59
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
Definition tile_distribution.hpp:22
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
coord_transform_enum
Definition coordinate_transform.hpp:17
@ replicate
Definition coordinate_transform.hpp:24
@ unmerge
Definition coordinate_transform.hpp:23
@ merge
Definition coordinate_transform.hpp:22
CK_TILE_DEVICE index_t get_lane_id()
Definition arch.hpp:101
CK_TILE_HOST_DEVICE constexpr auto container_reduce(const Container &x, Reduce reduce, Init init, number< IBegin >=number< 0 >{}, number< IEnd >=number< Container::size()>{}, number< IStep >=number< 1 >{})
Definition tile/core/container/container_helper.hpp:198
CK_TILE_HOST_DEVICE constexpr auto pick_sequence_elements_by_mask(Seq, Mask)
Definition tile/core/container/sequence.hpp:945
CK_TILE_HOST_DEVICE constexpr auto generate_sequence_v2(F &&f, number< N >)
Definition tile/core/container/sequence.hpp:1045
CK_TILE_DEVICE index_t get_warp_id(bool_constant< ReturnSgpr >={})
Definition arch.hpp:104
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_old2new(const array< TData, NSize > &old_array, sequence< IRs... > old2new)
Definition tile/core/container/container_helper.hpp:48
constexpr auto reverse_slice_sequence(Seq, number< SliceSize >, Mask=typename uniform_sequence_gen< Seq::size(), 1 >::type{})
Definition tile/core/container/sequence.hpp:1223
CK_TILE_HOST_DEVICE constexpr auto container_concat(const X &x, const Ys &... ys)
Definition tile/core/container/container_helper.hpp:363
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X &x)
Definition tile/core/container/tuple.hpp:505
CK_TILE_HOST_DEVICE constexpr auto make_tensor_descriptor_from_adaptor(const Adaptor &adaptor, const ElementSpaceSize &element_space_size)
Definition tile/core/tensor/tensor_descriptor.hpp:183
CK_TILE_HOST_DEVICE constexpr void set_container_subset(array< T, N > &y, sequence< Is... > picks, const array< T, sizeof...(Is)> &x)
Definition tile/core/container/container_helper.hpp:420
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto to_array(const std::vector< X > &x)
Definition tile/core/container/array.hpp:286
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
CK_TILE_HOST_DEVICE constexpr auto get_container_subset(const array< T, N > &arr, sequence< Is... >)
Definition tile/core/container/container_helper.hpp:389
CK_TILE_HOST_DEVICE constexpr auto make_zero_multi_index()
Definition tile/core/container/multi_index.hpp:26
CK_TILE_HOST_DEVICE constexpr auto make_tensor_adaptor_coordinate(const Adaptor &adaptor, const TopIndex &idx_top)
Definition tensor_adaptor_coordinate.hpp:55
constexpr index_t container_find(sequence< Is... > seq, index_t value)
Definition tile/core/container/container_helper.hpp:447
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1609
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr auto to_array_of_array(tuple< Seqs... > t_of_s)
Definition tile/core/container/tuple.hpp:630
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition tile/core/container/sequence.hpp:287
A fixed-size array container similar to std::array with additional utilities.
Definition tile/core/container/array.hpp:43
Definition tile_distribution.hpp:432
static constexpr auto rh_major_minor_to_adaptor_hidden_idss_
Definition tile_distribution.hpp:433
Definition meta_data_buffer.hpp:16
Definition tile/core/numeric/math.hpp:98
Definition tile/core/container/sequence.hpp:49
static CK_TILE_HOST_DEVICE constexpr index_t size()
Definition tile/core/container/sequence.hpp:53
Definition tile/core/utility/functional.hpp:43
Definition tile_distribution.hpp:42
static constexpr auto impl_
Definition tile_distribution.hpp:45
sequence< PartialHsIndices... > Impl
Definition tile_distribution.hpp:43
static CK_TILE_HOST_DEVICE constexpr bool is_static()
Definition tile_distribution.hpp:47
Definition tile_distribution.hpp:31
static CK_TILE_HOST_DEVICE constexpr bool is_static()
Definition tile_distribution.hpp:36
sequence< PartialHsLengths... > Impl
Definition tile_distribution.hpp:32
static constexpr auto impl_
Definition tile_distribution.hpp:34
Definition tile_distribution_encoding.hpp:26
Definition tile_distribution.hpp:72
remove_cvref_t< Ys2DDescriptor_ > Ys2DDescriptor
Definition tile_distribution.hpp:74
PsYs2XsAdaptor ps_ys_to_xs_
Definition tile_distribution.hpp:86
static CK_TILE_HOST_DEVICE auto _get_partition_index()
Definition tile_distribution.hpp:94
static constexpr index_t NDimY
Definition tile_distribution.hpp:82
static CK_TILE_HOST_DEVICE constexpr auto get_lengths()
Definition tile_distribution.hpp:109
static CK_TILE_HOST_DEVICE constexpr auto get_distributed_spans()
Definition tile_distribution.hpp:185
static CK_TILE_HOST_DEVICE constexpr index_t get_num_of_dimension_x()
Definition tile_distribution.hpp:89
remove_cvref_t< StaticTileDistributionEncoding_ > DstrEncode
Definition tile_distribution.hpp:75
static CK_TILE_HOST_DEVICE constexpr index_t get_num_of_dimension_p()
Definition tile_distribution.hpp:91
CK_TILE_HOST_DEVICE constexpr const auto & get_ys_to_d_descriptor() const
Definition tile_distribution.hpp:131
CK_TILE_HOST_DEVICE constexpr const auto & get_ps_ys_to_xs_adaptor() const
Definition tile_distribution.hpp:126
remove_cvref_t< TileDistributionDetail_ > DstrDetail
Definition tile_distribution.hpp:76
CK_TILE_HOST_DEVICE auto calculate_index(const PartitionIndex &ps_idx=_get_partition_index()) const
Definition tile_distribution.hpp:177
static constexpr index_t NDimP
Definition tile_distribution.hpp:83
static CK_TILE_HOST_DEVICE constexpr auto get_y_indices_from_distributed_indices(DistributedIndices)
Definition tile_distribution.hpp:205
CK_TILE_HOST_DEVICE auto calculate_rs_index_from_ps_index(const PartitionIndex &ps_idx) const
Definition tile_distribution.hpp:142
static CK_TILE_HOST_DEVICE constexpr index_t get_num_of_dimension_r()
Definition tile_distribution.hpp:92
static CK_TILE_HOST_DEVICE constexpr index_t get_num_of_dimension_y()
Definition tile_distribution.hpp:90
remove_cvref_t< PsYs2XsAdaptor_ > PsYs2XsAdaptor
Definition tile_distribution.hpp:73
static CK_TILE_HOST_DEVICE constexpr bool is_static()
Definition tile_distribution.hpp:227
static constexpr index_t NDimR
Definition tile_distribution.hpp:84
Ys2DDescriptor ys_to_d_
Definition tile_distribution.hpp:87
static CK_TILE_HOST_DEVICE constexpr auto get_static_tile_distribution_encoding()
Definition tile_distribution.hpp:133
static constexpr index_t NDimX
Definition tile_distribution.hpp:81
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes)
Definition tile/core/container/container_helper.hpp:486
#define CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(encoded_tensor_adaptor)
Definition tile/core/tensor/tensor_adaptor.hpp:840
#define CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(encoded_tensor_adaptor)
Definition tile/core/tensor/tensor_adaptor.hpp:716
#define TO_SEQUENCE(a, n)
Definition to_sequence.hpp:10