device_contraction_multiple_d_xdl_cshuffle.hpp Source File

device_contraction_multiple_d_xdl_cshuffle.hpp Source File#

Composable Kernel: device_contraction_multiple_d_xdl_cshuffle.hpp Source File
device_contraction_multiple_d_xdl_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
20
21namespace ck {
22
// Device kernel for the contraction: forwards all of its arguments, plus a
// shared-memory scratch buffer, to GridwiseGemm::Run.
// HasMainKBlockLoop is passed through as the first template argument of Run;
// the second template argument is always InMemoryDataOperationEnum::Set.
// NOTE(review): listing lines 38 and 40 are absent from this extract --
// presumably the launch-bounds attribute and the kernel name with its opening
// parenthesis; confirm against the original header.
23template <typename GridwiseGemm,
24 typename FloatAB,
25 typename FloatDsPointer,
26 typename FloatE,
27 typename AElementwiseOperation,
28 typename BElementwiseOperation,
29 typename CDEElementwiseOperation,
30 typename AGridDesc_AK0_M_AK1,
31 typename BGridDesc_BK0_N_BK1,
32 typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
33 typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
34 typename Block2ETileMap,
35 bool HasMainKBlockLoop>
36__global__ void
37#if CK_USE_LAUNCH_BOUNDS
39#endif
41 const FloatAB* __restrict__ p_a_grid,
42 const FloatAB* __restrict__ p_b_grid,
43 FloatDsPointer p_ds_grid,
44 FloatE* __restrict__ p_e_grid,
45 const AElementwiseOperation a_element_op,
46 const BElementwiseOperation b_element_op,
47 const CDEElementwiseOperation cde_element_op,
48 const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
49 const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
50 const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
51 ds_grid_desc_mblock_mperblock_nblock_nperblock,
52 const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
53 e_grid_desc_mblock_mperblock_nblock_nperblock,
54 const Block2ETileMap block_2_etile_map)
55{
// The real body is only compiled when one of the gfx9 / gfx11 / gfx12 target
// macros is defined.
56#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
// Compile-time guard: the body is instantiated only when the gridwise GEMM
// reports its compilation parameters as valid.
57 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
58 {
// Shared-memory scratch buffer; its byte size is supplied by the gridwise GEMM.
59 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
60
61 GridwiseGemm::template Run<HasMainKBlockLoop, InMemoryDataOperationEnum::Set>(
62 p_a_grid,
63 p_b_grid,
64 p_ds_grid,
65 p_e_grid,
66 p_shared,
67 a_element_op,
68 b_element_op,
69 cde_element_op,
70 a_grid_desc_ak0_m_ak1,
71 b_grid_desc_bk0_n_bk1,
72 ds_grid_desc_mblock_mperblock_nblock_nperblock,
73 e_grid_desc_mblock_mperblock_nblock_nperblock,
74 block_2_etile_map);
75 }
76#else
// On any other target the kernel does no work; each parameter is assigned to
// `ignore` (presumably to silence unused-parameter warnings).
77 ignore = p_a_grid;
78 ignore = p_b_grid;
79 ignore = p_ds_grid;
80 ignore = p_e_grid;
81 ignore = a_element_op;
82 ignore = b_element_op;
83 ignore = cde_element_op;
84 ignore = a_grid_desc_ak0_m_ak1;
85 ignore = b_grid_desc_bk0_n_bk1;
86 ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
87 ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
88 ignore = block_2_etile_map;
89#endif
90}
91
92} // namespace ck
93
94namespace ck {
95namespace tensor_operation {
96namespace device {
97
98// Tensor Contraction:
99// input : A
100// input : B
101// input : D0, D1, ...
102// output : E
103// C = a_op(A) * b_op(B)
104// E = cde_op(C, D0, D1, ...)
105// Assume:
106// A[M0, M1, M2, ..., K0, K1, K2, ...]
107// B[N0, N1, N2, ..., K0, K1, K2, ...]
108// D[M0, M1, M2, ..., N0, N1, N2, ...]
109// E[M0, M1, M2, ..., N0, N1, N2, ...]
110template <index_t NumDimM,
111 index_t NumDimN,
112 index_t NumDimK,
113 typename ADataType,
114 typename BDataType,
115 typename AccDataType,
116 typename CShuffleDataType,
117 typename DsDataType,
118 typename EDataType,
119 typename AElementwiseOperation,
120 typename BElementwiseOperation,
121 typename CDEElementwiseOperation,
122 GemmSpecialization GemmSpec,
123 index_t NumGemmKPrefetchStage,
124 index_t BlockSize,
125 index_t MPerBlock,
126 index_t NPerBlock,
127 index_t KPerBlock,
128 index_t AK1,
129 index_t BK1,
130 index_t MPerXDL,
131 index_t NPerXDL,
132 index_t MXdlPerWave,
133 index_t NXdlPerWave,
134 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
135 typename ABlockTransferThreadClusterArrangeOrder,
136 typename ABlockTransferSrcAccessOrder,
137 index_t ABlockTransferSrcVectorDim,
138 index_t ABlockTransferSrcScalarPerVector,
139 index_t ABlockTransferDstScalarPerVector_AK1,
140 bool ABlockLdsExtraM,
141 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
142 typename BBlockTransferThreadClusterArrangeOrder,
143 typename BBlockTransferSrcAccessOrder,
144 index_t BBlockTransferSrcVectorDim,
145 index_t BBlockTransferSrcScalarPerVector,
146 index_t BBlockTransferDstScalarPerVector_BK1,
147 bool BBlockLdsExtraN,
148 index_t CShuffleMXdlPerWavePerShuffle,
149 index_t CShuffleNXdlPerWavePerShuffle,
150 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
151 index_t CDEBlockTransferScalarPerVector_NPerBlock,
152 typename ComputeDataType = ADataType,
155 : public DeviceContractionMultipleD<NumDimM,
156 NumDimN,
157 NumDimK,
158 ADataType,
159 BDataType,
160 DsDataType,
161 EDataType,
162 AElementwiseOperation,
163 BElementwiseOperation,
164 CDEElementwiseOperation,
165 ComputeDataType>
166{
168
170 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
171 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
172
173 static constexpr index_t NumDTensor = DsDataType::Size();
174
175 static constexpr auto I0 = Number<0>{};
176 static constexpr auto I1 = Number<1>{};
177 static constexpr auto I2 = Number<2>{};
178 static constexpr auto I3 = Number<3>{};
179
180 static constexpr auto matrix_padder =
181 MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
182
// Builds the 2-D [M, K] grid descriptor of tensor A.
// Assume: A[M0, M1, M2, ..., K0, K1, K2, ...]
//   a_ms_ks_lengths_vec: per-dimension lengths, M dimensions first, then K.
//   a_ms_ks_strides_vec: per-dimension strides, in the same order.
// Both vectors must hold exactly NumDimM + NumDimK entries (checked by assert
// only). Returns the descriptor produced by matrix_padder.PadADescriptor_M_K.
// NOTE(review): listing lines 202, 217 and 219 are absent from this extract
// (the kDimIds initializer and two arguments of transform_tensor_descriptor).
183 // Assume: A[M0, M1, M2, ..., K0, K1, K2, ...]
184 static auto MakeAGridDescriptor_M_K(const std::vector<index_t>& a_ms_ks_lengths_vec,
185 const std::vector<index_t>& a_ms_ks_strides_vec)
186 {
187 assert(a_ms_ks_lengths_vec.size() == NumDimM + NumDimK &&
188 a_ms_ks_strides_vec.size() == NumDimM + NumDimK);
189
// Copies the first `num` entries of a runtime vector into a tuple.
190 const auto to_tuple = [&](auto& vec, auto num) {
191 return generate_tuple([&](auto i) { return vec[i]; }, num);
192 };
193
194 const auto a_ms_ks_lengths = to_tuple(a_ms_ks_lengths_vec, Number<NumDimM + NumDimK>{});
195 const auto a_ms_ks_strides = to_tuple(a_ms_ks_strides_vec, Number<NumDimM + NumDimK>{});
196
197 // dimension Ids for M0, M1, ...
198 constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
199
200 // dimension Ids for K0, K1, ...
201 constexpr auto kDimIds =
203
204 // lengths for M0, M1, ...
205 const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds);
206
207 // lengths for K0, K1, ...
208 const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds);
209
210 // naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
211 const auto a_grid_desc_ms_ks =
212 make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides);
213
214 // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
215 const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor(
216 a_grid_desc_ms_ks,
218 make_tuple(mDimIds, kDimIds),
220
// Padding is delegated to the class-level matrix_padder.
221 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
222 }
223
// Builds the 2-D [N, K] grid descriptor of tensor B.
// Assume: B[N0, N1, N2, ..., K0, K1, K2, ...]
//   b_ns_ks_lengths_vec: per-dimension lengths, N dimensions first, then K.
//   b_ns_ks_strides_vec: per-dimension strides, in the same order.
// Both vectors must hold exactly NumDimN + NumDimK entries (checked by assert
// only). Returns the descriptor produced by matrix_padder.PadBDescriptor_N_K.
// NOTE(review): listing lines 243, 258 and 260 are absent from this extract
// (the kDimIds initializer and two arguments of transform_tensor_descriptor).
224 // Assume: B[N0, N1, N2, ..., K0, K1, K2, ...]
225 static auto MakeBGridDescriptor_N_K(const std::vector<index_t>& b_ns_ks_lengths_vec,
226 const std::vector<index_t>& b_ns_ks_strides_vec)
227 {
228 assert(b_ns_ks_lengths_vec.size() == NumDimN + NumDimK &&
229 b_ns_ks_strides_vec.size() == NumDimN + NumDimK);
230
// Copies the first `num` entries of a runtime vector into a tuple.
231 const auto to_tuple = [&](auto& vec, auto num) {
232 return generate_tuple([&](auto i) { return vec[i]; }, num);
233 };
234
235 const auto b_ns_ks_lengths = to_tuple(b_ns_ks_lengths_vec, Number<NumDimN + NumDimK>{});
236 const auto b_ns_ks_strides = to_tuple(b_ns_ks_strides_vec, Number<NumDimN + NumDimK>{});
237
238 // dimension Ids for N0, N1, ...
239 constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{};
240
241 // dimension Ids for K0, K1, ...
242 constexpr auto kDimIds =
244
245 // lengths for K0, K1, ...
246 const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds);
247
248 // lengths for N0, N1, ...
249 const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds);
250
251 // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...]
252 const auto b_grid_desc_ns_ks =
253 make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides);
254
255 // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...]
256 const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor(
257 b_grid_desc_ns_ks,
259 make_tuple(nDimIds, kDimIds),
261
// Padding is delegated to the class-level matrix_padder.
262 return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
263 }
264
// Builds the 2-D [M, N] grid descriptor of tensor E.
// Assume: E[M0, M1, M2, ..., N0, N1, N2, ...]
//   e_ms_ns_lengths_vec: per-dimension lengths, M dimensions first, then N.
//   e_ms_ns_strides_vec: per-dimension strides, in the same order.
// Both vectors must hold exactly NumDimM + NumDimN entries (checked by assert
// only). Returns the descriptor produced by matrix_padder.PadCDescriptor_M_N.
// The same helper is also called with the lengths/strides of each D tensor.
// NOTE(review): listing lines 284, 299 and 301 are absent from this extract
// (the nDimIds initializer and two arguments of transform_tensor_descriptor).
265 // Assume: E[M0, M1, M2, ..., N0, N1, N2, ...]
266 static auto MakeEGridDescriptor_M_N(const std::vector<index_t>& e_ms_ns_lengths_vec,
267 const std::vector<index_t>& e_ms_ns_strides_vec)
268 {
269 assert(e_ms_ns_lengths_vec.size() == NumDimM + NumDimN &&
270 e_ms_ns_strides_vec.size() == NumDimM + NumDimN);
271
// Copies the first `num` entries of a runtime vector into a tuple.
272 const auto to_tuple = [&](auto& vec, auto num) {
273 return generate_tuple([&](auto i) { return vec[i]; }, num);
274 };
275
276 const auto e_ms_ns_lengths = to_tuple(e_ms_ns_lengths_vec, Number<NumDimM + NumDimN>{});
277 const auto e_ms_ns_strides = to_tuple(e_ms_ns_strides_vec, Number<NumDimM + NumDimN>{});
278
279 // dimension Ids for M0, M1, ...
280 constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
281
282 // dimension Ids for N0, N1, ...
283 constexpr auto nDimIds =
285
286 // lengths for M0, M1, ...
287 const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds);
288
289 // lengths for N0, N1, ...
290 const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds);
291
292 // naive tensor E[M0, M1, M2, ..., N0, N1, N2...]
293 const auto e_grid_desc_ms_ns =
294 make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides);
295
296 // transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
297 const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor(
298 e_grid_desc_ms_ns,
300 make_tuple(mDimIds, nDimIds),
302
// Padding is delegated to the class-level matrix_padder (its "C" padding
// routine is the one applied to E).
303 return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
304 }
305
307 const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_lengths_vec,
308 const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_strides_vec)
309 {
310 return generate_tuple(
311 [&](auto i) {
312 return DeviceOp::MakeEGridDescriptor_M_N(ds_ms_ns_lengths_vec[i],
313 ds_ms_ns_strides_vec[i]);
314 },
316 }
317
318 using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K({}, {}));
319 using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K({}, {}));
321 using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));
322
323 // GridwiseGemm
324 template <index_t NXdlPerWave_>
326 ADataType, // TODO: distinguish A/B datatype
327 BDataType,
328 ComputeDataType,
329 AccDataType,
330 CShuffleDataType,
331 DsDataType,
332 EDataType,
333 AElementwiseOperation,
334 BElementwiseOperation,
335 CDEElementwiseOperation,
336 NumGemmKPrefetchStage,
337 BlockSize,
338 MPerBlock,
339 NPerBlock,
340 KPerBlock,
341 AK1,
342 BK1,
343 MPerXDL,
344 NPerXDL,
345 MXdlPerWave,
346 NXdlPerWave_,
347 ABlockTransferThreadClusterLengths_AK0_M_AK1,
348 ABlockTransferThreadClusterArrangeOrder,
349 ABlockTransferSrcAccessOrder,
350 ABlockTransferSrcVectorDim,
351 ABlockTransferSrcScalarPerVector,
352 ABlockTransferDstScalarPerVector_AK1,
353 false,
354 ABlockLdsExtraM,
355 BBlockTransferThreadClusterLengths_BK0_N_BK1,
356 BBlockTransferThreadClusterArrangeOrder,
357 BBlockTransferSrcAccessOrder,
358 BBlockTransferSrcVectorDim,
359 BBlockTransferSrcScalarPerVector,
360 BBlockTransferDstScalarPerVector_BK1,
361 false,
362 BBlockLdsExtraN,
363 CShuffleMXdlPerWavePerShuffle,
364 CShuffleNXdlPerWavePerShuffle,
365 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
366 CDEBlockTransferScalarPerVector_NPerBlock,
367 LoopSched>;
370
371 // desc for blockwise copy
374 AGridDesc_M_K{}))>;
377 BGridDesc_N_K{}))>;
380 DsGridDesc_M_N{}))>;
383 EGridDesc_M_N{}))>;
384
385 // block-to-e-tile map
388
389 // Argument
// Host-side bundle of everything one contraction launch needs: typed tensor
// pointers, grid descriptors, the block-to-E-tile map, the element-wise
// operators and the per-tensor values consulted by the vector-access checks.
// NOTE(review): many member declarations and several initializer / assignment
// lines (e.g. listing lines 413, 415, 417, 432, 437, 440, 445, 449, 466,
// 470-473, 476-480, 483, 491-499) are absent from this extract.
390 struct Argument : public BaseArgument
391 {
// Casts the untyped pointers to the element types, builds the descriptors
// from the lengths/strides vectors, and stores the element-wise operators.
// Vector size expectations are those asserted by the MakeXGridDescriptor
// helpers called below.
392 Argument(const void* p_a_grid,
393 const void* p_b_grid,
394 std::array<const void*, NumDTensor> p_ds_grid,
395 void* p_e_grid,
396 const std::vector<index_t>& a_ms_ks_lengths,
397 const std::vector<index_t>& a_ms_ks_strides,
398 const std::vector<index_t>& b_ns_ks_lengths,
399 const std::vector<index_t>& b_ns_ks_strides,
400 const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_lengths,
401 const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_strides,
402 const std::vector<index_t>& e_ms_ns_lengths,
403 const std::vector<index_t>& e_ms_ns_strides,
404 AElementwiseOperation a_element_op,
405 BElementwiseOperation b_element_op,
406 CDEElementwiseOperation cde_element_op)
407 : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
408 p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
// D pointers start value-initialized and are filled in the constructor body.
409 p_ds_grid_{},
410 p_e_grid_{static_cast<EDataType*>(p_e_grid)},
411 a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_ms_ks_lengths, a_ms_ks_strides)},
412 b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(b_ns_ks_lengths, b_ns_ks_strides)},
414 e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_ms_ns_lengths, e_ms_ns_strides)},
// The block-copy descriptors and the tile map are derived through the
// GridwiseGemm64 instantiation.
416 GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
418 GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
419 block_2_etile_map_{GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
420 a_element_op_{a_element_op},
421 b_element_op_{b_element_op},
422 cde_element_op_{cde_element_op}
423 {
424 // populate pointer and descriptor for each D tensor
425 static_for<0, NumDTensor, 1>{}([&](auto i) {
426 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
427
428 // D pointer
429 p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
430
431 // D desc
433 DeviceOp::MakeEGridDescriptor_M_N(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]);
434 });
435
436 // for sanity check of vector memory access
// NOTE(review): the left-hand sides of the CalculateMaxRead calls below are in
// lines missing from this extract; presumably they fill the *_continous_dim_
// and *_max_read_elems_ / e_max_write_elems_ members -- confirm.
438 CalculateMaxRead<NumDimM, NumDimK>(a_ms_ks_lengths, a_ms_ks_strides);
439
441 CalculateMaxRead<NumDimN, NumDimK>(b_ns_ks_lengths, b_ns_ks_strides);
442
443 for(index_t i = 0; i < NumDTensor; ++i)
444 {
446 CalculateMaxRead<NumDimM, NumDimN>(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]);
447 }
448
450 CalculateMaxRead<NumDimM, NumDimN>(e_ms_ns_lengths, e_ms_ns_strides);
451 }
452
// Streams the A, B, each D and the E [row, column] descriptors to std::cout.
453 void Print() const
454 {
455 std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
456 std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
458 [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
459 std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
460 }
461
// Members are left public (the access specifier below is commented out); the
// invoker and IsSupportedArgument read them directly.
462 // private:
463 // pointers
464 const ADataType* p_a_grid_;
465 const BDataType* p_b_grid_;
467 EDataType* p_e_grid_;
468
469 // tensor descriptors for problem definition
474
475 // tensor descriptors for block/thread-wise copy
481
482 // block-to-e-tile map
484
485 // element-wise op
486 AElementwiseOperation a_element_op_;
487 BElementwiseOperation b_element_op_;
488 CDEElementwiseOperation cde_element_op_;
489
490 // Describes, for each of A/B/D/E, which dimension group is the contiguous
// ("continous" in the member names) one; IsSupportedArgument compares these
// values against 0 and 1.
493 std::array<index_t, NumDTensor> ds_continous_dim_;
495
// Per-D-tensor element counts that IsSupportedArgument checks for divisibility
// by the configured vector width.
498 std::array<index_t, NumDTensor> ds_max_read_elems_;
500 };
501
502 // Invoker
503 struct Invoker : public BaseInvoker
504 {
506
// Validates the argument against the given GridwiseGemm, derives the
// block-level D/E descriptors and the grid size, and launches the kernel
// through launch_and_time_kernel.
//   arg:           prepared Argument for this launch.
//   stream_config: forwarded unchanged to launch_and_time_kernel.
// Returns whatever launch_and_time_kernel returns (a float; presumably the
// measured time -- confirm).
// Throws std::runtime_error when GridwiseGemm::CheckValidity rejects arg.
// NOTE(review): listing lines 511-514, 525, 535, 543-547, 562-563 and 566 are
// absent from this extract (remaining CheckValidity arguments, the kernel
// alias and some kernel template / launch arguments).
507 template <typename GridwiseGemm>
508 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
509 {
510 if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
515 {
516 throw std::runtime_error(
517 "wrong! GridwiseGemmMultipleD_xdl_cshuffle has invalid setting");
518 }
// The block-level descriptors are rebuilt here from the [M, N] descriptors
// using the GridwiseGemm this call was instantiated with.
519 auto e_grid_desc_mblock_mperblock_nblock_nperblock =
520 GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
521 arg.e_grid_desc_m_n_);
522
523 auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
524 GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
526 const index_t grid_size =
527 arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
528
// K is the product of the lengths of dimensions 0 and 2 of the AK0_M_AK1
// descriptor.
529 const auto K =
530 arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
531
// Takes the main-K-loop flag as an integral_constant so that it can be used
// as a compile-time template argument of the kernel.
532 auto launch_kernel = [&](auto has_main_k_block_loop) {
533 constexpr bool has_main_loop = has_main_k_block_loop.value;
534
536 GridwiseGemm,
537 ADataType, // TODO: distinguish A/B datatype
538 typename GridwiseGemm::DsGridPointer,
539 EDataType,
540 AElementwiseOperation,
541 BElementwiseOperation,
542 CDEElementwiseOperation,
548 has_main_loop>;
549
// One-dimensional launch: grid_size blocks of BlockSize threads; the 0 is
// passed in the argument slot that follows the block dimension.
550 return launch_and_time_kernel(stream_config,
551 kernel,
552 dim3(grid_size),
553 dim3(BlockSize),
554 0,
555 arg.p_a_grid_,
556 arg.p_b_grid_,
557 arg.p_ds_grid_,
558 arg.p_e_grid_,
559 arg.a_element_op_,
560 arg.b_element_op_,
561 arg.cde_element_op_,
564 ds_grid_desc_mblock_mperblock_nblock_nperblock,
565 e_grid_desc_mblock_mperblock_nblock_nperblock,
567 };
568
// Runtime decision mapped onto the two compile-time kernel variants.
569 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
570 {
571 return launch_kernel(integral_constant<bool, true>{});
572 }
573 else
574 {
575 return launch_kernel(integral_constant<bool, false>{});
576 }
577 }
578
580
581 // polymorphic
// Type-erased entry point: downcasts p_arg to this operation's Argument and
// forwards to the Run overload taking `const Argument&` (not visible in this
// extract; presumably generated by INVOKER_RUN_IMPL -- confirm).
// NOTE(review): the dynamic_cast result is dereferenced without a null check,
// so a p_arg that is null or not an Argument is undefined behaviour here.
582 float Run(const BaseArgument* p_arg,
583 const StreamConfig& stream_config = StreamConfig{}) override
584 {
585 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
586 }
587 };
588
// Host-side check of whether this instance can run the problem described by
// arg. Returns false as soon as any check fails and true only when all pass.
// NOTE(review): listing lines 591, 604-608 and 615-619 are absent from this
// extract: the condition of the first early return and the statements that
// set `valid` for each warp size.
589 static bool IsSupportedArgument(const Argument& arg)
590 {
// NOTE(review): the guarding condition for this return is in missing line 591
// (presumably a device-capability test -- confirm).
592 {
593 return false;
594 }
// Rejects double-typed A data whenever is_lds_direct_load_supported() is false.
595 if(!ck::is_lds_direct_load_supported() && std::is_same<ADataType, double>::value)
596 {
597 return false;
598 }
// `valid` stays false unless one of the branches below sets it; the
// assignments themselves are in the missing lines. A branch whose
// NXdlPerWave value is not positive is compiled out entirely.
599 bool valid = false;
600 if(get_warp_size() == 64)
601 {
602 if constexpr(NXdlPerWave64 > 0)
603 {
609 }
610 }
611 else
612 {
613 if constexpr(NXdlPerWave32 > 0)
614 {
620 }
621 }
622
623 if(!valid)
624 {
625 return false;
626 }
627
628 // check vector access
// Only source vector dimensions 1 and 2 are accepted for A and B.
629 static_assert((ABlockTransferSrcVectorDim == 1 || ABlockTransferSrcVectorDim == 2) &&
630 (BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2),
631 "wrong!");
632
// A: the readable element count must be a multiple of the vector width, and
// the vector dimension must match the contiguous dimension (vector dim 1 with
// a_continous_dim_ == 0, vector dim 2 with a_continous_dim_ == 1) unless the
// vector width is 1.
633 const bool valid_a_vector_size =
634 arg.a_max_read_elems_ % ABlockTransferSrcScalarPerVector == 0;
635 const bool valid_a_access_dim_m =
636 ABlockTransferSrcVectorDim == 1 && arg.a_continous_dim_ == 0;
637 const bool valid_a_access_dim_k =
638 ABlockTransferSrcVectorDim == 2 && arg.a_continous_dim_ == 1;
639 const bool valid_a_access_dim =
640 valid_a_access_dim_m || valid_a_access_dim_k || ABlockTransferSrcScalarPerVector == 1;
641 if(!(valid_a_vector_size && valid_a_access_dim))
642 {
643 return false;
644 }
645
// B: same rule as for A, using the B transfer parameters.
646 const bool valid_b_vector_size =
647 arg.b_max_read_elems_ % BBlockTransferSrcScalarPerVector == 0;
648 const bool valid_b_access_dim_n =
649 BBlockTransferSrcVectorDim == 1 && arg.b_continous_dim_ == 0;
650 const bool valid_b_access_dim_k =
651 BBlockTransferSrcVectorDim == 2 && arg.b_continous_dim_ == 1;
652 const bool valid_b_access_dim =
653 valid_b_access_dim_n || valid_b_access_dim_k || BBlockTransferSrcScalarPerVector == 1;
654 if(!(valid_b_vector_size && valid_b_access_dim))
655 {
656 return false;
657 }
658
// Ds: every D tensor must pass; the lambda cannot return from the enclosing
// function, so failures are recorded in a flag and checked afterwards.
659 bool valid_ds_access = true;
660 static_for<0, NumDTensor, 1>{}([&](auto i) {
661 const bool valid_d_vector_size =
662 arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
663 // Vector read of Ds is always on N dimension.
664 const bool valid_d_access_dim =
665 arg.ds_continous_dim_[i] == 1 || CDEBlockTransferScalarPerVector_NPerBlock == 1;
666 if(!(valid_d_vector_size && valid_d_access_dim))
667 {
668 valid_ds_access = false;
669 }
670 });
671 if(!valid_ds_access)
672 {
673 return false;
674 }
675
// E: same rule as for the D tensors, applied to the write side.
676 const bool valid_e_vector_size =
677 arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
678 // Vector write of E is always on N dimension.
679 const bool valid_e_access_dim =
680 arg.e_continous_dim_ == 1 || CDEBlockTransferScalarPerVector_NPerBlock == 1;
681 if(!(valid_e_vector_size && valid_e_access_dim))
682 {
683 return false;
684 }
685
686 return true;
687 }
688
689 // polymorphic
// Type-erased overload: downcasts p_arg to this operation's Argument and
// defers to the static IsSupportedArgument(const Argument&).
// NOTE(review): the dynamic_cast result is dereferenced without a null check,
// so a p_arg that is null or not an Argument is undefined behaviour here
// rather than a `false` result.
690 bool IsSupportedArgument(const BaseArgument* p_arg) override
691 {
692 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
693 }
694
// Convenience factory: returns an Argument by value, constructed from exactly
// the parameters given, in the same order as Argument's constructor.
//   p_a, p_b, p_ds, p_e: untyped pointers to the A, B, D and E tensors.
//   *_lengths / *_strides: per-dimension lengths and strides of each tensor.
//   a_element_op, b_element_op, cde_element_op: element-wise operators.
695 static auto MakeArgument(const void* p_a,
696 const void* p_b,
697 std::array<const void*, NumDTensor> p_ds,
698 void* p_e,
699 const std::vector<index_t>& a_ms_ks_lengths,
700 const std::vector<index_t>& a_ms_ks_strides,
701 const std::vector<index_t>& b_ns_ks_lengths,
702 const std::vector<index_t>& b_ns_ks_strides,
703 const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_lengths,
704 const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_strides,
705 const std::vector<index_t>& e_ms_ns_lengths,
706 const std::vector<index_t>& e_ms_ns_strides,
707 AElementwiseOperation a_element_op,
708 BElementwiseOperation b_element_op,
709 CDEElementwiseOperation cde_element_op)
710 {
711 return Argument{p_a,
712 p_b,
713 p_ds,
714 p_e,
715 a_ms_ks_lengths,
716 a_ms_ks_strides,
717 b_ns_ks_lengths,
718 b_ns_ks_strides,
719 ds_ms_ns_lengths,
720 ds_ms_ns_strides,
721 e_ms_ns_lengths,
722 e_ms_ns_strides,
723 a_element_op,
724 b_element_op,
725 cde_element_op};
726 }
727
// Returns a default-constructed Invoker by value.
728 static auto MakeInvoker() { return Invoker{}; }
729
730 // polymorphic
// Type-erased factory: heap-allocates an Argument from the given parameters
// and returns it as a std::unique_ptr<BaseArgument>. Parameters have the same
// meaning and order as in MakeArgument / Argument's constructor.
731 std::unique_ptr<BaseArgument>
732 MakeArgumentPointer(const void* p_a,
733 const void* p_b,
734 std::array<const void*, NumDTensor> p_ds,
735 void* p_e,
736 const std::vector<index_t>& a_ms_ks_lengths,
737 const std::vector<index_t>& a_ms_ks_strides,
738 const std::vector<index_t>& b_ns_ks_lengths,
739 const std::vector<index_t>& b_ns_ks_strides,
740 const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_lengths,
741 const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_strides,
742 const std::vector<index_t>& e_ms_ns_lengths,
743 const std::vector<index_t>& e_ms_ns_strides,
744 AElementwiseOperation a_element_op,
745 BElementwiseOperation b_element_op,
746 CDEElementwiseOperation cde_element_op) override
747 {
748 return std::make_unique<Argument>(p_a,
749 p_b,
750 p_ds,
751 p_e,
752 a_ms_ks_lengths,
753 a_ms_ks_strides,
754 b_ns_ks_lengths,
755 b_ns_ks_strides,
756 ds_ms_ns_lengths,
757 ds_ms_ns_strides,
758 e_ms_ns_lengths,
759 e_ms_ns_strides,
760 a_element_op,
761 b_element_op,
762 cde_element_op);
763 }
764
765 // polymorphic
// Type-erased factory: returns a heap-allocated Invoker as a
// std::unique_ptr<BaseInvoker>. The Invoker is constructed from a temporary
// Invoker{} passed to make_unique.
766 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
767 {
768 return std::make_unique<Invoker>(Invoker{});
769 }
770
771 // polymorphic
// Returns a human-readable identifier of this instance: the operation name
// followed by a comma-separated subset of its template parameters inside
// angle brackets, in the order NumDimM, NumDimN, NumDimK, BlockSize,
// MPerBlock, NPerBlock, KPerBlock, AK1, BK1, ABlockTransferSrcVectorDim,
// BBlockTransferSrcVectorDim. The remaining template parameters are not
// part of the string.
772 std::string GetTypeString() const override
773 {
774 auto str = std::stringstream();
775
776 // clang-format off
777 str << "DeviceContractionMultipleD_Xdl_CShuffle"
778 << "<"
779 << NumDimM << ", "
780 << NumDimN << ", "
781 << NumDimK << ", "
782 << BlockSize << ", "
783 << MPerBlock << ", "
784 << NPerBlock << ", "
785 << KPerBlock << ", "
786 << AK1 << ", "
787 << BK1 << ", "
788 << ABlockTransferSrcVectorDim << ", "
789 << BBlockTransferSrcVectorDim
790 << ">";
791 // clang-format on
792
793 return str.str();
794 }
795};
796
797} // namespace device
798} // namespace tensor_operation
799} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
auto CalculateMaxRead(const std::vector< index_t > &lengths, const std::vector< index_t > &strides)
Definition device_contraction_utils.hpp:33
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
bool is_lds_direct_load_supported()
Definition host_utility/device_prop.hpp:101
__global__ void kernel_contraction_multiple_d_xdl_cshuffle(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatDsPointer p_ds_grid, FloatE *__restrict__ p_e_grid, const index_t batch_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2ETileMap block_2_etile_map)
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:41
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
__host__ __device__ constexpr auto get_container_subset(const Array< T, N > &arr, Sequence< Is... >)
Definition utility/container_helper.hpp:346
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:78
Definition utility/sequence.hpp:43
typename conditional< kHasContent, type0, type1 >::type type
Definition utility/sequence.hpp:271
Definition utility/integral_constant.hpp:20
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:391
index_t a_continous_dim_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:491
BGridDesc_N_K b_grid_desc_n_k_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:471
Argument(const void *p_a_grid, const void *p_b_grid, std::array< const void *, NumDTensor > p_ds_grid, void *p_e_grid, const std::vector< index_t > &a_ms_ks_lengths, const std::vector< index_t > &a_ms_ks_strides, const std::vector< index_t > &b_ns_ks_lengths, const std::vector< index_t > &b_ns_ks_strides, const std::array< std::vector< index_t >, NumDTensor > &ds_ms_ns_lengths, const std::array< std::vector< index_t >, NumDTensor > &ds_ms_ns_strides, const std::vector< index_t > &e_ms_ns_lengths, const std::vector< index_t > &e_ms_ns_strides, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:392
index_t e_continous_dim_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:494
index_t b_continous_dim_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:492
CDEElementwiseOperation cde_element_op_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:488
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:477
index_t e_max_write_elems_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:499
EGridDesc_M_N e_grid_desc_m_n_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:473
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:480
AElementwiseOperation a_element_op_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:486
DsGridDesc_M_N ds_grid_desc_m_n_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:472
void Print() const
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:453
const ADataType * p_a_grid_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:464
Block2ETileMap block_2_etile_map_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:483
index_t b_max_read_elems_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:497
const BDataType * p_b_grid_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:465
EDataType * p_e_grid_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:467
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:476
AGridDesc_M_K a_grid_desc_m_k_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:470
std::array< index_t, NumDTensor > ds_continous_dim_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:493
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:479
GridwiseGemm64::DsGridPointer p_ds_grid_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:466
index_t a_max_read_elems_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:496
BElementwiseOperation b_element_op_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:487
std::array< index_t, NumDTensor > ds_max_read_elems_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:498
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:504
DeviceOp::Argument Argument
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:505
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:508
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:582
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:166
static auto MakeEGridDescriptor_M_N(const std::vector< index_t > &e_ms_ns_lengths_vec, const std::vector< index_t > &e_ms_ns_strides_vec)
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:266
DeviceContractionMultipleD_Xdl_CShuffle DeviceOp
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:167
decltype(MakeBGridDescriptor_N_K({}, {})) BGridDesc_N_K
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:319
remove_cvref_t< decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{}))> DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:378
decltype(MakeEGridDescriptor_M_N({}, {})) EGridDesc_M_N
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:321
static auto MakeBGridDescriptor_N_K(const std::vector< index_t > &b_ns_ks_lengths_vec, const std::vector< index_t > &b_ns_ks_strides_vec)
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:225
static auto MakeArgument(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, const std::vector< index_t > &a_ms_ks_lengths, const std::vector< index_t > &a_ms_ks_strides, const std::vector< index_t > &b_ns_ks_lengths, const std::vector< index_t > &b_ns_ks_strides, const std::array< std::vector< index_t >, NumDTensor > &ds_ms_ns_lengths, const std::array< std::vector< index_t >, NumDTensor > &ds_ms_ns_strides, const std::vector< index_t > &e_ms_ns_lengths, const std::vector< index_t > &e_ms_ns_strides, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:695
remove_cvref_t< decltype(GridwiseGemm64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( EGridDesc_M_N{}))> EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:381
GridwiseGemmMultipleD_xdl_cshuffle< ADataType, BDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched > GridwiseGemmBase
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:325
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1( BGridDesc_N_K{}))> BGridDesc_BK0_N_BK1
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:375
static constexpr auto NXdlPerWave32
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:171
decltype(MakeAGridDescriptor_M_K({}, {})) AGridDesc_M_K
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:318
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:170
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:690
static auto MakeAGridDescriptor_M_K(const std::vector< index_t > &a_ms_ks_lengths_vec, const std::vector< index_t > &a_ms_ks_strides_vec)
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:184
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:766
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:369
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:368
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, const std::vector< index_t > &a_ms_ks_lengths, const std::vector< index_t > &a_ms_ks_strides, const std::vector< index_t > &b_ns_ks_lengths, const std::vector< index_t > &b_ns_ks_strides, const std::array< std::vector< index_t >, NumDTensor > &ds_ms_ns_lengths, const std::array< std::vector< index_t >, NumDTensor > &ds_ms_ns_strides, const std::vector< index_t > &e_ms_ns_lengths, const std::vector< index_t > &e_ms_ns_strides, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:732
std::string GetTypeString() const override
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:772
static constexpr auto matrix_padder
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:180
static constexpr index_t NumDTensor
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:173
remove_cvref_t< decltype(MakeDsGridDescriptor_M_N({{}}, {{}}))> DsGridDesc_M_N
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:320
static constexpr auto I1
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:176
static auto MakeDsGridDescriptor_M_N(const std::array< std::vector< index_t >, NumDTensor > &ds_ms_ns_lengths_vec, const std::array< std::vector< index_t >, NumDTensor > &ds_ms_ns_strides_vec)
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:306
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1( AGridDesc_M_K{}))> AGridDesc_AK0_M_AK1
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:372
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))> Block2ETileMap
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:386
static constexpr auto I2
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:177
static bool IsSupportedArgument(const Argument &arg)
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:589
static constexpr auto I3
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:178
static constexpr auto I0
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:175
static auto MakeInvoker()
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:728
Definition device_contraction_multiple_d.hpp:39
Definition matrix_padder.hpp:180