Skip to content

Commit

Permalink
[ET-VK] Minor improvement to q_linear op shader. (pytorch#7932)
Browse files Browse the repository at this point in the history
Pull Request resolved: pytorch#7728

This diff contains a minor improvement to the q_linear op shader in the Vulkan backend for Executorch. The code changes in the q_8w_linear.glsl file include a change in position parameter from a 3-element u16vec3 to a 2-element u16vec2.
ghstack-source-id: 262853434

Differential Revision: [D68113154](https://our.internmc.facebook.com/intern/diff/D68113154/)

Co-authored-by: Vivek Trivedi <[email protected]>
  • Loading branch information
2 people authored and Zonglin Peng committed Jan 30, 2025
1 parent f24ff01 commit e3e41d4
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ void main() {

#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require

VEC4_T q_8w_linear(const u16vec3 out_pos, const uint16_t K) {
VEC4_T q_8w_linear(const u16vec2 out_pos, const uint16_t K) {
const uint16_t qmat2_pos_y = out_pos.x * uint16_t(4);

VEC4_T outtex = VEC4_T(0);
Expand All @@ -101,7 +101,7 @@ VEC4_T q_8w_linear(const u16vec3 out_pos, const uint16_t K) {
const VEC4_T scales = load_texel(t_scales, scales_pos);

for (uint16_t i = uint16_t(0), x = uint16_t(0); i < K; i += uint16_t(4), x++) {
const VEC4_T mat1_tex = load_texel(t_mat1, u16vec3(x, out_pos.yz));
const VEC4_T mat1_tex = load_texel(t_mat1, u16vec3(x, out_pos.y, 0));
const VEC4_T sums = VEC4_T(
dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y, 0))),
dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y + uint16_t(1), 0))),
Expand All @@ -117,16 +117,15 @@ VEC4_T q_8w_linear(const u16vec3 out_pos, const uint16_t K) {
}

void main() {
const u16vec3 out_pos = u16vec3(
const u16vec2 out_pos = u16vec2(
gl_GlobalInvocationID.x / out_limits.y,
gl_GlobalInvocationID.x % out_limits.y,
0);
gl_GlobalInvocationID.x % out_limits.y);
if (out_pos.x >= out_limits.x) {
return;
}

VEC4_T outtex = q_8w_linear(out_pos, uint16_t(mat1_sizes.x));
write_texel(t_out, out_pos, outtex);
write_texel(t_out, u16vec3(out_pos, 0), outtex);
}

#endif

0 comments on commit e3e41d4

Please sign in to comment.