diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp
index 02df85c33e8..75726ae0892 100644
--- a/backends/vulkan/runtime/VulkanBackend.cpp
+++ b/backends/vulkan/runtime/VulkanBackend.cpp
@@ -495,6 +495,7 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
     builder.build_graph();
 
     compute_graph->prepare();
+    compute_graph->prepare_pipelines();
 
     compute_graph->encode_prepack();
     compute_graph->prepack();
diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp
index 68935e63123..8e498d5f2d1 100644
--- a/backends/vulkan/runtime/graph/ComputeGraph.cpp
+++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp
@@ -561,6 +561,46 @@ void ComputeGraph::update_descriptor_counts(
   }
 }
 
+// Registers the compute pipeline described by the arguments for batched
+// creation in prepare_pipelines(). Nothing is recorded when an identical
+// pipeline is already in the context's pipeline cache or already pending.
+void ComputeGraph::register_pipeline_to_create(
+    const vkapi::ShaderInfo& shader_info,
+    const utils::WorkgroupSize& local_workgroup_size,
+    const vkapi::SpecVarList& spec_vars,
+    const std::vector<PushConstantDataInfo>& push_constants) {
+  VkDescriptorSetLayout shader_layout =
+      context()->shader_layout_cache().retrieve(shader_info.kernel_layout);
+
+  // The push constants are written to a scratch buffer; only the accumulated
+  // offset (pc_offset) is used below, to retrieve the pipeline layout.
+  uint32_t pc_offset = 0;
+  std::array<uint8_t, kMaxPushConstantSize> pc_data;
+  for (const auto& pc : push_constants) {
+    pc_offset += pc.write(pc_data.data(), pc_offset, kMaxPushConstantSize);
+  }
+
+  // The local workgroup size fills the first three specialization constants;
+  // the caller-provided spec_vars are appended after it.
+  vkapi::SpecVarList spec_constants = {
+      SV(local_workgroup_size[0u]),
+      SV(local_workgroup_size[1u]),
+      SV(local_workgroup_size[2u])};
+
+  spec_constants.append(spec_vars);
+
+  const vkapi::ComputePipelineCache::Key desc = {
+      context()->pipeline_layout_cache().retrieve(shader_layout, pc_offset),
+      context()->shader_cache().retrieve(shader_info),
+      spec_constants};
+
+  if (context_->pipeline_cache().contains(desc)) {
+    return;
+  }
+  // unordered_set::insert is a no-op for a descriptor that is already pending.
+  pipeline_descriptors_.insert(desc);
+}
+
 utils::uvec3 ComputeGraph::create_global_wg_size(const ValueRef idx) {
   if (is_buffer_storage(idx)) {
     return {uint32_t(numel_of(idx)), 1u, 1u};
@@ -670,6 +710,23 @@ void ComputeGraph::prepare() {
   }
 }
 
+// Collects the pipelines needed by every prepack and execute node, then asks
+// the pipeline cache to create them in one batch.
+void ComputeGraph::prepare_pipelines() {
+  for (std::unique_ptr<PrepackNode>& node : prepack_nodes_) {
+    node->prepare_pipelines(this);
+  }
+  for (std::unique_ptr<ExecuteNode>& node : execute_nodes_) {
+    node->prepare_pipelines(this);
+  }
+  context_->pipeline_cache().create_pipelines(pipeline_descriptors_);
+
+  // The pending set is reset once its pipelines have been handed to the cache.
+  pipeline_descriptors_ = std::unordered_set<
+      vkapi::ComputePipelineCache::Key,
+      vkapi::ComputePipelineCache::Hasher>();
+}
+
 void ComputeGraph::encode_prepack() {
   for (std::unique_ptr<PrepackNode>& node : prepack_nodes_) {
     node->encode(this);
diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h
index 0e8a5eba51f..90f89ea18d6 100644
--- a/backends/vulkan/runtime/graph/ComputeGraph.h
+++ b/backends/vulkan/runtime/graph/ComputeGraph.h
@@ -185,6 +185,11 @@ class ComputeGraph final {
   std::vector<IOValueRef> inputs_;
   std::vector<IOValueRef> outputs_;
 
+  std::unordered_set<
+      vkapi::ComputePipelineCache::Key,
+      vkapi::ComputePipelineCache::Hasher>
+      pipeline_descriptors_;
+
  protected:
   size_t values_in_use_ = 0;
   size_t execute_count_ = 0;
@@ -711,8 +716,16 @@ class ComputeGraph final {
       const vkapi::ShaderInfo& shader_info,
       bool execute);
 
+  void register_pipeline_to_create(
+      const vkapi::ShaderInfo& shader_info,
+      const utils::WorkgroupSize& local_workgroup_size,
+      const vkapi::SpecVarList& spec_vars,
+      const std::vector<PushConstantDataInfo>& push_constants);
+
   void prepare();
 
+  void prepare_pipelines();
+
   //
   // Dispatch Utilities
   //
diff --git a/backends/vulkan/runtime/graph/ops/DispatchNode.cpp b/backends/vulkan/runtime/graph/ops/DispatchNode.cpp
index a0d3a4c2e5c..b5644cf3dcd 100644
--- a/backends/vulkan/runtime/graph/ops/DispatchNode.cpp
+++ b/backends/vulkan/runtime/graph/ops/DispatchNode.cpp
@@ -35,6 +35,16 @@ DispatchNode::DispatchNode(
   graph.update_descriptor_counts(shader, /*execute = */ true);
 }
 
+// Registers this node's compute pipeline with the graph for batched creation.
+void DispatchNode::prepare_pipelines(ComputeGraph* graph) {
+  // Nodes without a shader are skipped, mirroring the guard in encode().
+  if (!shader_) {
+    return;
+  }
+  graph->register_pipeline_to_create(
+      shader_, local_workgroup_size_, spec_vars_, push_constants_);
+}
+
 void DispatchNode::encode(ComputeGraph* graph) {
   if (!shader_) {
     return;
diff --git a/backends/vulkan/runtime/graph/ops/DispatchNode.h b/backends/vulkan/runtime/graph/ops/DispatchNode.h
index db95adfee39..b6eb8624c26 100644
--- a/backends/vulkan/runtime/graph/ops/DispatchNode.h
+++ b/backends/vulkan/runtime/graph/ops/DispatchNode.h
@@ -40,6 +40,8 @@ class DispatchNode : public ExecuteNode {
 
   ~DispatchNode() override = default;
 
+  void prepare_pipelines(ComputeGraph* graph) override;
+
   void encode(ComputeGraph* graph) override;
 
  protected:
diff --git a/backends/vulkan/runtime/graph/ops/ExecuteNode.h b/backends/vulkan/runtime/graph/ops/ExecuteNode.h
index 0731722e13a..6a815b246ef 100644
--- a/backends/vulkan/runtime/graph/ops/ExecuteNode.h
+++ b/backends/vulkan/runtime/graph/ops/ExecuteNode.h
@@ -61,6 +61,10 @@ class ExecuteNode {
 
   virtual ~ExecuteNode() = default;
 
+  virtual void prepare_pipelines(ComputeGraph* graph) {
+    (void)graph;
+  }
+
   virtual void encode(ComputeGraph* graph) {
     (void)graph;
   }
diff --git a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp
index d84d893540c..bdbecc866ab 100644
--- a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp
+++ b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp
@@ -67,6 +67,14 @@ api::StagingBuffer PrepackNode::create_staging_buffer(ComputeGraph* graph) {
   return staging;
 }
 
+// Registers the pipelines for this node's shader_ and for noop_shader_.
+void PrepackNode::prepare_pipelines(ComputeGraph* graph) {
+  graph->register_pipeline_to_create(
+      shader_, local_workgroup_size_, spec_vars_, push_constants_);
+  graph->register_pipeline_to_create(
+      noop_shader_, utils::WorkgroupSize(1, 1, 1), {}, {});
+}
+
 void PrepackNode::encode(ComputeGraph* graph) {
   api::Context* const context = graph->context();
 
diff --git a/backends/vulkan/runtime/graph/ops/PrepackNode.h b/backends/vulkan/runtime/graph/ops/PrepackNode.h
index a45deb9ff70..8ce8ac9f773 100644
--- a/backends/vulkan/runtime/graph/ops/PrepackNode.h
+++ b/backends/vulkan/runtime/graph/ops/PrepackNode.h
@@ -40,6 +40,8 @@ class PrepackNode final {
 
   ~PrepackNode() = default;
 
+  void prepare_pipelines(ComputeGraph* graph);
+
   void encode(ComputeGraph* graph);
 
   inline void set_node_id(uint32_t node_id) {
diff --git a/backends/vulkan/runtime/vk_api/Pipeline.cpp b/backends/vulkan/runtime/vk_api/Pipeline.cpp
index 5dcb00168b2..994b46b8c76 100644
--- a/backends/vulkan/runtime/vk_api/Pipeline.cpp
+++ b/backends/vulkan/runtime/vk_api/Pipeline.cpp
@@ -270,6 +270,9 @@ void swap(PipelineLayout& lhs, PipelineLayout& rhs) noexcept {
 // ComputePipeline
 //
 
+ComputePipeline::ComputePipeline(VkDevice device, VkPipeline handle)
+    : device_{device}, handle_{handle} {}
+
 ComputePipeline::ComputePipeline(
     VkDevice device,
     const ComputePipeline::Descriptor& descriptor,
@@ -444,19 +447,111 @@ ComputePipelineCache::~ComputePipelineCache() {
   pipeline_cache_ = VK_NULL_HANDLE;
 }
 
+// Returns true when a pipeline for `key` is already stored in the cache.
+bool ComputePipelineCache::contains(const ComputePipelineCache::Key& key) {
+  std::lock_guard<std::mutex> lock(cache_mutex_);
+
+  auto it = cache_.find(key);
+  return it != cache_.cend();
+}
+
+// Creates the pipelines for all `descriptors` that are not cached yet with a
+// single vkCreateComputePipelines call and stores them in the cache.
+void ComputePipelineCache::create_pipelines(
+    const std::unordered_set<Key, Hasher>& descriptors) {
+  std::lock_guard<std::mutex> lock(cache_mutex_);
+
+  // Keep only the keys that are missing from the cache, so that every pipeline
+  // created below gets inserted (no leaked handles) and pipelines[i] always
+  // belongs to pending[i].
+  std::vector<const ComputePipelineCache::Key*> pending;
+  pending.reserve(descriptors.size());
+  for (const auto& key : descriptors) {
+    if (cache_.find(key) == cache_.cend()) {
+      pending.push_back(&key);
+    }
+  }
+
+  // vkCreateComputePipelines requires createInfoCount to be greater than 0.
+  if (pending.empty()) {
+    return;
+  }
+
+  const auto num_pipelines = pending.size();
+  std::vector<VkPipeline> pipelines(num_pipelines);
+
+  std::vector<std::vector<VkSpecializationMapEntry>> map_entries;
+  map_entries.reserve(num_pipelines);
+
+  std::vector<VkSpecializationInfo> specialization_infos;
+  specialization_infos.reserve(num_pipelines);
+
+  std::vector<VkPipelineShaderStageCreateInfo> shader_stage_create_infos;
+  shader_stage_create_infos.reserve(num_pipelines);
+
+  std::vector<VkComputePipelineCreateInfo> create_infos;
+  create_infos.reserve(num_pipelines);
+
+  // The reserve() calls above guarantee that no vector reallocates inside this
+  // loop, which keeps the pointers taken from back() valid.
+  for (const ComputePipelineCache::Key* pending_key : pending) {
+    const ComputePipelineCache::Key& key = *pending_key;
+    map_entries.push_back(key.specialization_constants.generate_map_entries());
+
+    specialization_infos.push_back(VkSpecializationInfo{
+        key.specialization_constants.size(), // mapEntryCount
+        map_entries.back().data(), // pMapEntries
+        key.specialization_constants.data_nbytes(), // dataSize
+        key.specialization_constants.data(), // pData
+    });
+
+    shader_stage_create_infos.push_back(VkPipelineShaderStageCreateInfo{
+        VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, // sType
+        nullptr, // pNext
+        0u, // flags
+        VK_SHADER_STAGE_COMPUTE_BIT, // stage
+        key.shader_module, // module
+        "main", // pName
+        &specialization_infos.back(), // pSpecializationInfo
+    });
+
+    create_infos.push_back(VkComputePipelineCreateInfo{
+        VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO, // sType
+        nullptr, // pNext
+        0u, // flags
+        shader_stage_create_infos.back(), // stage
+        key.pipeline_layout, // layout
+        VK_NULL_HANDLE, // basePipelineHandle
+        0u, // basePipelineIndex
+    });
+  }
+
+  VK_CHECK(vkCreateComputePipelines(
+      device_,
+      pipeline_cache_,
+      static_cast<uint32_t>(create_infos.size()),
+      create_infos.data(),
+      nullptr,
+      pipelines.data()));
+
+  for (size_t i = 0; i < num_pipelines; ++i) {
+    cache_.insert(
+        {*pending[i], ComputePipelineCache::Value(device_, pipelines[i])});
+  }
+}
+
 VkPipeline ComputePipelineCache::retrieve(
     const ComputePipelineCache::Key& key) {
   std::lock_guard<std::mutex> lock(cache_mutex_);
 
   auto it = cache_.find(key);
-  if (cache_.cend() == it) {
+  if (it == cache_.cend()) {
     it = cache_
              .insert(
                  {key,
                   ComputePipelineCache::Value(device_, key, pipeline_cache_)})
              .first;
   }
-
   return it->second.handle();
 }
 
diff --git a/backends/vulkan/runtime/vk_api/Pipeline.h b/backends/vulkan/runtime/vk_api/Pipeline.h
index 1e0fc1e28aa..67dfaebe75b 100644
--- a/backends/vulkan/runtime/vk_api/Pipeline.h
+++ b/backends/vulkan/runtime/vk_api/Pipeline.h
@@ -19,6 +19,7 @@
 
 #include <mutex>
 #include <unordered_map>
+#include <unordered_set>
 
 #define SV(x) ::vkcompute::vkapi::SpecVar(x)
 
@@ -158,6 +159,8 @@ class ComputePipeline final {
     SpecVarList specialization_constants;
   };
 
+  explicit ComputePipeline(VkDevice device, VkPipeline handle);
+
   explicit ComputePipeline(
       VkDevice device,
       const Descriptor& descriptor,
@@ -185,6 +188,10 @@ class ComputePipeline final {
   // does not allow for move assignment. The swap function will
   // be used in the hash map.
   friend void swap(ComputePipeline& lhs, ComputePipeline& rhs) noexcept;
+
+  friend bool operator==(
+      const ComputePipeline::Descriptor& _1,
+      const ComputePipeline::Descriptor& _2);
 };
 
 class PipelineLayoutCache final {
@@ -293,6 +300,8 @@ class ComputePipelineCache final {
   const std::string cache_data_path_;
 
  public:
+  bool contains(const Key&);
+  void create_pipelines(const std::unordered_set<Key, Hasher>&);
   VkPipeline retrieve(const Key&);
   void purge();
 };
diff --git a/backends/vulkan/runtime/vk_api/Runtime.cpp b/backends/vulkan/runtime/vk_api/Runtime.cpp
index 9b30aaacda8..2073a7364ba 100644
--- a/backends/vulkan/runtime/vk_api/Runtime.cpp
+++ b/backends/vulkan/runtime/vk_api/Runtime.cpp
@@ -258,7 +258,8 @@ uint32_t select_first(const std::vector<Runtime::DeviceMapping>& devices) {
 // Global runtime initialization
 //
 
-std::unique_ptr<Runtime> init_global_vulkan_runtime() {
+std::unique_ptr<Runtime> init_global_vulkan_runtime(
+    const std::string& cache_data_path) {
   // Load Vulkan drivers
 #if defined(USE_VULKAN_VOLK)
   if (VK_SUCCESS != volkInitialize()) {
@@ -278,7 +279,6 @@ std::unique_ptr<Runtime> init_global_vulkan_runtime() {
 #endif /* VULKAN_DEBUG */
   const bool init_default_device = true;
   const uint32_t num_requested_queues = 1; // TODO: raise this value
-  const std::string cache_data_path = ""; // TODO: expose to client
 
   const RuntimeConfig default_config{
       enable_validation_messages,
@@ -377,13 +377,27 @@ uint32_t Runtime::create_adapter(const Selector& selector) {
   return adapter_i;
 }
 
+// Stores `file_path` as the global pipeline cache data path when it is
+// non-empty, and returns the stored path. runtime() reads it with an empty
+// argument when the global runtime is first initialized.
+std::string& set_and_get_pipeline_cache_data_path(
+    const std::string& file_path) {
+  // The global cache data path is declared as a static local variable for the
+  // same reasons as the global runtime below.
+  static std::string global_cache_data_path;
+  if (!file_path.empty()) {
+    global_cache_data_path = file_path;
+  }
+  return global_cache_data_path;
+}
+
 Runtime* runtime() {
   // The global vulkan runtime is declared as a static local variable within a
   // non-static function to ensure it has external linkage. If it were a global
   // static variable there would be one copy per translation unit that includes
   // Runtime.h as it would have internal linkage.
   static const std::unique_ptr<Runtime> p_runtime =
-      init_global_vulkan_runtime();
+      init_global_vulkan_runtime(set_and_get_pipeline_cache_data_path(""));
 
   VK_CHECK_COND(
       p_runtime,
diff --git a/backends/vulkan/runtime/vk_api/Runtime.h b/backends/vulkan/runtime/vk_api/Runtime.h
index 16f1400021c..c1b67c0dbdc 100644
--- a/backends/vulkan/runtime/vk_api/Runtime.h
+++ b/backends/vulkan/runtime/vk_api/Runtime.h
@@ -100,6 +100,8 @@ class Runtime final {
   uint32_t create_adapter(const Selector&);
 };
 
+std::string& set_and_get_pipeline_cache_data_path(const std::string& file_path);
+
 // The global runtime is retrieved using this function, where it is declared as
 // a static local variable.
 Runtime* runtime();