Skip to content

Commit

Permalink
[ET-VK] Changing texture access pattern for conv2d dw ops to improve …
Browse files Browse the repository at this point in the history
…performance.

This diff changes the texture access pattern for convolutional depthwise (DW) operations in Executorch's Vulkan backend to iterate first on x axis then y and then z to improve performance.

Differential Revision: [D67770160](https://our.internmc.facebook.com/intern/diff/D67770160/)

[ghstack-poisoned]
  • Loading branch information
trivedivivek committed Jan 2, 2025
1 parent 4d9679f commit 969ebaf
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 3 deletions.
5 changes: 4 additions & 1 deletion backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
* output at a single output location.
*/
void main() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);
const ivec3 pos = ivec3(
gl_GlobalInvocationID.x % out_limits.x,
(gl_GlobalInvocationID.x / out_limits.x) % out_limits.y,
gl_GlobalInvocationID.x / (out_limits.x * out_limits.y));

if (any(greaterThanEqual(pos, out_limits))) {
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
* output at a single output location.
*/
void main() {
const u16vec3 pos = u16vec3(gl_GlobalInvocationID);
const u16vec3 pos = u16vec3(
gl_GlobalInvocationID.x % out_limits.x,
(gl_GlobalInvocationID.x / out_limits.x) % out_limits.y,
gl_GlobalInvocationID.x / (out_limits.x * out_limits.y));

if (any(greaterThanEqual(pos, out_limits))) {
return;
Expand Down
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ void add_conv2d_node(

utils::uvec3 wg_size = create_conv2d_global_wg_size(graph, method, out);

if (method == Conv2dMethod::Pointwise) {
if (method == Conv2dMethod::Pointwise || method == Conv2dMethod::Depthwise) {
wg_size = {wg_size[0] * wg_size[1] * wg_size[2], 1, 1};
}

Expand Down

0 comments on commit 969ebaf

Please sign in to comment.