diff --git a/src/layer/vulkan/gemm_vulkan.cpp b/src/layer/vulkan/gemm_vulkan.cpp index 4bfe1f7d2e8..81c92aeb6f1 100644 --- a/src/layer/vulkan/gemm_vulkan.cpp +++ b/src/layer/vulkan/gemm_vulkan.cpp @@ -182,56 +182,78 @@ int Gemm_vulkan::forward(const std::vector& bottom_blobs, std::vectorconvert_packing(A0, A, 1, cmd, opt); vkdev->convert_packing(B0, B, 1, cmd, opt); - vkdev->convert_packing(C0, C, 1, cmd, opt); const int M = constantM ? constantM : transA ? A.w : (A.dims == 3 ? A.c : A.h); const int K = constantK ? constantK : transA ? (A.dims == 3 ? A.c : A.h) : A.w; const int N = constantN ? constantN : transB ? (B.dims == 3 ? B.c : B.h) : B.w; - int broadcast_type_C; + VkMat C; + int broadcast_type_C = -1; if (constantC) { + vkdev->convert_packing(C_data_gpu, C, 1, cmd, opt); broadcast_type_C = constant_broadcast_type_C; } else { - if (C.dims == 1 && C.w == 1) - { - // scalar - broadcast_type_C = 0; - } - if (C.dims == 1 && C.w == M) + VkMat C0; + if (constantA && constantB) { - // M - // auto broadcast from h to w is the ncnn-style convention - broadcast_type_C = 1; + C0 = bottom_blobs.size() == 1 ? bottom_blobs[0] : VkMat(); } - if (C.dims == 1 && C.w == N) + else if (constantA) { - // N - broadcast_type_C = 4; + C0 = bottom_blobs.size() == 2 ? bottom_blobs[1] : VkMat(); } - if (C.dims == 2 && C.w == 1 && C.h == M) + else if (constantB) { - // Mx1 - broadcast_type_C = 2; + C0 = bottom_blobs.size() == 2 ? bottom_blobs[1] : VkMat(); } - if (C.dims == 2 && C.w == N && C.h == M) + else { - // MxN - broadcast_type_C = 3; + C0 = bottom_blobs.size() == 3 ? bottom_blobs[2] : VkMat(); } - if (C.dims == 2 && C.w == N && C.h == 1) + + if (!C0.empty()) { - // 1xN - broadcast_type_C = 4; + vkdev->convert_packing(C0, C, 1, cmd, opt); + + if (C0.dims == 1 && C0.w == 1) + { + // scalar + broadcast_type_C = 0; + } + if (C0.dims == 1 && C0.w == M) + { + // M + // auto broadcast from h to w is the ncnn-style convention + broadcast_type_C = 1; + } + if (C0.dims == 1 && C0.w == N) + { + // N + broadcast_type_C = 4; + } + if (C0.dims == 2 && C0.w == 1 && C0.h == M) + { + // Mx1 + broadcast_type_C = 2; + } + if (C0.dims == 2 && C0.w == N && C0.h == M) + { + // MxN + broadcast_type_C = 3; + } + if (C0.dims == 2 && C0.w == N && C0.h == 1) + { + // 1xN + broadcast_type_C = 4; + } } } @@ -314,56 +336,78 @@ int Gemm_vulkan::forward(const std::vector& bottom_blobs, std::vecto { const VkImageMat& A0 = constantA ? A_data_gpu_image : bottom_blobs[0]; const VkImageMat& B0 = constantB ? B_data_gpu_image : constantA ? bottom_blobs[0] : bottom_blobs[1]; - const VkImageMat& C0 = constantC ? C_data_gpu_image : bottom_blobs[bottom_blobs.size() - 1]; VkImageMat A; VkImageMat B; - VkImageMat C; vkdev->convert_packing(A0, A, 1, cmd, opt); vkdev->convert_packing(B0, B, 1, cmd, opt); - vkdev->convert_packing(C0, C, 1, cmd, opt); const int M = constantM ? constantM : transA ? A.w : (A.dims == 3 ? A.c : A.h); const int K = constantK ? constantK : transA ? (A.dims == 3 ? A.c : A.h) : A.w; const int N = constantN ? constantN : transB ? (B.dims == 3 ? B.c : B.h) : B.w; - int broadcast_type_C; + VkImageMat C; + int broadcast_type_C = -1; if (constantC) { + vkdev->convert_packing(C_data_gpu_image, C, 1, cmd, opt); broadcast_type_C = constant_broadcast_type_C; } else { - if (C.dims == 1 && C.w == 1) - { - // scalar - broadcast_type_C = 0; - } - if (C.dims == 1 && C.w == M) + VkImageMat C0; + if (constantA && constantB) { - // M - // auto broadcast from h to w is the ncnn-style convention - broadcast_type_C = 1; + C0 = bottom_blobs.size() == 1 ? bottom_blobs[0] : VkImageMat(); } - if (C.dims == 1 && C.w == N) + else if (constantA) { - // N - broadcast_type_C = 4; + C0 = bottom_blobs.size() == 2 ? bottom_blobs[1] : VkImageMat(); } - if (C.dims == 2 && C.w == 1 && C.h == M) + else if (constantB) { - // Mx1 - broadcast_type_C = 2; + C0 = bottom_blobs.size() == 2 ? bottom_blobs[1] : VkImageMat(); } - if (C.dims == 2 && C.w == N && C.h == M) + else { - // MxN - broadcast_type_C = 3; + C0 = bottom_blobs.size() == 3 ? bottom_blobs[2] : VkImageMat(); } - if (C.dims == 2 && C.w == N && C.h == 1) + + if (!C0.empty()) { - // 1xN - broadcast_type_C = 4; + vkdev->convert_packing(C0, C, 1, cmd, opt); + + if (C.dims == 1 && C.w == 1) + { + // scalar + broadcast_type_C = 0; + } + if (C.dims == 1 && C.w == M) + { + // M + // auto broadcast from h to w is the ncnn-style convention + broadcast_type_C = 1; + } + if (C.dims == 1 && C.w == N) + { + // N + broadcast_type_C = 4; + } + if (C.dims == 2 && C.w == 1 && C.h == M) + { + // Mx1 + broadcast_type_C = 2; + } + if (C.dims == 2 && C.w == N && C.h == M) + { + // MxN + broadcast_type_C = 3; + } + if (C.dims == 2 && C.w == N && C.h == 1) + { + // 1xN + broadcast_type_C = 4; + } } }