diff --git a/src/gpu.cpp b/src/gpu.cpp index 6711590237a..35620344d99 100644 --- a/src/gpu.cpp +++ b/src/gpu.cpp @@ -1216,6 +1216,12 @@ int create_gpu_instance() gpu_info.support_subgroup_ballot = physicalDeviceSubgroupProperties.supportedOperations & VK_SUBGROUP_FEATURE_BALLOT_BIT; gpu_info.support_subgroup_shuffle = physicalDeviceSubgroupProperties.supportedOperations & VK_SUBGROUP_FEATURE_SHUFFLE_BIT; } + + if (physicalDeviceProperties.vendorID == 0x5143) + { + // double subgroup size for fp16 + gpu_info.subgroup_size *= 2; + } } else { @@ -1232,6 +1238,10 @@ int create_gpu_instance() if (physicalDeviceProperties.vendorID == 0x8086) // intel gpu_info.subgroup_size = 32; } + + // sanitize some weird subgroup size + // though there may be 1/4/8 on some cpu or awkward gpu implementations --- nihui + gpu_info.subgroup_size = std::min(std::max(gpu_info.subgroup_size, 16u), 128u); } // cache memory properties diff --git a/src/pipeline.cpp b/src/pipeline.cpp index efdaec80bde..7cd4d03514f 100644 --- a/src/pipeline.cpp +++ b/src/pipeline.cpp @@ -115,6 +115,98 @@ void Pipeline::set_optimal_local_size_xyz(const Mat& local_size_xyz) void Pipeline::set_local_size_xyz(int w, int h, int c) { + int local_size = w * h * c; + + // be multiple of subgroup size + int subgroup_size = vkdev->info.subgroup_size(); + int local_size2 = std::max(1, local_size / subgroup_size) * subgroup_size; + for (; local_size > local_size2; local_size = w * h * c) + { + if (local_size == local_size2 * 2) + { + if (c % 2 == 0) + c /= 2; + else if (h % 2 == 0) + w /= 2; + else if (w % 2 == 0) + h /= 2; + else + c /= 2; + } + else if (local_size == local_size2 * 4) + { + if (w % 2 == 0 && h % 2 == 0) + { + w /= 2; + h /= 2; + } + else if (h % 2 == 0 && c % 2 == 0) + { + h /= 2; + c /= 2; + } + else if (w % 2 == 0 && c % 2 == 0) + { + w /= 2; + c /= 2; + } + else if (c % 4 == 0) + { + c /= 4; + } + else if (h % 4 == 0) + { + h /= 4; + } + else if (w % 4 == 0) + { + w /= 4; + } + else if (c % 2 == 0) + { + h /= 2; + c /= 2; + } + else if (h % 2 == 0) + { + w /= 2; + h /= 2; + } + else if (w % 2 == 0) + { + w /= 2; + h /= 2; + } + else + { + w /= 2; + h /= 2; + } + } + else + { + if (w % 2 != 0 && h % 2 != 0 && c % 2 != 0) + { + w /= 2; + h /= 2; + c /= 2; + } + else + { + if (c % 2 == 0) + c /= 2; + if (h % 2 == 0) + h /= 2; + if (w % 2 == 0) + w /= 2; + } + } + + w = std::max(1, w); + h = std::max(1, h); + c = std::max(1, c); + } + d->local_size_x = w; d->local_size_y = h; d->local_size_z = c;