@@ -407,27 +407,68 @@ void add_conv2d_node(
407
407
wg_size = {wg_size[0 ] * wg_size[1 ] * wg_size[2 ], 1 , 1 };
408
408
}
409
409
410
- graph.execute_nodes ().emplace_back (new DispatchNode (
411
- graph,
412
- shader,
413
- wg_size,
414
- graph.create_local_wg_size (wg_size),
415
- // Inputs and Outputs
416
- {{out, vkapi::MemoryAccessType::WRITE},
417
- {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
418
- // Shader params buffers
419
- {
420
- t_out->logical_limits_ubo (),
421
- t_in->sizes_ubo (),
422
- graph.create_params_buffer (kernel_params),
423
- graph.create_params_buffer (extra_params),
424
- graph.create_params_buffer (out_params),
425
- },
426
- // Specialization Constants
427
- {},
428
- // Resizing Logic
429
- resize_conv2d_node,
430
- {weight_data, stride, padding, dilation, transposed, output_padding}));
410
+ if (method == Conv2dMethod::Pointwise) {
411
+ const utils::ivec4 kernel_param_size_stride = {
412
+ kernel_params.kernel_size [0 ],
413
+ kernel_params.kernel_size [1 ],
414
+ kernel_params.stride [0 ],
415
+ kernel_params.stride [1 ]};
416
+
417
+ const utils::ivec4 kernel_param_pad_dial = {
418
+ kernel_params.padding [0 ],
419
+ kernel_params.padding [1 ],
420
+ kernel_params.dilation [0 ],
421
+ kernel_params.dilation [1 ]};
422
+
423
+ graph.execute_nodes ().emplace_back (new DispatchNode (
424
+ graph,
425
+ shader,
426
+ wg_size,
427
+ graph.create_local_wg_size (wg_size),
428
+ // Inputs and Outputs
429
+ {{out, vkapi::MemoryAccessType::WRITE},
430
+ {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
431
+ // Shader params buffers
432
+ {},
433
+ // Specialization Constants
434
+ {},
435
+ // Resizing Logic
436
+ resize_conv2d_node,
437
+ {weight_data, stride, padding, dilation, transposed, output_padding},
438
+ {
439
+ graph.logical_limits_pc_of (out),
440
+ graph.sizes_pc_of (in),
441
+ PushConstantDataInfo (
442
+ &kernel_param_size_stride, sizeof (kernel_param_size_stride)),
443
+ PushConstantDataInfo (
444
+ &kernel_param_pad_dial, sizeof (kernel_param_pad_dial)),
445
+ PushConstantDataInfo (
446
+ &extra_params, sizeof (extra_params), sizeof (utils::ivec4)),
447
+ PushConstantDataInfo (&out_params, sizeof (out_params)),
448
+ }));
449
+ } else {
450
+ graph.execute_nodes ().emplace_back (new DispatchNode (
451
+ graph,
452
+ shader,
453
+ wg_size,
454
+ graph.create_local_wg_size (wg_size),
455
+ // Inputs and Outputs
456
+ {{out, vkapi::MemoryAccessType::WRITE},
457
+ {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
458
+ // Shader params buffers
459
+ {
460
+ t_out->logical_limits_ubo (),
461
+ t_in->sizes_ubo (),
462
+ graph.create_params_buffer (kernel_params),
463
+ graph.create_params_buffer (extra_params),
464
+ graph.create_params_buffer (out_params),
465
+ },
466
+ // Specialization Constants
467
+ {},
468
+ // Resizing Logic
469
+ resize_conv2d_node,
470
+ {weight_data, stride, padding, dilation, transposed, output_padding}));
471
+ }
431
472
}
432
473
433
474
void add_conv1d_node (
0 commit comments