device_grouped_conv_bwd_weight_explicit_xdl.hpp Source File

device_grouped_conv_bwd_weight_explicit_xdl.hpp Source File

Composable Kernel: device_grouped_conv_bwd_weight_explicit_xdl.hpp Source File
device_grouped_conv_bwd_weight_explicit_xdl.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <numeric>
8#include <sstream>
9
11
18
19namespace ck {
20namespace tensor_operation {
21namespace device {
22
23// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
24template <ck::index_t NDimSpatial,
25 typename InLayout,
26 typename WeiLayout,
27 typename OutLayout,
28 typename InDataType,
29 typename WeiDataType,
30 typename OutDataType,
31 typename InElementwiseOperation,
32 typename WeiElementwiseOperation,
33 typename OutElementwiseOperation,
34 typename DeviceGemmV3Op>
36 : public DeviceGroupedConvBwdWeight<NDimSpatial,
37 InLayout,
38 WeiLayout,
39 OutLayout,
40 InDataType,
41 WeiDataType,
42 OutDataType,
43 InElementwiseOperation,
44 WeiElementwiseOperation,
45 OutElementwiseOperation>
46{
50
51 static constexpr auto I0 = Number<0>{};
52 static constexpr auto I1 = Number<1>{};
53 static constexpr auto I2 = Number<2>{};
54
55 static constexpr bool IsTwoStageNeeded =
56 sizeof(WeiDataType) % 4 != 0 &&
57 DeviceGemmV3Op::CDEShuffleBlockTransferScalarPerVectors_::At(I0) % 2 != 0;
58
60 using TwoStageIntermediateType = typename DeviceGemmV3Op::CDataType_;
61
62 static constexpr index_t ElementwiseBlockSize = 256;
63 static constexpr index_t ElemsPerBlock = 256;
64
65 static auto GetElementwiseCGridDesc(index_t merged_filter_dims)
66 {
67 const auto padd_size = merged_filter_dims % ElemsPerBlock == 0
68 ? 0
69 : ElemsPerBlock - merged_filter_dims % ElemsPerBlock;
70 const auto desc = make_naive_tensor_descriptor_packed(make_tuple(I1, merged_filter_dims));
72 desc,
74 make_right_pad_transform(merged_filter_dims, padd_size)),
77 }
78
86 WeiElementwiseOperation,
88 I1,
90 I1,
95 I1,
96 I1>;
97
98 struct Argument : public BaseArgument
99 {
100 using GemmArgument = typename DeviceGemmV3Op::Argument;
101
102 Argument(const InDataType* p_in_grid,
103 WeiDataType* p_wei_grid,
104 const OutDataType* p_out_grid,
105 const std::array<index_t, NDimSpatial + 3>&, // input
106 const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_strides,
107 const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths, // weight
108 const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_strides,
109 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output
110 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides,
111 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
112 const std::array<ck::index_t, NDimSpatial>&,
113 const std::array<ck::index_t, NDimSpatial>& input_left_pads,
114 const std::array<ck::index_t, NDimSpatial>& input_right_pads,
115 InElementwiseOperation in_element_op,
116 WeiElementwiseOperation wei_element_op,
117 OutElementwiseOperation out_element_op,
118 ck::index_t split_k)
120 conv_filter_strides_{conv_filter_strides},
121 input_left_pads_{input_left_pads},
122 input_right_pads_{input_right_pads},
123 p_wei_grid_{p_wei_grid}
124 {
125 constexpr index_t spatial_offset = 3;
126 const index_t DoHoWo = std::accumulate(begin(a_g_n_k_wos_lengths) + spatial_offset,
127 end(a_g_n_k_wos_lengths),
128 index_t{1},
129 std::multiplies<>{});
130 const index_t M = e_g_k_c_xs_lengths[I1];
131 const index_t N = e_g_k_c_xs_lengths[I2];
132 const index_t K = a_g_n_k_wos_lengths[I1] * DoHoWo;
133
134 const index_t StrideOut = a_g_n_k_wos_strides[spatial_offset + NDimSpatial - 1];
135 const index_t StrideIn = b_g_n_c_wis_strides[spatial_offset + NDimSpatial - 1];
136 const index_t StrideWei = e_g_k_c_xs_strides[I1];
137 const index_t StrideBatchOut = a_g_n_k_wos_strides[I0];
138 const index_t StrideBatchIn = b_g_n_c_wis_strides[I0];
139 const index_t StrideBatchWei = e_g_k_c_xs_strides[I0];
140
141 const index_t BatchSize = a_g_n_k_wos_lengths[I0];
142
143 std::copy(begin(e_g_k_c_xs_lengths) + spatial_offset,
144 end(e_g_k_c_xs_lengths),
146
147 if constexpr(IsTwoStageNeeded)
148 {
149 if(split_k < 0)
150 {
151 const auto max_occupancy = DeviceGemmV3Op::GetMaxOccupancy();
152 index_t gdx, gdy, gdz;
153 std::tie(gdx, gdy, gdz) =
154 DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize);
155 const index_t grid_size = gdx * gdy * gdz;
156 split_k_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size);
157 }
158 else
159 {
160 split_k_ = split_k;
161 }
162 }
163 else
164 {
165#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
166 if(split_k < 0)
167 {
168 const auto max_occupancy = DeviceGemmV3Op::GetMaxOccupancy();
169 index_t gdx, gdy, gdz;
170 std::tie(gdx, gdy, gdz) =
171 DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize);
172 const index_t grid_size = gdx * gdy * gdz;
173 split_k_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size);
174 }
175 else
176#endif
177 {
178 split_k_ = split_k;
179 }
180 }
181
182 if constexpr(IsTwoStageNeeded)
183 {
184 const index_t merged_filter_dims = std::accumulate(begin(e_g_k_c_xs_lengths),
185 end(e_g_k_c_xs_lengths),
186 index_t{1},
187 std::multiplies<>{});
188 elementwise_desc_ = GetElementwiseCGridDesc(merged_filter_dims);
190 // Check if stride to last dimension is product of all other dimensions. Then it is
191 // packed.
193 e_g_k_c_xs_strides[0] == (merged_filter_dims / e_g_k_c_xs_lengths[0]);
194
195 // Data type is modified during launch. It is checked in IsSupported if user
196 // allocated workspace
197 explicit_gemm_args = GemmArgument{p_out_grid,
198 p_in_grid,
199 {},
200 static_cast<TwoStageIntermediateType*>(nullptr),
201 M,
202 N,
203 K,
204 StrideOut,
205 StrideIn,
206 {},
207 StrideWei,
208 StrideBatchOut,
209 StrideBatchIn,
210 {},
211 StrideBatchWei,
212 BatchSize,
213 out_element_op,
214 in_element_op,
215 wei_element_op,
216 split_k_};
217 }
218 else
219 {
220 explicit_gemm_args = GemmArgument{p_out_grid,
221 p_in_grid,
222 {},
223 p_wei_grid,
224 M,
225 N,
226 K,
227 StrideOut,
228 StrideIn,
229 {},
230 StrideWei,
231 StrideBatchOut,
232 StrideBatchIn,
233 {},
234 StrideBatchWei,
235 BatchSize,
236 out_element_op,
237 in_element_op,
238 wei_element_op,
239 split_k_};
240 }
241 }
242
244 {
245 if constexpr(IsTwoStageNeeded)
246 {
247 return sizeof(TwoStageIntermediateType) * elementwise_desc_.GetElementSpaceSize();
248 }
249 else
250 {
251 return 0;
252 }
253 }
254
255 std::size_t GetWorkspaceSizeBytes() const
256 {
257 if constexpr(IsTwoStageNeeded)
258 {
260 }
261 else
262 {
263 return 0;
264 }
265 }
266
268 std::array<ck::index_t, NDimSpatial> filter_spatial_lengths_;
269 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
270 const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
271 const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
272 WeiDataType* p_wei_grid_;
277 };
278
279 // Invoker
280 struct Invoker : public BaseInvoker
281 {
283 using GemmArgument = typename DeviceGemmV3Op::Argument;
284
285 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
286 {
287 if constexpr(IsTwoStageNeeded)
288 {
289 // Modify to use workspace as output
290 GemmArgument explicit_gemm_args_with_workspace = arg.explicit_gemm_args;
291 explicit_gemm_args_with_workspace.p_c_grid =
292 static_cast<TwoStageIntermediateType*>(arg.p_workspace_);
293 float avg_time =
294 explicit_gemm_op.Run(explicit_gemm_args_with_workspace, stream_config);
295 const index_t grid_size =
296 arg.elementwise_block_2_ctile_map_.CalculateGridSize(arg.elementwise_desc_);
303 WeiElementwiseOperation>;
304
305 avg_time += launch_and_time_kernel(
306 stream_config,
307 kernel,
308 dim3(grid_size),
310 0,
313 make_tuple(static_cast<const TwoStageIntermediateType*>(arg.p_workspace_)),
317 return avg_time;
318 }
319 else
320 {
321 return explicit_gemm_op.Run(arg.explicit_gemm_args, stream_config);
322 }
323 }
324
325 float Run(const BaseArgument* p_arg,
326 const StreamConfig& stream_config = StreamConfig{}) override
327 {
328 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
329 }
330
331 typename DeviceGemmV3Op::Invoker explicit_gemm_op;
332 };
333
/// Compile-time validity hook for this instance.
/// No checks are performed yet, so every instantiation reports itself valid.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool is_valid = true;
    return is_valid;
}
339
340 static bool IsSupportedArgument(const Argument& arg)
341 {
342#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
343 if constexpr(!IsTwoStageNeeded)
344 {
345 if(arg.split_k_ < 0)
346 {
347 return false;
348 }
349 }
350#endif
351
352 if constexpr(NDimSpatial == 2)
353 {
355 {
356 return false;
357 }
358 }
359 else if constexpr(NDimSpatial == 3)
360 {
362 {
363 return false;
364 }
365 }
366 else
367 {
368 return false;
369 }
370
371 // check if it's 1x1, stride=1 pad = 0 conv
372 for(int i = 0; i < NDimSpatial; i++)
373 {
374 if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
375 arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
376 {
377 return false;
378 }
379 }
380 if constexpr(IsTwoStageNeeded)
381 {
382 if(!arg.is_filter_data_packed)
383 {
384 return false;
385 }
386 // Check this here, it allows to use other instances from factory even
387 // if workspace is not allocated
388 if(!arg.p_workspace_)
389 {
390 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
391 {
392 std::cout << "Warning: Workspace for "
393 "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument is not "
394 "allocated, use SetWorkSpacePointer."
395 << std::endl;
396 }
397 return false;
398 }
399 }
400 // Gridwise GEMM size
401 return DeviceGemmV3Op::IsSupportedArgument(arg.explicit_gemm_args);
402 }
403
404 bool IsSupportedArgument(const BaseArgument* p_arg) override
405 {
406 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
407 }
408
409 static auto
410 MakeArgument(const InDataType* p_in_grid,
411 WeiDataType* p_wei_grid,
412 const OutDataType* p_out_grid,
413 const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_lengths, // input
414 const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_strides,
415 const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths, // weight
416 const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_strides,
417 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output
418 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides,
419 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
420 const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
421 const std::array<ck::index_t, NDimSpatial>& input_left_pads,
422 const std::array<ck::index_t, NDimSpatial>& input_right_pads,
423 InElementwiseOperation in_element_op,
424 WeiElementwiseOperation wei_element_op,
425 OutElementwiseOperation out_element_op,
426 const ck::index_t split_k)
427 {
428 return Argument{p_in_grid,
429 p_wei_grid,
430 p_out_grid,
431 b_g_n_c_wis_lengths, // input
432 b_g_n_c_wis_strides,
433 e_g_k_c_xs_lengths, // weight
434 e_g_k_c_xs_strides,
435 a_g_n_k_wos_lengths, // output
436 a_g_n_k_wos_strides,
437 conv_filter_strides,
438 conv_filter_dilations,
439 input_left_pads,
440 input_right_pads,
441 in_element_op,
442 wei_element_op,
443 out_element_op,
444 split_k};
445 }
446
447 static auto MakeInvoker() { return Invoker{}; }
448
449 std::unique_ptr<BaseArgument>
450 MakeArgumentPointer(const void* p_in_grid,
451 void* p_wei_grid,
452 const void* p_out_grid,
453 const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_lengths, // input
454 const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_strides,
455 const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths, // weight
456 const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_strides,
457 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output
458 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides,
459 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
460 const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
461 const std::array<ck::index_t, NDimSpatial>& input_left_pads,
462 const std::array<ck::index_t, NDimSpatial>& input_right_pads,
463 InElementwiseOperation in_element_op,
464 WeiElementwiseOperation wei_element_op,
465 OutElementwiseOperation out_element_op,
466 const ck::index_t split_k) override
467 {
468 return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
469 static_cast<WeiDataType*>(p_wei_grid),
470 static_cast<const OutDataType*>(p_out_grid),
471 b_g_n_c_wis_lengths, // input
472 b_g_n_c_wis_strides,
473 e_g_k_c_xs_lengths, // weight
474 e_g_k_c_xs_strides,
475 a_g_n_k_wos_lengths, // output
476 a_g_n_k_wos_strides,
477 conv_filter_strides,
478 conv_filter_dilations,
479 input_left_pads,
480 input_right_pads,
481 in_element_op,
482 wei_element_op,
483 out_element_op,
484 split_k);
485 }
486
487 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
488 {
489 return std::make_unique<Invoker>(Invoker{});
490 }
491
492 std::string GetTypeString() const override
493 {
494 auto str = std::stringstream();
495
496 // clang-format off
497 str << "DeviceGroupedConvBwdWeight_Explicit_Xdl"
498 << "<" << DeviceGemmV3Op{}.GetTypeString() << ">";
499 // clang-format on
500
501 return str.str();
502 }
503 size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
504 {
505 auto arg = dynamic_cast<const Argument*>(p_arg);
506 if(arg)
507 {
508 return arg->GetWorkspaceSizeBytes();
509 }
510 else
511 throw std::runtime_error(
512 "The argument pointer is not an object of "
513 "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument structure!");
514 }
515
517 void* p_workspace,
518 const StreamConfig& = StreamConfig{}) const override
519 {
520 auto p_arg_ = dynamic_cast<Argument*>(p_arg);
521 if(p_arg_)
522 {
523 p_arg_->p_workspace_ = p_workspace;
524 }
525 else
526 throw std::runtime_error(
527 "The argument pointer is not an object of "
528 "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument structure!");
529 }
530};
531
532} // namespace device
533} // namespace tensor_operation
534} // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition convolution_backward_data_specialization.hpp:8
constexpr bool is_NHWGC_GKYXC_NHWGK()
Definition device_grouped_conv_utils.hpp:40
ck::index_t get_best_occupancy_k_batch_value(int max_occupancy, ck::index_t grid_size)
Definition split_k_utils.hpp:30
constexpr bool is_NDHWGC_GKZYXC_NDHWGK()
Definition device_grouped_conv_utils.hpp:80
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__global__ void kernel_elementwise(const InGridDescTuple in_grid_desc_tuple, const OutGridDescTuple out_grid_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const Block2TileMap block_2_tile_map, const ElementwiseOperation elementwise_op)
Definition gridwise_elementwise_2d.hpp:29
Definition ck/stream_config.hpp:10
Definition block_to_ctile_map.hpp:261
Definition gridwise_elementwise_2d.hpp:278
Definition utility/sequence.hpp:43
Definition utility/tuple.hpp:117
Definition device_base.hpp:197
void * p_workspace_
Definition device_base.hpp:204
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:99
std::array< ck::index_t, NDimSpatial > filter_spatial_lengths_
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:268
CElementwiseGridDesc elementwise_desc_
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:274
Block2TileMapElementwise elementwise_block_2_ctile_map_
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:275
const std::array< ck::index_t, NDimSpatial > & conv_filter_strides_
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:269
std::size_t GetWorkspaceETensorSizeBytes() const
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:243
Argument(const InDataType *p_in_grid, WeiDataType *p_wei_grid, const OutDataType *p_out_grid, const std::array< index_t, NDimSpatial+3 > &, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_strides, const std::array< ck::index_t, NDimSpatial > &, const std::array< ck::index_t, NDimSpatial > &input_left_pads, const std::array< ck::index_t, NDimSpatial > &input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, ck::index_t split_k)
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:102
ck::index_t split_k_
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:276
WeiDataType * p_wei_grid_
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:272
const std::array< ck::index_t, NDimSpatial > & input_left_pads_
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:270
std::size_t GetWorkspaceSizeBytes() const
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:255
const std::array< ck::index_t, NDimSpatial > & input_right_pads_
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:271
GemmArgument explicit_gemm_args
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:267
bool is_filter_data_packed
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:273
typename DeviceGemmV3Op::Argument GemmArgument
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:100
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:281
DeviceGemmV3Op::Invoker explicit_gemm_op
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:331
typename DeviceGemmV3Op::Argument GemmArgument
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:283
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:285
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:325
DeviceOp::Argument Argument
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:282
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:46
static constexpr index_t ElemsPerBlock
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:63
static auto MakeInvoker()
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:447
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:404
GridwiseElementwise< Tuple< CElementwiseGridDesc >, Tuple< CElementwiseGridDesc >, Tuple< const float * >, Tuple< WeiDataType * >, Block2TileMapElementwise, WeiElementwiseOperation, ElementwiseBlockSize, I1, ElemsPerBlock, I1, ElemsPerBlock/ElementwiseBlockSize, Sequence< 0, 1 >, Sequence< 1 >, Sequence< 1 >, I1, I1 > GridwiseElementwiseCast
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:81
static auto MakeArgument(const InDataType *p_in_grid, WeiDataType *p_wei_grid, const OutDataType *p_out_grid, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_dilations, const std::array< ck::index_t, NDimSpatial > &input_left_pads, const std::array< ck::index_t, NDimSpatial > &input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, const ck::index_t split_k)
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:410
static constexpr bool IsValidCompilationParameter()
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:334
static constexpr auto I0
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:51
static constexpr index_t ElementwiseBlockSize
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:62
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:340
void SetWorkSpacePointer(BaseArgument *p_arg, void *p_workspace, const StreamConfig &=StreamConfig{}) const override
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:516
static constexpr auto I1
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:52
static constexpr auto I2
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:53
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:487
typename DeviceGemmV3Op::CDataType_ TwoStageIntermediateType
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:60
static constexpr bool IsTwoStageNeeded
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:55
BlockToCTileMap_M00_N0_M01Adapt< 1, ElemsPerBlock > Block2TileMapElementwise
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:80
remove_cvref_t< decltype(GetElementwiseCGridDesc(I1))> CElementwiseGridDesc
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:79
std::string GetTypeString() const override
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:492
static auto GetElementwiseCGridDesc(index_t merged_filter_dims)
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:65
DeviceGroupedConvBwdWeight_Explicit_Xdl DeviceOp
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:59
size_t GetWorkSpaceSize(const BaseArgument *p_arg) const override
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:503
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_in_grid, void *p_wei_grid, const void *p_out_grid, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_dilations, const std::array< ck::index_t, NDimSpatial > &input_left_pads, const std::array< ck::index_t, NDimSpatial > &input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, const ck::index_t split_k) override
Definition device_grouped_conv_bwd_weight_explicit_xdl.hpp:450
Definition device_grouped_conv_bwd_weight.hpp:29
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
#define CK_ENV(name)
Definition utility/env.hpp:129