6#include <hip/hip_runtime.h>
20 hipDeviceProp_t dev_prop;
25 num_cu_ = dev_prop.multiProcessorCount;
33 const int max_capacity = max_occupancy * device_properties.
num_cu_;
36 const auto optimal_split =
37 static_cast<ck::index_t>(std::floor((1.0 * max_capacity) / grid_size));
40 k_batch = optimal_split;
45 std::cout <<
"[SPLIT-K AUTODEDUCE] Max active thread blocks per CU for GEMM kernel: "
46 << max_occupancy << std::endl;
47 std::cout <<
"[SPLIT-K AUTODEDUCE] Output grid size: " << grid_size << std::endl;
48 std::cout <<
"[SPLIT-K AUTODEDUCE] Optimal split-k value " << k_batch << std::endl;
53template <ck::index_t NDimSpatial>
56 const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths)
63 constexpr index_t spatial_offset = 3;
64 const index_t DoHoWo = std::accumulate(begin(a_g_n_k_wos_lengths) + spatial_offset,
65 end(a_g_n_k_wos_lengths),
68 const auto gemmK = a_g_n_k_wos_lengths[I1] * DoHoWo;
71 const auto gemmM = e_g_k_c_xs_lengths[I1];
75 const index_t XYZ = std::accumulate(begin(e_g_k_c_xs_lengths) + spatial_offset,
76 end(e_g_k_c_xs_lengths),
79 const auto gemmN = e_g_k_c_xs_lengths[I2] * XYZ;
80 return std::make_tuple(gemmM, gemmN, gemmK);
83template <ck::index_t MPerBlock, ck::index_t NPerBlock>
void hip_check_error(hipError_t x)
Definition host_utility/hip_check_error.hpp:10
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
auto get_bwd_weight_gemm_sizes(const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_lengths)
Definition split_k_utils.hpp:55
ck::index_t get_best_occupancy_k_batch_value(int max_occupancy, ck::index_t grid_size)
Definition split_k_utils.hpp:30
ck::index_t calculate_mn_grid_size(ck::index_t gemmM, ck::index_t gemmN)
Definition split_k_utils.hpp:84
Definition convolution_backward_data_specialization.hpp:7
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
Definition split_k_utils.hpp:17
DeviceProperties()
Definition split_k_utils.hpp:18
int num_cu_
Definition split_k_utils.hpp:27
#define CK_ENV(name)
Definition utility/env.hpp:129