Skip to content

Commit 5353db8

Browse files
pytorchbot and trivedivivek
authored and committed
[ET-VK] Using push constants for conv2d pw.
Pull Request resolved: #7814 This diff relates to the use of push constants for the conv2d pw (pointwise) shader in ExecuTorch's Vulkan backend. This optimization improves performance and memory usage. ghstack-source-id: 263238730 exported-using-ghexport Differential Revision: [D68400677](https://our.internmc.facebook.com/intern/diff/D68400677/) Co-authored-by: Vivek Trivedi <[email protected]>
1 parent 19c5d6c commit 5353db8

File tree

2 files changed

+78
-28
lines changed

2 files changed

+78
-28
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl

+16-7
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,20 @@ ${layout_declare_tensor(0, "w", "t_out", DTYPE, "texture3d")}
2424
${layout_declare_tensor(1, "r", "t_in", DTYPE, "texture3d")}
2525
${layout_declare_tensor(2, "r", "t_kernel", DTYPE, "texture2d")}
2626
${layout_declare_tensor(3, "r", "t_bias", DTYPE, "texture2d")}
27-
${layout_declare_ubo(4, "ivec3", "out_limits")}
28-
${layout_declare_ubo(5, "ivec4", "in_sizes")}
29-
${layout_declare_ubo(6, "ivec2", "kernel_size", "ivec2", "stride", "ivec2", "padding", "ivec2", "dilation")}
30-
${layout_declare_ubo(7, "ivec2", "overlay_region", "int", "in_group_size")}
31-
${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}
27+
28+
layout(push_constant) uniform restrict Block {
29+
ivec4 out_limits;
30+
ivec4 in_sizes;
31+
ivec2 kernel_size;
32+
ivec2 stride;
33+
ivec2 padding;
34+
ivec2 dilation;
35+
ivec2 overlay_region;
36+
int in_group_size;
37+
int dummy_padding;
38+
float out_min;
39+
float out_max;
40+
};
3241

3342
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3443

@@ -70,7 +79,7 @@ void main() {
7079

7180
// If the top left position is out of bounds, then this invocation will have
7281
// no work to do.
73-
if (any(greaterThanEqual(ivec3(pos[0], gpos.z), out_limits))) {
82+
if (any(greaterThanEqual(ivec3(pos[0], gpos.z), out_limits.xyz))) {
7483
return;
7584
}
7685

@@ -144,7 +153,7 @@ void main() {
144153

145154
for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
146155
const ivec2 pos = pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex];
147-
if (all(lessThan(ivec3(pos, gpos.z), out_limits))) {
156+
if (all(lessThan(ivec3(pos, gpos.z), out_limits.xyz))) {
148157
imageStore(t_out, ivec3(pos, gpos.z), op(sum[i], out_min, out_max));
149158
}
150159
}

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

+62-21
Original file line numberDiff line numberDiff line change
@@ -407,27 +407,68 @@ void add_conv2d_node(
407407
wg_size = {wg_size[0] * wg_size[1] * wg_size[2], 1, 1};
408408
}
409409

410-
graph.execute_nodes().emplace_back(new DispatchNode(
411-
graph,
412-
shader,
413-
wg_size,
414-
graph.create_local_wg_size(wg_size),
415-
// Inputs and Outputs
416-
{{out, vkapi::MemoryAccessType::WRITE},
417-
{{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
418-
// Shader params buffers
419-
{
420-
t_out->logical_limits_ubo(),
421-
t_in->sizes_ubo(),
422-
graph.create_params_buffer(kernel_params),
423-
graph.create_params_buffer(extra_params),
424-
graph.create_params_buffer(out_params),
425-
},
426-
// Specialization Constants
427-
{},
428-
// Resizing Logic
429-
resize_conv2d_node,
430-
{weight_data, stride, padding, dilation, transposed, output_padding}));
410+
if (method == Conv2dMethod::Pointwise) {
411+
const utils::ivec4 kernel_param_size_stride = {
412+
kernel_params.kernel_size[0],
413+
kernel_params.kernel_size[1],
414+
kernel_params.stride[0],
415+
kernel_params.stride[1]};
416+
417+
const utils::ivec4 kernel_param_pad_dial = {
418+
kernel_params.padding[0],
419+
kernel_params.padding[1],
420+
kernel_params.dilation[0],
421+
kernel_params.dilation[1]};
422+
423+
graph.execute_nodes().emplace_back(new DispatchNode(
424+
graph,
425+
shader,
426+
wg_size,
427+
graph.create_local_wg_size(wg_size),
428+
// Inputs and Outputs
429+
{{out, vkapi::MemoryAccessType::WRITE},
430+
{{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
431+
// Shader params buffers
432+
{},
433+
// Specialization Constants
434+
{},
435+
// Resizing Logic
436+
resize_conv2d_node,
437+
{weight_data, stride, padding, dilation, transposed, output_padding},
438+
{
439+
graph.logical_limits_pc_of(out),
440+
graph.sizes_pc_of(in),
441+
PushConstantDataInfo(
442+
&kernel_param_size_stride, sizeof(kernel_param_size_stride)),
443+
PushConstantDataInfo(
444+
&kernel_param_pad_dial, sizeof(kernel_param_pad_dial)),
445+
PushConstantDataInfo(
446+
&extra_params, sizeof(extra_params), sizeof(utils::ivec4)),
447+
PushConstantDataInfo(&out_params, sizeof(out_params)),
448+
}));
449+
} else {
450+
graph.execute_nodes().emplace_back(new DispatchNode(
451+
graph,
452+
shader,
453+
wg_size,
454+
graph.create_local_wg_size(wg_size),
455+
// Inputs and Outputs
456+
{{out, vkapi::MemoryAccessType::WRITE},
457+
{{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
458+
// Shader params buffers
459+
{
460+
t_out->logical_limits_ubo(),
461+
t_in->sizes_ubo(),
462+
graph.create_params_buffer(kernel_params),
463+
graph.create_params_buffer(extra_params),
464+
graph.create_params_buffer(out_params),
465+
},
466+
// Specialization Constants
467+
{},
468+
// Resizing Logic
469+
resize_conv2d_node,
470+
{weight_data, stride, padding, dilation, transposed, output_padding}));
471+
}
431472
}
432473

433474
void add_conv1d_node(

0 commit comments

Comments (0)