device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp Source File

device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp Source File#

Composable Kernel: device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp Source File
device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <functional>
7#include <iostream>
8#include <iterator>
9#include <numeric>
10#include <sstream>
11
14#include "ck/utility/env.hpp"
31#ifdef CK_EXPERIMENTAL_BUILDER
32#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
33#endif
34
35namespace ck {
36namespace tensor_operation {
37namespace device {
38
39namespace {
40
41/*
42 * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
43 *
44 * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
45 * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
46 * strided batched, but we can easily extend to other layouts. The returned offset can be either \p
47 * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
48 * limitations.
49 *
50 * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and
51 * returns the 2D index of the tile that it computes. \see
52 * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
53 *
54 * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
55 * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
56 * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
57 * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for
58 * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the
59 * computing of pointer offset into \p ComputePtrOffsetOfStridedBatch.
60 *
61 * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes.
62 * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
63 * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
64 *
65 */
// Device entry point for the grouped forward convolution (single shared-memory buffer).
//
// Each workgroup reads its group index from blockIdx.y and its n index from blockIdx.z,
// advances the A / B / Ds / E base pointers held in `karg` by the offsets that
// `compute_ptr_offset_of_groups` and `compute_ptr_offset_of_n` report for those indices,
// and then forwards everything to GridwiseGemm::Run.
//
// \param karg                          GridwiseGemm argument pack; supplies the base pointers
//                                      (p_a_grid, p_b_grid, p_ds_grid, p_c_grid), M, N and the
//                                      element-wise operators used below.
// \param a_grid_desc_ak0_m_ak1         A descriptor, passed through GridwiseGemm::TransformGrid.
// \param b_grid_desc_bk0_n_bk1         B descriptor, passed through GridwiseGemm::TransformGrid.
// \param ds_grid_desc_m_n              Ds descriptors, forwarded unchanged.
// \param c_grid_desc_m_n               Output descriptor, forwarded unchanged.
// \param compute_ptr_offset_of_groups  Offset provider indexed by blockIdx.y.
// \param compute_ptr_offset_of_n       Offset provider indexed by blockIdx.z.
//
// The body is only compiled for gfx9 / gfx11 / gfx12 targets; elsewhere the kernel is a no-op.
66template <typename GridwiseGemm,
67 typename ComputePtrOffset,
68 typename AGridDesc_AK0_M_K1,
69 typename BGridDesc_BK0_N_K1,
70 typename DsGridDesc_M_N,
71 typename EGridDesc_M_N,
72 bool HasMainKBlockLoop,
73 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
74 index_t MinimumOccupancy = 1,
// NOTE(review): `TailNum` is used in the Run<> calls below, but the template parameter that
// declares it is not visible in this listing (numbered line 75 is absent) — confirm against
// the original header.
76__global__ void
77#if CK_USE_LAUNCH_BOUNDS
78__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
79#endif
80 kernel_grouped_conv_fwd_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg,
81 const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
82 const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
83 const DsGridDesc_M_N ds_grid_desc_m_n,
84 const EGridDesc_M_N c_grid_desc_m_n,
85 const ComputePtrOffset compute_ptr_offset_of_groups,
86 const ComputePtrOffset compute_ptr_offset_of_n)
87{
88#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
89 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
90 {
91 // offset base pointer for each work-group
92 const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
93 const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
94
// Per-D-tensor offsets for this group and this n index; both are added to every D pointer.
95 const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx);
96 const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx);
97
98 static constexpr index_t NumDTensor = GridwiseGemm::NumDTensor;
99 using DsGridPointer = typename GridwiseGemm::DsGridPointer;
100 DsGridPointer p_ds_grid_grp{};
101
102 static_for<0, NumDTensor, 1>{}([&](auto i) {
103 p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_n_offset[i] + ds_group_offset[i];
104 });
105
106 const long_index_t a_group_offset =
107 amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
108 const long_index_t b_group_offset =
109 amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
110 const long_index_t e_group_offset =
111 amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx));
112
// Only A and E receive an n offset; B is advanced by the group offset alone
// (presumably because B does not vary with n — confirm).
113 const long_index_t a_n_offset =
114 amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
115 const long_index_t e_n_offset =
116 amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
117
// Shared-memory scratch handed to GridwiseGemm::Run; its size comes from GridwiseGemm.
118 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
119
120 using Block2CTileMap = typename GridwiseGemm::Block2CTileMapDefault;
// NOTE(review): the meaning of the literal third argument (4) is defined by
// Block2CTileMapDefault and is not visible here — confirm.
121 const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4};
122
// Both branches issue the same Run call; the direct-load one is only compiled for gfx950,
// so on any other target a DirectLoadEnabled instantiation executes nothing here.
123 if constexpr(GridwiseGemm::DirectLoadEnabled)
124 {
125#if defined(__gfx950__)
126 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
127 karg.p_a_grid + a_group_offset + a_n_offset,
128 karg.p_b_grid + b_group_offset,
129 p_ds_grid_grp,
130 karg.p_c_grid + e_group_offset + e_n_offset,
131 p_shared,
132 karg,
133 karg.a_element_op,
134 karg.b_element_op,
135 karg.c_element_op,
136 block_2_ctile_map,
137 GridwiseGemm::template TransformGrid<decltype(a_grid_desc_ak0_m_ak1),
138 GridwiseGemm::AK0Number,
139 GridwiseGemm::AK1Number>(
140 a_grid_desc_ak0_m_ak1),
141 GridwiseGemm::template TransformGrid<decltype(b_grid_desc_bk0_n_bk1),
142 GridwiseGemm::BK0Number,
143 GridwiseGemm::BK1Number>(
144 b_grid_desc_bk0_n_bk1),
145 ds_grid_desc_m_n,
146 c_grid_desc_m_n);
147#endif
148 }
149 else
150 {
151 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
152 karg.p_a_grid + a_group_offset + a_n_offset,
153 karg.p_b_grid + b_group_offset,
154 p_ds_grid_grp,
155 karg.p_c_grid + e_group_offset + e_n_offset,
156 p_shared,
157 karg,
158 karg.a_element_op,
159 karg.b_element_op,
160 karg.c_element_op,
161 block_2_ctile_map,
162 GridwiseGemm::template TransformGrid<decltype(a_grid_desc_ak0_m_ak1),
163 GridwiseGemm::AK0Number,
164 GridwiseGemm::AK1Number>(
165 a_grid_desc_ak0_m_ak1),
166 GridwiseGemm::template TransformGrid<decltype(b_grid_desc_bk0_n_bk1),
167 GridwiseGemm::BK0Number,
168 GridwiseGemm::BK1Number>(
169 b_grid_desc_bk0_n_bk1),
170 ds_grid_desc_m_n,
171 c_grid_desc_m_n);
172 }
173 }
174#else
// Other targets: every parameter is assigned to `ignore` (presumably to avoid
// unused-parameter diagnostics) and the kernel does no work.
175 ignore = karg;
176 ignore = a_grid_desc_ak0_m_ak1;
177 ignore = b_grid_desc_bk0_n_bk1;
178 ignore = ds_grid_desc_m_n;
179 ignore = c_grid_desc_m_n;
180 ignore = compute_ptr_offset_of_groups;
181 ignore = compute_ptr_offset_of_n;
182#endif // end of if (defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__))
183}
184
// Device entry point for the grouped forward convolution, two-LDS-buffer variant.
//
// Performs the same per-workgroup pointer offsetting as
// kernel_grouped_conv_fwd_xdl_cshuffle_v3 (group index from blockIdx.y, n index from
// blockIdx.z), but declares two separate shared-memory buffers and calls
// GridwiseGemm::Run_2Lds with both of them.
//
// \param karg                          GridwiseGemm argument pack; supplies the base pointers
//                                      (p_a_grid, p_b_grid, p_ds_grid, p_c_grid), M, N and the
//                                      element-wise operators used below.
// \param a_grid_desc_ak0_m_ak1         A descriptor, passed through GridwiseGemm::TransformGrid.
// \param b_grid_desc_bk0_n_bk1         B descriptor, passed through GridwiseGemm::TransformGrid.
// \param ds_grid_desc_m_n              Ds descriptors, forwarded unchanged.
// \param c_grid_desc_m_n               Output descriptor, forwarded unchanged.
// \param compute_ptr_offset_of_groups  Offset provider indexed by blockIdx.y.
// \param compute_ptr_offset_of_n       Offset provider indexed by blockIdx.z.
//
// The body is only compiled for gfx9 / gfx11 / gfx12 targets; elsewhere the kernel is a no-op.
185template <typename GridwiseGemm,
186 typename ComputePtrOffset,
187 typename AGridDesc_AK0_M_K1,
188 typename BGridDesc_BK0_N_K1,
189 typename DsGridDesc_M_N,
190 typename EGridDesc_M_N,
191 bool HasMainKBlockLoop,
192 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
193 index_t MinimumOccupancy = 1,
// NOTE(review): `TailNum` is used in the Run_2Lds<> calls below, but the template parameter
// that declares it is not visible in this listing (numbered line 194 is absent) — confirm
// against the original header.
195__global__ void
196#if CK_USE_LAUNCH_BOUNDS
197__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
198#endif
199 kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds(
200 typename GridwiseGemm::Argument karg,
201 const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
202 const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
203 const DsGridDesc_M_N ds_grid_desc_m_n,
204 const EGridDesc_M_N c_grid_desc_m_n,
205 const ComputePtrOffset compute_ptr_offset_of_groups,
206 const ComputePtrOffset compute_ptr_offset_of_n)
207{
208#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
209 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
210 {
211 // offset base pointer for each work-group
212 const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
213 const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
214
// Per-D-tensor offsets for this group and this n index; both are added to every D pointer.
215 const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx);
216 const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx);
217
218 static constexpr index_t NumDTensor = GridwiseGemm::NumDTensor;
219 using DsGridPointer = typename GridwiseGemm::DsGridPointer;
220 DsGridPointer p_ds_grid_grp{};
221
222 static_for<0, NumDTensor, 1>{}([&](auto i) {
223 p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_n_offset[i] + ds_group_offset[i];
224 });
225
226 const long_index_t a_group_offset =
227 amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
228 const long_index_t b_group_offset =
229 amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
230 const long_index_t e_group_offset =
231 amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx));
232
// Only A and E receive an n offset; B is advanced by the group offset alone
// (presumably because B does not vary with n — confirm).
233 const long_index_t a_n_offset =
234 amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
235 const long_index_t e_n_offset =
236 amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
237
238 // Passing two separate LDS pointers is what tells the compiler that ds_read/write
239 // operate on different LDS chunks at the same time without an ordering dependency
240 __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
241 __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
242
243 using Block2CTileMap = typename GridwiseGemm::Block2CTileMapDefault;
// NOTE(review): the meaning of the literal third argument (4) is defined by
// Block2CTileMapDefault and is not visible here — confirm.
244 const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4};
245
// Both branches issue the same Run_2Lds call; the direct-load one is only compiled for
// gfx950, so on any other target a DirectLoadEnabled instantiation executes nothing here.
246 if constexpr(GridwiseGemm::DirectLoadEnabled)
247 {
248#if defined(__gfx950__)
249 GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
250 karg.p_a_grid + a_group_offset + a_n_offset,
251 karg.p_b_grid + b_group_offset,
252 p_ds_grid_grp,
253 karg.p_c_grid + e_group_offset + e_n_offset,
254 p_shared_0,
255 p_shared_1,
256 karg,
257 karg.a_element_op,
258 karg.b_element_op,
259 karg.c_element_op,
260 block_2_ctile_map,
261 GridwiseGemm::template TransformGrid<decltype(a_grid_desc_ak0_m_ak1),
262 GridwiseGemm::AK0Number,
263 GridwiseGemm::AK1Number>(
264 a_grid_desc_ak0_m_ak1),
265 GridwiseGemm::template TransformGrid<decltype(b_grid_desc_bk0_n_bk1),
266 GridwiseGemm::BK0Number,
267 GridwiseGemm::BK1Number>(
268 b_grid_desc_bk0_n_bk1),
269 ds_grid_desc_m_n,
270 c_grid_desc_m_n);
271#endif
272 }
273 else
274 {
275 GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
276 karg.p_a_grid + a_group_offset + a_n_offset,
277 karg.p_b_grid + b_group_offset,
278 p_ds_grid_grp,
279 karg.p_c_grid + e_group_offset + e_n_offset,
280 p_shared_0,
281 p_shared_1,
282 karg,
283 karg.a_element_op,
284 karg.b_element_op,
285 karg.c_element_op,
286 block_2_ctile_map,
287 GridwiseGemm::template TransformGrid<decltype(a_grid_desc_ak0_m_ak1),
288 GridwiseGemm::AK0Number,
289 GridwiseGemm::AK1Number>(
290 a_grid_desc_ak0_m_ak1),
291 GridwiseGemm::template TransformGrid<decltype(b_grid_desc_bk0_n_bk1),
292 GridwiseGemm::BK0Number,
293 GridwiseGemm::BK1Number>(
294 b_grid_desc_bk0_n_bk1),
295 ds_grid_desc_m_n,
296 c_grid_desc_m_n);
297 }
298 }
299#else
// Other targets: every parameter is assigned to `ignore` (presumably to avoid
// unused-parameter diagnostics) and the kernel does no work.
300 ignore = karg;
301 ignore = a_grid_desc_ak0_m_ak1;
302 ignore = b_grid_desc_bk0_n_bk1;
303 ignore = ds_grid_desc_m_n;
304 ignore = c_grid_desc_m_n;
305 ignore = compute_ptr_offset_of_groups;
306 ignore = compute_ptr_offset_of_n;
307#endif // end of if (defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__))
308}
309
310} // namespace
311
312template <typename T>
313using is_tuple = decltype(std::declval<T&>().IsTuple());
314
315//
316// @brief Device Convolution operation.
317//
318// Supports:
319// @li Forward convolution with up to 3 spatial dimensions
320// @li Input tensor in GNWC data format
321// @li Weight tensor in GKXC data format
322// @li Output tensor in GNWK data format
323//
324// 1D:
325// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
326// 2D:
327// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
328// 3D:
329// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
330//
331template <index_t NDimSpatial,
332 typename ALayout,
333 typename BLayout,
334 typename DsLayout,
335 typename ELayout,
336 typename ADataType,
337 typename BDataType,
338 typename AccDataType,
339 typename CShuffleDataType,
340 typename DsDataType,
341 typename EDataType,
342 typename AElementwiseOperation,
343 typename BElementwiseOperation,
344 typename CDEElementwiseOperation,
345 ConvolutionForwardSpecialization ConvForwardSpecialization,
346 GemmSpecialization GemmSpec,
347 index_t BlockSize,
348 index_t MPerBlock,
349 index_t NPerBlock,
350 index_t KPerBlock,
351 index_t AK1,
352 index_t BK1,
353 index_t MPerXDL,
354 index_t NPerXDL,
355 index_t MXdlPerWave,
356 index_t NXdlPerWave,
357 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
358 typename ABlockTransferThreadClusterArrangeOrder,
359 typename ABlockTransferSrcAccessOrder,
360 index_t ABlockTransferSrcVectorDim,
361 index_t ABlockTransferSrcScalarPerVector,
362 index_t ABlockTransferDstScalarPerVector_AK1,
363 index_t ABlockLdsExtraM,
364 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
365 typename BBlockTransferThreadClusterArrangeOrder,
366 typename BBlockTransferSrcAccessOrder,
367 index_t BBlockTransferSrcVectorDim,
368 index_t BBlockTransferSrcScalarPerVector,
369 index_t BBlockTransferDstScalarPerVector_BK1,
370 index_t BBlockLdsExtraN,
371 index_t CShuffleMXdlPerWavePerShuffle,
372 index_t CShuffleNXdlPerWavePerShuffle,
373 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
374 index_t CDEBlockTransferScalarPerVector_NPerBlock,
377 typename AComputeDataType =
378 decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
379 Number<0>,
380 ADataType>()), // ComputeType is InputType by default (first
381 // in tuple for MultiAB), unpack if tuple was
382 // passed
383 typename BComputeDataType = AComputeDataType,
384 bool DirectLoad = false>
386 : public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
387 ALayout,
388 BLayout,
389 DsLayout,
390 ELayout,
391 ADataType,
392 BDataType,
393 DsDataType,
394 EDataType,
395 AElementwiseOperation,
396 BElementwiseOperation,
397 CDEElementwiseOperation,
398 AComputeDataType,
399 BComputeDataType>
400{
403 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
404 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
405
408 static constexpr bool isMultiD = DsDataType::Size() > 0;
409 static constexpr bool isMultiABD = isMultiA && isMultiB && isMultiD;
410
411 static constexpr bool DoElementwiseBeforeCShuffle =
414
415 static constexpr index_t NumATensor = GetNumABTensors<isMultiA, ADataType>();
416 static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
417 static constexpr index_t NumDTensor = DsDataType::Size();
418
419 static constexpr auto I0 = Number<0>{};
420 static constexpr auto I1 = Number<1>{};
421 static constexpr auto I2 = Number<2>{};
422 static constexpr auto I3 = Number<3>{};
423 static constexpr auto I4 = Number<4>{};
424 static constexpr auto I5 = Number<5>{};
425
426 // Generate vector size for C & Ds
428 typename uniform_sequence_gen<NumDTensor + 1,
429 CDEBlockTransferScalarPerVector_NPerBlock>::type;
430
432 ConvForwardSpecialization,
433 true /*SplitN*/,
434 ADataType,
435 EDataType>;
436
437 using ComputePtrOffset = ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>;
438
439 static constexpr auto matrix_padder =
440 MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
441
443 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
444
445 static constexpr auto conv_ngchw_to_nhwgc_transformer =
447 BLayout,
448 ELayout,
449 NDimSpatial,
450 MPerBlock / ClusterLengthNPerBlock,
451 NPerBlock / ClusterLengthNPerBlock>{};
452
453 template <typename ALay>
454 static auto
456
457 {
458 namespace ctc = tensor_layout::convolution;
459 using Layout = std::conditional_t<
461 ctc::NHWGC,
462 std::conditional_t<is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>(),
463 ctc::NDHWGC,
464 ALay>>;
465
466 const auto in_gemmmraw_gemmkraw_desc =
467 conv_to_gemm_transformer.template MakeADescriptor_M_K<Layout>();
468
469 const auto in_gemmm_gemmk_desc =
470 matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
471
472 const auto M = in_gemmm_gemmk_desc.GetLength(I0);
473 const auto K = in_gemmm_gemmk_desc.GetLength(I1);
474
475 const auto AK0 = K / AK1;
476
477 return transform_tensor_descriptor(in_gemmm_gemmk_desc,
482 }
483
484 template <typename BLay>
485 static auto
487 {
488 namespace ctc = tensor_layout::convolution;
489 using Layout = std::conditional_t<
491 ctc::GKYXC,
492 std::conditional_t<is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>(),
493 ctc::GKZYXC,
494 BLay>>;
495
496 const auto wei_gemmnraw_gemmkraw_desc =
497 conv_to_gemm_transformer.template MakeBDescriptor_N_K<Layout>();
498
499 const auto wei_gemmn_gemmk_desc =
500 matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
501
502 const auto N = wei_gemmn_gemmk_desc.GetLength(I0);
503 const auto K = wei_gemmn_gemmk_desc.GetLength(I1);
504
505 const auto BK0 = K / BK1;
506
507 return transform_tensor_descriptor(wei_gemmn_gemmk_desc,
512 }
513
// Builds the padded [M, N] descriptor for the output tensor.
//
// The raw descriptor comes from conv_to_gemm_transformer.MakeCDescriptor_M_N<Layout>() and
// is then padded with matrix_padder.PadCDescriptor_M_N.
//
// \tparam ELay  Requested output layout; used as-is unless one of the layout checks below
//               substitutes ctc::NHWGK or ctc::NDHWGK.
// \param conv_to_gemm_transformer  Transformer that produces the raw descriptor.
// \return The padded descriptor.
//
// NOTE(review): the condition of the outer std::conditional_t (numbered line 520) is absent
// from this listing; only the NGCDHW/GKCZYX/NGKDHW check for the 3D case is visible.
514 template <typename ELay>
515 static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
516
517 {
518 namespace ctc = tensor_layout::convolution;
519 using Layout = std::conditional_t<
521 ctc::NHWGK,
522 std::conditional_t<is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>(),
523 ctc::NDHWGK,
524 ELay>>;
525
526 const auto out_gemmmraw_gemmnraw_desc =
527 conv_to_gemm_transformer.template MakeCDescriptor_M_N<Layout>();
528
529 const auto out_gemmm_gemmn_desc =
530 matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
531
532 return out_gemmm_gemmn_desc;
533 }
534
535 // Shape of Ds and E must be aligned. Strides can be different.
536 // Pass e_g_n_k_wos_lengths for logical broadcast.
// Builds a tuple of [M, N] descriptors for the D tensors: for each index i it takes the
// i-th entry of DsLayout and calls DeviceOp::MakeEGridDescriptor_M_N with that layout
// and the same transformer.
// NOTE(review): the final argument of generate_tuple (numbered line 545, presumably the
// tuple size) is absent from this listing — confirm against the original header.
537 static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
538 {
539 return generate_tuple(
540 [&](auto i) {
541 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
542
543 return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer);
544 },
546 }
547
548 // desc for problem definition
554
556 ABlockTransferSrcScalarPerVector * sizeof(ADataType) == 8
557 ? 4 / sizeof(ADataType)
558 : ABlockTransferSrcScalarPerVector;
560 BBlockTransferSrcScalarPerVector * sizeof(BDataType) == 8
561 ? 4 / sizeof(BDataType)
562 : BBlockTransferSrcScalarPerVector;
563
564 // Use appropriate gridwise gemm
565 template <index_t NXdlPerWave_>
569 DsLayout,
571 ADataType,
572 BDataType,
573 AccDataType,
574 CShuffleDataType,
575 DsDataType,
576 EDataType,
577 AElementwiseOperation,
578 BElementwiseOperation,
579 CDEElementwiseOperation,
580 GemmSpec,
581 BlockSize,
582 MPerBlock,
583 NPerBlock,
584 KPerBlock,
585 AK1,
586 BK1,
587 MPerXDL,
588 NPerXDL,
589 MXdlPerWave,
590 NXdlPerWave_,
591 ABlockTransferThreadClusterLengths_AK0_M_AK1,
592 ABlockTransferThreadClusterArrangeOrder,
593 ABlockTransferSrcAccessOrder,
594 ABlockTransferSrcVectorDim,
595 DirectLoad ? ABlockTransferSrcScalarPerVectorAligned : ABlockTransferSrcScalarPerVector,
596 ABlockTransferDstScalarPerVector_AK1,
597 false,
598 ABlockLdsExtraM,
599 BBlockTransferThreadClusterLengths_BK0_N_BK1,
600 BBlockTransferThreadClusterArrangeOrder,
601 BBlockTransferSrcAccessOrder,
602 BBlockTransferSrcVectorDim,
603 DirectLoad ? BBlockTransferSrcScalarPerVectorAligned : BBlockTransferSrcScalarPerVector,
604 BBlockTransferDstScalarPerVector_BK1,
605 false,
606 BBlockLdsExtraN,
607 CShuffleMXdlPerWavePerShuffle,
608 CShuffleNXdlPerWavePerShuffle,
609 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
611 BlkGemmPipeSched,
612 BlkGemmPipelineVer,
613 AComputeDataType,
614 BComputeDataType,
615 ADataType,
616 BDataType,
618 DirectLoad>;
621
622 // #undef GridwiseGemmV3TemplateParams
623
625
628 .template MakeNGCHWTransposeDesc<NDimSpatial>({}, {}))>;
631 .template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
632
635 .template MakeGKCYXTransposeDesc<NDimSpatial>({}, {}))>;
638 .template MakeGKYXCTransposeDesc<NDimSpatial>({}, {}))>;
639
641
650 NPerBlock,
651 NPerBlock,
652 NPerBlock / ClusterLengthNPerBlock,
653 NPerBlock / ClusterLengthNPerBlock,
657 I1,
658 I0>;
659
668 NPerBlock,
669 NPerBlock,
670 NPerBlock / ClusterLengthNPerBlock,
671 NPerBlock / ClusterLengthNPerBlock,
675 I0,
676 I1>;
677
686 NPerBlock,
687 NPerBlock,
688 NPerBlock / ClusterLengthNPerBlock,
689 NPerBlock / ClusterLengthNPerBlock,
693 I0,
694 I1>;
695
696 // desc for blockwise copy
701
702 // Argument
703 struct Argument : public BaseArgument
704 {
705 Argument(const void* p_as,
706 const void* p_bs,
707 const std::array<const void*, NumDTensor>& p_ds,
708 void* p_e,
709 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
710 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
711 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
712 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
713 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
714 ds_g_n_k_wos_lengths,
715 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
716 ds_g_n_k_wos_strides,
717 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
718 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
719 const std::array<index_t, NDimSpatial>& conv_filter_strides,
720 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
721 const std::array<index_t, NDimSpatial>& input_left_pads,
722 const std::array<index_t, NDimSpatial>& input_right_pads,
723 const AElementwiseOperation& a_element_op,
724 const BElementwiseOperation& b_element_op,
725 const CDEElementwiseOperation& cde_element_op)
726 : p_a_grid_{},
727 p_b_grid_{},
728 p_ds_grid_{p_ds},
729 p_e_grid_{static_cast<EDataType*>(p_e)},
730 a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
732 a_g_n_c_wis_lengths, a_g_n_c_wis_strides)},
733 b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
735 b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
736 ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths},
737 ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides},
738 e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
740 e_g_n_k_wos_lengths, e_g_n_k_wos_strides)},
741 conv_filter_strides_{conv_filter_strides},
742 conv_filter_dilations_{conv_filter_dilations},
743 input_left_pads_{input_left_pads},
744 input_right_pads_{input_right_pads},
766 a_element_op_{a_element_op},
767 b_element_op_{b_element_op},
768 cde_element_op_{cde_element_op}
769 {
770 // A/B/E Batch/N Stride
774
775 // p_as and p_bs are pointers
776 p_a_grid_ = static_cast<const ADataType*>(p_as);
777 p_b_grid_ = static_cast<const BDataType*>(p_bs);
778
779 // populate pointer, batch stride, desc for Ds
780 static_for<0, NumDTensor, 1>{}([&](auto i) {
781 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
782 // D batch stride
784 compute_ptr_offset_of_n_.BatchStrideDs_(i) =
786
787 ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths_,
797
798 // D desc
800 DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer_d);
801 });
802
805
808 {
809 // Use the unmodified base strides
811 conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
812 a_g_n_c_wis_lengths, a_g_n_c_wis_strides);
814 conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
815 a_g_n_c_wis_lengths, a_g_n_c_wis_strides);
816
818 conv_ngchw_to_nhwgc_transformer.template MakeGKCYXTransposeDesc<NDimSpatial>(
819 b_g_k_c_xs_lengths, b_g_k_c_xs_strides);
821 conv_ngchw_to_nhwgc_transformer.template MakeGKYXCTransposeDesc<NDimSpatial>(
822 b_g_k_c_xs_lengths, b_g_k_c_xs_strides);
823
825 conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
826 e_g_n_k_wos_lengths, e_g_n_k_wos_strides);
828 conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
829 e_g_n_k_wos_lengths, e_g_n_k_wos_strides);
830
832 a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
834 b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)};
836 e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)};
837 }
838 }
839
841 {
844 {
846 a_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
847 // Align to 128B
848 return math::integer_divide_ceil(sizeof(ADataType) * a_acum, 128) * 128;
849 }
850 else
851 {
852 return 0;
853 }
854 }
855
857 {
860 {
862 b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
863 // Align to 128B
864 return math::integer_divide_ceil(sizeof(BDataType) * b_acum, 128) * 128;
865 }
866 else
867 {
868 return 0;
869 }
870 }
871
873 {
876 {
878 e_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
879 return sizeof(EDataType) * e_accum;
880 }
881 else
882 {
883 return 0;
884 }
885 }
886
892
// Writes the A, B, Ds and E grid descriptors held by this Argument to std::cout,
// one labelled line per descriptor.
// NOTE(review): the statement that invokes the Ds lambda (numbered line 897, presumably
// a loop over the D tensors) is absent from this listing — confirm.
893 void Print() const
894 {
895 std::cout << "A[AK0, M, AK1]: " << a_grid_desc_ak0_m_ak1_ << std::endl;
896 std::cout << "B[BK0, N, BK1]: " << b_grid_desc_bk0_n_bk1_ << std::endl;
898 [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
899 std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
900 }
901
902 // private:
903 // pointers (tuple if multi AB, pointer if no)
904 const ADataType* p_a_grid_;
905 const BDataType* p_b_grid_;
906 const std::array<const void*, NumDTensor> p_ds_grid_;
907 EDataType* p_e_grid_;
908
909 // for checking IsSupportedArgument()
910 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
911 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
912 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
913 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
914 std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_;
915 std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_;
916 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
917 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
918 std::array<index_t, NDimSpatial> conv_filter_strides_;
919 std::array<index_t, NDimSpatial> conv_filter_dilations_;
920 std::array<index_t, NDimSpatial> input_left_pads_;
921 std::array<index_t, NDimSpatial> input_right_pads_;
922
923 // tensor descriptors for problem definition
925
928
929 // tensor descriptors for block/thread-wise copy
932
935
936 // for computing batch offset
939
940 // element-wise op
941 AElementwiseOperation a_element_op_;
942 BElementwiseOperation b_element_op_;
943 CDEElementwiseOperation cde_element_op_;
944
945 // block-to-e-tile map
948
953 };
954
955 // Invoker
956 struct Invoker : public BaseInvoker
957 {
959
960 template <typename GridwiseGemm>
961 float RunGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
962 {
963 if(stream_config.log_level_ > 0)
964 {
965 arg.Print();
966 }
967
968 float ave_time = 0;
969
970 constexpr index_t minimum_occupancy =
971 BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
972
973 const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
974 const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1);
975 const index_t GemmK =
976 arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
977
978 const index_t num_workgroups_per_Conv_N =
980
981 index_t gdx, gdy, gdz;
982 // TODO: Do we want to support kbatch ??
983 std::tie(gdx, gdy, gdz) =
984 GridwiseGemm::CalculateGridSize(GemmM, GemmN, I1 /*arg.KBatch*/);
985
986 gdy = arg.num_group_;
987 gdz = num_workgroups_per_Conv_N;
988
989 index_t K_split = (GemmK + KPerBlock - 1) / KPerBlock * KPerBlock;
990 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
991
992 const ADataType* p_a_grid = arg.p_a_grid_;
993 const BDataType* p_b_grid = arg.p_b_grid_;
994 EDataType* p_e_grid = arg.p_e_grid_;
995
998 {
1001 arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
1002 p_e_grid =
1005 sizeof(EDataType);
1006 }
1007
1008 typename GridwiseGemm::Argument gemm_arg{
1009 p_a_grid,
1010 p_b_grid,
1011 arg.p_ds_grid_,
1012 p_e_grid,
1013 GemmM,
1014 GemmN,
1015 GemmK,
1016 // No need to set strides, we pass descs to kernel
1017 I0,
1018 I0,
1019 {},
1020 I0,
1021 I1, // kbatch
1022 arg.a_element_op_,
1023 arg.b_element_op_,
1024 arg.cde_element_op_};
1025
1026 const auto Run = [&](const auto& kernel) {
1027 if(stream_config.flush_cache)
1028 {
1029 typename GridwiseGemm::Argument gemm_arg_ = gemm_arg;
1031 gemm_arg_,
1032 stream_config.rotating_count,
1033 gemm_arg_.M * gemm_arg_.K * sizeof(ADataType),
1034 gemm_arg_.K * gemm_arg_.N * sizeof(BDataType));
1035 rotating_mem.Print();
1036
1037 auto run_flush_cache = [&]() {
1038 // flush icache
1040 // rotating mem
1041 rotating_mem.Next();
1042 };
1043
1045 stream_config,
1046 run_flush_cache,
1047 kernel,
1048 dim3(gdx, gdy, gdz),
1049 dim3(BlockSize),
1050 0,
1051 gemm_arg_,
1055 arg.e_grid_desc_m_n_,
1058 }
1059 else
1060 {
1061 ave_time += launch_and_time_kernel(stream_config,
1062 kernel,
1063 dim3(gdx, gdy, gdz),
1064 dim3(BlockSize),
1065 0,
1066 gemm_arg,
1070 arg.e_grid_desc_m_n_,
1073 }
1074 };
1075
1076 if(has_main_k_block_loop)
1077 {
1078 // Tail number always full
1079 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
1080 BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
1081 {
1082 const auto kernel =
1083 kernel_grouped_conv_fwd_xdl_cshuffle_v3<GridwiseGemm,
1084 ComputePtrOffset,
1085 DeviceOp::AGridDesc_AK0_M_AK1,
1086 DeviceOp::BGridDesc_BK0_N_BK1,
1087 DeviceOp::DsGridDesc_M_N,
1088 DeviceOp::EGridDesc_M_N,
1089 true,
1090 InMemoryDataOperationEnum::Set,
1091 minimum_occupancy>;
1092 Run(kernel);
1093 }
1094 // Tail number could be One to Seven
1095 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
1096 {
1097 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
1098 {
1099 const auto kernel =
1100 kernel_grouped_conv_fwd_xdl_cshuffle_v3<GridwiseGemm,
1101 ComputePtrOffset,
1102 DeviceOp::AGridDesc_AK0_M_AK1,
1103 DeviceOp::BGridDesc_BK0_N_BK1,
1104 DeviceOp::DsGridDesc_M_N,
1105 DeviceOp::EGridDesc_M_N,
1106 true,
1107 InMemoryDataOperationEnum::Set,
1108 minimum_occupancy,
1109 TailNumber::One>;
1110 Run(kernel);
1111 }
1112 else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
1113 {
1114 const auto kernel =
1115 kernel_grouped_conv_fwd_xdl_cshuffle_v3<GridwiseGemm,
1116 ComputePtrOffset,
1117 DeviceOp::AGridDesc_AK0_M_AK1,
1118 DeviceOp::BGridDesc_BK0_N_BK1,
1119 DeviceOp::DsGridDesc_M_N,
1120 DeviceOp::EGridDesc_M_N,
1121 true,
1122 InMemoryDataOperationEnum::Set,
1123 minimum_occupancy,
1124 TailNumber::Full>;
1125 Run(kernel);
1126 }
1127
1128 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
1129 {
1130 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
1131 {
1132 const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3<
1133 GridwiseGemm,
1134 ComputePtrOffset,
1135 DeviceOp::AGridDesc_AK0_M_AK1,
1136 DeviceOp::BGridDesc_BK0_N_BK1,
1137 DeviceOp::DsGridDesc_M_N,
1138 DeviceOp::EGridDesc_M_N,
1139 true,
1140 InMemoryDataOperationEnum::Set,
1141 minimum_occupancy,
1142 TailNumber::Two>;
1143 Run(kernel);
1144 }
1145 }
1146
1147 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
1148 {
1149 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three)
1150 {
1151 const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3<
1152 GridwiseGemm,
1153 ComputePtrOffset,
1154 DeviceOp::AGridDesc_AK0_M_AK1,
1155 DeviceOp::BGridDesc_BK0_N_BK1,
1156 DeviceOp::DsGridDesc_M_N,
1157 DeviceOp::EGridDesc_M_N,
1158 true,
1159 InMemoryDataOperationEnum::Set,
1160 minimum_occupancy,
1161 TailNumber::Three>;
1162 Run(kernel);
1163 }
1164 }
1165
1166 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
1167 {
1168 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four)
1169 {
1170 const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3<
1171 GridwiseGemm,
1172 ComputePtrOffset,
1173 DeviceOp::AGridDesc_AK0_M_AK1,
1174 DeviceOp::BGridDesc_BK0_N_BK1,
1175 DeviceOp::DsGridDesc_M_N,
1176 DeviceOp::EGridDesc_M_N,
1177 true,
1178 InMemoryDataOperationEnum::Set,
1179 minimum_occupancy,
1180 TailNumber::Four>;
1181 Run(kernel);
1182 }
1183 }
1184
1185 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
1186 {
1187 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five)
1188 {
1189 const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3<
1190 GridwiseGemm,
1191 ComputePtrOffset,
1192 DeviceOp::AGridDesc_AK0_M_AK1,
1193 DeviceOp::BGridDesc_BK0_N_BK1,
1194 DeviceOp::DsGridDesc_M_N,
1195 DeviceOp::EGridDesc_M_N,
1196 true,
1197 InMemoryDataOperationEnum::Set,
1198 minimum_occupancy,
1199 TailNumber::Five>;
1200 Run(kernel);
1201 }
1202 }
1203
1204 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
1205 {
1206 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
1207 {
1208 const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3<
1209 GridwiseGemm,
1210 ComputePtrOffset,
1211 DeviceOp::AGridDesc_AK0_M_AK1,
1212 DeviceOp::BGridDesc_BK0_N_BK1,
1213 DeviceOp::DsGridDesc_M_N,
1214 DeviceOp::EGridDesc_M_N,
1215 true,
1216 InMemoryDataOperationEnum::Set,
1217 minimum_occupancy,
1218 TailNumber::Six>;
1219 Run(kernel);
1220 }
1221 }
1222
1223 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
1224 {
1225 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven)
1226 {
1227 const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3<
1228 GridwiseGemm,
1229 ComputePtrOffset,
1230 DeviceOp::AGridDesc_AK0_M_AK1,
1231 DeviceOp::BGridDesc_BK0_N_BK1,
1232 DeviceOp::DsGridDesc_M_N,
1233 DeviceOp::EGridDesc_M_N,
1234 true,
1235 InMemoryDataOperationEnum::Set,
1236 minimum_occupancy,
1237 TailNumber::Seven>;
1238 Run(kernel);
1239 }
1240 }
1241 }
1242 // Tail number could be Odd or Even
1243 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
1244 {
1245 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
1246 {
1247 const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds<
1248 GridwiseGemm,
1249 ComputePtrOffset,
1250 DeviceOp::AGridDesc_AK0_M_AK1,
1251 DeviceOp::BGridDesc_BK0_N_BK1,
1252 DeviceOp::DsGridDesc_M_N,
1253 DeviceOp::EGridDesc_M_N,
1254 true,
1255 InMemoryDataOperationEnum::Set,
1256 minimum_occupancy,
1257 TailNumber::Odd>;
1258 Run(kernel);
1259 }
1260 else
1261 {
1262 const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds<
1263 GridwiseGemm,
1264 ComputePtrOffset,
1265 DeviceOp::AGridDesc_AK0_M_AK1,
1266 DeviceOp::BGridDesc_BK0_N_BK1,
1267 DeviceOp::DsGridDesc_M_N,
1268 DeviceOp::EGridDesc_M_N,
1269 true,
1270 InMemoryDataOperationEnum::Set,
1271 minimum_occupancy,
1272 TailNumber::Even>;
1273 Run(kernel);
1274 }
1275 }
1276 else
1277 {
1278 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
1279 {
1280 const auto kernel =
1281 kernel_grouped_conv_fwd_xdl_cshuffle_v3<GridwiseGemm,
1282 ComputePtrOffset,
1283 DeviceOp::AGridDesc_AK0_M_AK1,
1284 DeviceOp::BGridDesc_BK0_N_BK1,
1285 DeviceOp::DsGridDesc_M_N,
1286 DeviceOp::EGridDesc_M_N,
1287 true,
1288 InMemoryDataOperationEnum::Set,
1289 minimum_occupancy,
1290 TailNumber::Odd>;
1291 Run(kernel);
1292 }
1293 else
1294 {
1295 const auto kernel =
1296 kernel_grouped_conv_fwd_xdl_cshuffle_v3<GridwiseGemm,
1297 ComputePtrOffset,
1298 DeviceOp::AGridDesc_AK0_M_AK1,
1299 DeviceOp::BGridDesc_BK0_N_BK1,
1300 DeviceOp::DsGridDesc_M_N,
1301 DeviceOp::EGridDesc_M_N,
1302 true,
1303 InMemoryDataOperationEnum::Set,
1304 minimum_occupancy,
1305 TailNumber::Even>;
1306 Run(kernel);
1307 }
1308 }
1309 }
1310 // has_main_k_block_loop
1311 else
1312 {
1313 // Tail number always 1
1314 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
1315 {
1316 const auto kernel =
1317 kernel_grouped_conv_fwd_xdl_cshuffle_v3<GridwiseGemm,
1318 ComputePtrOffset,
1319 DeviceOp::AGridDesc_AK0_M_AK1,
1320 DeviceOp::BGridDesc_BK0_N_BK1,
1321 DeviceOp::DsGridDesc_M_N,
1322 DeviceOp::EGridDesc_M_N,
1323 false,
1324 InMemoryDataOperationEnum::Set,
1325 minimum_occupancy>;
1326 Run(kernel);
1327 }
1328 }
1329
1330 return ave_time;
1331 }
1332
1333 template <typename GridwiseGemm>
1334 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
1335 {
1336 float avg_time = 0.f;
1337 if constexpr(!isMultiABD)
1338 {
1339 // Transpose to NGHWC layout
1342 {
1343 const index_t a_grid_size =
1344 arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
1346 const index_t b_grid_size =
1347 arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize(
1349
1350 ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_);
1351 BDataType* p_b_out_grid =
1353 arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
1354
1355 auto kernel_transpose =
1369
1370 avg_time +=
1371 launch_and_time_kernel(stream_config,
1372 kernel_transpose,
1373 dim3(a_grid_size + b_grid_size),
1375 0,
1380 make_tuple(arg.p_a_grid_),
1381 make_tuple(arg.p_b_grid_),
1382 make_tuple(p_a_out_grid),
1383 make_tuple(p_b_out_grid),
1387 a_grid_size);
1388 }
1389
1390 avg_time += RunGemm<GridwiseGemm>(arg, stream_config);
1391
1392 // Transpose result back to NGCHW
1395 {
1396 const index_t grid_size =
1397 arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
1399
1400 const EDataType* p_e_in_grid =
1403 sizeof(EDataType);
1404
1405 EDataType* p_e_out_grid = arg.p_e_grid_;
1406
1407 auto kernel_transpose = kernel_elementwise<GridwiseElementwiseOutputTranspose,
1412 Block2TileMapElementwise,
1414
1415 avg_time +=
1416 launch_and_time_kernel(stream_config,
1417 kernel_transpose,
1418 dim3(grid_size),
1419 dim3(ElementwiseBlocksize),
1420 0,
1423 make_tuple(p_e_in_grid),
1424 make_tuple(p_e_out_grid),
1427 }
1428 }
1429 return avg_time;
1430 }
1431
1433
1434 float Run(const BaseArgument* p_arg,
1435 const StreamConfig& stream_config = StreamConfig{}) override
1436 {
1437 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
1438 }
1439 };
1440
1441 static bool IsSupportedArgument(const Argument& arg)
1442 {
1443 namespace ctc = tensor_layout::convolution;
1444
1445 const index_t G = arg.b_g_k_c_xs_lengths_[I0];
1446 const index_t K = arg.b_g_k_c_xs_lengths_[I1];
1447 const index_t C = arg.b_g_k_c_xs_lengths_[I2];
1448 // Move this to runtime check to align Conv instances
1449 // with Conv Multiple D instances
1450 if constexpr(isMultiABD)
1451 {
1452 return false;
1453 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1454 {
1455 std::cout << "The MultiABD is not supported!" << " In " << __FILE__ << ":"
1456 << __LINE__ << ", in function: " << __func__ << std::endl;
1457 }
1458 }
1459
1460 // check device
1461 if constexpr(DirectLoad)
1462 {
1463 if(get_device_name() != "gfx950")
1464 {
1465 return false;
1466 }
1467 }
1468
1469 if(get_device_name() == "gfx908")
1470 {
1471 // FIXME: re-enable fp64 when SWDEV-335738 is fixed
1473 {
1474 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1475 {
1476 std::cout
1477 << "On gfx908 the accumulation data type must be one of fp32 or int32!"
1478 << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1479 << std::endl;
1480 }
1481 return false;
1482 }
1483 }
1484
1486 {
1487 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1488 {
1489 std::cout << "Current device does not support xdl instructions!" << " In "
1490 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1491 << std::endl;
1492 }
1493 return false;
1494 }
1495
1498 {
1499 if(!is_tf32_supported())
1500 {
1501 return false;
1502 }
1504 {
1505 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1506 {
1507 std::cout << "ComputeDataType for A and B should be same while using TF32"
1508 << std::endl;
1509 }
1510 return false;
1511 }
1512 }
1513
1514 // check ConvolutionForwardSpecialization
1515 if constexpr(ConvForwardSpecialization ==
1517 {
1518 // check if it's 1x1, stride=1 conv
1519 for(index_t i = 0; i < NDimSpatial; ++i)
1520 {
1521 const index_t SpatialDim = arg.b_g_k_c_xs_lengths_[i + 3];
1522 const index_t ConvStride = arg.conv_filter_strides_[i];
1523 const index_t LeftPad = arg.input_left_pads_[i];
1524 const index_t RightPad = arg.input_right_pads_[i];
1525
1526 if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
1527 {
1528 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1529 {
1530 std::cout << "The input paramters do not align with specialization "
1531 "Filter1x1Stride1Pad0!"
1532 << " In " << __FILE__ << ":" << __LINE__
1533 << ", in function: " << __func__ << std::endl;
1534 }
1535 return false;
1536 }
1537 }
1538 }
1539 else if constexpr(ConvForwardSpecialization ==
1541 {
1542 // check if it's 1x1 conv
1543 for(index_t i = 0; i < NDimSpatial; ++i)
1544 {
1545 const index_t SpatialDim = arg.b_g_k_c_xs_lengths_[i + 3];
1546 const index_t LeftPad = arg.input_left_pads_[i];
1547 const index_t RightPad = arg.input_right_pads_[i];
1548
1549 if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
1550 {
1551 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1552 {
1553 std::cout
1554 << "The input paramters do not align with specialization Filter1x1Pad0!"
1555 << " In " << __FILE__ << ":" << __LINE__
1556 << ", in function: " << __func__ << std::endl;
1557 }
1558 return false;
1559 }
1560 }
1561 }
1562
1563 // check vector access of A
1564 // FIXME: layout
1571 {
1572 if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
1573 {
1574 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1575 {
1576 std::cout << "[A Layout] The number of input channels is not a multiple of "
1577 "ABlockTransferSrcScalarPerVector!"
1578 << " In " << __FILE__ << ":" << __LINE__
1579 << ", in function: " << __func__ << std::endl;
1580 }
1581 return false;
1582 }
1583 }
1584 else
1585 {
1586 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1587 {
1588 std::cout << "Unsupported A Layout!" << " In " << __FILE__ << ":" << __LINE__
1589 << ", in function: " << __func__ << std::endl;
1590 }
1591 return false;
1592 }
1593
1594 // check vector access of B
1595 // FIXME: layout
1602
1603 {
1604 if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
1605 {
1606 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1607 {
1608 std::cout << "[B Layout] The number of input channels is not a multiple of "
1609 "BBlockTransferSrcScalarPerVector!"
1610 << " In " << __FILE__ << ":" << __LINE__
1611 << ", in function: " << __func__ << std::endl;
1612 }
1613 return false;
1614 }
1615 }
1616 else
1617 {
1618 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1619 {
1620 std::cout << "Unsupported A Layout!" << " In " << __FILE__ << ":" << __LINE__
1621 << ", in function: " << __func__ << std::endl;
1622 }
1623 return false;
1624 }
1625
1628 {
1629 if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
1630 {
1631 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1632 {
1633 std::cout << "[NGCHW Layout] The G * C is not a multiple of "
1634 "CDEBlockTransferScalarPerVector_NPerBlock"
1635 << " In " << __FILE__ << ":" << __LINE__
1636 << ", in function: " << __func__ << std::endl;
1637 }
1638 return false;
1639 }
1640
1641 if((G * K) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
1642 {
1643 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1644 {
1645 std::cout << "[NGCHW Layout] The G * K is not a multiple of "
1646 "CDEBlockTransferScalarPerVector_NPerBlock"
1647 << " In " << __FILE__ << ":" << __LINE__
1648 << ", in function: " << __func__ << std::endl;
1649 }
1650 return false;
1651 }
1652
1653 const index_t input_spatial_acum = ck::accumulate_n<index_t>(
1654 arg.a_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
1655 const index_t output_spatial_acum = ck::accumulate_n<index_t>(
1656 arg.e_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
1657
1658 if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
1659 {
1660 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1661 {
1662 std::cout << "[NGCHW Layout] The input_spatial_acum is not a multiple of "
1663 "CDEBlockTransferScalarPerVector_NPerBlock"
1664 << " In " << __FILE__ << ":" << __LINE__
1665 << ", in function: " << __func__ << std::endl;
1666 }
1667 return false;
1668 }
1669
1670 if(output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
1671 {
1672 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1673 {
1674 std::cout << "[NGCHW Layout] The output_spatial_acum is not a multiple of "
1675 "CDEBlockTransferScalarPerVector_NPerBlock"
1676 << " In " << __FILE__ << ":" << __LINE__
1677 << ", in function: " << __func__ << std::endl;
1678 }
1679 return false;
1680 }
1681
1682 if(!arg.p_workspace_)
1683 {
1684 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1685 {
1686 std::cout << "Warning: Workspace for "
1687 "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3::Argument is not "
1688 "allocated, use SetWorkSpacePointer."
1689 << std::endl;
1690 }
1691 return false;
1692 }
1693
1694 constexpr long_index_t TwoGB = (long_index_t{1} << 31);
1695 if(!(arg.a_out_transpose_desc_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
1696 arg.e_in_transpose_desc_.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
1697 {
1698 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1699 {
1700 std::cout << "[NGCHW Layout] One of the transposed vectors is exceeding 2GB "
1701 "memory size!"
1702 << " In " << __FILE__ << ":" << __LINE__
1703 << ", in function: " << __func__ << std::endl;
1704 }
1705 return false;
1706 }
1707 }
1708
1709 // check vector access of E
1716 {
1717 if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
1718 {
1719 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1720 {
1721 std::cout << "[E Layout] The K is not a multiple of "
1722 "CDEBlockTransferScalarPerVector_NPerBlock"
1723 << " In " << __FILE__ << ":" << __LINE__
1724 << ", in function: " << __func__ << std::endl;
1725 }
1726 return false;
1727 }
1728 }
1729 else
1730 {
1731 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1732 {
1733 std::cout << "Unsupported E Layout!" << " In " << __FILE__ << ":" << __LINE__
1734 << ", in function: " << __func__ << std::endl;
1735 }
1736 return false;
1737 }
1738
1739 // Gridwise gemm v3 doesn't verify descriptors size
1741 {
1742 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1743 {
1744 std::cout
1745 << "[conv_to_gemm_transformer_] One of the descriptors is bigger than 2GB!"
1746 << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1747 << std::endl;
1748 }
1749 return false;
1750 }
1751
1752 // check Gridwise GEMM
1753 const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
1754 const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1);
1755 const index_t GemmK =
1756 arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
1757
1758 if(get_warp_size() == 64)
1759 {
1760 if constexpr(NXdlPerWave64 > 0)
1761 {
1762 typename GridwiseGemm64::Argument gemm_arg{nullptr,
1763 nullptr,
1764 {},
1765 nullptr,
1766 GemmM,
1767 GemmN,
1768 GemmK,
1769 I0,
1770 I0,
1771 {},
1772 I0,
1773 I1 /*KBatch*/,
1774 arg.a_element_op_,
1775 arg.b_element_op_,
1776 arg.cde_element_op_};
1777 return GridwiseGemm64::CheckValidity(gemm_arg);
1778 }
1779 }
1780 else
1781 {
1782 if constexpr(NXdlPerWave32 > 0)
1783 {
1784 typename GridwiseGemm32::Argument gemm_arg{nullptr,
1785 nullptr,
1786 {},
1787 nullptr,
1788 GemmM,
1789 GemmN,
1790 GemmK,
1791 I0,
1792 I0,
1793 {},
1794 I0,
1795 I1 /*KBatch*/,
1796 arg.a_element_op_,
1797 arg.b_element_op_,
1798 arg.cde_element_op_};
1799 return GridwiseGemm32::CheckValidity(gemm_arg);
1800 }
1801 }
1802
1803 return false;
1804 }
1805
1806 bool IsSupportedArgument(const BaseArgument* p_arg) override
1807 {
1808 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
1809 }
1810
1811 static auto MakeArgument(
1812 const void* p_as,
1813 const void* p_bs,
1814 const std::array<const void*, NumDTensor>& p_ds,
1815 void* p_e,
1816 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
1817 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
1818 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
1819 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
1820 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
1821 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
1822 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
1823 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
1824 const std::array<index_t, NDimSpatial>& conv_filter_strides,
1825 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
1826 const std::array<index_t, NDimSpatial>& input_left_pads,
1827 const std::array<index_t, NDimSpatial>& input_right_pads,
1828 const AElementwiseOperation& a_element_op,
1829 const BElementwiseOperation& b_element_op,
1830 const CDEElementwiseOperation& cde_element_op)
1831 {
1832 return Argument{p_as,
1833 p_bs,
1834 p_ds,
1835 p_e,
1836 a_g_n_c_wis_lengths,
1837 a_g_n_c_wis_strides,
1838 b_g_k_c_xs_lengths,
1839 b_g_k_c_xs_strides,
1840 ds_g_n_k_wos_lengths,
1841 ds_g_n_k_wos_strides,
1842 e_g_n_k_wos_lengths,
1843 e_g_n_k_wos_strides,
1844 conv_filter_strides,
1845 conv_filter_dilations,
1846 input_left_pads,
1847 input_right_pads,
1848 a_element_op,
1849 b_element_op,
1850 cde_element_op};
1851 }
1852
1853 static auto
1854 MakeArgument(const void* p_as,
1855 const void* p_bs,
1856 const std::array<const void*, NumDTensor>& p_ds,
1857 void* p_e,
1858 const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
1859 const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
1860 const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
1861 const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
1862 const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
1863 ds_g_n_k_wos_lengths,
1864 const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
1865 ds_g_n_k_wos_strides,
1866 const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
1867 const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
1868 const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
1869 const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
1870 const std::array<long_index_t, NDimSpatial>& input_left_pads,
1871 const std::array<long_index_t, NDimSpatial>& input_right_pads,
1872 const AElementwiseOperation& a_element_op,
1873 const BElementwiseOperation& b_element_op,
1874 const CDEElementwiseOperation& cde_element_op)
1875 {
1876 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
1877 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
1878 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
1879 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
1880 std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
1881 std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
1882 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
1883 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
1884 std::array<index_t, NDimSpatial> conv_filter_strides_i32;
1885 std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
1886 std::array<index_t, NDimSpatial> input_left_pads_i32;
1887 std::array<index_t, NDimSpatial> input_right_pads_i32;
1888
1889 array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
1890 array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
1891 array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
1892 array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
1893 for(index_t d = 0; d < NumDTensor; d++)
1894 {
1895 array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
1896 array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
1897 }
1898 array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
1899 array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
1900 array_convert(conv_filter_strides_i32, conv_filter_strides);
1901 array_convert(conv_filter_dilations_i32, conv_filter_dilations);
1902 array_convert(input_left_pads_i32, input_left_pads);
1903 array_convert(input_right_pads_i32, input_right_pads);
1904
1905 return Argument{p_as,
1906 p_bs,
1907 p_ds,
1908 p_e,
1909 a_g_n_c_wis_lengths_i32,
1910 a_g_n_c_wis_strides_i32,
1911 b_g_k_c_xs_lengths_i32,
1912 b_g_k_c_xs_strides_i32,
1913 ds_g_n_k_wos_lengths_i32,
1914 ds_g_n_k_wos_strides_i32,
1915 e_g_n_k_wos_lengths_i32,
1916 e_g_n_k_wos_strides_i32,
1917 conv_filter_strides_i32,
1918 conv_filter_dilations_i32,
1919 input_left_pads_i32,
1920 input_right_pads_i32,
1921 a_element_op,
1922 b_element_op,
1923 cde_element_op};
1924 }
1925
1926 static auto MakeInvoker() { return Invoker{}; }
1927
1928 std::unique_ptr<BaseArgument> MakeArgumentPointer(
1929 const void* p_a,
1930 const void* p_b,
1931 const std::array<const void*, NumDTensor>& p_ds,
1932 void* p_e,
1933 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
1934 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
1935 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
1936 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
1937 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
1938 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
1939 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
1940 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
1941 const std::array<index_t, NDimSpatial>& conv_filter_strides,
1942 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
1943 const std::array<index_t, NDimSpatial>& input_left_pads,
1944 const std::array<index_t, NDimSpatial>& input_right_pads,
1945 const AElementwiseOperation& a_element_op,
1946 const BElementwiseOperation& b_element_op,
1947 const CDEElementwiseOperation& cde_element_op) override
1948 {
1949 return std::make_unique<Argument>(p_a,
1950 p_b,
1951 p_ds,
1952 p_e,
1953 a_g_n_c_wis_lengths,
1954 a_g_n_c_wis_strides,
1955 b_g_k_c_xs_lengths,
1956 b_g_k_c_xs_strides,
1957 ds_g_n_k_wos_lengths,
1958 ds_g_n_k_wos_strides,
1959 e_g_n_k_wos_lengths,
1960 e_g_n_k_wos_strides,
1961 conv_filter_strides,
1962 conv_filter_dilations,
1963 input_left_pads,
1964 input_right_pads,
1965 a_element_op,
1966 b_element_op,
1967 cde_element_op);
1968 }
1969
1970 std::unique_ptr<BaseArgument>
1971 MakeArgumentPointer(const void* p_a,
1972 const void* p_b,
1973 const std::array<const void*, NumDTensor>& p_ds,
1974 void* p_e,
1975 const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
1976 const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
1977 const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
1978 const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
1979 const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
1980 ds_g_n_k_wos_lengths,
1981 const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
1982 ds_g_n_k_wos_strides,
1983 const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
1984 const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
1985 const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
1986 const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
1987 const std::array<long_index_t, NDimSpatial>& input_left_pads,
1988 const std::array<long_index_t, NDimSpatial>& input_right_pads,
1989 const AElementwiseOperation& a_element_op,
1990 const BElementwiseOperation& b_element_op,
1991 const CDEElementwiseOperation& cde_element_op) override
1992 {
1993 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
1994 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
1995 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
1996 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
1997 std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
1998 std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
1999 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
2000 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
2001 std::array<index_t, NDimSpatial> conv_filter_strides_i32;
2002 std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
2003 std::array<index_t, NDimSpatial> input_left_pads_i32;
2004 std::array<index_t, NDimSpatial> input_right_pads_i32;
2005
2006 array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
2007 array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
2008 array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
2009 array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
2010 for(index_t d = 0; d < NumDTensor; d++)
2011 {
2012 array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
2013 array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
2014 }
2015 array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
2016 array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
2017 array_convert(conv_filter_strides_i32, conv_filter_strides);
2018 array_convert(conv_filter_dilations_i32, conv_filter_dilations);
2019 array_convert(input_left_pads_i32, input_left_pads);
2020 array_convert(input_right_pads_i32, input_right_pads);
2021
2022 return std::make_unique<Argument>(p_a,
2023 p_b,
2024 p_ds,
2025 p_e,
2026 a_g_n_c_wis_lengths_i32,
2027 a_g_n_c_wis_strides_i32,
2028 b_g_k_c_xs_lengths_i32,
2029 b_g_k_c_xs_strides_i32,
2030 ds_g_n_k_wos_lengths_i32,
2031 ds_g_n_k_wos_strides_i32,
2032 e_g_n_k_wos_lengths_i32,
2033 e_g_n_k_wos_strides_i32,
2034 conv_filter_strides_i32,
2035 conv_filter_dilations_i32,
2036 input_left_pads_i32,
2037 input_right_pads_i32,
2038 a_element_op,
2039 b_element_op,
2040 cde_element_op);
2041 }
2042
2043 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
2044 {
2045 return std::make_unique<Invoker>(Invoker{});
2046 }
2047
2048 std::string GetTypeString() const override
2049 {
2050 auto str = std::stringstream();
2051
2052 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
2055
2056 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
2062
2063 // clang-format off
2064 str << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3";
2065
2066 if constexpr(DirectLoad) {
2067 str << "_DirectLoad";
2068 }
2069
2070 str << "<"
2071 << BlockSize << ", "
2072 << MPerBlock << ", "
2073 << NPerBlock << ", "
2074 << KPerBlock << ", "
2075 << getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
2076 << MPerXDL << ", "
2077 << NPerXDL << ", "
2078 << MXdlPerWave << ", "
2079 << NXdlPerWave << ", "
2080 << ABlockTransferSrcScalarPerVector << ", "
2081 << BBlockTransferSrcScalarPerVector << ", "
2082 << CDEBlockTransferScalarPerVector_NPerBlock << ", "
2083 << CShuffleMXdlPerWavePerShuffle << ", "
2084 << CShuffleNXdlPerWavePerShuffle << ", "
2085 << "BlkGemmPipelineScheduler: "
2086 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
2087 << "BlkGemmPipelineVersion: "
2088 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer]
2089 << ">";
2090 // clang-format on
2091
2092 return str.str();
2093 }
2094
2095#ifdef CK_EXPERIMENTAL_BUILDER
2096 std::string GetInstanceString() const override
2097 {
2098 static_assert(ck_tile::reflect::HasInstanceTraits<DeviceOp>,
2099 "Specialization of instance_traits not found. Please check that a "
2100 "specialization exists in file "
2101 "ck_tile/builder/reflect/"
2102 "instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp "
2103 "for the given template parameters.");
2104 return ck_tile::reflect::instance_string<DeviceOp>();
2105 }
2106#endif
2107
2108 size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
2109 {
2110 auto arg = dynamic_cast<const Argument*>(p_arg);
2111 if(arg)
2112 {
2113 return arg->GetWorkspaceSizeBytes();
2114 }
2115 else
2116 throw std::runtime_error(
2117 "The argument pointer is not an object of "
2118 "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!");
2119 }
2120
2122 void* p_workspace,
2123 const StreamConfig& = StreamConfig{}) const override
2124 {
2125 auto p_arg_ = dynamic_cast<Argument*>(p_arg);
2126 if(p_arg_)
2127 {
2128 p_arg_->p_workspace_ = p_workspace;
2129 }
2130 else
2131 throw std::runtime_error(
2132 "The argument pointer is not an object of "
2133 "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!");
2134 }
2135};
2136
2137} // namespace device
2138} // namespace tensor_operation
2139} // namespace ck
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition tensor_operation/gpu/device/tensor_layout.hpp:42
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
constexpr bool is_NGCDHW_NGKDHW()
Definition device_grouped_conv_utils.hpp:112
constexpr bool is_NGCHW_GKCYX_NGKHW()
Definition device_grouped_conv_utils.hpp:64
decltype(std::declval< T & >().IsTuple()) is_tuple
Definition device_grouped_conv_fwd_multiple_abd.hpp:23
ConvolutionForwardSpecialization
Definition convolution_forward_specialization.hpp:15
@ Filter1x1Stride1Pad0
Definition convolution_forward_specialization.hpp:18
@ Filter1x1Pad0
Definition convolution_forward_specialization.hpp:17
std::string getConvForwardSpecializationString(const ConvolutionForwardSpecialization &s)
Definition convolution_forward_specialization.hpp:24
constexpr bool is_NGCDHW_GKCZYX_NGKDHW()
Definition device_grouped_conv_utils.hpp:104
constexpr bool is_NGCHW_NGKHW()
Definition device_grouped_conv_utils.hpp:72
Definition convolution_backward_data_specialization.hpp:7
void flush_icache()
Definition flush_cache.hpp:383
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition flush_cache.hpp:398
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition is_detected.hpp:34
integral_constant< index_t, N > Number
Definition number.hpp:12
std::string get_device_name()
Definition host_utility/device_prop.hpp:19
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Full
Definition blkgemmpipe_scheduler.hpp:49
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition amd_wave_read_first_lane.hpp:100
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
bool is_tf32_supported()
Definition host_utility/device_prop.hpp:132
__host__ __device__ void array_convert(std::array< Y, NumElems > &y, const std::array< X, NumElems > &x)
Definition utility/type_convert.hpp:2466
constexpr bool is_same_v
Definition type.hpp:283
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
auto accumulate_n(ForwardIterator first, Size count, T init, BinaryOperation op) -> decltype(std::accumulate(first, std::next(first, count), init, op))
Definition library/utility/numeric.hpp:11
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__global__ void kernel_elementwise_dual(const InAGridDescTuple in_grid_desc_tuple_a, const InBGridDescTuple in_grid_desc_tuple_b, const OutAGridDescTuple out_grid_desc_tuple_a, const OutBGridDescTuple out_grid_desc_tuple_b, const InADataTypePointerTuple p_in_global_tuple_a, const InBDataTypePointerTuple p_in_global_tuple_b, const OutADataTypePointerTuple p_out_global_tuple_a, const OutBDataTypePointerTuple p_out_global_tuple_b, const Block2TileMapA block_2_tile_map_a, const Block2TileMapB block_2_tile_map_b, const ElementwiseOperation elementwise_op, const index_t a_grid_size)
Definition gridwise_elementwise_2d.hpp:61
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__global__ void kernel_elementwise(const InGridDescTuple in_grid_desc_tuple, const OutGridDescTuple out_grid_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const Block2TileMap block_2_tile_map, const ElementwiseOperation elementwise_op)
Definition gridwise_elementwise_2d.hpp:29
Layout wrapper that performs the tensor descriptor logic.
Definition layout.hpp:24
Definition ck/stream_config.hpp:10
Definition block_to_ctile_map.hpp:261
Definition gridwise_elementwise_2d.hpp:278
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:157
Definition multi_index_transform.hpp:196
Definition multi_index_transform.hpp:284
Definition utility/sequence.hpp:43
Definition utility/tuple.hpp:117
Definition functional2.hpp:33
Definition tensor_operation/gpu/device/tensor_layout.hpp:31
Definition tensor_operation/gpu/device/tensor_layout.hpp:26
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:25
__host__ bool AreDescriptorsSmallerThan2GB() const
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:375
Definition transform_conv_ngchw_to_nhwgc.hpp:31
Definition device_base.hpp:197
void * p_workspace_
Definition device_base.hpp:204
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:704
AElementwiseOperation a_element_op_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:941
Argument(const void *p_as, const void *p_bs, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:705
ComputePtrOffset compute_ptr_offset_of_n_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:938
std::size_t GetWorkspaceSizeBytes() const
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:887
DsGridDesc_M_N ds_grid_desc_m_n_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:930
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_b_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:947
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:933
NGCHWTransposeDescType a_in_transpose_desc_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:949
const std::array< const void *, NumDTensor > p_ds_grid_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:906
std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > ds_g_n_k_wos_lengths_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:914
std::array< index_t, NDimSpatial > input_left_pads_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:920
BElementwiseOperation b_element_op_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:942
ComputePtrOffset compute_ptr_offset_of_groups_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:937
const BDataType * p_b_grid_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:905
NHWGCTransposeDescType e_in_transpose_desc_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:950
GKCYXTransposeDescType b_in_transpose_desc_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:951
const ADataType * p_a_grid_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:904
std::array< index_t, NDimSpatial+3 > e_g_n_k_wos_lengths_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:916
std::array< index_t, NDimSpatial > conv_filter_strides_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:918
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_e_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:947
CDEElementwiseOperation cde_element_op_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:943
std::array< index_t, NDimSpatial+3 > b_g_k_c_xs_strides_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:913
std::array< index_t, NDimSpatial > input_right_pads_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:921
std::size_t GetWorkspaceBTensorSizeBytes() const
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:856
std::array< index_t, NDimSpatial+3 > b_g_k_c_xs_lengths_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:912
std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > ds_g_n_k_wos_strides_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:915
NGCHWTransposeDescType e_out_transpose_desc_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:949
GKYXCTransposeDescType b_out_transpose_desc_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:952
std::array< index_t, NDimSpatial+3 > a_g_n_c_wis_lengths_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:910
EDataType * p_e_grid_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:907
NHWGCTransposeDescType a_out_transpose_desc_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:950
index_t num_group_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:924
std::array< index_t, NDimSpatial > conv_filter_dilations_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:919
ConvToGemmFwdTransformer conv_to_gemm_transformer_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:926
void Print() const
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:893
std::array< index_t, NDimSpatial+3 > e_g_n_k_wos_strides_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:917
std::size_t GetWorkspaceATensorSizeBytes() const
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:840
EGridDesc_M_N e_grid_desc_m_n_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:931
std::array< index_t, NDimSpatial+3 > a_g_n_c_wis_strides_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:911
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:934
std::size_t GetWorkspaceETensorSizeBytes() const
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:872
index_t conv_N_per_block_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:927
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:946
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:957
float RunGemm(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:961
DeviceOp::Argument Argument
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:958
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:1334
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:1434
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:400
GridwiseElementwise< Tuple< NGCHWTransposeDescType >, Tuple< NHWGCTransposeDescType >, Tuple< const ADataType * >, Tuple< ADataType * >, Block2TileMapElementwise, element_wise::PassThrough, ElementwiseBlocksize, NPerBlock, NPerBlock, NPerBlock/ClusterLengthNPerBlock, NPerBlock/ClusterLengthNPerBlock, Sequence< 1, 0 >, Sequence< CDEBlockTransferScalarPerVector_NPerBlock >, Sequence< CDEBlockTransferScalarPerVector_NPerBlock >, I1, I0 > GridwiseElementwiseInputTranspose
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:642
remove_cvref_t< decltype(conv_ngchw_to_nhwgc_transformer .template MakeGKYXCTransposeDesc< NDimSpatial >({}, {}))> GKYXCTransposeDescType
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:636
static constexpr auto I3
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:422
static auto MakeArgument(const void *p_as, const void *p_bs, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:1811
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:1441
static constexpr bool isMultiABD
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:409
size_t GetWorkSpaceSize(const BaseArgument *p_arg) const override
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:2108
static constexpr bool isMultiB
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:407
remove_cvref_t< decltype(MakeAGridDescriptor_AK0_M_AK1< ALayout >( dummy_conv_to_gemm_transformer))> AGridDesc_AK0_M_AK1
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:697
static constexpr index_t NumBTensor
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:416
static constexpr auto conv_ngchw_to_nhwgc_transformer
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:445
static constexpr index_t BBlockTransferSrcScalarPerVectorAligned
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:559
static auto MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:486
remove_cvref_t< decltype(MakeEGridDescriptor_M_N< ELayout >(dummy_conv_to_gemm_transformer))> EGridDesc_M_N
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:550
static constexpr index_t ABlockTransferSrcScalarPerVectorAligned
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:555
GridwiseElementwise< Tuple< NHWGCTransposeDescType >, Tuple< NGCHWTransposeDescType >, Tuple< const EDataType * >, Tuple< EDataType * >, Block2TileMapElementwise, element_wise::PassThrough, ElementwiseBlocksize, NPerBlock, NPerBlock, NPerBlock/ClusterLengthNPerBlock, NPerBlock/ClusterLengthNPerBlock, Sequence< 1, 0 >, Sequence< CDEBlockTransferScalarPerVector_NPerBlock >, Sequence< CDEBlockTransferScalarPerVector_NPerBlock >, I0, I1 > GridwiseElementwiseOutputTranspose
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:678
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:619
static constexpr auto matrix_padder
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:439
static auto MakeArgument(const void *p_as, const void *p_bs, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_dilations, const std::array< long_index_t, NDimSpatial > &input_left_pads, const std::array< long_index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:1854
static constexpr index_t ElementwiseBlocksize
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:640
BlockToCTileMap_M00_N0_M01Adapt< NPerBlock, NPerBlock > Block2TileMapElementwise
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:624
remove_cvref_t< decltype(MakeBGridDescriptor_BK0_N_BK1< BLayout >( dummy_conv_to_gemm_transformer))> BGridDesc_BK0_N_BK1
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:699
static constexpr ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:549
GridwiseElementwise< Tuple< GKCYXTransposeDescType >, Tuple< GKYXCTransposeDescType >, Tuple< const BDataType * >, Tuple< BDataType * >, Block2TileMapElementwise, element_wise::PassThrough, ElementwiseBlocksize, NPerBlock, NPerBlock, NPerBlock/ClusterLengthNPerBlock, NPerBlock/ClusterLengthNPerBlock, Sequence< 1, 0 >, Sequence< 1 >, Sequence< CDEBlockTransferScalarPerVector_NPerBlock >, I0, I1 > GridwiseElementwiseWeightTranspose
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:660
static constexpr auto I5
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:424
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:2043
remove_cvref_t< decltype(MakeDsGridDescriptor_M_N(dummy_conv_to_gemm_transformer))> DsGridDesc_M_N
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:552
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:620
static constexpr index_t NumATensor
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:415
GridwiseGemmMultiD_xdl_cshuffle_v3< tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, DsLayout, tensor_layout::gemm::RowMajor, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, DirectLoad ? ABlockTransferSrcScalarPerVectorAligned :ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, DirectLoad ? BBlockTransferSrcScalarPerVectorAligned :BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, AComputeDataType, BComputeDataType, ADataType, BDataType, DoElementwiseBeforeCShuffle, DirectLoad > GridwiseGemmBase
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:566
static constexpr auto I1
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:420
remove_cvref_t< decltype(conv_ngchw_to_nhwgc_transformer .template MakeNHWGCTransposeDesc< NDimSpatial >({}, {}))> NHWGCTransposeDescType
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:629
static auto MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:455
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op) override
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:1928
TransformConvFwdToGemm< NDimSpatial, ConvForwardSpecialization, true, ADataType, EDataType > ConvToGemmFwdTransformer
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:431
std::string GetTypeString() const override
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:2048
static constexpr bool DoElementwiseBeforeCShuffle
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:411
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 DeviceOp
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:401
static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:537
typename uniform_sequence_gen< NumDTensor+1, CDEBlockTransferScalarPerVector_NPerBlock >::type CDEBlockTransferScalarPerVectors
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:427
ComputePtrOffsetOfStridedBatch< I1, I1, NumDTensor > ComputePtrOffset
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:437
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_dilations, const std::array< long_index_t, NDimSpatial > &input_left_pads, const std::array< long_index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op) override
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:1971
remove_cvref_t< decltype(conv_ngchw_to_nhwgc_transformer .template MakeNGCHWTransposeDesc< NDimSpatial >({}, {}))> NGCHWTransposeDescType
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:626
static auto MakeInvoker()
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:1926
static constexpr index_t ClusterLengthNPerBlock
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:442
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:515
void SetWorkSpacePointer(BaseArgument *p_arg, void *p_workspace, const StreamConfig &=StreamConfig{}) const override
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:2121
static constexpr bool isMultiD
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:408
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:403
static constexpr index_t NumDTensor
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:417
static constexpr bool isMultiA
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:406
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:1806
remove_cvref_t< decltype(conv_ngchw_to_nhwgc_transformer .template MakeGKCYXTransposeDesc< NDimSpatial >({}, {}))> GKCYXTransposeDescType
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:633
static constexpr auto I2
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:421
static constexpr auto I0
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:419
static constexpr auto I4
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:423
static constexpr auto NXdlPerWave32
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:404
Grouped Convolution Forward.
Definition device_grouped_conv_fwd_multiple_abd.hpp:73
Definition matrix_padder.hpp:180
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
Definition utility/sequence.hpp:289
Definition flush_cache.hpp:299
#define CK_ENV(name)
Definition utility/env.hpp:129