tile_scatter_gather.hpp Source File

tile_scatter_gather.hpp Source File#

Composable Kernel: tile_scatter_gather.hpp Source File
tile_scatter_gather.hpp
Go to the documentation of this file.
1
2// SPDX-License-Identifier: MIT
3// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
4
5#pragma once
6
19
20namespace ck_tile {
21
33template <typename BottomTensorView_,
34 typename WindowLengths_,
35 typename StaticTileDistribution_,
36 typename StaticPageIndexArray_,
37 typename StaticValidArray_,
38 index_t HsGatherDim = 0,
39 index_t NumCoord = 1,
40 index_t YsGatherDim = 0>
42{
48 using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
49 using BottomTensorDesc = typename BottomTensorView::TensorDesc;
50
52
53 static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
54 static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
55
56 static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p();
57 static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y();
58
59 static constexpr auto I0 = number<0>{};
60 static constexpr auto I1 = number<1>{};
61 static_assert(NumCoord == 1);
62
63 // TODO: check WindowLengths and StaticTileDistribution are consistent
64
66 "wrong! lengths should be static");
67 static_assert(TileDstr::is_static(), "wrong!");
68
69 static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
70 "wrong! inconsistent # of diemsnions");
71
74
77
80
82 {
83 private:
84 static constexpr auto get_vector_dim_y_scalar_per_vector()
85 {
86 const auto [ys_vector_lengths, ys_vector_strides] =
88
89 index_t VectorDimY_ = 0;
90 index_t ScalarPerVector_ = 1;
91
92 for(index_t i = 0; i < NDimY; ++i)
93 {
94 if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
95 {
96 ScalarPerVector_ = ys_vector_lengths[i];
97 VectorDimY_ = i;
98 }
99 }
100
101 return make_tuple(VectorDimY_, ScalarPerVector_);
102 }
103
104 public:
105 static constexpr index_t PackedSize =
107 static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
108 static constexpr index_t ScalarPerVector =
109 get_vector_dim_y_scalar_per_vector().template at<1>();
110
111 // using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
112 // using vector_t = typename vector_type_t::type;
114
115 private:
116 static constexpr auto scalars_per_access_ = [] {
117 constexpr auto scalars_per_access_arr = generate_array(
118 [&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number<NDimY>{});
119
121 constexpr auto NDimY_ = NDimY;
122
123 return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
124 }();
125
// Builds the space-filling-curve type used to walk this thread's [y0, y1, ...] index space.
// The curve is parameterized by:
//   - the per-thread Y lengths taken from the distribution's ys-to-d descriptor,
//   - a dimension access order generated by arithmetic_sequence_gen<0, NDimY, 1>
//     (presumably the natural order 0..NDimY-1 -- see the FIXME below),
//   - scalars_per_access_, the per-dimension access width defined above.
// Only the type of the returned (default-constructed) object is used: SFC_Ys is declared as
// decltype(get_space_filling_curve()).
126 static constexpr auto get_space_filling_curve()
127 {
128 constexpr auto tile_dstr = TileDstr{};
129
// lengths of the thread-local tensor along each Y dimension, converted to a sequence type
130 constexpr auto thread_tensor_lengths_ys =
131 to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths());
132
133 // FIXME: need logic to judge dim access order
134 using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
135
136 return space_filling_curve<decltype(thread_tensor_lengths_ys),
137 DimAccessOrder,
138 decltype(scalars_per_access_)>{};
139 }
140
141 public:
142 using SFC_Ys = decltype(get_space_filling_curve());
143
144 static constexpr index_t NumAccess = SFC_Ys::get_num_of_access();
145
146 static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");
147 static_assert(NumAccess % NumCoord == 0, "wrong! # of access is not divisible by NumCoord");
148 };
149
151
152 CK_TILE_DEVICE constexpr tile_scatter_gather() = default;
153
154 CK_TILE_DEVICE constexpr tile_scatter_gather(const BottomTensorView& bottom_tensor_view,
155 const WindowLengths& window_lengths,
156 const BottomTensorIndex& window_origin,
158 const PageIdxArray& page_idx,
159 const ValidArray& valids)
160 : bottom_tensor_view_{bottom_tensor_view},
161 window_lengths_{window_lengths},
162 window_origin_{window_origin},
164 page_idx_{page_idx},
165 valids_{valids},
167 {
168#if 0 // debug
169 // TODO: this uses more registers for FA, but fewer registers for GEMM
170 // need investigation
171 // only support warp-tile and block-tile
172 static_assert(NDimP == 1 or NDimP == 2, "wrong!");
173
174 WindowAdaptorCoord window_adaptor_thread_coord_tmp;
175
176 if constexpr(NDimP == 1)
177 {
178 window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
180 }
181 else if constexpr(NDimP == 2)
182 {
183 window_adaptor_thread_coord_tmp =
185 AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
186 }
187#else
188 // TODO: this uses fewer registers for FA, but more registers for GEMM
189 // need investigation
190 const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
194#endif
195
196 BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
197 window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
198 bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0;
199 const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
200 bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
201
202 // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
203 // future load/store() calls (might allocate more registers)
204 using Traits = load_store_traits;
205 using SFC_Ys = typename Traits::SFC_Ys;
206
207 static_for<0, NumCoord, 1>{}([&](auto iCoord) {
208 auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
209 auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
210
211 constexpr auto idx_diff_ys =
212 SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
213
214 constexpr auto idx_diff_ps_ys = container_concat(
215 generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
216
218 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
219
220 pre_computed_coords_(iCoord) =
221 make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
222 });
223 }
224
226
228 {
229 return TileDstr::is_static();
230 }
231
// Returns (by value) the window lengths stored in window_lengths_.
232 CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
233
// Returns (by value) the tile distribution object stored in tile_dstr_.
234 CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; }
235
237
// Returns (by value) the current window origin stored in window_origin_.
238 CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
239
// Re-points the bottom tensor view's buffer at `data` by overwriting
// bottom_tensor_view_.buf_.p_data_ directly.
// Nothing else is updated here: window_origin_, page_idx_, valids_ and
// pre_computed_coords_ are left as they were.
// NOTE(review): no check is made on `data`; presumably the caller guarantees it addresses a
// buffer compatible with the existing tensor descriptor -- confirm against callers.
240 CK_TILE_DEVICE constexpr void
241 set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
242 {
243 bottom_tensor_view_.buf_.p_data_ = data;
244 }
245
246 // move thread's window adaptor coordinate and bottom tensor coordinate
247 // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
248 template <typename ATopIndex>
250 WindowAdaptorCoord& window_adaptor_thread_coord,
251 BottomTensorCoord& bottom_tensor_thread_coord,
252 const ATopIndex& idx_diff_adaptor_top) const
253 {
254 array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom;
255
256 move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
257 window_adaptor_thread_coord,
258 idx_diff_adaptor_top,
259 idx_diff_adaptor_bottom);
260
261 move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
262 bottom_tensor_thread_coord,
263 idx_diff_adaptor_bottom);
264 }
265
266 // return vector dimension among [y0, y1, ...]
268 {
269 // bottom tensor top dimension vector lengths and strides
270 const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] =
271 BottomTensorDesc::get_top_dimension_safe_vector_length_strides();
272
273 // window vector lengths/strides
274 const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths;
275 const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides;
276
277 // window adaptor [p0, p1, ..., y0, y1, ...]
278 array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_lengths{
279 -1};
280 array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_strides{
281 -1};
282
283 constexpr auto window_adaptor_bottom_dims =
284 WindowAdaptor::get_bottom_dimension_hidden_ids();
285
286 set_container_subset(window_adaptor_vector_lengths,
287 window_adaptor_bottom_dims,
288 window_adaptor_bottom_dim_vector_lengths);
289 set_container_subset(window_adaptor_vector_strides,
290 window_adaptor_bottom_dims,
291 window_adaptor_bottom_dim_vector_strides);
292
293 const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
294 WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
295 window_adaptor_vector_lengths, window_adaptor_vector_strides);
296
297 // [y0, y1, ...]
298 constexpr auto y_dims = typename arithmetic_sequence_gen<TileDstr::get_num_of_dimension_p(),
300 1>::type{};
301
302 return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims),
303 get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
304 }
305
307
308 template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
311 {
312 constexpr auto tile_dstr = TileDstr{};
313 auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
315 return dst_tensor;
316 }
317
318 template <typename DistributedTensor,
319 index_t i_access_unsupport_ = -1,
320 bool oob_conditional_check = true>
321 CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
324 {
325 using Traits = load_store_traits;
326 using vector_t = typename Traits::vector_t;
327 using SFC_Ys = typename Traits::SFC_Ys;
328
329 constexpr auto tile_dstr = TileDstr{};
330
331 // loop over thread tensor space [y0, y1, ...]
332 static_for<0, NumCoord, 1>{}([&](auto iCoord) {
334 auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
335 auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
336
337 static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
338 constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
339
340 // data index [y0, y1, ...]
341 constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
342 constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
343 const auto page_offset = page_idx_[idx_gather];
344
345 // read from bottom tensor
346 const vector_t vec_value = [&]() {
347 if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
348 {
349 return get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
350 bottom_tensor_thread_coord,
351 page_offset,
353 }
354 else
355 {
356 return get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
357 bottom_tensor_thread_coord,
358 page_offset,
359 valids_[idx_gather],
361 }
362 }();
363#if 1
364 // write into distributed tensor
365 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
366 constexpr auto idx_ys = generate_tuple(
367 [&](auto jj) {
368 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
369 : idx_ys_start[jj];
370 },
371 number<NDimY>{});
372
373 constexpr index_t d =
374 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
375 Traits::PackedSize;
376
377 dst_tensor.get_thread_buffer().template at<d>() =
378 vec_value.template get_as<DataType>()[j / Traits::PackedSize];
379 });
380#else
381 constexpr index_t d =
382 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
383 static_assert(d % Traits::ScalarPerVector == 0);
384
385 dst_tensor.get_thread_buffer().template get_as<vector_t>()(
386 number<d / Traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
387#endif
388 // move thread coordinate
389 if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
390 {
391 constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
392
393 constexpr auto forward_step_scatter = generate_tuple(
394 [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
395 number<NDimY>{});
396
397 constexpr auto idx_diff_ps_ys = container_concat(
398 generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
399 forward_step_scatter);
400
402 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
403 }
404 });
405 });
406 }
407
408 template <typename LdsTileWindow_,
409 index_t i_access_unsupport_ = -1,
410 bool oob_conditional_check = true>
411 CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
414 {
415 using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
416 using LdsDataType = typename LdsTileWindow::DataType;
417 using Traits = load_store_traits;
418 using vector_t = typename Traits::vector_t;
419 using SFC_Ys = typename Traits::SFC_Ys;
420
421 constexpr auto tile_dstr = TileDstr{};
422
423 // Precompute invariant values outside loops
424 const auto window_origin = lds_tile.get_window_origin();
425 const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
426 const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
427 auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
428
429 // loop over thread tensor space [y0, y1, ...]
430 static_for<0, NumCoord, 1>{}([&](auto iCoord) {
432 auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
433 auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
434
435 auto lds_window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
436 auto lds_bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
437
438 static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
439 constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
440
441 // Use precomputed window origin
442 auto lds_bottom_tensor_thread_idx =
443 window_origin + lds_window_adaptor_thread_coord.get_bottom_index();
444 // Use precomputed tensor descriptor
445 const auto lds_coord =
446 make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
447 // Calculate SMEM address using base pointer
448 CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();
449
450 // data index [y0, y1, ...]
451 constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
452 constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
453 const auto page_offset = page_idx_[idx_gather];
454
455 // merge page_offset into bottom_coord
456 auto mixed_bottom_thread_coord = bottom_tensor_thread_coord;
457 mixed_bottom_thread_coord.get_hidden_index()[number<0>{}] += page_offset;
458
459 // read from bottom tensor
460 if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
461 this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
462 smem,
463 mixed_bottom_thread_coord,
464 number<0>{},
466 else
467 this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
468 smem,
469 mixed_bottom_thread_coord,
470 number<0>{},
471 valids_[idx_gather],
473
474 // move thread coordinate
475 if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
476 {
477 constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
478
479 constexpr auto forward_step_scatter = generate_tuple(
480 [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
481 number<NDimY>{});
482
483 constexpr auto idx_diff_ps_ys = container_concat(
484 generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
485 forward_step_scatter);
486 // lds_diff doesn't need to mask the difference of the gather-dim.
487 constexpr auto lds_idx_diff_ps_ys = container_concat(
488 generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
489 idx_diff_ys);
490
492 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
494 lds_window_adaptor_thread_coord,
495 lds_bottom_tensor_thread_coord,
496 lds_idx_diff_ps_ys);
497 }
498 });
499 });
500 }
501
502 // TODO: currently async load only implemented in inline asm
503 template <typename LdsTileWindow_,
504 index_t i_access_unsupport_ = -1,
505 bool oob_conditional_check = true,
506 bool pre_nop = false>
507 CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
510 bool_constant<pre_nop> = {}) const
511 {
512 using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
513 // using LdsTensorView = typename LdsTileWindow::BottomTensorView;
514 using LdsDataType = typename LdsTileWindow::DataType;
515 // using LdsDescriptor = typename LdsTileWindow::BottomTensorDesc;
516
517 // issues * warps * lanes
518 static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
519
520 const index_t size_per_buf =
521 lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
523 sizeof(LdsDataType);
524
525 const index_t size_per_wave =
526 lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
528 sizeof(LdsDataType) -
529 size_per_buf;
530
531 const index_t size_per_issue =
532 lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
534 sizeof(LdsDataType) -
535 size_per_buf;
536
537 const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
538 m0_set_with_memory(m0_init_value); // This should be wave independent
539
540 using Traits = load_store_traits;
541
542 // using vector_type_t = typename Traits::vector_type_t;
543 using vector_t = typename Traits::vector_t;
544 using SFC_Ys = typename Traits::SFC_Ys;
545
546 LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
547
548 // loop over thread tensor space [y0, y1, ...]
549 static_for<0, NumCoord, 1>{}([&](auto iCoord) {
551 auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
552 auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
553
554 static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
555 constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
556 constexpr auto pre_nop_ = [&]() {
557 if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
558 return bool_constant<true>{};
559 else
560 return bool_constant<false>{};
561 }();
562
563 constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
564 constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
565 const auto page_offset = page_idx_[idx_gather];
566
567 // read from bottom tensor
568 if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
569 {
570 get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
571 smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_);
572 }
573 else
574 {
575 get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
576 smem,
577 bottom_tensor_thread_coord,
578 page_offset,
579 valids_[idx_gather],
580 0,
581 pre_nop_);
582 }
583
584 // move thread coordinate
585 if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
586 {
587 constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
588
589 constexpr auto forward_step_scatter = generate_tuple(
590 [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
591 number<NDimY>{});
592
593 constexpr auto idx_diff_ps_ys = container_concat(
594 generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
595 forward_step_scatter);
596
598 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
599
600 m0_inc_with_memory(size_per_issue);
601 }
602 });
603 });
604 }
605
606 template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
610 {
611 using Traits = load_store_traits;
612
613 // using vector_type_t = typename Traits::vector_type_t;
614 using vector_t = typename Traits::vector_t;
615 using SFC_Ys = typename Traits::SFC_Ys;
616
617 constexpr auto tile_dstr = TileDstr{};
618
619 static_for<0, NumCoord, 1>{}([&](auto iCoord) {
620 auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
621 auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
622
623 static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
624 constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
625
626 // data index [y0, y1, ...]
627 constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
628 constexpr auto idx_gather = idx_ys_start[number<0>{}];
629 const auto page_offset = page_idx_[idx_gather];
630
631 // read from distributed tensor
632 vector_t vec_value;
633
634 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
635 constexpr auto idx_ys = generate_tuple(
636 [&](auto jj) {
637 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
638 : idx_ys_start[jj];
639 },
640 number<NDimY>{});
641
642 constexpr index_t d =
643 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
644 Traits::PackedSize;
645
646 vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
647 dstr_tensor.get_thread_buffer().template at<d>();
648 });
649
650 // write into bottom tensor
651 if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
652 {
653 get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
654 bottom_tensor_thread_coord,
655 page_offset,
656 vec_value,
658 }
659 else
660 {
661 get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
662 bottom_tensor_thread_coord,
663 page_offset,
664 valids_[idx_gather],
665 vec_value,
667 }
668
669 if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
670 {
671 constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
672
673 constexpr auto forward_step_scatter = generate_tuple(
674 [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
675 number<NDimY>{});
676
677 constexpr auto idx_diff_ps_ys = container_concat(
678 generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
679 forward_step_scatter);
680
682 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
683 }
684 });
685 });
686 }
687
688 template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
692 {
693 using Traits = load_store_traits;
694
695 // using vector_type_t = typename Traits::vector_type_t;
696 using vector_t = typename Traits::vector_t;
697 using SFC_Ys = typename Traits::SFC_Ys;
698
699 constexpr auto tile_dstr = TileDstr{};
700 // printf("off %d\n", page_idx_[I0]);
701 // loop over thread tensor space [y0, y1, ...]
702 static_for<0, NumCoord, 1>{}([&](auto iCoord) {
703 auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
704 auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
705
706 static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
707 constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
708
709 // data index [y0, y1, ...]
710 constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
711 constexpr auto idx_gather = idx_ys_start[number<0>{}];
712 const auto page_offset = page_idx_[idx_gather];
713
714 // printf("idx_ys_start[0], idx_ys_start[1](%d, %d) \n",
715 // idx_ys_start[number<0>{}]+0, idx_ys_start[number<1>{}]+0);
716
717 // read from distributed tensor
718 // vector_type_t vec;
719 vector_t vec_value;
720
721 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
722 constexpr auto idx_ys = generate_tuple(
723 [&](auto jj) {
724 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
725 : idx_ys_start[jj];
726 },
727 number<NDimY>{});
728
729 constexpr index_t d =
730 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
731 Traits::PackedSize;
732 // printf("thread_idx_m: %d j: %d\n", idx_ys[number<0>{}] + 0, 0+j);
733 vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
734 dstr_tensor.get_thread_buffer().template at<d>();
735 });
736
737 // const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
738
739 // write into bottom tensor
740 if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
741 {
742 get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
743 bottom_tensor_thread_coord,
744 page_offset,
745 vec_value,
747 }
748 else
749 {
750 get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
751 bottom_tensor_thread_coord,
752 page_offset,
753 valids_[idx_gather],
754 vec_value,
756 }
757
758 // printf("coord_offset:%d, scatter_offset:%d \n",
759 // bottom_tensor_thread_coord.get_offset(), offset); move thread coordinate
760 if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
761 {
762 constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
763
764 constexpr auto forward_step_scatter = generate_tuple(
765 [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
766 number<NDimY>{});
767
768 constexpr auto idx_diff_ps_ys = container_concat(
769 generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
770 forward_step_scatter);
771
773 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
774 }
775 });
776 });
777 }
778
779 // move thread's bottom tensor coordinate
780 // [x0', x1', ... ] ==> [offset]
781 // also move window-origin
783 {
784 window_origin_ += step;
785 BottomTensorIndex step_new = step;
786 step_new(HsGatherDim) = 0;
787 static_for<0, NumCoord, 1>{}([&](auto iCoord) {
788 move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
789 pre_computed_coords_(iCoord)(I1),
790 step_new);
791 });
792 }
793
// Replaces the stored page index array (page_idx_) with `new_idx`.
// load()/store()-style members read page_idx_[idx_gather] to obtain the page offset of each
// access, so subsequent accesses use the new offsets. Coordinates and window origin are untouched.
794 CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx) { page_idx_ = new_idx; }
795
797 {
798 if constexpr(std::is_same_v<ValidArray, std::nullptr_t> == false)
799 {
800 valids_ = new_valids;
801 }
802 }
803
805 const ValidArray& new_valids)
806 {
807 update_page_idx(new_idx);
808 update_valids(new_valids);
809 }
810
812 {
813 window_origin_ = new_window_origin;
814
815#if 0 // debug
816 // TODO: this uses more registers for FA, but fewer registers for GEMM
817 // need investigation
818 // only support warp-tile and block-tile
819 static_assert(NDimP == 1 or NDimP == 2, "wrong!");
820
821 WindowAdaptorCoord window_adaptor_thread_coord_tmp;
822
823 if constexpr(NDimP == 1)
824 {
825 window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
826 tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
827 }
828 else if constexpr(NDimP == 2)
829 {
830 window_adaptor_thread_coord_tmp =
831 make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
832 AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
833 }
834#else
835 // TODO: this uses fewer registers for FA, but more registers for GEMM
836 // need investigation
837 const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
838 tile_dstr_.get_ps_ys_to_xs_adaptor(),
840#endif
841
842 BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
843 window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
844
845 bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0;
846 const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
847 bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
848
849 // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
850 // future load/store() calls (might allocate more registers)
851 using Traits = load_store_traits;
852 using SFC_Ys = typename Traits::SFC_Ys;
853
854 static_for<0, NumCoord, 1>{}([&](auto iCoord) {
855 auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
856 auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
857
858 constexpr auto idx_diff_ys =
859 SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
860
861 constexpr auto idx_diff_ps_ys = container_concat(
862 generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
863
865 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
866
867 pre_computed_coords_(iCoord) =
868 make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
869 });
870 }
871
873
874 // this is the bottom tensor view
875 // [x0', x1', ...] ==> [offset]
877
878 //
880
881 // origin ([x0', x1', ...]) of window on bottom tensor
883
884 // Tile tensor distribution, which contains:
885 // 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
886 // 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
888
891
892 // this contains:
893 // per-thread coordinate for window adaptor
894 // per-thread coordinate for bottom tensor
896};
897
898// TODO: use strategy
899template <typename TensorView_,
900 typename WindowLengths_,
901 typename StaticTileDistribution_,
902 typename StaticPageIndexArray_,
903 index_t HsGatherDim = 0,
904 index_t NumCoord = 1>
905CK_TILE_DEVICE constexpr auto
907 const WindowLengths_& window_lengths,
908 const multi_index<TensorView_::get_num_of_dimension()>& origin,
909 const StaticTileDistribution_& tile_distribution,
910 const StaticPageIndexArray_& page_idx,
912 number<NumCoord> = {})
913{
918 std::nullptr_t,
919 HsGatherDim,
920 NumCoord>{
921 tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr};
922}
923
924template <typename TensorView,
925 typename WindowLengths,
926 typename StaticTileDistribution,
927 typename StaticPageIndexArray,
928 index_t HsGatherDim>
931 const multi_index<TensorView::get_num_of_dimension()>& origin,
932 const StaticTileDistribution& tile_distribution,
933 const StaticPageIndexArray& page_idx,
935{
937 tile_window.get_window_lengths(),
938 origin,
939 tile_distribution,
940 page_idx,
942}
943
944template <typename TensorView,
945 typename WindowLengths,
946 typename StaticTileDistribution,
947 typename StaticPageIndexArray,
948 index_t HsGatherDim>
951 const StaticTileDistribution& tile_distribution,
952 const StaticPageIndexArray& page_idx,
954{
956 tile_window.get_window_lengths(),
957 tile_window.get_window_origin(),
958 tile_distribution,
959 page_idx,
961}
962
963template <typename TensorView_,
964 typename WindowLengths_,
965 typename StaticTileDistribution_,
966 typename StaticPageIndexArray_,
967 typename StaticValidArray_,
968 index_t HsGatherDim = 0,
969 index_t NumCoord = 1>
970CK_TILE_DEVICE constexpr auto
972 const WindowLengths_& window_lengths,
973 const multi_index<TensorView_::get_num_of_dimension()>& origin,
974 const StaticTileDistribution_& tile_distribution,
975 const StaticPageIndexArray_& page_idx,
976 const StaticValidArray_& valids,
978 number<NumCoord> = {})
979{
985 HsGatherDim,
986 NumCoord>{
987 tensor_view, window_lengths, origin, tile_distribution, page_idx, valids};
988}
989
990template <typename TensorView,
991 typename WindowLengths,
992 typename StaticTileDistribution,
993 typename StaticPageIndexArray,
994 typename StaticValidArray,
995 index_t HsGatherDim>
998 const multi_index<TensorView::get_num_of_dimension()>& origin,
999 const StaticTileDistribution& tile_distribution,
1000 const StaticPageIndexArray& page_idx,
1001 const StaticValidArray& valids,
1003{
1005 tile_window.get_window_lengths(),
1006 origin,
1007 tile_distribution,
1008 page_idx,
1009 valids,
1011}
1012
1013template <typename TensorView,
1014 typename WindowLengths,
1015 typename StaticTileDistribution,
1016 typename StaticPageIndexArray,
1017 typename StaticValidArray,
1018 index_t HsGatherDim>
1021 const StaticTileDistribution& tile_distribution,
1022 const StaticPageIndexArray& page_idx,
1023 const StaticValidArray& valids,
1025{
1027 tile_window.get_window_lengths(),
1028 tile_window.get_window_origin(),
1029 tile_distribution,
1030 page_idx,
1031 valids,
1033}
1034
// Creates a new scatter/gather window over `new_tensor_view`, reusing the state of an
// existing window.
//
// new_tensor_view: the tensor view the returned window should be built on.
// tile_window:     the existing window; its window_lengths_, window_origin_, tile_dstr_,
//                  page_idx_ and valids_ members are read and passed on unchanged.
// Returns whatever make_tile_scatter_gather() produces for those six arguments.
//
// NOTE(review): HsGatherDim and NumCoord are template parameters of this function, but they
// are not passed to make_tile_scatter_gather() in the call below. Confirm that the returned
// window ends up with the same HsGatherDim / NumCoord as `tile_window` when they are not the
// default values.
// NOTE(review): the parameter type names seven template arguments of tile_scatter_gather, so
// the trailing YsGatherDim argument is left at its default -- confirm windows with a
// non-default YsGatherDim are not expected here.
1035template <typename NewTensorView_,
1036 typename OldTensorView_,
1037 typename WindowLengths_,
1038 typename StaticTileDistribution_,
1039 typename StaticPageIndexArray_,
1040 typename StaticValidArray_,
1041 index_t HsGatherDim = 0,
1042 index_t NumCoord = 1>
1043CK_TILE_DEVICE auto replace_bottom_tensor_view(const NewTensorView_& new_tensor_view,
1044 const tile_scatter_gather<OldTensorView_,
1045 WindowLengths_,
1046 StaticTileDistribution_,
1047 StaticPageIndexArray_,
1048 StaticValidArray_,
1049 HsGatherDim,
1050 NumCoord>& tile_window)
1051{
// rebuild the window from the old window's stored state, swapping only the tensor view
1052 return make_tile_scatter_gather(new_tensor_view,
1053 tile_window.window_lengths_,
1054 tile_window.window_origin_,
1055 tile_window.tile_dstr_,
1056 tile_window.page_idx_,
1057 tile_window.valids_);
1058}
1059
1060} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_LDS_ADDR
Definition config.hpp:58
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
Definition tile_distribution.hpp:22
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
typename std::remove_reference< T >::type remove_reference_t
Definition type_traits.hpp:15
CK_TILE_DEVICE auto replace_bottom_tensor_view(const NewTensorView_ &new_tensor_view, const tile_scatter_gather< OldTensorView_, WindowLengths_, StaticTileDistribution_, StaticPageIndexArray_, StaticValidArray_, HsGatherDim, NumCoord > &tile_window)
Definition tile_scatter_gather.hpp:1043
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_DEVICE index_t get_warp_id(bool_constant< ReturnSgpr >={})
Definition arch.hpp:104
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X &x)
Definition bit_cast.hpp:11
CK_TILE_HOST_DEVICE constexpr auto container_concat(const X &x, const Ys &... ys)
Definition tile/core/container/container_helper.hpp:363
CK_TILE_HOST_DEVICE constexpr void move_tensor_adaptor_coordinate(const Adaptor &adaptor, AdaptorCoord &coord, const TopIndex &idx_diff_top, BottomIndex &idx_diff_bottom)
Definition tensor_adaptor_coordinate.hpp:97
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_HOST_DEVICE constexpr auto generate_array(F &&f, number< N >)
Definition tile/core/container/sequence.hpp:1115
CK_TILE_HOST_DEVICE constexpr void set_container_subset(array< T, N > &y, sequence< Is... > picks, const array< T, sizeof...(Is)> &x)
Definition tile/core/container/container_helper.hpp:420
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
array< index_t, N > multi_index
Definition tile/core/container/multi_index.hpp:17
CK_TILE_HOST_DEVICE constexpr auto get_container_subset(const array< T, N > &arr, sequence< Is... >)
Definition tile/core/container/container_helper.hpp:389
CK_TILE_HOST_DEVICE constexpr void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const Index &coord_step)
Definition tensor_coordinate.hpp:72
CK_TILE_HOST_DEVICE constexpr auto make_tensor_adaptor_coordinate(const Adaptor &adaptor, const TopIndex &idx_top)
Definition tensor_adaptor_coordinate.hpp:55
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple< number< Is >... >)
Definition tile/core/container/sequence.hpp:1055
CK_TILE_DEVICE constexpr auto make_tile_scatter_gather(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, const StaticPageIndexArray_ &page_idx, number< HsGatherDim >={}, number< NumCoord >={})
Definition tile_scatter_gather.hpp:906
CK_TILE_DEVICE void m0_set_with_memory(index_t v)
Definition utility.hpp:19
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE void m0_inc_with_memory(index_t v)
Definition utility.hpp:25
CK_TILE_HOST_DEVICE constexpr auto make_tensor_coordinate(const TensorDesc &tensor_desc, const TopIndex &idx_top)
Definition tensor_coordinate.hpp:60
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition tile/core/container/sequence.hpp:287
typename std::conditional< kHasContent, type0, type1 >::type type
Definition tile/core/container/sequence.hpp:302
A fixed-size array container similar to std::array with additional utilities.
Definition tile/core/container/array.hpp:43
static constexpr bool value
Definition type_traits.hpp:77
Definition tile/core/numeric/numeric.hpp:81
Definition space_filling_curve.hpp:20
Definition static_distributed_tensor.hpp:21
CK_TILE_HOST_DEVICE constexpr const auto & get_thread_buffer() const
Definition static_distributed_tensor.hpp:58
Definition tile/core/utility/functional.hpp:43
Definition tensor_view.hpp:41
Definition tile/core/utility/debug.hpp:67
Definition tile_distribution.hpp:72
CK_TILE_HOST_DEVICE constexpr const auto & get_ps_ys_to_xs_adaptor() const
Definition tile_distribution.hpp:126
Definition tile_scatter_gather.hpp:82
static constexpr index_t PackedSize
Definition tile_scatter_gather.hpp:105
static constexpr index_t NumAccess
Definition tile_scatter_gather.hpp:144
thread_buffer< DataType, ScalarPerVector/PackedSize > vector_t
Definition tile_scatter_gather.hpp:113
decltype(get_space_filling_curve()) SFC_Ys
Definition tile_scatter_gather.hpp:142
static constexpr index_t VectorDimY
Definition tile_scatter_gather.hpp:107
static constexpr index_t ScalarPerVector
Definition tile_scatter_gather.hpp:108
This class provides tile (windowed) view and access to the device memory.
Definition tile_scatter_gather.hpp:42
CK_TILE_DEVICE constexpr auto get_window_lengths() const
Definition tile_scatter_gather.hpp:232
CK_TILE_DEVICE void move(const BottomTensorIndex &step)
Definition tile_scatter_gather.hpp:782
static constexpr index_t NumAccessPerCoord
Definition tile_scatter_gather.hpp:150
static constexpr auto I1
Definition tile_scatter_gather.hpp:60
BottomTensorIndex window_origin_
Definition tile_scatter_gather.hpp:882
CK_TILE_DEVICE auto load(number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition tile_scatter_gather.hpp:309
decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{})) BottomTensorCoord
Definition tile_scatter_gather.hpp:78
WindowLengths window_lengths_
Definition tile_scatter_gather.hpp:879
static constexpr index_t NDimBottomTensor
Definition tile_scatter_gather.hpp:54
decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{})) WindowAdaptorCoord
Definition tile_scatter_gather.hpp:75
CK_TILE_DEVICE constexpr auto get_tile_distribution() const
Definition tile_scatter_gather.hpp:234
array< index_t, NDimBottomTensor > BottomTensorIndex
Definition tile_scatter_gather.hpp:73
PageIdxArray page_idx_
Definition tile_scatter_gather.hpp:889
CK_TILE_DEVICE auto async_load(LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition tile_scatter_gather.hpp:411
remove_cvref_t< WindowLengths_ > WindowLengths
Definition tile_scatter_gather.hpp:44
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex &new_window_origin)
Definition tile_scatter_gather.hpp:811
array< tuple< WindowAdaptorCoord, BottomTensorCoord >, NumCoord > pre_computed_coords_
Definition tile_scatter_gather.hpp:895
static CK_TILE_DEVICE constexpr index_t get_num_of_dimension()
Definition tile_scatter_gather.hpp:225
remove_cvref_t< StaticTileDistribution_ > TileDstr
Definition tile_scatter_gather.hpp:45
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(WindowAdaptorCoord &window_adaptor_thread_coord, BottomTensorCoord &bottom_tensor_thread_coord, const ATopIndex &idx_diff_adaptor_top) const
Definition tile_scatter_gather.hpp:249
CK_TILE_DEVICE auto load(DistributedTensor &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition tile_scatter_gather.hpp:321
CK_TILE_DEVICE void store(const static_distributed_tensor< DataType, TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition tile_scatter_gather.hpp:689
CK_TILE_DEVICE constexpr void set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType *data)
Definition tile_scatter_gather.hpp:241
static CK_TILE_DEVICE constexpr auto get_window_adaptor_ys_safe_vector_length_strides()
Definition tile_scatter_gather.hpp:267
CK_TILE_DEVICE void update_page_idx_and_valids(const PageIdxArray &new_idx, const ValidArray &new_valids)
Definition tile_scatter_gather.hpp:804
typename BottomTensorView::TensorDesc BottomTensorDesc
Definition tile_scatter_gather.hpp:49
TileDstr tile_dstr_
Definition tile_scatter_gather.hpp:887
ValidArray valids_
Definition tile_scatter_gather.hpp:890
static constexpr index_t NDimY
Definition tile_scatter_gather.hpp:57
remove_cvref_t< typename BottomTensorView::DataType > DataType
Definition tile_scatter_gather.hpp:51
static constexpr index_t NDimWindowAdaptorTop
Definition tile_scatter_gather.hpp:53
remove_cvref_t< StaticValidArray_ > ValidArray
Definition tile_scatter_gather.hpp:47
static constexpr index_t NDimP
Definition tile_scatter_gather.hpp:56
remove_reference_t< BottomTensorView_ > BottomTensorView
Definition tile_scatter_gather.hpp:43
CK_TILE_DEVICE void update(const static_distributed_tensor< DataType, TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition tile_scatter_gather.hpp:607
CK_TILE_DEVICE constexpr auto get_window_origin() const
Definition tile_scatter_gather.hpp:238
remove_cvref_t< StaticPageIndexArray_ > PageIdxArray
Definition tile_scatter_gather.hpp:46
CK_TILE_DEVICE constexpr tile_scatter_gather(const BottomTensorView &bottom_tensor_view, const WindowLengths &window_lengths, const BottomTensorIndex &window_origin, const TileDstr &tile_distribution, const PageIdxArray &page_idx, const ValidArray &valids)
Definition tile_scatter_gather.hpp:154
CK_TILE_HOST_DEVICE void init_raw()
Definition tile_scatter_gather.hpp:872
static constexpr auto I0
Definition tile_scatter_gather.hpp:59
CK_TILE_DEVICE constexpr auto get_num_of_access() const
Definition tile_scatter_gather.hpp:306
typename TileDstr::PsYs2XsAdaptor WindowAdaptor
Definition tile_scatter_gather.hpp:48
BottomTensorView bottom_tensor_view_
Definition tile_scatter_gather.hpp:876
CK_TILE_DEVICE void update_valids(const ValidArray &new_valids)
Definition tile_scatter_gather.hpp:796
CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const
Definition tile_scatter_gather.hpp:236
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition tile_scatter_gather.hpp:507
array< index_t, NDimWindowAdaptorTop > AdaptorTopIndex
Definition tile_scatter_gather.hpp:72
CK_TILE_DEVICE constexpr tile_scatter_gather()=default
static CK_TILE_DEVICE constexpr bool has_static_tile_distribution()
Definition tile_scatter_gather.hpp:227
CK_TILE_DEVICE void update_page_idx(const PageIdxArray &new_idx)
Definition tile_scatter_gather.hpp:794
CK_TILE_DEVICE constexpr auto get_window_origin() const
Definition tile_window_base.hpp:45
CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const
Definition tile_window_base.hpp:47
BottomTensorIndex window_origin_
Definition tile_window_base.hpp:79
CK_TILE_DEVICE constexpr auto get_window_lengths() const
Definition tile_window_base.hpp:46
WindowLengths window_lengths_
Definition tile_window_base.hpp:81
This class provides description of tile windowed view on the device memory.
Definition tile_window.hpp:1016
#define TO_SEQUENCE(a, n)
Definition to_sequence.hpp:10