transform_conv_ngchw_to_nhwgc.hpp Source File

transform_conv_ngchw_to_nhwgc.hpp Source File

Composable Kernel: transform_conv_ngchw_to_nhwgc.hpp Source File
transform_conv_ngchw_to_nhwgc.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
11
12namespace ck {
13namespace tensor_operation {
14
15/*
16 * Transform Convolution NGCHW to NHWGC. We transform an [N, G, C, H, W] tensor
17 * descriptor into [N * G * C, H * W] (input or output image). The first
18 * dimension is the store dimension and the second one is the load dimension.
19 * For NHWGC to NGCHW, load and store are reversed. For weights we transform
20 * [G, K, C, Y, X] to [G * K * Y * X, C]. The first dim is the load dimension,
21 * the second dim is the store dimension.
22 */
23
// Helper that builds transpose descriptors for the NGCHW/NHWGC activation
// layouts and the GKCYX/GKYXC weight layouts, plus helpers that compute the
// corresponding channels-last strides.
//   ALayout, BLayout, ELayout - layout tags; not read by the member bodies in
//                               this struct's descriptor builders (presumably
//                               consumed by the layout checks guarding the
//                               stride helpers - TODO confirm)
//   NDimSpatial               - number of spatial dims; every length/stride
//                               array has NDimSpatial + 3 entries
//   MPerThread, NPerThread    - tile lengths used to pad both dimensions of
//                               every descriptor returned by Make*TransposeDesc
24template <typename ALayout,
25 typename BLayout,
26 typename ELayout,
27 index_t NDimSpatial,
28 index_t MPerThread,
29 index_t NPerThread>
31{
// Compile-time indices into the NDimSpatial + 3 length/stride arrays
// (slot 0 = G, 1 = N or K, 2 = C, 3.. = spatial dims).
32 static constexpr auto I0 = Number<0>{};
33 static constexpr auto I1 = Number<1>{};
34 static constexpr auto I2 = Number<2>{};
35 static constexpr auto I3 = Number<3>{};
36 static constexpr auto I4 = Number<4>{};
37 static constexpr auto I5 = Number<5>{};
38
// MakeNGCHWTransposeDesc (1-D spatial): descriptor over a strided N-G-C-W image.
//   g_n_c_wis_lengths - lengths in {G, N, C, W} order
//   g_n_c_wis_strides - strides in {G, N, C, W} order; all four used as given
//   split_n_size      - N is integer-divided by this so the descriptor covers
//                       one N split (presumably N % split_n_size == 0 - TODO
//                       confirm at call sites)
// Builds a naive (N, G, C, Wi) descriptor, merges it to 2-D and returns it
// padded on both dimensions with tile lengths (MPerThread, NPerThread).
// NOTE(review): merge grouping expected to match the file header comment
// ([N * G * C, W]); confirm against the merge transform.
39 template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
40 static auto
41 MakeNGCHWTransposeDesc(const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_lengths,
42 const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_strides,
43 const index_t split_n_size = 1)
44 {
45 const index_t& G = g_n_c_wis_lengths[I0];
46 const index_t N = g_n_c_wis_lengths[I1] / split_n_size;
47 const index_t& C = g_n_c_wis_lengths[I2];
48 const index_t& Wi = g_n_c_wis_lengths[I3];
49
50 const index_t& GStride = g_n_c_wis_strides[I0];
51 const index_t& NStride = g_n_c_wis_strides[I1];
52 const index_t& CStride = g_n_c_wis_strides[I2];
53 const index_t& WiStride = g_n_c_wis_strides[I3];
54
55 const auto desc = make_naive_tensor_descriptor(
56 make_tuple(N, G, C, Wi), make_tuple(NStride, GStride, CStride, WiStride));
57 const auto merged_desc =
64 merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
65 }
66
// MakeNHWGCTransposeDesc (1-D spatial): descriptor over a channels-last
// N-W-G-C image, viewed with dims in (N, G, C, Wi) order.
//   g_n_c_wis_lengths - lengths in {G, N, C, W} order
//   g_n_c_wis_strides - only entry [1] (the N stride) is read
//   split_n_size      - N is integer-divided by this (one N split)
// The naive descriptor is merged to 2-D and padded on both dimensions with
// tile lengths (MPerThread, NPerThread).
67 template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
68 static auto
69 MakeNHWGCTransposeDesc(const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_lengths,
70 const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_strides,
71 const index_t split_n_size = 1)
72 {
73 const index_t& G = g_n_c_wis_lengths[I0];
74 const index_t N = g_n_c_wis_lengths[I1] / split_n_size;
75 const index_t& C = g_n_c_wis_lengths[I2];
76 const index_t& Wi = g_n_c_wis_lengths[I3];
77
// Caller-supplied G/C/W strides are ignored. They are derived as a densely
// packed W-G-C layout (C innermost), and only the N stride is taken as given.
78 const index_t& NStride = g_n_c_wis_strides[I1];
79 const index_t WiStride = G * C;
80 const index_t GStride = C;
81 const index_t CStride = 1;
82
83 const auto desc = make_naive_tensor_descriptor(
84 make_tuple(N, G, C, Wi), make_tuple(NStride, GStride, CStride, WiStride));
85 const auto merged_desc =
92 merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
93 }
94
// MakeNGCHWTransposeDesc (2-D spatial): descriptor over a strided N-G-C-H-W image.
//   g_n_c_wis_lengths - lengths in {G, N, C, H, W} order
//   g_n_c_wis_strides - strides in {G, N, C, H, W} order; all used as given
//   split_n_size      - N is integer-divided by this (one N split)
// Builds a naive (N, G, C, Hi, Wi) descriptor, merges it to 2-D and returns it
// padded on both dimensions with tile lengths (MPerThread, NPerThread).
// NOTE(review): merge grouping expected to be [N * G * C, H * W] per the file
// header comment; confirm against the merge transform.
95 template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
96 static auto
97 MakeNGCHWTransposeDesc(const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_lengths,
98 const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_strides,
99 const index_t split_n_size = 1)
100 {
101 const index_t& G = g_n_c_wis_lengths[I0];
102 const index_t N = g_n_c_wis_lengths[I1] / split_n_size;
103 const index_t& C = g_n_c_wis_lengths[I2];
104 const index_t& Hi = g_n_c_wis_lengths[I3];
105 const index_t& Wi = g_n_c_wis_lengths[I4];
106
107 const index_t& GStride = g_n_c_wis_strides[I0];
108 const index_t& NStride = g_n_c_wis_strides[I1];
109 const index_t& CStride = g_n_c_wis_strides[I2];
110 const index_t& HiStride = g_n_c_wis_strides[I3];
111 const index_t& WiStride = g_n_c_wis_strides[I4];
112
113 const auto desc = make_naive_tensor_descriptor(
114 make_tuple(N, G, C, Hi, Wi), make_tuple(NStride, GStride, CStride, HiStride, WiStride));
115 const auto merged_desc =
122 merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
123 }
124
// MakeNHWGCTransposeDesc (2-D spatial): descriptor over a channels-last
// N-H-W-G-C image, viewed with dims in (N, G, C, Hi, Wi) order.
//   g_n_c_wis_lengths - lengths in {G, N, C, H, W} order
//   g_n_c_wis_strides - only entry [1] (the N stride) is read
//   split_n_size      - N is integer-divided by this (one N split)
// The naive descriptor is merged to 2-D and padded on both dimensions with
// tile lengths (MPerThread, NPerThread).
125 template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
126 static auto
127 MakeNHWGCTransposeDesc(const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_lengths,
128 const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_strides,
129 const index_t split_n_size = 1)
130 {
131 const index_t& G = g_n_c_wis_lengths[I0];
132 const index_t N = g_n_c_wis_lengths[I1] / split_n_size;
133 const index_t& C = g_n_c_wis_lengths[I2];
134 const index_t& Hi = g_n_c_wis_lengths[I3];
135 const index_t& Wi = g_n_c_wis_lengths[I4];
136
// Non-N strides are derived as a densely packed H-W-G-C layout (C innermost).
// The caller's G/C/H/W strides are not consulted.
137 const index_t& NStride = g_n_c_wis_strides[I1];
138 const index_t HiStride = Wi * G * C;
139 const index_t WiStride = G * C;
140 const index_t GStride = C;
141 const index_t CStride = 1;
142
143 const auto desc = make_naive_tensor_descriptor(
144 make_tuple(N, G, C, Hi, Wi), make_tuple(NStride, GStride, CStride, HiStride, WiStride));
145 const auto merged_desc =
152 merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
153 }
154
// MakeNGCHWTransposeDesc (3-D spatial): descriptor over a strided
// N-G-C-D-H-W image.
//   g_n_c_wis_lengths - lengths in {G, N, C, D, H, W} order
//   g_n_c_wis_strides - strides in {G, N, C, D, H, W} order; all used as given
//   split_n_size      - N is integer-divided by this (one N split)
// Builds a naive (N, G, C, Di, Hi, Wi) descriptor. Di, Hi and Wi are merged
// into a single dimension, and the 2-D result is padded on both dimensions
// with tile lengths (MPerThread, NPerThread).
155 template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
156 static auto
157 MakeNGCHWTransposeDesc(const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_lengths,
158 const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_strides,
159 const index_t split_n_size = 1)
160 {
161 const index_t& G = g_n_c_wis_lengths[I0];
162 const index_t N = g_n_c_wis_lengths[I1] / split_n_size;
163 const index_t& C = g_n_c_wis_lengths[I2];
164 const index_t& Di = g_n_c_wis_lengths[I3];
165 const index_t& Hi = g_n_c_wis_lengths[I4];
166 const index_t& Wi = g_n_c_wis_lengths[I5];
167
168 const index_t& GStride = g_n_c_wis_strides[I0];
169 const index_t& NStride = g_n_c_wis_strides[I1];
170 const index_t& CStride = g_n_c_wis_strides[I2];
171 const index_t& DiStride = g_n_c_wis_strides[I3];
172 const index_t& HiStride = g_n_c_wis_strides[I4];
173 const index_t& WiStride = g_n_c_wis_strides[I5];
174
175 const auto desc = make_naive_tensor_descriptor(
176 make_tuple(N, G, C, Di, Hi, Wi),
177 make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
178 const auto merged_desc =
181 make_merge_transform(make_tuple(Di, Hi, Wi))),
185 merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
186 }
187
// MakeNHWGCTransposeDesc (3-D spatial): descriptor over a channels-last
// N-D-H-W-G-C image, viewed with dims in (N, G, C, Di, Hi, Wi) order.
//   g_n_c_wis_lengths - lengths in {G, N, C, D, H, W} order
//   g_n_c_wis_strides - only entry [1] (the N stride) is read
//   split_n_size      - N is integer-divided by this (one N split)
// Di, Hi and Wi are merged into a single dimension, and the 2-D result is
// padded on both dimensions with tile lengths (MPerThread, NPerThread).
188 template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
189 static auto
190 MakeNHWGCTransposeDesc(const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_lengths,
191 const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_strides,
192 const index_t split_n_size = 1)
193 {
194 const index_t& G = g_n_c_wis_lengths[I0];
195 const index_t N = g_n_c_wis_lengths[I1] / split_n_size;
196 const index_t& C = g_n_c_wis_lengths[I2];
197 const index_t& Di = g_n_c_wis_lengths[I3];
198 const index_t& Hi = g_n_c_wis_lengths[I4];
199 const index_t& Wi = g_n_c_wis_lengths[I5];
200
// Non-N strides are derived as a densely packed D-H-W-G-C layout
// (C innermost). The caller's G/C/spatial strides are not consulted.
201 const index_t& NStride = g_n_c_wis_strides[I1];
202 const index_t DiStride = Hi * Wi * G * C;
203 const index_t HiStride = Wi * G * C;
204 const index_t WiStride = G * C;
205 const index_t GStride = C;
206 const index_t CStride = 1;
207
208 const auto desc = make_naive_tensor_descriptor(
209 make_tuple(N, G, C, Di, Hi, Wi),
210 make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
211 const auto merged_desc =
214 make_merge_transform(make_tuple(Di, Hi, Wi))),
218 merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
219 }
220
// MakeGKCYXTransposeDesc (1-D spatial): descriptor over a strided G-K-C-X
// weight tensor.
//   g_k_c_wis_lengths - lengths in {G, K, C, X} order
//   g_k_c_wis_strides - strides in {G, K, C, X} order; all used as given
// Builds a naive (G, K, C, X) descriptor, transforms it to 2-D and returns it
// padded on both dimensions with tile lengths (MPerThread, NPerThread).
// NOTE(review): grouping expected to be [G * K * X, C] per the file header
// comment; confirm against the transform.
221 template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
222 static auto
223 MakeGKCYXTransposeDesc(const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_lengths,
224 const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_strides)
225 {
226 const index_t& G = g_k_c_wis_lengths[I0];
227 const index_t& K = g_k_c_wis_lengths[I1];
228 const index_t& C = g_k_c_wis_lengths[I2];
229 const index_t& X = g_k_c_wis_lengths[I3];
230
231 const index_t& GStride = g_k_c_wis_strides[I0];
232 const index_t& KStride = g_k_c_wis_strides[I1];
233 const index_t& CStride = g_k_c_wis_strides[I2];
234 const index_t& XStride = g_k_c_wis_strides[I3];
235
236 const auto desc = make_naive_tensor_descriptor(
237 make_tuple(G, K, C, X), make_tuple(GStride, KStride, CStride, XStride));
238 const auto merged_desc = transform_tensor_descriptor(
239 desc,
244 merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
245 }
246
// MakeGKYXCTransposeDesc (1-D spatial): descriptor over a G-K-X-C
// (channels-innermost) weight tensor, viewed with dims in (G, K, C, X) order.
//   g_k_c_wis_lengths - lengths in {G, K, C, X} order
//   g_k_c_wis_strides - only entries [0] (G) and [1] (K) are read
// The naive descriptor is transformed to 2-D and padded on both dimensions
// with tile lengths (MPerThread, NPerThread).
247 template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
248 static auto
249 MakeGKYXCTransposeDesc(const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_lengths,
250 const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_strides)
251 {
252 const index_t& G = g_k_c_wis_lengths[I0];
253 const index_t& K = g_k_c_wis_lengths[I1];
254 const index_t& C = g_k_c_wis_lengths[I2];
255 const index_t& X = g_k_c_wis_lengths[I3];
256
// G and K strides come from the caller. Within one K slice, the layout is
// assumed densely packed as X-C: C has stride 1 and X has stride C.
257 const index_t& GStride = g_k_c_wis_strides[I0];
258 const index_t KStride = g_k_c_wis_strides[I1];
259 const index_t CStride = 1;
260 const index_t XStride = C;
261
262 const auto desc = make_naive_tensor_descriptor(
263 make_tuple(G, K, C, X), make_tuple(GStride, KStride, CStride, XStride));
264 const auto merged_desc = transform_tensor_descriptor(
265 desc,
270 merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
271 }
272
// MakeGKCYXTransposeDesc (2-D spatial): descriptor over a strided G-K-C-Y-X
// weight tensor.
//   g_k_c_wis_lengths - lengths in {G, K, C, Y, X} order
//   g_k_c_wis_strides - strides in {G, K, C, Y, X} order; all used as given
// Builds a naive (G, K, C, Y, X) descriptor, merges it to 2-D and returns it
// padded on both dimensions with tile lengths (MPerThread, NPerThread).
// NOTE(review): grouping expected to be [G * K * Y * X, C] per the file header
// comment; confirm against the merge transform.
273 template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
274 static auto
275 MakeGKCYXTransposeDesc(const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_lengths,
276 const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_strides)
277 {
278 const index_t& G = g_k_c_wis_lengths[I0];
279 const index_t& K = g_k_c_wis_lengths[I1];
280 const index_t& C = g_k_c_wis_lengths[I2];
281 const index_t& Y = g_k_c_wis_lengths[I3];
282 const index_t& X = g_k_c_wis_lengths[I4];
283
284 const index_t& GStride = g_k_c_wis_strides[I0];
285 const index_t& KStride = g_k_c_wis_strides[I1];
286 const index_t& CStride = g_k_c_wis_strides[I2];
287 const index_t& YStride = g_k_c_wis_strides[I3];
288 const index_t& XStride = g_k_c_wis_strides[I4];
289
290 const auto desc = make_naive_tensor_descriptor(
291 make_tuple(G, K, C, Y, X), make_tuple(GStride, KStride, CStride, YStride, XStride));
292 const auto merged_desc =
299 merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
300 }
301
// MakeGKYXCTransposeDesc (2-D spatial): descriptor over a G-K-Y-X-C
// (channels-innermost) weight tensor, viewed with dims in (G, K, C, Y, X) order.
//   g_k_c_wis_lengths - lengths in {G, K, C, Y, X} order
//   g_k_c_wis_strides - only entries [0] (G) and [1] (K) are read
// The naive descriptor is merged to 2-D and padded on both dimensions with
// tile lengths (MPerThread, NPerThread).
302 template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
303 static auto
304 MakeGKYXCTransposeDesc(const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_lengths,
305 const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_strides)
306 {
307 const index_t& G = g_k_c_wis_lengths[I0];
308 const index_t& K = g_k_c_wis_lengths[I1];
309 const index_t& C = g_k_c_wis_lengths[I2];
310 const index_t& Y = g_k_c_wis_lengths[I3];
311 const index_t& X = g_k_c_wis_lengths[I4];
312
// G and K strides come from the caller. Within one K slice, the layout is
// assumed densely packed as Y-X-C (C innermost).
313 const index_t& GStride = g_k_c_wis_strides[I0];
314 const index_t KStride = g_k_c_wis_strides[I1];
315 const index_t CStride = 1;
316 const index_t YStride = X * C;
317 const index_t XStride = C;
318
319 const auto desc = make_naive_tensor_descriptor(
320 make_tuple(G, K, C, Y, X), make_tuple(GStride, KStride, CStride, YStride, XStride));
321 const auto merged_desc =
328 merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
329 }
330
// MakeGKCYXTransposeDesc (3-D spatial): descriptor over a strided
// G-K-C-Z-Y-X weight tensor.
//   g_k_c_wis_lengths - lengths in {G, K, C, Z, Y, X} order
//   g_k_c_wis_strides - strides in {G, K, C, Z, Y, X} order; all used as given
// Builds a naive (G, K, C, Z, Y, X) descriptor, merges it to 2-D and returns
// it padded on both dimensions with tile lengths (MPerThread, NPerThread).
// NOTE(review): grouping expected to be [G * K * Z * Y * X, C] by analogy with
// the file header comment; confirm against the merge transform.
331 template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
332 static auto
333 MakeGKCYXTransposeDesc(const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_lengths,
334 const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_strides)
335 {
336 const index_t& G = g_k_c_wis_lengths[I0];
337 const index_t& K = g_k_c_wis_lengths[I1];
338 const index_t& C = g_k_c_wis_lengths[I2];
339 const index_t& Z = g_k_c_wis_lengths[I3];
340 const index_t& Y = g_k_c_wis_lengths[I4];
341 const index_t& X = g_k_c_wis_lengths[I5];
342
343 const index_t& GStride = g_k_c_wis_strides[I0];
344 const index_t& KStride = g_k_c_wis_strides[I1];
345 const index_t& CStride = g_k_c_wis_strides[I2];
346 const index_t& ZStride = g_k_c_wis_strides[I3];
347 const index_t& YStride = g_k_c_wis_strides[I4];
348 const index_t& XStride = g_k_c_wis_strides[I5];
349
350 const auto desc = make_naive_tensor_descriptor(
351 make_tuple(G, K, C, Z, Y, X),
352 make_tuple(GStride, KStride, CStride, ZStride, YStride, XStride));
353 const auto merged_desc =
360 merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
361 }
362
// MakeGKYXCTransposeDesc (3-D spatial): descriptor over a G-K-Z-Y-X-C
// (channels-innermost) weight tensor, viewed with dims in
// (G, K, C, Z, Y, X) order.
//   g_k_c_wis_lengths - lengths in {G, K, C, Z, Y, X} order
//   g_k_c_wis_strides - only entries [0] (G) and [1] (K) are read
// The naive descriptor is merged to 2-D and padded on both dimensions with
// tile lengths (MPerThread, NPerThread).
363 template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
364 static auto
365 MakeGKYXCTransposeDesc(const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_lengths,
366 const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_strides)
367 {
368 const index_t& G = g_k_c_wis_lengths[I0];
369 const index_t& K = g_k_c_wis_lengths[I1];
370 const index_t& C = g_k_c_wis_lengths[I2];
371 const index_t& Z = g_k_c_wis_lengths[I3];
372 const index_t& Y = g_k_c_wis_lengths[I4];
373 const index_t& X = g_k_c_wis_lengths[I5];
374
// G and K strides come from the caller. Within one K slice, the layout is
// assumed densely packed as Z-Y-X-C (C innermost).
375 const index_t& GStride = g_k_c_wis_strides[I0];
376 const index_t KStride = g_k_c_wis_strides[I1];
377 const index_t CStride = 1;
378 const index_t ZStride = Y * X * C;
379 const index_t YStride = X * C;
380 const index_t XStride = C;
381
382 const auto desc = make_naive_tensor_descriptor(
383 make_tuple(G, K, C, Z, Y, X),
384 make_tuple(GStride, KStride, CStride, ZStride, YStride, XStride));
385 const auto merged_desc =
392 merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
393 }
394
// TransposeInOutStrides: returns strides (in {G, N, C, spatial...} slot order)
// describing the channels-last (N, spatial..., G, C) counterpart of an
// activation tensor. If no transpose is needed, the input strides are
// returned unchanged.
//   g_n_c_wis_lengths - lengths in {G, N, C, spatial...} order
//   g_n_c_wis_strides - original strides; on the transpose path only the
//                       N stride (entry [1]) is carried over
// On the transpose path: G stride = C, C stride = 1, and the spatial strides
// are densely packed over (spatial..., G, C). The N stride keeps its input value.
395 static auto TransposeInOutStrides(const std::array<index_t, NDimSpatial + 3>& g_n_c_wis_lengths,
396 const std::array<index_t, NDimSpatial + 3>& g_n_c_wis_strides)
397 {
400 {
// NOTE(review): this array is not value-initialized, and only the
// NDimSpatial == 2 / 3 branches below fill entries [3..]. If this path were
// ever taken with NDimSpatial == 1, entry [3] would be returned
// indeterminate. Confirm that the guarding condition excludes 1-D layouts.
401 std::array<index_t, NDimSpatial + 3> g_n_c_wis_strides_transposed;
402 const auto G = g_n_c_wis_lengths[I0];
403 const auto C = g_n_c_wis_lengths[I2];
404
405 g_n_c_wis_strides_transposed[I0] = C;
406 g_n_c_wis_strides_transposed[I1] = g_n_c_wis_strides[I1];
407 g_n_c_wis_strides_transposed[I2] = I1;
408 if constexpr(NDimSpatial == 2)
409 {
// H stride = W * G * C, W stride = G * C.
410 g_n_c_wis_strides_transposed[I3] = g_n_c_wis_lengths[I4] * G * C;
411 g_n_c_wis_strides_transposed[I4] = G * C;
412 }
413 else if constexpr(NDimSpatial == 3)
414 {
// D stride = H * W * G * C, H stride = W * G * C, W stride = G * C.
415 g_n_c_wis_strides_transposed[I3] =
416 g_n_c_wis_lengths[I4] * g_n_c_wis_lengths[I5] * G * C;
417 g_n_c_wis_strides_transposed[I4] = g_n_c_wis_lengths[I5] * G * C;
418 g_n_c_wis_strides_transposed[I5] = G * C;
419 }
420 return g_n_c_wis_strides_transposed;
421 }
422 else
423 {
424 // transpose not needed
425 return g_n_c_wis_strides;
426 }
427 }
428
// TransposeWeiStrides: returns weight strides (in {G, K, C, spatial...} slot
// order) for the C-innermost (G, K, spatial..., C) layout. If no transpose is
// needed, the input strides are returned unchanged.
//   g_k_c_wis_lengths - lengths in {G, K, C, spatial...} order
//   g_k_c_wis_strides - original strides; the result starts as a copy of them
// On the transpose path, the G and K strides keep their input values.
// C gets stride 1, and the spatial dims are packed as (X*C, C) for 2-D or
// (Y*X*C, X*C, C) for 3-D.
// NOTE(review): keeping the K stride is correct when each K slice is densely
// packed, because its size C*Y*X does not depend on dim order. Confirm for
// non-packed inputs.
// NOTE(review): there is no NDimSpatial == 1 branch. A 1-D input reaching the
// transpose path would come back with its original strides.
429 static auto
430 TransposeWeiStrides(const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_lengths,
431 const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_strides)
432 {
435 {
436 std::array<index_t, NDimSpatial + 3> g_k_c_wis_strides_transposed = g_k_c_wis_strides;
437 const index_t C = g_k_c_wis_lengths[I2];
438
439 if constexpr(NDimSpatial == 2)
440 {
441 const index_t X = g_k_c_wis_lengths[I4];
442 g_k_c_wis_strides_transposed[I2] = 1;
443 g_k_c_wis_strides_transposed[I3] = X * C;
444 g_k_c_wis_strides_transposed[I4] = C;
445 }
446 else if constexpr(NDimSpatial == 3)
447 {
448 const index_t Y = g_k_c_wis_lengths[I4];
449 const index_t X = g_k_c_wis_lengths[I5];
450 g_k_c_wis_strides_transposed[I2] = 1;
451 g_k_c_wis_strides_transposed[I3] = Y * X * C;
452 g_k_c_wis_strides_transposed[I4] = X * C;
453 g_k_c_wis_strides_transposed[I5] = C;
454 }
455 return g_k_c_wis_strides_transposed;
456 }
457 else
458 {
459 // transpose not needed
460 return g_k_c_wis_strides;
461 }
462 }
463};
464
465} // namespace tensor_operation
466} // namespace ck
__host__ __device__ constexpr auto PadTensorDescriptor(const TensorDesc &desc, const TileLengths &tile_lengths, DoPads)
Definition matrix_padder.hpp:19
constexpr bool is_NGCDHW_NGKDHW()
Definition device_grouped_conv_utils.hpp:112
constexpr bool is_NGCHW_GKCYX_NGKHW()
Definition device_grouped_conv_utils.hpp:64
constexpr bool is_NGCDHW_GKCZYX_NGKDHW()
Definition device_grouped_conv_utils.hpp:104
constexpr bool is_NGCHW_NGKHW()
Definition device_grouped_conv_utils.hpp:72
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
Definition utility/sequence.hpp:43
Definition transform_conv_ngchw_to_nhwgc.hpp:31
static constexpr auto I1
Definition transform_conv_ngchw_to_nhwgc.hpp:33
static auto TransposeInOutStrides(const std::array< index_t, NDimSpatial+3 > &g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &g_n_c_wis_strides)
Definition transform_conv_ngchw_to_nhwgc.hpp:395
static auto TransposeWeiStrides(const std::array< ck::index_t, NDimSpatial+3 > &g_k_c_wis_lengths, const std::array< ck::index_t, NDimSpatial+3 > &g_k_c_wis_strides)
Definition transform_conv_ngchw_to_nhwgc.hpp:430
static auto MakeNGCHWTransposeDesc(const std::array< ck::index_t, NDimSpatial+3 > &g_n_c_wis_lengths, const std::array< ck::index_t, NDimSpatial+3 > &g_n_c_wis_strides, const index_t split_n_size=1)
Definition transform_conv_ngchw_to_nhwgc.hpp:41
static auto MakeGKCYXTransposeDesc(const std::array< ck::index_t, NDimSpatial+3 > &g_k_c_wis_lengths, const std::array< ck::index_t, NDimSpatial+3 > &g_k_c_wis_strides)
Definition transform_conv_ngchw_to_nhwgc.hpp:223
static constexpr auto I4
Definition transform_conv_ngchw_to_nhwgc.hpp:36
static constexpr auto I3
Definition transform_conv_ngchw_to_nhwgc.hpp:35
static auto MakeGKYXCTransposeDesc(const std::array< ck::index_t, NDimSpatial+3 > &g_k_c_wis_lengths, const std::array< ck::index_t, NDimSpatial+3 > &g_k_c_wis_strides)
Definition transform_conv_ngchw_to_nhwgc.hpp:249
static constexpr auto I5
Definition transform_conv_ngchw_to_nhwgc.hpp:37
static constexpr auto I2
Definition transform_conv_ngchw_to_nhwgc.hpp:34
static auto MakeNHWGCTransposeDesc(const std::array< ck::index_t, NDimSpatial+3 > &g_n_c_wis_lengths, const std::array< ck::index_t, NDimSpatial+3 > &g_n_c_wis_strides, const index_t split_n_size=1)
Definition transform_conv_ngchw_to_nhwgc.hpp:69
static constexpr auto I0
Definition transform_conv_ngchw_to_nhwgc.hpp:32