coordinate_transform.hpp Source File

coordinate_transform.hpp Source File

Composable Kernel: coordinate_transform.hpp Source File
coordinate_transform.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
13
14namespace ck_tile {
15
29
30template <index_t NDimLow, index_t NDimUp>
32{
33 CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
34 {
36 }
37
38 CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_lower_dimension() { return NDimLow; }
39
40 CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_upper_dimension() { return NDimUp; }
41
42 // return safe value for vector length/stride, based on compile-time known only
43 // variables
44 // MUST be static function
45 template <typename LowVectorLengths, typename LowVectorStrides>
46 CK_TILE_HOST_DEVICE static constexpr auto
48 const LowVectorStrides&)
49 {
50 if constexpr(NDimUp > 0)
51 {
52 array<index_t, NDimUp> up_vector_lengths{-1};
53 array<index_t, NDimUp> up_vector_strides{-1};
54
55 return make_tuple(up_vector_lengths, up_vector_strides);
56 }
57 else
58 {
60 }
61 }
62};
63
64template <typename LowLength>
65struct pass_through : public base_transform<1, 1>
66{
68
71
72 using UpLengths = decltype(make_tuple(LowLength{}));
73
75
76 CK_TILE_HOST_DEVICE constexpr pass_through() = default;
77
78 CK_TILE_HOST_DEVICE constexpr pass_through(const LowLength& low_length)
79 : up_lengths_{make_tuple(low_length)}
80 {
81 }
82
83 CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
84 {
86 }
87
88 CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
89
90 template <typename LowIdx, typename UpIdx>
91 CK_TILE_HOST_DEVICE static constexpr void calculate_lower_index(LowIdx& idx_low,
92 const UpIdx& idx_up)
93 {
94 static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
95 "wrong! inconsistent # of dimension");
96
97 idx_low(number<0>{}) = idx_up[number<0>{}];
98 }
99
100 template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
101 CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
102 const UpIdxDiff& idx_diff_up,
103 LowIdx& idx_low,
104 const UpIdx&)
105 {
106 static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
107 UpIdx::size() == 1,
108 "wrong! inconsistent # of dimension");
109
110 constexpr auto I0 = number<0>{};
111
112 idx_diff_low[I0] = idx_diff_up[I0];
113
114 idx_low += idx_diff_low;
115 }
116
117 CK_TILE_HOST_DEVICE static constexpr bool
122
123 template <typename UpIdx>
124 CK_TILE_HOST_DEVICE static constexpr bool
126 {
127 return true;
128 }
129
134
135 // MUST be static function
136 template <typename LowVectorLengths, typename LowVectorStrides>
137 CK_TILE_HOST_DEVICE static constexpr auto
138 calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
139 const LowVectorStrides& low_vector_strides)
140 {
141 return make_tuple(low_vector_lengths, low_vector_strides);
142 }
143};
144
145template <typename LowLength>
146CK_TILE_HOST_DEVICE static void print(const pass_through<LowLength>& pt)
147{
148 printf("pass_through{");
149
150 printf("up_lengths_: ");
151 print(pt.get_upper_lengths());
152
153 printf("}");
154}
155
156template <typename LowLength,
157 typename LeftPadLength,
158 typename RightPadLength,
159 bool SkipIsValidCheck = false>
160struct pad : public base_transform<1, 1>
161{
164
165 using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{} + RightPadLength{}));
166
168 LeftPadLength left_pad_length_;
169 RightPadLength right_pad_length_;
170
172
173 CK_TILE_HOST_DEVICE constexpr pad(const LowLength& low_length,
174 const LeftPadLength& left_pad_length,
175 const RightPadLength& right_pad_length)
176 : up_lengths_{make_tuple(low_length + left_pad_length + right_pad_length)},
177 left_pad_length_{left_pad_length},
178 right_pad_length_{right_pad_length}
179 {
180 }
181
182 CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
183
184 template <typename LowIdx, typename UpIdx>
185 CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
186 const UpIdx& idx_up) const
187 {
188 static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
189 "wrong! inconsistent # of dimension");
190
191 idx_low(number<0>{}) = idx_up[number<0>{}] - left_pad_length_;
192 }
193
194 template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
195 CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
196 const UpIdxDiff& idx_diff_up,
197 LowIdx& idx_low,
198 const UpIdx&)
199 {
200 static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
201 UpIdx::size() == 1,
202 "wrong! inconsistent # of dimension");
203
204 constexpr auto I0 = number<0>{};
205
206 idx_diff_low[I0] = idx_diff_up[I0];
207
208 idx_low += idx_diff_low;
209 }
210
211 CK_TILE_HOST_DEVICE static constexpr bool
213 {
214 return SkipIsValidCheck;
215 }
216
217 template <typename UpIdx>
218 CK_TILE_HOST_DEVICE constexpr bool
220 {
221 return SkipIsValidCheck ||
222 ((idx_up[number<0>{}] >= left_pad_length_) &&
224 }
225
232};
233
234template <typename LowLength,
235 typename LeftPadLength,
236 typename RightPadLength,
237 bool SkipIsValidCheck>
238CK_TILE_HOST_DEVICE static void
240{
241 printf("pad{");
242 printf("up_lengths_: ");
243 print(p.up_lengths_);
244 printf(", left_pad_length_: ");
245 print(p.left_pad_length_);
246 printf(", right_pad_length_: ");
247 print(p.right_pad_length_);
248 printf("}");
249}
250
251template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
253{
256
257 using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{}));
258
260 LeftPadLength left_pad_length_;
261
262 CK_TILE_HOST_DEVICE constexpr left_pad() = default;
263
264 CK_TILE_HOST_DEVICE constexpr left_pad(const LowLength& low_length,
265 const LeftPadLength& left_pad_length)
266 : up_lengths_{make_tuple(low_length + left_pad_length)}, left_pad_length_{left_pad_length}
267 {
268 }
269
270 CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
271
272 template <typename LowIdx, typename UpIdx>
273 CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
274 const UpIdx& idx_up) const
275 {
276 static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
277 "wrong! inconsistent # of dimension");
278
279 idx_low(number<0>{}) = idx_up[number<0>{}] - left_pad_length_;
280 }
281
282 template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
283 CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
284 const UpIdxDiff& idx_diff_up,
285 LowIdx& idx_low,
286 const UpIdx&)
287 {
288 static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
289 UpIdx::size() == 1,
290 "wrong! inconsistent # of dimension");
291
292 constexpr auto I0 = number<0>{};
293
294 idx_diff_low[I0] = idx_diff_up[I0];
295
296 idx_low += idx_diff_low;
297 }
298
299 CK_TILE_HOST_DEVICE static constexpr bool
301 {
302 return SkipIsValidCheck;
303 }
304
305 template <typename UpIdx>
306 CK_TILE_HOST_DEVICE constexpr bool
308 {
309 return SkipIsValidCheck || (idx_up[number<0>{}] >= left_pad_length_);
310 }
311
317
318 // MUST be static function
319 template <typename LowVectorLengths, typename LowVectorStrides>
320 CK_TILE_HOST_DEVICE static constexpr auto
321 calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
322 const LowVectorStrides& low_vector_strides)
323 {
324 // TODO: we allow pass through this vector length. If one need per-pixel check,
325 // should change the guaranteed vector length while creating the tensor view.
326 // It's up to runtime to check the padding length should be multiple of vector length
327 return make_tuple(low_vector_lengths, low_vector_strides);
328 }
329};
330
331template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck>
332CK_TILE_HOST_DEVICE static void
333print(const left_pad<LowLength, LeftPadLength, SkipIsValidCheck>& lp)
334{
335 printf("left_pad{");
336 printf("up_lengths_: ");
337 print(lp.up_lengths_);
338 printf(", left_pad_length_: ");
339 print(lp.left_pad_length_);
340 printf("}");
341}
342
343template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
344struct right_pad : public base_transform<1, 1>
345{
348
349 using UpLengths = decltype(make_tuple(LowLength{} + RightPadLength{}));
350
352 LowLength low_length_;
353 RightPadLength right_pad_length_;
354
355 CK_TILE_HOST_DEVICE constexpr right_pad() = default;
356
357 CK_TILE_HOST_DEVICE constexpr right_pad(const LowLength& low_length,
358 const RightPadLength& right_pad_length)
359 : up_lengths_{make_tuple(low_length + right_pad_length)},
360 low_length_{low_length},
361 right_pad_length_{right_pad_length}
362 {
363 }
364
365 CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
366
367 template <typename LowIdx, typename UpIdx>
368 CK_TILE_HOST_DEVICE static constexpr void calculate_lower_index(LowIdx& idx_low,
369 const UpIdx& idx_up)
370 {
371 static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
372 "wrong! inconsistent # of dimension");
373
374 idx_low(number<0>{}) = idx_up[number<0>{}];
375 }
376
377 template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
378 CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
379 const UpIdxDiff& idx_diff_up,
380 LowIdx& idx_low,
381 const UpIdx&)
382 {
383 static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
384 UpIdx::size() == 1,
385 "wrong! inconsistent # of dimension");
386
387 constexpr auto I0 = number<0>{};
388
389 idx_diff_low[I0] = idx_diff_up[I0];
390
391 idx_low += idx_diff_low;
392 }
393
394 CK_TILE_HOST_DEVICE static constexpr bool
396 {
397 return SkipIsValidCheck;
398 }
399
400 template <typename UpIdx>
401 CK_TILE_HOST_DEVICE constexpr bool
403 {
404 return SkipIsValidCheck || (idx_up[number<0>{}] < low_length_);
405 }
406
413
414 // MUST be static function
415 template <typename LowVectorLengths, typename LowVectorStrides>
416 CK_TILE_HOST_DEVICE static constexpr auto
417 calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
418 const LowVectorStrides& low_vector_strides)
419 {
420 // TODO: we allow pass through this vector length. If one need per-pixel check,
421 // should change the guaranteed vector length while creating the tensor view.
422 // It's up to runtime to check the padding length should be multiple of vector length
423 return make_tuple(low_vector_lengths, low_vector_strides);
424 }
425};
426
427template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck>
428CK_TILE_HOST_DEVICE static void
429print(const right_pad<LowLength, RightPadLength, SkipIsValidCheck>& rp)
430{
431 printf("right_pad{");
432 printf("up_lengths_: ");
433 print(rp.up_lengths_);
434 printf(", right_pad_length_: ");
435 print(rp.right_pad_length_);
436 printf("}");
437}
438
// idx_low = sum over i in [0, NDimUp) of coefficients[i] * idx_up[i]
// UpLengths and Coefficients can each be one of the following:
// 1) a tuple of index_t, known at run time,
// 2) a tuple of number, known at compile time, or
// 3) a tuple mixing index_t and number, known partially at run time and partially at
//    compile time.
445template <typename UpLengths,
446 typename Coefficients,
447 typename std::enable_if<UpLengths::size() == Coefficients::size(), bool>::type = false>
448struct embed : public base_transform<1, UpLengths::size()>
449{
450 static constexpr index_t NDimUp = UpLengths::size();
451
454
455 UpLengths up_lengths_;
456 Coefficients coefficients_;
457
458 CK_TILE_HOST_DEVICE constexpr embed() = default;
459
460 CK_TILE_HOST_DEVICE constexpr embed(const UpLengths& up_lengths,
461 const Coefficients& coefficients)
462 : up_lengths_{up_lengths}, coefficients_{coefficients}
463 {
464 }
465
466 CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
467 {
469 }
470
471 CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
472
473 template <typename LowIdx, typename UpIdx>
474 CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
475 const UpIdx& idx_up) const
476 {
477 static_assert(LowIdx::size() == 1 && UpIdx::size() == NDimUp,
478 "wrong! inconsistent # of dimension");
479
480 idx_low(number<0>{}) = 0;
481
482 static_for<0, NDimUp, 1>{}([&idx_low, &idx_up, this](auto i) {
483 idx_low(number<0>{}) += idx_up[i] * this->coefficients_[i];
484 });
485 }
486
487 template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
488 CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
489 const UpIdxDiff& idx_diff_up,
490 LowIdx& idx_low,
491 const UpIdx&) const
492 {
493 static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == NDimUp &&
494 LowIdx::size() == 1 && UpIdx::size() == NDimUp,
495 "wrong! inconsistent # of dimension");
496
497 idx_diff_low(number<0>{}) = 0;
498
500 [&](auto i) { idx_diff_low(number<0>{}) += idx_diff_up[i] * coefficients_[i]; });
501
502 idx_low += idx_diff_low;
503 }
504
505 CK_TILE_HOST_DEVICE static constexpr bool
510
511 template <typename UpIdx>
512 CK_TILE_HOST_DEVICE static constexpr bool
514 {
515 return true;
516 }
517
523};
524
525template <typename UpLengths, typename Coefficients>
526CK_TILE_HOST_DEVICE static void print(const embed<UpLengths, Coefficients>& e)
527{
528 printf("embed{");
529 printf("up_lengths_: ");
530 print(e.up_lengths_);
531 printf(", coefficients_: ");
532 print(e.coefficients_);
533 printf("}");
534}
535
536template <typename LowLengths>
538{
539 template <index_t I>
541 {
542 return magic_division::calculate_magic_numbers(LowLengths{}[i]);
543 }
544};
545
// Implementation of the "merge" transformation primitive that uses magic-number division to
// lower both a multi-index and a multi-index delta.
// Caution:
// 1. The magic-number division implementation used here produces correct results only if the
//    dividend is uint32_t and its value lies within the 31-bit range of uint32_t.
// 2. Magic-number division for an int32_t dividend is not implemented; an int32_t dividend is
//    bit-wise reinterpreted as uint32_t and the uint32_t implementation is used instead.
// 3. For the merge primitive, the upper index is the dividend.
// 4. When the upper index is uint32_t, its value must be within the 31-bit range.
// 5. When the upper index is int32_t (i.e. when index_t is int32_t), its value must be
//    non-negative.
558template <typename LowLengths>
559struct merge_v2_magic_division : public base_transform<LowLengths::size(), 1>
560{
561 static constexpr index_t NDimLow = LowLengths::size();
562
565
566 using UpLengths =
567 decltype(make_tuple(container_reduce(LowLengths{}, multiplies{}, number<1>{})));
568
571 number<NDimLow>{}));
572
573 LowLengths low_lengths_;
576
577 static constexpr auto I0 = number<0>{};
578 static constexpr auto I1 = number<1>{};
579
581
582 CK_TILE_HOST_DEVICE constexpr merge_v2_magic_division(const LowLengths& low_lengths)
583 : low_lengths_{low_lengths},
585 [&](auto i) { return magic_division::calculate_magic_numbers(low_lengths[i]); },
586 number<NDimLow>{})},
587 up_lengths_{make_tuple(container_reduce(low_lengths, multiplies{}, I1))}
588 {
589 static_assert(LowerIndex::size() == NDimLow, "wrong!");
590 }
591
592 CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
593 {
595 }
596
597 CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
598
599 template <typename LowIdx, typename UpIdx>
600 CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
601 const UpIdx& idx_up) const
602 {
603 static_assert(LowIdx::size() == NDimLow && UpIdx::size() == 1,
604 "wrong! inconsistent # of dimension");
605
606 index_t tmp = idx_up[I0];
607
608 static_for<NDimLow - 1, 0, -1>{}([&, this](auto i) {
609 index_t tmp2 =
613 idx_low(i) = tmp - tmp2 * this->low_lengths_[i];
614 tmp = tmp2;
615 });
616
617 idx_low(number<0>{}) = tmp;
618 }
619
620 template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
621 CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
622 const UpIdxDiff&,
623 LowIdx& idx_low,
624 const UpIdx& idx_up_new) const
625 {
626 static_assert(LowIdxDiff::size() == NDimLow && UpIdxDiff::size() == 1 &&
627 LowIdx::size() == NDimLow && UpIdx::size() == 1,
628 "wrong! inconsistent # of dimension");
629
630 index_t tmp = idx_up_new[number<0>{}];
631
632 static_for<NDimLow - 1, 0, -1>{}([&, this](auto i) {
633 index_t tmp2 =
637
638 index_t idx_low_old = idx_low[i];
639
640 idx_low(i) = tmp - tmp2 * this->low_lengths_[i];
641 tmp = tmp2;
642
643 idx_diff_low(i) = idx_low[i] - idx_low_old;
644 });
645
646 idx_diff_low(number<0>{}) = tmp - idx_low(number<0>{});
647
648 idx_low(number<0>{}) = tmp;
649 }
650
651 CK_TILE_HOST_DEVICE static constexpr bool
656
663
664 template <typename UpIdx>
665 CK_TILE_HOST_DEVICE static constexpr bool
667 {
668 return true;
669 }
670
671 // MUST be static function
672 template <typename LowVectorLengths, typename LowVectorStrides>
673 CK_TILE_HOST_DEVICE static constexpr auto
674 calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
675 const LowVectorStrides& low_vector_strides)
676 {
677 array<index_t, 1> up_vector_lengths{-1};
678 array<index_t, 1> up_vector_strides{-1};
679
680 up_vector_lengths[0] = low_vector_lengths[number<NDimLow - 1>{}];
681 up_vector_strides[0] = low_vector_strides[number<NDimLow - 1>{}];
682
683 return make_tuple(up_vector_lengths, up_vector_strides);
684 }
685};
686
687template <typename LowLengths>
688CK_TILE_HOST_DEVICE static void print(const merge_v2_magic_division<LowLengths>& m)
689{
690 printf("merge_v2_magic_division{");
691 printf("low_lengths_: ");
692 print(m.low_lengths_);
693 printf(", up_lengths_: ");
694 print(m.up_lengths_);
695 printf("}");
696}
697
// Implementation of the "merge" transformation primitive that uses plain division and modulo.
// It is intended for low_lengths that are compile-time constants and powers of 2, so that the
// division and modulo can be strength-reduced by the compiler; otherwise performance will be
// very poor.
701template <typename LowLengths>
702struct merge_v3_division_mod : public base_transform<LowLengths::size(), 1>
703{
704 static constexpr index_t NDimLow = LowLengths::size();
705
708
710 decltype(container_reverse_exclusive_scan(LowLengths{}, multiplies{}, number<1>{}));
711
712 using UpLengths =
713 decltype(make_tuple(container_reduce(LowLengths{}, multiplies{}, number<1>{})));
714
715 LowLengths low_lengths_;
718
720
721 CK_TILE_HOST_DEVICE constexpr merge_v3_division_mod(const LowLengths& low_lengths)
722 : low_lengths_{low_lengths},
725 up_lengths_{make_tuple(container_reduce(low_lengths, multiplies{}, number<1>{}))}
726 {
727 static_assert(LowerIndex::size() == NDimLow, "wrong!");
728 }
729
730 CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
731
732 template <typename LowIdx, typename UpIdx>
733 CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
734 const UpIdx& idx_up) const
735 {
736 static_assert(LowIdx::size() == NDimLow && UpIdx::size() == 1,
737 "wrong! inconsistent # of dimension");
738
739 index_t tmp = idx_up[number<0>{}];
740
741 // division and mod
742 static_for<0, NDimLow - 1, 1>{}([&](auto i) {
743 idx_low(i) = tmp / this->low_lengths_scan_[i];
744 tmp %= this->low_lengths_scan_[i];
745 });
746
747 idx_low(number<NDimLow - 1>{}) = tmp;
748 }
749
750 template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
751 CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
752 const UpIdxDiff&,
753 LowIdx& idx_low,
754 const UpIdx& idx_up_new) const
755 {
756 static_assert(LowIdxDiff::size() == NDimLow && UpIdxDiff::size() == 1 &&
757 LowIdx::size() == NDimLow && UpIdx::size() == 1,
758 "wrong! inconsistent # of dimension");
759
760 constexpr auto I0 = number<0>{};
761 constexpr auto INm1 = number<NDimLow - 1>{};
762
763 index_t tmp = idx_up_new[I0];
764
765 static_for<0, NDimLow - 1, 1>{}([&](auto i) {
766 const index_t tmp2 = idx_low[i];
767 idx_low(i) = tmp / this->low_lengths_scan_[i];
768 idx_diff_low(i) = idx_low[i] - tmp2;
769 tmp %= this->low_lengths_scan_[i];
770 });
771
772 const index_t tmp2 = idx_low[INm1];
773 idx_low(INm1) = tmp;
774 idx_diff_low(INm1) = idx_low[INm1] - tmp2;
775 }
776
777 CK_TILE_HOST_DEVICE static constexpr bool
782
789
790 template <typename UpIdx>
791 CK_TILE_HOST_DEVICE static constexpr bool
793 {
794 return true;
795 }
796
797 // MUST be static function
798 template <typename LowVectorLengths, typename LowVectorStrides>
799 CK_TILE_HOST_DEVICE static constexpr auto
800 calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
801 const LowVectorStrides& low_vector_strides)
802 {
803 array<index_t, 1> up_vector_lengths{-1};
804 array<index_t, 1> up_vector_strides{-1};
805
806 up_vector_lengths[0] = low_vector_lengths[number<NDimLow - 1>{}];
807 up_vector_strides[0] = low_vector_strides[number<NDimLow - 1>{}];
808
809 return make_tuple(up_vector_lengths, up_vector_strides);
810 }
811};
812
813template <typename LowLengths>
814CK_TILE_HOST_DEVICE static void print(const merge_v3_division_mod<LowLengths>& m)
815{
816 printf("merge_v3_division_mod{");
817 printf("low_lengths_: ");
818 print(m.low_lengths_);
819 printf(", low_lengths_scan_: ");
820 print(m.low_lengths_scan_);
821 printf(", up_lengths_: ");
822 print(m.up_lengths_);
823 printf("}");
824}
825
826template <typename UpLengths, bool Use24BitIntegerCalculation>
827struct unmerge : public base_transform<1, UpLengths::size()>
828{
829 static constexpr index_t NDimUp = UpLengths::size();
830
833
835 decltype(container_reverse_exclusive_scan(UpLengths{}, multiplies{}, number<1>{}));
836
837 UpLengths up_lengths_;
839
840 CK_TILE_HOST_DEVICE constexpr unmerge() = default;
841
842 CK_TILE_HOST_DEVICE constexpr unmerge(const UpLengths& up_lengths)
843 : up_lengths_{up_lengths},
845 {
846 }
847
848 CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
849 {
851 }
852
853 CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
854
855 template <typename LowIdx, typename UpIdx>
856 CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
857 const UpIdx& idx_up) const
858 {
859 if constexpr(!Use24BitIntegerCalculation)
860 {
861 idx_low(number<0>{}) = idx_up[number<NDimUp - 1>{}];
862
863 static_for<0, NDimUp - 1, 1>{}(
864 [&](auto i) { idx_low(number<0>{}) += idx_up[i] * up_lengths_scan_[i]; });
865 }
866 else
867 {
868 idx_low(number<0>{}) = idx_up[number<NDimUp - 1>{}];
869
870 static_for<0, NDimUp - 1, 1>{}([&](auto i) {
871 idx_low(number<0>{}) =
872 (0x00ffffff & idx_low[number<0>{}]) +
873 (0x00ffffff & idx_up[i]) * (0x00ffffff & up_lengths_scan_[i]);
874 });
875 }
876 }
877
878 template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
879 CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
880 const UpIdxDiff& idx_diff_up,
881 LowIdx& idx_low,
882 const UpIdx&) const
883 {
884 calculate_lower_index(idx_diff_low, idx_diff_up);
885
886 idx_low += idx_diff_low;
887 }
888
889 CK_TILE_HOST_DEVICE static constexpr bool
894
895 template <typename UpIdx>
896 CK_TILE_HOST_DEVICE static constexpr bool
898 {
899 return true;
900 }
901
907
908 // MUST be static function
909 template <typename LowVectorLengths, typename LowVectorStrides>
910 CK_TILE_HOST_DEVICE static constexpr auto
911 calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
912 const LowVectorStrides& low_vector_strides)
913 {
914 array<index_t, NDimUp> up_vector_lengths{-1};
915 array<index_t, NDimUp> up_vector_strides{-1};
916
917 constexpr auto up_length_last = UpLengths{}[number<NDimUp - 1>{}];
918
919 if constexpr(ck_tile::is_known_at_compile_time<decltype(up_length_last)>::value)
920 {
921 if(low_vector_lengths[0] != -1)
922 {
923 up_vector_lengths(NDimUp - 1) = gcd(low_vector_lengths[0], up_length_last);
924 }
925 }
926
927 up_vector_strides(NDimUp - 1) = low_vector_strides[0];
928
929 return make_tuple(up_vector_lengths, up_vector_strides);
930 }
931};
932
933template <typename UpLengths, bool Use24BitIntegerCalculation>
935{
936 printf("unmerge{");
937 printf("up_lengths_: ");
938 print(u.up_lengths_);
939 printf(", up_lengths_scan_: ");
940 print(u.up_lengths_scan_);
941 printf("}");
942}
943
944template <typename LowerIndex>
945struct freeze : public base_transform<1, 0>
946{
947 LowerIndex low_idx_;
948
949 CK_TILE_HOST_DEVICE constexpr freeze() = default;
950
951 CK_TILE_HOST_DEVICE constexpr freeze(const LowerIndex& low_idx) : low_idx_{low_idx} {}
952
953 CK_TILE_HOST_DEVICE static constexpr auto get_upper_lengths() { return tuple<>{}; }
954
955 template <typename LowIdx, typename UpIdx>
956 CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
957 const UpIdx& /* idx_up */) const
958 {
959 static_assert(LowIdx::size() == 1 && UpIdx::size() == 0,
960 "wrong! inconsistent # of dimension");
961
962 idx_low(number<0>{}) = low_idx_;
963 }
964
965 template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
966 CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
967 const UpIdxDiff& /* idx_diff_up */,
968 LowIdx& /* idx_low */,
969 const UpIdx& /* idx_up_new */)
970 {
971 idx_diff_low(number<0>{}) = 0;
972 }
973
974 CK_TILE_HOST_DEVICE static constexpr bool
979
980 template <typename UpIdx>
981 CK_TILE_HOST_DEVICE static constexpr bool
983 {
984 return true;
985 }
986
991};
992
993template <typename LowerIndex>
994CK_TILE_HOST_DEVICE static void print(const freeze<LowerIndex>& f)
995{
996 printf("freeze{");
997 printf("low_idx_: ");
998 print(f.low_idx_);
999 printf("}");
1000}
1001
1002// insert a dangling upper dimension without lower dimension
1003template <typename UpperLength>
1004struct insert : public base_transform<0, 1>
1005{
1006 using UpLengths = decltype(make_tuple(UpperLength{}));
1007
1009
1010 CK_TILE_HOST_DEVICE constexpr insert() = default;
1011
1012 CK_TILE_HOST_DEVICE constexpr insert(const UpperLength& up_length)
1013 : up_lengths_{make_tuple(up_length)}
1014 {
1015 }
1016
1018
1020
1021 CK_TILE_HOST_DEVICE constexpr auto get_upper_lengths() const { return up_lengths_; }
1022
1023 template <typename LowIdx, typename UpIdx>
1024 CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx&, const UpIdx&) const
1025 {
1026 static_assert(LowIdx::size() == 0 && UpIdx::size() == 1,
1027 "wrong! inconsistent # of dimension");
1028 }
1029
1030 template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
1031 CK_TILE_HOST_DEVICE static void
1032 update_lower_index(LowIdxDiff&, const UpIdxDiff&, LowIdx&, const UpIdx&)
1033 {
1034 static_assert(LowIdxDiff::size() == 0 && UpIdxDiff::size() == 1 && LowIdx::size() == 0 &&
1035 UpIdx::size() == 1,
1036 "wrong! inconsistent # of dimension");
1037 }
1038
1039 CK_TILE_HOST_DEVICE static constexpr bool IsLinearTransform() { return true; }
1040
1041 CK_TILE_HOST_DEVICE static constexpr bool
1046
1047 template <typename UpIdx>
1048 CK_TILE_HOST_DEVICE static constexpr bool
1050 {
1051 return true;
1052 }
1053
1058};
1059
1060template <typename UpperLength>
1061CK_TILE_HOST_DEVICE static void print(const insert<UpperLength>& i)
1062{
1063 printf("insert{");
1064 printf("up_lengths_: ");
1065 print(i.up_lengths_);
1066 printf("}");
1067}
1068
1069// replicate the original tensor and create a higher dimensional tensor
1070template <typename UpLengths>
1071struct replicate : public base_transform<0, UpLengths::size()>
1072{
1073 static constexpr index_t NDimUp = UpLengths::size();
1074
1075 CK_TILE_HOST_DEVICE constexpr replicate() = default;
1076
1077 CK_TILE_HOST_DEVICE constexpr replicate(const UpLengths& up_lengths) : up_lengths_{up_lengths}
1078 {
1079 }
1080
1081 CK_TILE_HOST_DEVICE constexpr auto get_upper_lengths() const { return up_lengths_; }
1082
1083 template <typename LowIdx, typename UpIdx>
1084 CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx&, const UpIdx&) const
1085 {
1086 static_assert(LowIdx::size() == 0 && UpIdx::size() == NDimUp,
1087 "wrong! inconsistent # of dimension");
1088 }
1089
1090 template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
1091 CK_TILE_HOST_DEVICE static void
1092 update_lower_index(LowIdxDiff&, const UpIdxDiff&, LowIdx&, const UpIdx&)
1093 {
1094 static_assert(LowIdxDiff::size() == 0 && UpIdxDiff::size() == NDimUp &&
1095 LowIdx::size() == 0 && UpIdx::size() == NDimUp,
1096 "wrong! inconsistent # of dimension");
1097 }
1098
1099 CK_TILE_HOST_DEVICE static constexpr bool
1104
1105 template <typename UpIdx>
1106 CK_TILE_HOST_DEVICE static constexpr bool
1108 {
1109 return true;
1110 }
1111
1116
1117 //
1118 UpLengths up_lengths_;
1119};
1120
1121template <typename UpLengths>
1122CK_TILE_HOST_DEVICE static void print(const replicate<UpLengths>& r)
1123{
1124 printf("replicate{");
1125 printf("up_lengths_: ");
1126 print(r.up_lengths_);
1127 printf("}");
1128}
1129
1130template <typename LowLength, typename SliceBegin, typename SliceEnd>
1131struct slice : public base_transform<1, 1>
1132{
1135
1136 using UpLengths = decltype(make_tuple(SliceEnd{} - SliceBegin{}));
1137
1139 SliceBegin slice_begin_;
1140 SliceEnd slice_end_;
1141
1142 CK_TILE_HOST_DEVICE constexpr slice() = default;
1143
1144 CK_TILE_HOST_DEVICE constexpr slice(const LowLength&,
1145 const SliceBegin& slice_begin,
1146 const SliceEnd& slice_end)
1147 : up_lengths_{make_tuple(slice_end - slice_begin)},
1148 slice_begin_{slice_begin},
1149 slice_end_{slice_end}
1150 {
1151 }
1152
1153 CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
1154
1155 template <typename LowIdx, typename UpIdx>
1156 CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
1157 const UpIdx& idx_up) const
1158 {
1159 static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
1160 "wrong! inconsistent # of dimension");
1161
1162 idx_low(number<0>{}) = idx_up[number<0>{}] + slice_begin_;
1163 }
1164
1165 template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
1166 CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
1167 const UpIdxDiff& idx_diff_up,
1168 LowIdx& idx_low,
1169 const UpIdx&)
1170 {
1171 static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
1172 UpIdx::size() == 1,
1173 "wrong! inconsistent # of dimension");
1174
1175 constexpr auto I0 = number<0>{};
1176
1177 idx_diff_low[I0] = idx_diff_up[I0];
1178
1179 idx_low += idx_diff_low;
1180 }
1181
1182 CK_TILE_HOST_DEVICE static constexpr bool
1187
1188 template <typename UpIdx>
1189 CK_TILE_HOST_DEVICE constexpr bool
1191 {
1192 return true;
1193 }
1194
1201};
1202
1203template <typename LowLength, typename SliceBegin, typename SliceEnd>
1205{
1206 printf("slice{");
1207 printf("up_lengths_: ");
1208 print(s.up_lengths_);
1209 printf(", slice_begin_: ");
1210 print(s.slice_begin_);
1211 printf(", slice_end_: ");
1212 print(s.slice_end_);
1213 printf("}");
1214}
1215
1216/*
1217 * \brief lower_idx = upper_idx % modulus.
1218 * TODO: Need an improved implementation since the modulo operation is expensive.
1219 */
1220template <typename Modulus, typename UpLength>
1221struct modulo : public base_transform<1, 1>
1222{
1225 using UpLengths = decltype(make_tuple(UpLength{}));
1226
1227 Modulus modulus_;
1229
1230 CK_TILE_HOST_DEVICE constexpr modulo() = default;
1231
1232 CK_TILE_HOST_DEVICE constexpr modulo(const Modulus& modulus, const UpLength& up_length)
1233 : modulus_{modulus}, up_lengths_{make_tuple(up_length)}
1234 {
1235 }
1236
1237 CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
1238
1239 template <typename LowIdx, typename UpIdx>
1240 CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
1241 const UpIdx& idx_up) const
1242 {
1243 static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
1244 "wrong! inconsistent # of dimension");
1245
1246 idx_low(number<0>{}) = idx_up[number<0>{}] % modulus_;
1247 }
1248
1249 template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
1250 CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
1251 const UpIdxDiff& idx_diff_up,
1252 LowIdx& idx_low,
1253 const UpIdx& up_idx) const
1254 {
1255 static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
1256 UpIdx::size() == 1,
1257 "wrong! inconsistent # of dimension");
1258
1259 constexpr auto I0 = number<0>{};
1260
1261 const auto idx_low_old = idx_low;
1262 idx_low[I0] = (up_idx[I0] + idx_diff_up[I0]) % modulus_;
1263 idx_diff_low[I0] = idx_low - idx_low_old;
1264 }
1265
1266 CK_TILE_HOST_DEVICE static constexpr bool
1271
1272 template <typename UpIdx>
1273 CK_TILE_HOST_DEVICE static constexpr bool
1275 {
1276 return true;
1277 }
1278
1283};
1284
1285template <typename Modulus, typename UpLength>
1286CK_TILE_HOST_DEVICE static void print(const modulo<Modulus, UpLength>& m)
1287{
1288 printf("modulo{");
1289 printf("modulus_: ");
1290 print(m.modulus_);
1291 printf(", up_lengths_: ");
1292 print(m.up_lengths_);
1293 printf("}");
1294}
1295
1296// 2D XOR, NOTE: "xor" is a keyword
1297template <typename LowLengths>
1298struct xor_t : public base_transform<2, 2>
1299{
1301
1304
1305 using UpLengths = LowLengths;
1306
1308
1310
1311 CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths& low_lengths) : up_lengths_{low_lengths} {}
1312
1313 CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
1314 {
1316 }
1317
1318 CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
1319
1320 template <typename LowIdx, typename UpIdx>
1321 CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
1322 const UpIdx& idx_up) const
1323 {
1324 static_assert(LowIdx::size() == 2 && UpIdx::size() == 2,
1325 "wrong! inconsistent # of dimension");
1326
1327 idx_low(number<0>{}) = idx_up[number<0>{}];
1328
1329 idx_low(number<1>{}) =
1330 idx_up[number<1>{}] ^ (idx_up[number<0>{}] % up_lengths_[number<1>{}]);
1331 }
1332
1333 template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
1334 CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
1335 const UpIdxDiff&,
1336 LowIdx& idx_low,
1337 const UpIdx& idx_up) const
1338 {
1339 static_assert(LowIdxDiff::size() == 2 && UpIdxDiff::size() == 2 && LowIdx::size() == 2 &&
1340 UpIdx::size() == 2,
1341 "wrong! inconsistent # of dimension");
1342
1343 const auto idx_low_old = idx_low;
1344
1345 calculate_lower_index(idx_low, idx_up);
1346
1347 idx_diff_low = idx_low - idx_low_old;
1348 }
1349
1350 CK_TILE_HOST_DEVICE static constexpr bool
1355
1356 template <typename UpIdx>
1357 CK_TILE_HOST_DEVICE static constexpr bool
1359 {
1360 return true;
1361 }
1362
1367
1368 // MUST be static function
1369 template <typename LowVectorLengths, typename LowVectorStrides>
1371 const LowVectorLengths& low_vector_lengths,
1372 const LowVectorStrides& low_vector_strides) const
1373 {
1374 array<index_t, 2> up_vector_lengths = low_vector_lengths;
1375 array<index_t, 2> up_vector_strides = low_vector_strides;
1376
1377 return make_tuple(up_vector_lengths, up_vector_strides);
1378 }
1379};
1380
1381template <typename LowLengths>
1382CK_TILE_HOST_DEVICE static void print(const xor_t<LowLengths>& x)
1383{
1384 printf("xor_t{");
1385 printf("up_lengths_: ");
1386 print(x.up_lengths_);
1387 printf("}");
1388}
1389
1390template <typename LowLength, typename OffsetLength>
1391struct offset : public base_transform<1, 1>
1392{
1395
1396 using UpLengths = decltype(make_tuple(LowLength{}));
1397
1399 OffsetLength offset_length_;
1400
1401 CK_TILE_HOST_DEVICE constexpr offset() = default;
1402
1403 CK_TILE_HOST_DEVICE constexpr offset(const LowLength& low_length,
1404 const OffsetLength& offset_length)
1405 : up_lengths_{make_tuple(low_length)}, offset_length_{offset_length}
1406 {
1407 }
1408
1409 CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
1410 {
1412 }
1413
1414 CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
1415
1416 template <typename LowIdx, typename UpIdx>
1417 CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
1418 const UpIdx& idx_up) const
1419 {
1420 static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
1421 "wrong! inconsistent # of dimension");
1422
1423 idx_low(number<0>{}) = idx_up[number<0>{}] + offset_length_;
1424 }
1425
1426 template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
1427 CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
1428 const UpIdxDiff& idx_diff_up,
1429 LowIdx& idx_low,
1430 const UpIdx&)
1431 {
1432 static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
1433 UpIdx::size() == 1,
1434 "wrong! inconsistent # of dimension");
1435
1436 constexpr auto I0 = number<0>{};
1437
1438 idx_diff_low[I0] = idx_diff_up[I0];
1439
1440 idx_low += idx_diff_low;
1441 }
1442
1443 CK_TILE_HOST_DEVICE static constexpr bool
1448
1449 template <typename UpIdx>
1450 CK_TILE_HOST_DEVICE constexpr bool
1452 {
1453 return true;
1454 }
1455
1461};
1462
1463template <typename LowLength, typename OffsetLength>
1464CK_TILE_HOST_DEVICE static void print(const offset<LowLength, OffsetLength>& o)
1465{
1466 printf("offset{");
1467 printf("up_lengths_: ");
1468 print(o.up_lengths_);
1469 printf(", offset_length_: ");
1470 print(o.offset_length_);
1471 printf("}");
1472}
1473
1474template <typename UpLength, typename IndexingAdaptor>
1475struct indexing : public base_transform<1, 1>
1476{
1477 static constexpr index_t NDimUp = 1;
1478
1481
1482 using UpLengths = decltype(make_tuple(UpLength{}));
1484 IndexingAdaptor iadaptor_;
1485
1486 CK_TILE_HOST_DEVICE constexpr indexing() = default;
1487
1488 CK_TILE_HOST_DEVICE constexpr indexing(const UpLength& up_length,
1489 const IndexingAdaptor& iadaptor)
1490 : up_lengths_{make_tuple(up_length)}, iadaptor_{iadaptor}
1491 {
1492 }
1493
1494 CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
1495 {
1497 }
1498
1499 CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
1500
1501 template <typename LowIdx, typename UpIdx>
1502 CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
1503 const UpIdx& idx_up) const
1504 {
1505 static_assert(LowIdx::size() == 1 && UpIdx::size() == NDimUp,
1506 "wrong! inconsistent # of dimension");
1507 iadaptor_.calculate_lower_index(idx_low, idx_up);
1508 }
1509
1510 template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
1511 CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
1512 const UpIdxDiff& idx_diff_up,
1513 LowIdx& idx_low,
1514 const UpIdx& idx_up) const
1515 {
1516 // TODO: nonthing changed here
1517 static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == NDimUp &&
1518 LowIdx::size() == 1 && UpIdx::size() == NDimUp,
1519 "wrong! inconsistent # of dimension");
1520
1521 iadaptor_.update_lower_index(idx_diff_low, idx_diff_up, idx_low, idx_up);
1522 }
1523
1524 CK_TILE_HOST_DEVICE static constexpr bool
1529
1530 template <typename UpIdx>
1531 CK_TILE_HOST_DEVICE static constexpr bool
1533 {
1534 return true;
1535 }
1536
1538 {
1540 IndexingAdaptor::is_known_at_compile_time();
1541 }
1542};
1543
1544template <typename UpLength, typename IndexingAdaptor>
1545CK_TILE_HOST_DEVICE static void print(const indexing<UpLength, IndexingAdaptor>& i)
1546{
1547 printf("indexing{");
1548 printf("up_lengths_: ");
1549 print(i.up_lengths_);
1550 printf(", iadaptor_: ");
1551 print(i.iadaptor_);
1552 printf("}");
1553}
1554
1555//*******************************************************************************************************
1556
1557template <typename LowLength>
1558CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength& low_length)
1559{
1560 return pass_through<LowLength>{low_length};
1561}
1562
1563template <typename LowLength, typename LeftPad, typename RightPad, bool SkipIsValidCheck = false>
1564CK_TILE_HOST_DEVICE constexpr auto
1565make_pad_transform(const LowLength& low_length,
1566 const LeftPad& left_pad,
1567 const RightPad& right_pad,
1569{
1570 return pad<LowLength, LeftPad, RightPad, SkipIsValidCheck>{low_length, left_pad, right_pad};
1571}
1572
1573template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
1574CK_TILE_HOST_DEVICE constexpr auto
1575make_left_pad_transform(const LowLength& low_length,
1576 const LeftPadLength& left_pad_,
1578{
1579 return left_pad<LowLength, LeftPadLength, SkipIsValidCheck>{low_length, left_pad_};
1580}
1581
1582template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
1583CK_TILE_HOST_DEVICE constexpr auto
1584make_right_pad_transform(const LowLength& low_length,
1585 const RightPadLength& right_pad_,
1587{
1588 return right_pad<LowLength, RightPadLength, SkipIsValidCheck>{low_length, right_pad_};
1589}
1590
1591template <typename UpLengths,
1592 typename Coefficients,
1593 typename std::enable_if<UpLengths::size() == Coefficients::size(), bool>::type = false>
1594CK_TILE_HOST_DEVICE constexpr auto make_embed_transform(const UpLengths& up_lengths,
1595 const Coefficients& coefficients)
1596{
1597 return embed<UpLengths, Coefficients>{up_lengths, coefficients};
1598}
1599
1600template <typename LowLengths>
1601CK_TILE_HOST_DEVICE constexpr auto
1602make_merge_transform_v2_magic_division(const LowLengths& low_lengths)
1603{
1604 return merge_v2_magic_division<LowLengths>{low_lengths};
1605}
1606
1607template <typename LowLengths>
1608CK_TILE_HOST_DEVICE constexpr auto
1609make_merge_transform_v3_division_mod(const LowLengths& low_lengths)
1610{
1611 return merge_v3_division_mod<LowLengths>{low_lengths};
1612}
1613
1614template <typename LowLengths>
1615CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths& low_lengths)
1616{
1617 return make_merge_transform_v2_magic_division(low_lengths);
1618}
1619
1620template <typename UpLengths, bool Use24BitIntegerCalculation = false>
1621CK_TILE_HOST_DEVICE constexpr auto
1627
1628template <typename LowerIndex>
1629CK_TILE_HOST_DEVICE constexpr auto make_freeze_transform(const LowerIndex& low_idx)
1630{
1631 return freeze<LowerIndex>{low_idx};
1632}
1633
1634template <typename UpperIndex>
1635CK_TILE_HOST_DEVICE constexpr auto make_insert_transform(const UpperIndex& up_idx)
1636{
1637 return insert<UpperIndex>{up_idx};
1638}
1639
1640template <typename UpLengths>
1641CK_TILE_HOST_DEVICE constexpr auto make_replicate_transform(const UpLengths& up_lengths)
1642{
1643 return replicate<UpLengths>{up_lengths};
1644}
1645
1646template <typename LowLength, typename SliceBegin, typename SliceEnd>
1647CK_TILE_HOST_DEVICE constexpr auto make_slice_transform(const LowLength& low_length,
1648 const SliceBegin& slice_begin,
1649 const SliceEnd& slice_end)
1650{
1651 return slice<LowLength, SliceBegin, SliceEnd>{low_length, slice_begin, slice_end};
1652}
1653
1654template <typename Modulus, typename UpLength>
1655CK_TILE_HOST_DEVICE constexpr auto make_modulo_transform(const Modulus& modulus,
1656 const UpLength& up_length)
1657{
1658 return modulo<Modulus, UpLength>{modulus, up_length};
1659}
1660
1661template <typename LowLengths>
1662CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths)
1663{
1664 return xor_t<LowLengths>{low_lengths};
1665}
1666
1667template <typename LowLength, typename OffsetLength>
1668CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength& low_length,
1669 const OffsetLength& offset_length)
1670{
1671 return offset<LowLength, OffsetLength>{low_length, offset_length};
1672}
1673
1674} // namespace ck_tile
1675
1677namespace ck_tile {
1678
1679template <typename UpLength, typename Indices>
1680CK_TILE_HOST_DEVICE constexpr auto make_indexing_transform(const UpLength& up_lengths,
1681 const Indices& indices)
1682{
1683 // by default we use the simplest one
1686}
1687
1688template <typename UpLength, typename IndexingAdaptor>
1689CK_TILE_HOST_DEVICE constexpr auto
1690make_indexing_transform_with_adaptor(const UpLength& up_lengths, const IndexingAdaptor& iadaptor)
1691{
1692 return indexing<UpLength, IndexingAdaptor>{up_lengths, iadaptor};
1693}
1694
1695} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
auto pad(ck::index_t mpb, ck::index_t npb, ck::index_t kpb, ck::tensor_operation::device::GemmSpecialization gemm, CDesc_MRaw_NRaw conv)
Definition helper.hpp:70
__host__ __device__ constexpr auto unmerge(const Layout< Shape, UnrolledDesc > &layout, const NewLengths &new_lengths, const NewIdxs &new_indexes)
Unmerge selected dim in layout.
Definition layout_utils.hpp:474
Definition tile/core/algorithm/cluster_descriptor.hpp:13
coord_transform_enum
Definition coordinate_transform.hpp:17
@ embed
Definition coordinate_transform.hpp:21
@ unmerge
Definition coordinate_transform.hpp:23
@ undefined
Definition coordinate_transform.hpp:18
@ merge
Definition coordinate_transform.hpp:22
@ xor_t
Definition coordinate_transform.hpp:25
@ offset
Definition coordinate_transform.hpp:26
@ indexing
Definition coordinate_transform.hpp:27
@ pass_through
Definition coordinate_transform.hpp:19
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
CK_TILE_HOST_DEVICE constexpr auto make_left_pad_transform(const LowLength &low_length, const LeftPadLength &left_pad_, bool_constant< SkipIsValidCheck >=bool_constant< false >{})
Definition coordinate_transform.hpp:1575
CK_TILE_HOST_DEVICE constexpr auto make_replicate_transform(const UpLengths &up_lengths)
Definition coordinate_transform.hpp:1641
CK_TILE_HOST_DEVICE constexpr auto container_reduce(const Container &x, Reduce reduce, Init init, number< IBegin >=number< 0 >{}, number< IEnd >=number< Container::size()>{}, number< IStep >=number< 1 >{})
Definition tile/core/container/container_helper.hpp:198
CK_TILE_HOST_DEVICE constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition coordinate_transform.hpp:1629
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
CK_TILE_HOST_DEVICE constexpr auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, bool_constant< SkipIsValidCheck >=bool_constant< false >{})
Definition coordinate_transform.hpp:1565
CK_TILE_HOST_DEVICE constexpr auto make_insert_transform(const UpperIndex &up_idx)
Definition coordinate_transform.hpp:1635
CK_TILE_HOST_DEVICE constexpr auto make_indexing_transform_with_adaptor(const UpLength &up_lengths, const IndexingAdaptor &iadaptor)
Definition coordinate_transform.hpp:1690
CK_TILE_HOST_DEVICE constexpr auto make_modulo_transform(const Modulus &modulus, const UpLength &up_length)
Definition coordinate_transform.hpp:1655
CK_TILE_HOST_DEVICE constexpr auto make_unmerge_transform(const UpLengths &up_lengths, bool_constant< Use24BitIntegerCalculation >=bool_constant< false >{})
Definition coordinate_transform.hpp:1622
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1662
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
array< index_t, N > multi_index
Definition tile/core/container/multi_index.hpp:17
CK_TILE_HOST_DEVICE constexpr index_t gcd(index_t x, index_t y)
Definition tile/core/numeric/math.hpp:268
CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength &low_length, const OffsetLength &offset_length)
Definition coordinate_transform.hpp:1668
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform_v2_magic_division(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1602
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1609
is_static< T > is_known_at_compile_time
Definition type_traits.hpp:94
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_slice_transform(const LowLength &low_length, const SliceBegin &slice_begin, const SliceEnd &slice_end)
Definition coordinate_transform.hpp:1647
CK_TILE_HOST_DEVICE constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad_, bool_constant< SkipIsValidCheck >=bool_constant< false >{})
Definition coordinate_transform.hpp:1584
CK_TILE_HOST_DEVICE constexpr auto make_indexing_transform(const UpLength &up_lengths, const Indices &indices)
Definition coordinate_transform.hpp:1680
CK_TILE_HOST_DEVICE constexpr auto container_reverse_exclusive_scan(const array< TData, NSize > &x, Reduce f, Init init)
Definition tile/core/container/container_helper.hpp:240
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
CK_TILE_HOST_DEVICE constexpr auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition coordinate_transform.hpp:1594
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
A fixed-size array container similar to std::array with additional utilities.
Definition tile/core/container/array.hpp:43
Definition coordinate_transform.hpp:32
static CK_TILE_HOST_DEVICE constexpr index_t get_num_of_lower_dimension()
Definition coordinate_transform.hpp:38
static CK_TILE_HOST_DEVICE constexpr auto get_type_enum()
Definition coordinate_transform.hpp:33
static CK_TILE_HOST_DEVICE constexpr auto calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths &, const LowVectorStrides &)
Definition coordinate_transform.hpp:47
static CK_TILE_HOST_DEVICE constexpr index_t get_num_of_upper_dimension()
Definition coordinate_transform.hpp:40
Definition coordinate_transform.hpp:449
CK_TILE_HOST_DEVICE constexpr embed(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition coordinate_transform.hpp:460
static CK_TILE_HOST_DEVICE constexpr auto get_type_enum()
Definition coordinate_transform.hpp:466
multi_index< 1 > LowerIndex
Definition coordinate_transform.hpp:452
static CK_TILE_HOST_DEVICE constexpr bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &)
Definition coordinate_transform.hpp:513
multi_index< NDimUp > UpperIndex
Definition coordinate_transform.hpp:453
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up) const
Definition coordinate_transform.hpp:474
static CK_TILE_HOST_DEVICE constexpr bool is_known_at_compile_time()
Definition coordinate_transform.hpp:518
CK_TILE_HOST_DEVICE constexpr embed()=default
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &) const
Definition coordinate_transform.hpp:488
UpLengths up_lengths_
Definition coordinate_transform.hpp:455
CK_TILE_HOST_DEVICE constexpr const auto & get_upper_lengths() const
Definition coordinate_transform.hpp:471
Coefficients coefficients_
Definition coordinate_transform.hpp:456
static CK_TILE_HOST_DEVICE constexpr bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition coordinate_transform.hpp:506
static constexpr index_t NDimUp
Definition coordinate_transform.hpp:450
Definition coordinate_transform.hpp:946
static CK_TILE_HOST_DEVICE constexpr bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &)
Definition coordinate_transform.hpp:982
static CK_TILE_HOST_DEVICE constexpr bool is_known_at_compile_time()
Definition coordinate_transform.hpp:987
static CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &, LowIdx &, const UpIdx &)
Definition coordinate_transform.hpp:966
static CK_TILE_HOST_DEVICE constexpr auto get_upper_lengths()
Definition coordinate_transform.hpp:953
static CK_TILE_HOST_DEVICE constexpr bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition coordinate_transform.hpp:975
LowerIndex low_idx_
Definition coordinate_transform.hpp:947
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx &idx_low, const UpIdx &) const
Definition coordinate_transform.hpp:956
CK_TILE_HOST_DEVICE constexpr freeze()=default
CK_TILE_HOST_DEVICE constexpr freeze(const LowerIndex &low_idx)
Definition coordinate_transform.hpp:951
static constexpr bool value
Definition type_traits.hpp:77
Definition indexing_adaptor.hpp:20
Definition coordinate_transform.hpp:1476
static CK_TILE_HOST_DEVICE constexpr bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &)
Definition coordinate_transform.hpp:1532
static CK_TILE_HOST_DEVICE constexpr bool is_known_at_compile_time()
Definition coordinate_transform.hpp:1537
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up) const
Definition coordinate_transform.hpp:1502
CK_TILE_HOST_DEVICE constexpr indexing()=default
static CK_TILE_HOST_DEVICE constexpr auto get_type_enum()
Definition coordinate_transform.hpp:1494
CK_TILE_HOST_DEVICE constexpr indexing(const UpLength &up_length, const IndexingAdaptor &iadaptor)
Definition coordinate_transform.hpp:1488
decltype(make_tuple(UpLength{})) UpLengths
Definition coordinate_transform.hpp:1482
UpLengths up_lengths_
Definition coordinate_transform.hpp:1483
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &idx_up) const
Definition coordinate_transform.hpp:1511
multi_index< 1 > LowerIndex
Definition coordinate_transform.hpp:1479
IndexingAdaptor iadaptor_
Definition coordinate_transform.hpp:1484
CK_TILE_HOST_DEVICE constexpr const auto & get_upper_lengths() const
Definition coordinate_transform.hpp:1499
static CK_TILE_HOST_DEVICE constexpr bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition coordinate_transform.hpp:1525
multi_index< 1 > UpperIndex
Definition coordinate_transform.hpp:1480
static constexpr index_t NDimUp
Definition coordinate_transform.hpp:1477
Definition coordinate_transform.hpp:1005
static CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &, const UpIdxDiff &, LowIdx &, const UpIdx &)
Definition coordinate_transform.hpp:1032
CK_TILE_HOST_DEVICE constexpr insert(const UpperLength &up_length)
Definition coordinate_transform.hpp:1012
UpLengths up_lengths_
Definition coordinate_transform.hpp:1008
static CK_TILE_HOST_DEVICE constexpr bool is_known_at_compile_time()
Definition coordinate_transform.hpp:1054
decltype(make_tuple(UpperLength{})) UpLengths
Definition coordinate_transform.hpp:1006
static CK_TILE_HOST_DEVICE constexpr bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &)
Definition coordinate_transform.hpp:1049
static CK_TILE_HOST_DEVICE constexpr bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition coordinate_transform.hpp:1042
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx &, const UpIdx &) const
Definition coordinate_transform.hpp:1024
CK_TILE_HOST_DEVICE constexpr auto get_upper_lengths() const
Definition coordinate_transform.hpp:1021
static CK_TILE_HOST_DEVICE constexpr index_t get_num_of_lower_dimension()
Definition coordinate_transform.hpp:1017
static CK_TILE_HOST_DEVICE constexpr bool IsLinearTransform()
Definition coordinate_transform.hpp:1039
static CK_TILE_HOST_DEVICE constexpr index_t get_num_of_upper_dimension()
Definition coordinate_transform.hpp:1019
CK_TILE_HOST_DEVICE constexpr insert()=default
CK_TILE_HOST_DEVICE constexpr auto operator()(number< I > i) const
Definition coordinate_transform.hpp:540
Definition coordinate_transform.hpp:253
CK_TILE_HOST_DEVICE constexpr bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &idx_up) const
Definition coordinate_transform.hpp:307
CK_TILE_HOST_DEVICE constexpr left_pad()=default
LeftPadLength left_pad_length_
Definition coordinate_transform.hpp:260
CK_TILE_HOST_DEVICE constexpr left_pad(const LowLength &low_length, const LeftPadLength &left_pad_length)
Definition coordinate_transform.hpp:264
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up) const
Definition coordinate_transform.hpp:273
UpLengths up_lengths_
Definition coordinate_transform.hpp:259
static CK_TILE_HOST_DEVICE constexpr bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition coordinate_transform.hpp:300
CK_TILE_HOST_DEVICE constexpr const auto & get_upper_lengths() const
Definition coordinate_transform.hpp:270
multi_index< 1 > LowerIndex
Definition coordinate_transform.hpp:254
decltype(make_tuple(LowLength{}+LeftPadLength{})) UpLengths
Definition coordinate_transform.hpp:257
static CK_TILE_HOST_DEVICE constexpr auto calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths &low_vector_lengths, const LowVectorStrides &low_vector_strides)
Definition coordinate_transform.hpp:321
static CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &)
Definition coordinate_transform.hpp:283
static CK_TILE_HOST_DEVICE constexpr bool is_known_at_compile_time()
Definition coordinate_transform.hpp:312
multi_index< 1 > UpperIndex
Definition coordinate_transform.hpp:255
static CK_TILE_DEVICE constexpr uint32_t do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
Definition magic_div.hpp:60
static CK_TILE_HOST_DEVICE constexpr auto calculate_magic_numbers(uint32_t divisor)
Definition magic_div.hpp:29
Definition coordinate_transform.hpp:560
static CK_TILE_HOST_DEVICE constexpr bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &)
Definition coordinate_transform.hpp:666
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &, LowIdx &idx_low, const UpIdx &idx_up_new) const
Definition coordinate_transform.hpp:621
LowLengthsMagicDivisor low_lengths_magic_divisor_
Definition coordinate_transform.hpp:574
decltype(make_tuple(container_reduce(LowLengths{}, multiplies{}, number< 1 >{}))) UpLengths
Definition coordinate_transform.hpp:566
static constexpr auto I1
Definition coordinate_transform.hpp:578
static CK_TILE_HOST_DEVICE constexpr auto get_type_enum()
Definition coordinate_transform.hpp:592
static CK_TILE_HOST_DEVICE constexpr bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition coordinate_transform.hpp:652
static CK_TILE_HOST_DEVICE constexpr auto calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths &low_vector_lengths, const LowVectorStrides &low_vector_strides)
Definition coordinate_transform.hpp:674
static constexpr index_t NDimLow
Definition coordinate_transform.hpp:561
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up) const
Definition coordinate_transform.hpp:600
multi_index< 1 > UpperIndex
Definition coordinate_transform.hpp:564
LowLengths low_lengths_
Definition coordinate_transform.hpp:573
static CK_TILE_HOST_DEVICE constexpr bool is_known_at_compile_time()
Definition coordinate_transform.hpp:657
CK_TILE_HOST_DEVICE constexpr merge_v2_magic_division()=default
static constexpr auto I0
Definition coordinate_transform.hpp:577
UpLengths up_lengths_
Definition coordinate_transform.hpp:575
CK_TILE_HOST_DEVICE constexpr const auto & get_upper_lengths() const
Definition coordinate_transform.hpp:597
CK_TILE_HOST_DEVICE constexpr merge_v2_magic_division(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:582
multi_index< NDimLow > LowerIndex
Definition coordinate_transform.hpp:563
decltype(generate_tuple( lambda_merge_generate_MagicDivision_calculate_magic_divisor< LowLengths >{}, number< NDimLow >{})) LowLengthsMagicDivisor
Definition coordinate_transform.hpp:569
Definition coordinate_transform.hpp:703
static CK_TILE_HOST_DEVICE constexpr auto calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths &low_vector_lengths, const LowVectorStrides &low_vector_strides)
Definition coordinate_transform.hpp:800
CK_TILE_HOST_DEVICE constexpr merge_v3_division_mod()=default
static CK_TILE_HOST_DEVICE constexpr bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition coordinate_transform.hpp:778
multi_index< 1 > UpperIndex
Definition coordinate_transform.hpp:707
CK_TILE_HOST_DEVICE constexpr const auto & get_upper_lengths() const
Definition coordinate_transform.hpp:730
decltype(make_tuple(container_reduce(LowLengths{}, multiplies{}, number< 1 >{}))) UpLengths
Definition coordinate_transform.hpp:712
static CK_TILE_HOST_DEVICE constexpr bool is_known_at_compile_time()
Definition coordinate_transform.hpp:783
LowLengths low_lengths_
Definition coordinate_transform.hpp:715
UpLengths up_lengths_
Definition coordinate_transform.hpp:717
static CK_TILE_HOST_DEVICE constexpr bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &)
Definition coordinate_transform.hpp:792
static constexpr index_t NDimLow
Definition coordinate_transform.hpp:704
LowLengthsScan low_lengths_scan_
Definition coordinate_transform.hpp:716
decltype(container_reverse_exclusive_scan(LowLengths{}, multiplies{}, number< 1 >{})) LowLengthsScan
Definition coordinate_transform.hpp:709
CK_TILE_HOST_DEVICE constexpr merge_v3_division_mod(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:721
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up) const
Definition coordinate_transform.hpp:733
multi_index< NDimLow > LowerIndex
Definition coordinate_transform.hpp:706
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &, LowIdx &idx_low, const UpIdx &idx_up_new) const
Definition coordinate_transform.hpp:751
Definition coordinate_transform.hpp:1222
UpLengths up_lengths_
Definition coordinate_transform.hpp:1228
multi_index< 1 > UpperIndex
Definition coordinate_transform.hpp:1224
decltype(make_tuple(UpLength{})) UpLengths
Definition coordinate_transform.hpp:1225
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &up_idx) const
Definition coordinate_transform.hpp:1250
static CK_TILE_HOST_DEVICE constexpr bool is_known_at_compile_time()
Definition coordinate_transform.hpp:1279
static CK_TILE_HOST_DEVICE constexpr bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition coordinate_transform.hpp:1267
Modulus modulus_
Definition coordinate_transform.hpp:1227
CK_TILE_HOST_DEVICE constexpr modulo(const Modulus &modulus, const UpLength &up_length)
Definition coordinate_transform.hpp:1232
CK_TILE_HOST_DEVICE constexpr const auto & get_upper_lengths() const
Definition coordinate_transform.hpp:1237
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up) const
Definition coordinate_transform.hpp:1240
CK_TILE_HOST_DEVICE constexpr modulo()=default
multi_index< 1 > LowerIndex
Definition coordinate_transform.hpp:1223
static CK_TILE_HOST_DEVICE constexpr bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &)
Definition coordinate_transform.hpp:1274
Definition tile/core/numeric/math.hpp:98
Definition coordinate_transform.hpp:1392
static CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &)
Definition coordinate_transform.hpp:1427
CK_TILE_HOST_DEVICE constexpr bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &) const
Definition coordinate_transform.hpp:1451
decltype(make_tuple(LowLength{})) UpLengths
Definition coordinate_transform.hpp:1396
static CK_TILE_HOST_DEVICE constexpr bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition coordinate_transform.hpp:1444
static CK_TILE_HOST_DEVICE constexpr bool is_known_at_compile_time()
Definition coordinate_transform.hpp:1456
multi_index< 1 > UpperIndex
Definition coordinate_transform.hpp:1394
static CK_TILE_HOST_DEVICE constexpr auto get_type_enum()
Definition coordinate_transform.hpp:1409
multi_index< 1 > LowerIndex
Definition coordinate_transform.hpp:1393
OffsetLength offset_length_
Definition coordinate_transform.hpp:1399
CK_TILE_HOST_DEVICE constexpr offset(const LowLength &low_length, const OffsetLength &offset_length)
Definition coordinate_transform.hpp:1403
CK_TILE_HOST_DEVICE constexpr const auto & get_upper_lengths() const
Definition coordinate_transform.hpp:1414
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up) const
Definition coordinate_transform.hpp:1417
UpLengths up_lengths_
Definition coordinate_transform.hpp:1398
CK_TILE_HOST_DEVICE constexpr offset()=default
Definition coordinate_transform.hpp:161
multi_index< 1 > UpperIndex
Definition coordinate_transform.hpp:163
decltype(make_tuple(LowLength{}+LeftPadLength{}+RightPadLength{})) UpLengths
Definition coordinate_transform.hpp:165
LeftPadLength left_pad_length_
Definition coordinate_transform.hpp:168
UpLengths up_lengths_
Definition coordinate_transform.hpp:167
static CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &)
Definition coordinate_transform.hpp:195
static CK_TILE_HOST_DEVICE constexpr bool is_known_at_compile_time()
Definition coordinate_transform.hpp:226
CK_TILE_HOST_DEVICE constexpr bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &idx_up) const
Definition coordinate_transform.hpp:219
CK_TILE_HOST_DEVICE constexpr pad()
Definition coordinate_transform.hpp:171
RightPadLength right_pad_length_
Definition coordinate_transform.hpp:169
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up) const
Definition coordinate_transform.hpp:185
multi_index< 1 > LowerIndex
Definition coordinate_transform.hpp:162
CK_TILE_HOST_DEVICE constexpr const auto & get_upper_lengths() const
Definition coordinate_transform.hpp:182
static CK_TILE_HOST_DEVICE constexpr bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition coordinate_transform.hpp:212
CK_TILE_HOST_DEVICE constexpr pad(const LowLength &low_length, const LeftPadLength &left_pad_length, const RightPadLength &right_pad_length)
Definition coordinate_transform.hpp:173
Definition coordinate_transform.hpp:66
CK_TILE_HOST_DEVICE constexpr const auto & get_upper_lengths() const
Definition coordinate_transform.hpp:88
static CK_TILE_HOST_DEVICE constexpr auto calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths &low_vector_lengths, const LowVectorStrides &low_vector_strides)
Definition coordinate_transform.hpp:138
UpLengths up_lengths_
Definition coordinate_transform.hpp:74
static CK_TILE_HOST_DEVICE constexpr bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &)
Definition coordinate_transform.hpp:125
decltype(make_tuple(LowLength{})) UpLengths
Definition coordinate_transform.hpp:72
static CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up)
Definition coordinate_transform.hpp:91
static constexpr auto type_enum
Definition coordinate_transform.hpp:67
static CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &)
Definition coordinate_transform.hpp:101
static CK_TILE_HOST_DEVICE constexpr bool is_known_at_compile_time()
Definition coordinate_transform.hpp:130
CK_TILE_HOST_DEVICE constexpr pass_through()=default
multi_index< 1 > UpperIndex
Definition coordinate_transform.hpp:70
static CK_TILE_HOST_DEVICE constexpr bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition coordinate_transform.hpp:118
static CK_TILE_HOST_DEVICE constexpr auto get_type_enum()
Definition coordinate_transform.hpp:83
CK_TILE_HOST_DEVICE constexpr pass_through(const LowLength &low_length)
Definition coordinate_transform.hpp:78
multi_index< 1 > LowerIndex
Definition coordinate_transform.hpp:69
Definition coordinate_transform.hpp:1072
static CK_TILE_HOST_DEVICE constexpr bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition coordinate_transform.hpp:1100
static CK_TILE_HOST_DEVICE constexpr bool is_known_at_compile_time()
Definition coordinate_transform.hpp:1112
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx &, const UpIdx &) const
Definition coordinate_transform.hpp:1084
CK_TILE_HOST_DEVICE constexpr auto get_upper_lengths() const
Definition coordinate_transform.hpp:1081
static CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &, const UpIdxDiff &, LowIdx &, const UpIdx &)
Definition coordinate_transform.hpp:1092
CK_TILE_HOST_DEVICE constexpr replicate()=default
UpLengths up_lengths_
Definition coordinate_transform.hpp:1118
static constexpr index_t NDimUp
Definition coordinate_transform.hpp:1073
CK_TILE_HOST_DEVICE constexpr replicate(const UpLengths &up_lengths)
Definition coordinate_transform.hpp:1077
static CK_TILE_HOST_DEVICE constexpr bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &)
Definition coordinate_transform.hpp:1107
Definition coordinate_transform.hpp:345
LowLength low_length_
Definition coordinate_transform.hpp:352
static CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &)
Definition coordinate_transform.hpp:378
CK_TILE_HOST_DEVICE constexpr right_pad()=default
static CK_TILE_HOST_DEVICE constexpr bool is_known_at_compile_time()
Definition coordinate_transform.hpp:407
multi_index< 1 > UpperIndex
Definition coordinate_transform.hpp:347
CK_TILE_HOST_DEVICE constexpr const auto & get_upper_lengths() const
Definition coordinate_transform.hpp:365
RightPadLength right_pad_length_
Definition coordinate_transform.hpp:353
CK_TILE_HOST_DEVICE constexpr bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &idx_up) const
Definition coordinate_transform.hpp:402
CK_TILE_HOST_DEVICE constexpr right_pad(const LowLength &low_length, const RightPadLength &right_pad_length)
Definition coordinate_transform.hpp:357
static CK_TILE_HOST_DEVICE constexpr auto calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths &low_vector_lengths, const LowVectorStrides &low_vector_strides)
Definition coordinate_transform.hpp:417
static CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up)
Definition coordinate_transform.hpp:368
static CK_TILE_HOST_DEVICE constexpr bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition coordinate_transform.hpp:395
decltype(make_tuple(LowLength{}+RightPadLength{})) UpLengths
Definition coordinate_transform.hpp:349
UpLengths up_lengths_
Definition coordinate_transform.hpp:351
multi_index< 1 > LowerIndex
Definition coordinate_transform.hpp:346
Definition coordinate_transform.hpp:1132
CK_TILE_HOST_DEVICE constexpr bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &) const
Definition coordinate_transform.hpp:1190
CK_TILE_HOST_DEVICE constexpr const auto & get_upper_lengths() const
Definition coordinate_transform.hpp:1153
multi_index< 1 > UpperIndex
Definition coordinate_transform.hpp:1134
static CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &)
Definition coordinate_transform.hpp:1166
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up) const
Definition coordinate_transform.hpp:1156
SliceBegin slice_begin_
Definition coordinate_transform.hpp:1139
multi_index< 1 > LowerIndex
Definition coordinate_transform.hpp:1133
UpLengths up_lengths_
Definition coordinate_transform.hpp:1138
decltype(make_tuple(SliceEnd{} - SliceBegin{})) UpLengths
Definition coordinate_transform.hpp:1136
static CK_TILE_HOST_DEVICE constexpr bool is_known_at_compile_time()
Definition coordinate_transform.hpp:1195
CK_TILE_HOST_DEVICE constexpr slice()=default
static CK_TILE_HOST_DEVICE constexpr bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition coordinate_transform.hpp:1183
SliceEnd slice_end_
Definition coordinate_transform.hpp:1140
CK_TILE_HOST_DEVICE constexpr slice(const LowLength &, const SliceBegin &slice_begin, const SliceEnd &slice_end)
Definition coordinate_transform.hpp:1144
Definition tile/core/utility/functional.hpp:43
Definition coordinate_transform.hpp:828
multi_index< NDimUp > UpperIndex
Definition coordinate_transform.hpp:832
CK_TILE_HOST_DEVICE constexpr const auto & get_upper_lengths() const
Definition coordinate_transform.hpp:853
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &) const
Definition coordinate_transform.hpp:879
CK_TILE_HOST_DEVICE constexpr unmerge(const UpLengths &up_lengths)
Definition coordinate_transform.hpp:842
static CK_TILE_HOST_DEVICE constexpr bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition coordinate_transform.hpp:890
CK_TILE_HOST_DEVICE constexpr unmerge()=default
decltype(container_reverse_exclusive_scan(UpLengths{}, multiplies{}, number< 1 >{})) UpLengthsScan
Definition coordinate_transform.hpp:834
UpLengthsScan up_lengths_scan_
Definition coordinate_transform.hpp:838
static CK_TILE_HOST_DEVICE constexpr auto calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths &low_vector_lengths, const LowVectorStrides &low_vector_strides)
Definition coordinate_transform.hpp:911
static CK_TILE_HOST_DEVICE constexpr auto get_type_enum()
Definition coordinate_transform.hpp:848
static constexpr index_t NDimUp
Definition coordinate_transform.hpp:829
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up) const
Definition coordinate_transform.hpp:856
static CK_TILE_HOST_DEVICE constexpr bool is_known_at_compile_time()
Definition coordinate_transform.hpp:902
UpLengths up_lengths_
Definition coordinate_transform.hpp:837
static CK_TILE_HOST_DEVICE constexpr bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &)
Definition coordinate_transform.hpp:897
multi_index< 1 > LowerIndex
Definition coordinate_transform.hpp:831
Definition coordinate_transform.hpp:1299
static CK_TILE_HOST_DEVICE constexpr auto get_type_enum()
Definition coordinate_transform.hpp:1313
multi_index< 2 > LowerIndex
Definition coordinate_transform.hpp:1302
CK_TILE_HOST_DEVICE constexpr xor_t()
Definition coordinate_transform.hpp:1309
CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1311
static CK_TILE_HOST_DEVICE constexpr bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &)
Definition coordinate_transform.hpp:1358
LowLengths UpLengths
Definition coordinate_transform.hpp:1305
static constexpr auto type_enum
Definition coordinate_transform.hpp:1300
CK_TILE_HOST_DEVICE constexpr auto calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths &low_vector_lengths, const LowVectorStrides &low_vector_strides) const
Definition coordinate_transform.hpp:1370
UpLengths up_lengths_
Definition coordinate_transform.hpp:1307
CK_TILE_HOST_DEVICE constexpr const auto & get_upper_lengths() const
Definition coordinate_transform.hpp:1318
multi_index< 2 > UpperIndex
Definition coordinate_transform.hpp:1303
static CK_TILE_HOST_DEVICE constexpr bool is_known_at_compile_time()
Definition coordinate_transform.hpp:1363
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &, LowIdx &idx_low, const UpIdx &idx_up) const
Definition coordinate_transform.hpp:1334
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up) const
Definition coordinate_transform.hpp:1321
static CK_TILE_HOST_DEVICE constexpr bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition coordinate_transform.hpp:1351
constexpr auto slice(const FromType from, const ToType to)
Get dim slice.
Definition tensor_utils.hpp:245