Skip to content

Commit 3437014

Browse files
trivedivivekkirklandsign
authored andcommitted
[ET-VK] Using shared memory offsetting in conv2d pw and saving ivec3 pos instead of ivec2 to improve performance.
Pull Request resolved: #7817 This diff changes conv2d pw op shader to offset shared memory based on thread local index to improve performance. Change also saves pos as ivec3 pos instead of ivec2. ghstack-source-id: 263238733 @exported-using-ghexport Differential Revision: [D68400786](https://our.internmc.facebook.com/intern/diff/D68400786/)
1 parent 6db3f87 commit 3437014

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

Diff for: backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl

+11-9
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#define TILE_SIZE_X ${TILE_SIZE_X}
1616
#define TILE_SIZE_Y ${TILE_SIZE_Y}
17+
#define LOCAL_WG_SIZE 64
1718

1819
#define op(X, A, B) ${OPERATOR}
1920

@@ -42,10 +43,10 @@ layout(push_constant) uniform restrict Block {
4243

4344
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4445

45-
// shared memory to hold calculated positions, this would reduce register usage thus improving performance.
46-
// 64 is the number of threads in the local wg
47-
$num_shared = 64 * TILE_SIZE_X * TILE_SIZE_Y
48-
shared ivec2 pos_shared[${num_shared}];
46+
// For performance improvement, reduce register usage by caching positions in shared memory.
47+
// Offset index by 1 every 16 points to avoid bank access conflict.
48+
#define offset_pos_index(index) (index + ((index) >> 4))
49+
shared ivec3 pos_shared[offset_pos_index(LOCAL_WG_SIZE * TILE_SIZE_X * TILE_SIZE_Y)];
4950

5051
/*
5152
* Computes a 2D pointwise convolution of an NxN output tile. Calculating an
@@ -54,7 +55,7 @@ shared ivec2 pos_shared[${num_shared}];
5455
*/
5556
void main() {
5657
const ivec2 out_limits_scaled = (out_limits.xy + ivec2(TILE_SIZE_X - 1, TILE_SIZE_Y - 1)) / ivec2(TILE_SIZE_X, TILE_SIZE_Y);
57-
const uint shared_mem_stride = 64;
58+
const uint shared_mem_stride = LOCAL_WG_SIZE;
5859

5960
const uint div_by_x = gl_GlobalInvocationID.x / out_limits_scaled.x;
6061
const ivec3 gpos = ivec3(
@@ -72,7 +73,7 @@ void main() {
7273
for (int y = 0, i = 0; y < TILE_SIZE_Y; ++y) {
7374
for (int x = 0; x < TILE_SIZE_X; ++x) {
7475
pos[i] = ivec2(gpos.x * TILE_SIZE_X + x, gpos.y * TILE_SIZE_Y + y);
75-
pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex] = pos[i];
76+
pos_shared[offset_pos_index((shared_mem_stride * i) + gl_LocalInvocationIndex)] = ivec3(pos[i], gpos.z);
7677
i++;
7778
}
7879
}
@@ -152,9 +153,10 @@ void main() {
152153
}
153154

154155
for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
155-
const ivec2 pos = pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex];
156-
if (all(lessThan(ivec3(pos, gpos.z), out_limits.xyz))) {
157-
imageStore(t_out, ivec3(pos, gpos.z), op(sum[i], out_min, out_max));
156+
const uint index = (shared_mem_stride * i) + gl_LocalInvocationIndex;
157+
const ivec3 pos = pos_shared[offset_pos_index(index)];
158+
if (all(lessThan(pos, out_limits.xyz))) {
159+
imageStore(t_out, pos, op(sum[i], out_min, out_max));
158160
}
159161
}
160162
}

0 commit comments

Comments
 (0)