From f5f5f65c4dfaf55ef6325426e8ad350972ffebab Mon Sep 17 00:00:00 2001 From: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com> Date: Thu, 23 Jan 2025 22:17:03 -0800 Subject: [PATCH] [ET-VK] Minor improvement to q_linear op shader. Pull Request resolved: https://github.com/pytorch/executorch/pull/7728 This diff makes a minor improvement to the q_linear op shader in the Vulkan backend for ExecuTorch. In q_8w_linear.glsl, the output position passed to q_8w_linear() is narrowed from a 3-element u16vec3 to a 2-element u16vec2; main() always set the third component to 0, so that 0 is now supplied directly where the t_mat1 texel is loaded and where the t_out texel is written. ghstack-source-id: 262853434 Differential Revision: [D68113154](https://our.internmc.facebook.com/intern/diff/D68113154/) --- .../vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl index 9750507a188..cd1a08909d0 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl @@ -92,7 +92,7 @@ void main() { #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require -VEC4_T q_8w_linear(const u16vec3 out_pos, const uint16_t K) { +VEC4_T q_8w_linear(const u16vec2 out_pos, const uint16_t K) { const uint16_t qmat2_pos_y = out_pos.x * uint16_t(4); VEC4_T outtex = VEC4_T(0); @@ -101,7 +101,7 @@ VEC4_T q_8w_linear(const u16vec3 out_pos, const uint16_t K) { const VEC4_T scales = load_texel(t_scales, scales_pos); for (uint16_t i = uint16_t(0), x = uint16_t(0); i < K; i += uint16_t(4), x++) { - const VEC4_T mat1_tex = load_texel(t_mat1, u16vec3(x, out_pos.yz)); + const VEC4_T mat1_tex = load_texel(t_mat1, u16vec3(x, out_pos.y, 0)); const VEC4_T sums = VEC4_T( dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y, 0))), dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y + uint16_t(1), 0))), @@ -117,16 +117,15 @@ VEC4_T q_8w_linear(const u16vec3 out_pos, const uint16_t K) 
{ } void main() { - const u16vec3 out_pos = u16vec3( + const u16vec2 out_pos = u16vec2( gl_GlobalInvocationID.x / out_limits.y, - gl_GlobalInvocationID.x % out_limits.y, - 0); + gl_GlobalInvocationID.x % out_limits.y); if (out_pos.x >= out_limits.x) { return; } VEC4_T outtex = q_8w_linear(out_pos, uint16_t(mat1_sizes.x)); - write_texel(t_out, out_pos, outtex); + write_texel(t_out, u16vec3(out_pos, 0), outtex); } #endif