@@ -407,7 +407,9 @@ void add_conv2d_node(
407
407
wg_size = {wg_size[0 ] * wg_size[1 ] * wg_size[2 ], 1 , 1 };
408
408
}
409
409
410
- if (method == Conv2dMethod::Pointwise) {
410
+ vkapi::ParamsBindList param_buffers;
411
+ std::vector<PushConstantDataInfo> push_constants;
412
+ if (method == Conv2dMethod::Pointwise || method == Conv2dMethod::Depthwise) {
411
413
const utils::ivec4 kernel_param_size_stride = {
412
414
kernel_params.kernel_size [0 ],
413
415
kernel_params.kernel_size [1 ],
@@ -420,55 +422,43 @@ void add_conv2d_node(
420
422
kernel_params.dilation [0 ],
421
423
kernel_params.dilation [1 ]};
422
424
423
- graph.execute_nodes ().emplace_back (new DispatchNode (
424
- graph,
425
- shader,
426
- wg_size,
427
- graph.create_local_wg_size (wg_size),
428
- // Inputs and Outputs
429
- {{out, vkapi::MemoryAccessType::WRITE},
430
- {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
431
- // Shader params buffers
432
- {},
433
- // Specialization Constants
434
- {},
435
- // Resizing Logic
436
- resize_conv2d_node,
437
- {weight_data, stride, padding, dilation, transposed, output_padding},
438
- {
439
- graph.logical_limits_pc_of (out),
440
- graph.sizes_pc_of (in),
441
- PushConstantDataInfo (
442
- &kernel_param_size_stride, sizeof (kernel_param_size_stride)),
443
- PushConstantDataInfo (
444
- &kernel_param_pad_dial, sizeof (kernel_param_pad_dial)),
445
- PushConstantDataInfo (
446
- &extra_params, sizeof (extra_params), sizeof (utils::ivec4)),
447
- PushConstantDataInfo (&out_params, sizeof (out_params)),
448
- }));
425
+ push_constants = {
426
+ graph.logical_limits_pc_of (out),
427
+ graph.sizes_pc_of (in),
428
+ PushConstantDataInfo (
429
+ &kernel_param_size_stride, sizeof (kernel_param_size_stride)),
430
+ PushConstantDataInfo (
431
+ &kernel_param_pad_dial, sizeof (kernel_param_pad_dial)),
432
+ PushConstantDataInfo (
433
+ &extra_params, sizeof (extra_params), sizeof (utils::ivec4)),
434
+ PushConstantDataInfo (&out_params, sizeof (out_params)),
435
+ };
449
436
} else {
450
- graph.execute_nodes ().emplace_back (new DispatchNode (
451
- graph,
452
- shader,
453
- wg_size,
454
- graph.create_local_wg_size (wg_size),
455
- // Inputs and Outputs
456
- {{out, vkapi::MemoryAccessType::WRITE},
457
- {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
458
- // Shader params buffers
459
- {
460
- t_out->logical_limits_ubo (),
461
- t_in->sizes_ubo (),
462
- graph.create_params_buffer (kernel_params),
463
- graph.create_params_buffer (extra_params),
464
- graph.create_params_buffer (out_params),
465
- },
466
- // Specialization Constants
467
- {},
468
- // Resizing Logic
469
- resize_conv2d_node,
470
- {weight_data, stride, padding, dilation, transposed, output_padding}));
437
+ param_buffers = {
438
+ t_out->logical_limits_ubo (),
439
+ t_in->sizes_ubo (),
440
+ graph.create_params_buffer (kernel_params),
441
+ graph.create_params_buffer (extra_params),
442
+ graph.create_params_buffer (out_params),
443
+ };
471
444
}
445
+
446
+ graph.execute_nodes ().emplace_back (new DispatchNode (
447
+ graph,
448
+ shader,
449
+ wg_size,
450
+ graph.create_local_wg_size (wg_size),
451
+ // Inputs and Outputs
452
+ {{out, vkapi::MemoryAccessType::WRITE},
453
+ {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
454
+ // Shader params buffers
455
+ param_buffers,
456
+ // Specialization Constants
457
+ {},
458
+ // Resizing Logic
459
+ resize_conv2d_node,
460
+ {weight_data, stride, padding, dilation, transposed, output_padding},
461
+ push_constants));
472
462
}
473
463
474
464
void add_conv1d_node (
0 commit comments