Skip to content

Commit 051d1a4

Browse files
[ET-VK] Using shared memory to save position in conv2d dw output op. (#8007)
Pull Request resolved: #7923 This diff introduces a change to conv2d dw op to save output positions in shared memory, which reduces register usage and improves performance. ghstack-source-id: 263440666 @exported-using-ghexport Differential Revision: [D68400890](https://our.internmc.facebook.com/intern/diff/D68400890/) Co-authored-by: Vivek Trivedi <[email protected]>
1 parent 15c772c commit 051d1a4

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121
#define BATCH_SIZE_Y ${BATCH_SIZE_Y}
2222

23+
// Invocation count used to size the shared position cache (pos_shared).
// NOTE(review): the actual workgroup size is set through the local_size_*_id
// specialization constants; this presumably must be >= x * y * z of that
// size, otherwise gl_LocalInvocationIndex could index past the cache — confirm
// against the dispatch code.
#define LOCAL_WG_SIZE 64
24+
2325
#define op(X, A, B) ${OPERATOR}
2426

2527
#include "indexing_utils.h"
@@ -38,6 +40,11 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}
3840

3941
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4042

43+
// For performance improvement, reduce register usage by caching positions in shared memory.
44+
// Offset index by 1 every 16 points to avoid bank access conflict.
45+
// Maps a linear index to a padded index by adding index / 16, i.e. one extra
// slot is skipped after every 16 entries (0..15 -> 0..15, 16..31 -> 17..32, ...).
// Every use of the argument is parenthesized so that expression arguments
// (e.g. `a | b` or `c ? x : y`) expand with the intended precedence.
#define offset_pos_index(index) ((index) + ((index) >> 4))
46+
// Workgroup-shared cache of output positions, sized for LOCAL_WG_SIZE entries
// plus the padding slots introduced by offset_pos_index. In the code shown,
// each invocation writes and later reads only its own slot,
// pos_shared[offset_pos_index(gl_LocalInvocationIndex)].
shared ivec3 pos_shared[offset_pos_index(LOCAL_WG_SIZE)];
47+
4148
/*
4249
* Computes a depthwise convolution. Each shader invocation calculates the
4350
* output at a single output location.
@@ -63,6 +70,8 @@ void main() {
6370
return;
6471
}
6572

73+
pos_shared[offset_pos_index(gl_LocalInvocationIndex)] = pos;
74+
6675
// Compute the index of the top-left element of the overlay region. Negative
6776
// indices indicate that the top-left element is in a region added by padding.
6877
const ivec2 ipos = pos.xy * stride - padding;
@@ -109,18 +118,19 @@ void main() {
109118
for (int j = 0; j < TILE_SIZE; j++, kx++) {
110119
prev_kernel_line[j] = texelFetch(t_kernel, ivec2(kx, pos.z), 0);
111120
for (int s = 0; s < BATCH_SIZE_X; s++) {
112-
sum[0][s] = fma(in_texels[j + s], prev_kernel_line[j], sum[0][s]);
121+
sum[0][s] = fma(in_texels[j + s], prev_kernel_line[j], sum[0][s]);
113122
}
114123
}
115124
}
116125
}
117126

127+
const ivec3 out_pos = pos_shared[offset_pos_index(gl_LocalInvocationIndex)];
118128
for (int y = 0; y < BATCH_SIZE_Y; y++) {
119129
for (int x = 0; x < BATCH_SIZE_X; x++) {
120-
if (any(greaterThanEqual(ivec3(pos.x + x, pos.y + y, pos.z), out_limits))) {
130+
if (any(greaterThanEqual(ivec3(out_pos.x + x, out_pos.y + y, out_pos.z), out_limits))) {
121131
continue;
122132
}
123-
imageStore(t_out, ivec3(pos.x + x, pos.y + y, pos.z), op(sum[y][x], out_min, out_max));
133+
imageStore(t_out, ivec3(out_pos.x + x, out_pos.y + y, out_pos.z), op(sum[y][x], out_min, out_max));
124134
}
125135
}
126136
}

0 commit comments

Comments
 (0)