Skip to content

Commit b818123

Browse files
authored
[ET-VK] Splitting TILE_SIZE to TILE_SIZE_X and TILE_SIZE_Y in conv2d pw.
Differential Revision: D68400783 Pull Request resolved: #7816
1 parent 337fdd5 commit b818123

File tree

2 files changed

+17
-16
lines changed

2 files changed

+17
-16
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl

+15-15
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212

1313
#define VEC4_T ${texel_type(DTYPE)}
1414

15-
#define TILE_SIZE ${TILE_SIZE}
15+
#define TILE_SIZE_X ${TILE_SIZE_X}
16+
#define TILE_SIZE_Y ${TILE_SIZE_Y}
1617

1718
#define op(X, A, B) ${OPERATOR}
1819

@@ -43,7 +44,7 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4344

4445
// shared memory to hold calculated positions, this would reduce register usage thus improving performance.
4546
// 64 is the number of threads in the local wg
46-
$num_shared = 64 * TILE_SIZE * TILE_SIZE
47+
$num_shared = 64 * TILE_SIZE_X * TILE_SIZE_Y
4748
shared ivec2 pos_shared[${num_shared}];
4849

4950
/*
@@ -52,8 +53,8 @@ shared ivec2 pos_shared[${num_shared}];
5253
* size is only 1x1, making it easier to re-use loaded texels from t_kernel.
5354
*/
5455
void main() {
55-
const ivec2 out_limits_scaled = (out_limits.xy + TILE_SIZE - 1) / TILE_SIZE;
56-
const uint shared_mem_stride = gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z;
56+
const ivec2 out_limits_scaled = (out_limits.xy + ivec2(TILE_SIZE_X - 1, TILE_SIZE_Y - 1)) / ivec2(TILE_SIZE_X, TILE_SIZE_Y);
57+
const uint shared_mem_stride = 64;
5758

5859
const uint div_by_x = gl_GlobalInvocationID.x / out_limits_scaled.x;
5960
const ivec3 gpos = ivec3(
@@ -67,11 +68,10 @@ void main() {
6768
// +--------+--------+
6869
// | pos[2] | pos[3] |
6970
// +--------+--------+
70-
ivec2 pos[TILE_SIZE * TILE_SIZE];
71-
for (int y = 0, i = 0; y < TILE_SIZE; ++y) {
72-
for (int x = 0; x < TILE_SIZE; ++x) {
73-
pos[i] = ivec2(
74-
gpos.x * TILE_SIZE + x, gpos.y * TILE_SIZE + y);
71+
ivec2 pos[TILE_SIZE_X * TILE_SIZE_Y];
72+
for (int y = 0, i = 0; y < TILE_SIZE_Y; ++y) {
73+
for (int x = 0; x < TILE_SIZE_X; ++x) {
74+
pos[i] = ivec2(gpos.x * TILE_SIZE_X + x, gpos.y * TILE_SIZE_Y + y);
7575
pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex] = pos[i];
7676
i++;
7777
}
@@ -86,14 +86,14 @@ void main() {
8686
// Compute the index of the input texture that needs to be loaded for each
8787
// output position. Note that negative indices can be produced indicating that
8888
// the top-left element is in a region added by padding.
89-
ivec2 ipos[TILE_SIZE * TILE_SIZE];
90-
for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
89+
ivec2 ipos[TILE_SIZE_X * TILE_SIZE_Y];
90+
for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
9191
ipos[i] = pos[i] * stride - padding;
9292
}
9393

94-
vec4 sum[TILE_SIZE * TILE_SIZE];
94+
vec4 sum[TILE_SIZE_X * TILE_SIZE_Y];
9595
sum[0] = texelFetch(t_bias, ivec2(gpos.z, 0), 0);
96-
for (int i = 1; i < TILE_SIZE * TILE_SIZE; ++i) {
96+
for (int i = 1; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
9797
sum[i] = sum[0];
9898
}
9999

@@ -109,7 +109,7 @@ void main() {
109109
const vec4 ktex_3 = texelFetchOffset(t_kernel, ivec2(z, gpos.z), 0, ivec2(3, 0));
110110

111111
#pragma unroll
112-
for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
112+
for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
113113
const vec4 in_tex = texelFetch(t_in, ivec3(ipos[i], z4), 0);
114114
// For 2x2 tile size algorithm works as follows.
115115
// To explain the calculations below, the contents of one in_tex and the
@@ -151,7 +151,7 @@ void main() {
151151
}
152152
}
153153

154-
for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
154+
for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
155155
const ivec2 pos = pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex];
156156
if (all(lessThan(ivec3(pos, gpos.z), out_limits.xyz))) {
157157
imageStore(t_out, ivec3(pos, gpos.z), op(sum[i], out_min, out_max));

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.yaml

+2-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ conv2d_pw:
99
OPERATOR: X
1010
NDIM: 3
1111
DTYPE: float
12-
TILE_SIZE: 2
12+
TILE_SIZE_X: 2
13+
TILE_SIZE_Y: 2
1314
generate_variant_forall:
1415
DTYPE:
1516
- VALUE: half

0 commit comments

Comments
 (0)