Skip to content

Commit d0d965b

Browse files
committed
[ET-VK] Consolidate shader compilation into one vkCreateComputePipelines call
Pull Request resolved: #11345. We target the Qualcomm (QC) Adreno driver implementation of Vulkan. The Vulkan API does not dictate how QC's driver actually uses the pipeline cache. As the plural naming of `vkCreateComputePipelines` suggests, we observed that its `createInfoCount`, `pCreateInfos` and `pPipelines` arguments allow construction of multiple compute pipelines in one invocation. We refactor ET-VK to accumulate the metadata necessary for pipeline construction and invoke `vkCreateComputePipelines` only once. QC's implementation makes the most of the cache when we create the same number of compute pipelines in fewer invocations of `vkCreateComputePipelines`. This decreases model load time for a sample model from 1.7s to 200ms. ghstack-source-id: 288170368 Differential Revision: [D75763660](https://our.internmc.facebook.com/intern/diff/D75763660/)
1 parent af0a246 commit d0d965b

File tree

10 files changed

+173
-2
lines changed

10 files changed

+173
-2
lines changed

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,7 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
495495
builder.build_graph();
496496

497497
compute_graph->prepare();
498+
compute_graph->prepare_pipelines();
498499

499500
compute_graph->encode_prepack();
500501
compute_graph->prepack();

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,42 @@ void ComputeGraph::update_descriptor_counts(
561561
}
562562
}
563563

564+
void ComputeGraph::register_pipeline_to_create(
565+
const vkapi::ShaderInfo& shader_info,
566+
const utils::WorkgroupSize& local_workgroup_size,
567+
const vkapi::SpecVarList& spec_vars,
568+
const std::vector<PushConstantDataInfo>& push_constants) {
569+
VkDescriptorSetLayout shader_layout =
570+
context()->shader_layout_cache().retrieve(shader_info.kernel_layout);
571+
572+
uint32_t pc_offset = 0;
573+
std::array<uint8_t, kMaxPushConstantSize> pc_data;
574+
for (const auto& pc : push_constants) {
575+
pc_offset += pc.write(pc_data.data(), pc_offset, kMaxPushConstantSize);
576+
}
577+
578+
vkapi::SpecVarList spec_constants = {
579+
SV(local_workgroup_size[0u]),
580+
SV(local_workgroup_size[1u]),
581+
SV(local_workgroup_size[2u])};
582+
583+
spec_constants.append(spec_vars);
584+
585+
const vkapi::ComputePipelineCache::Key desc = {
586+
context()->pipeline_layout_cache().retrieve(shader_layout, pc_offset),
587+
context()->shader_cache().retrieve(shader_info),
588+
spec_constants};
589+
590+
if (context_->pipeline_cache().contains(desc)) {
591+
return;
592+
}
593+
auto it = pipeline_descriptors_.find(desc);
594+
if (it != pipeline_descriptors_.cend()) {
595+
return;
596+
}
597+
pipeline_descriptors_.insert(desc);
598+
}
599+
564600
utils::uvec3 ComputeGraph::create_global_wg_size(const ValueRef idx) {
565601
if (is_buffer_storage(idx)) {
566602
return {uint32_t(numel_of(idx)), 1u, 1u};
@@ -670,6 +706,20 @@ void ComputeGraph::prepare() {
670706
}
671707
}
672708

709+
void ComputeGraph::prepare_pipelines() {
710+
for (std::unique_ptr<PrepackNode>& node : prepack_nodes_) {
711+
node->prepare_pipelines(this);
712+
}
713+
for (std::unique_ptr<ExecuteNode>& node : execute_nodes_) {
714+
node->prepare_pipelines(this);
715+
}
716+
context_->pipeline_cache().create_pipelines(pipeline_descriptors_);
717+
718+
pipeline_descriptors_ = std::unordered_set<
719+
vkapi::ComputePipelineCache::Key,
720+
vkapi::ComputePipelineCache::Hasher>();
721+
}
722+
673723
void ComputeGraph::encode_prepack() {
674724
for (std::unique_ptr<PrepackNode>& node : prepack_nodes_) {
675725
node->encode(this);

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,11 @@ class ComputeGraph final {
185185
std::vector<IOValueRef> inputs_;
186186
std::vector<IOValueRef> outputs_;
187187

188+
// Keys of compute pipelines that still have to be created.
// register_pipeline_to_create() inserts into this set; prepare_pipelines()
// passes it to the pipeline cache's create_pipelines() and then resets it.
std::unordered_set<
189+
vkapi::ComputePipelineCache::Key,
190+
vkapi::ComputePipelineCache::Hasher>
191+
pipeline_descriptors_;
192+
188193
protected:
189194
size_t values_in_use_ = 0;
190195
size_t execute_count_ = 0;
@@ -711,8 +716,16 @@ class ComputeGraph final {
711716
const vkapi::ShaderInfo& shader_info,
712717
bool execute);
713718

719+
// Records the compute pipeline described by the given shader, local workgroup
// size, specialization variables and push constants so that
// prepare_pipelines() can create it. Does nothing if the pipeline cache
// already contains that pipeline or if it has already been recorded.
void register_pipeline_to_create(
720+
const vkapi::ShaderInfo& shader_info,
721+
const utils::WorkgroupSize& local_workgroup_size,
722+
const vkapi::SpecVarList& spec_vars,
723+
const std::vector<PushConstantDataInfo>& push_constants);
724+
714725
void prepare();
715726

727+
// Calls prepare_pipelines() on every prepack and execute node, then asks the
// pipeline cache to create all pipelines registered by those calls and resets
// the set of pending pipeline keys.
void prepare_pipelines();
728+
716729
//
717730
// Dispatch Utilities
718731
//

backends/vulkan/runtime/graph/ops/DispatchNode.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@ DispatchNode::DispatchNode(
3535
graph.update_descriptor_counts(shader, /*execute = */ true);
3636
}
3737

38+
// Registers the compute pipeline described by this node's shader, local
// workgroup size, specialization variables and push constants with the graph.
// A node without a shader registers nothing, matching the early return that
// encode() performs for the same condition.
void DispatchNode::prepare_pipelines(ComputeGraph* graph) {
  if (!shader_) {
    return;
  }
  graph->register_pipeline_to_create(
      shader_, local_workgroup_size_, spec_vars_, push_constants_);
}
42+
3843
void DispatchNode::encode(ComputeGraph* graph) {
3944
if (!shader_) {
4045
return;

backends/vulkan/runtime/graph/ops/DispatchNode.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ class DispatchNode : public ExecuteNode {
4040

4141
~DispatchNode() override = default;
4242

43+
// Registers the pipeline for this node's shader configuration with `graph`.
void prepare_pipelines(ComputeGraph* graph) override;
44+
4345
void encode(ComputeGraph* graph) override;
4446

4547
protected:

backends/vulkan/runtime/graph/ops/ExecuteNode.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ class ExecuteNode {
6161

6262
virtual ~ExecuteNode() = default;
6363

64+
// Default implementation: registers no pipelines. Subclasses that dispatch a
// shader override this.
virtual void prepare_pipelines(ComputeGraph* graph) {
  static_cast<void>(graph); // intentionally unused
}
67+
6468
virtual void encode(ComputeGraph* graph) {
6569
(void)graph;
6670
}

backends/vulkan/runtime/graph/ops/PrepackNode.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,13 @@ api::StagingBuffer PrepackNode::create_staging_buffer(ComputeGraph* graph) {
6767
return staging;
6868
}
6969

70+
// Registers two pipelines with the graph: one for this node's own shader
// configuration, and one for the no-op shader with a 1x1x1 local workgroup
// size and no specialization variables or push constants.
void PrepackNode::prepare_pipelines(ComputeGraph* graph) {
  graph->register_pipeline_to_create(
      shader_, local_workgroup_size_, spec_vars_, push_constants_);

  const utils::WorkgroupSize noop_local_wg_size(1, 1, 1);
  graph->register_pipeline_to_create(
      noop_shader_, noop_local_wg_size, {}, {});
}
76+
7077
void PrepackNode::encode(ComputeGraph* graph) {
7178
api::Context* const context = graph->context();
7279

backends/vulkan/runtime/graph/ops/PrepackNode.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ class PrepackNode final {
4040

4141
~PrepackNode() = default;
4242

43+
// Registers the pipelines for this node's shader and for its no-op shader
// with `graph`.
void prepare_pipelines(ComputeGraph* graph);
44+
4345
void encode(ComputeGraph* graph);
4446

4547
inline void set_node_id(uint32_t node_id) {

backends/vulkan/runtime/vk_api/Pipeline.cpp

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,9 @@ void swap(PipelineLayout& lhs, PipelineLayout& rhs) noexcept {
270270
// ComputePipeline
271271
//
272272

273+
// Wraps a VkPipeline handle that has already been created, remembering the
// device passed alongside it. No Vulkan call is made here.
ComputePipeline::ComputePipeline(VkDevice device, VkPipeline handle)
    : device_(device), handle_(handle) {}
275+
273276
ComputePipeline::ComputePipeline(
274277
VkDevice device,
275278
const ComputePipeline::Descriptor& descriptor,
@@ -444,19 +447,94 @@ ComputePipelineCache::~ComputePipelineCache() {
444447
pipeline_cache_ = VK_NULL_HANDLE;
445448
}
446449

450+
// Returns true when the cache already holds an entry for `key`. The cache
// mutex is held for the duration of the lookup.
bool ComputePipelineCache::contains(const ComputePipelineCache::Key& key) {
  std::lock_guard<std::mutex> lock(cache_mutex_);
  return cache_.find(key) != cache_.cend();
}
456+
457+
void ComputePipelineCache::create_pipelines(
458+
const std::unordered_set<Key, Hasher>& descriptors) {
459+
std::lock_guard<std::mutex> lock(cache_mutex_);
460+
461+
const auto num_pipelines = descriptors.size();
462+
std::vector<VkPipeline> pipelines(num_pipelines);
463+
464+
std::vector<std::vector<VkSpecializationMapEntry>> map_entries;
465+
map_entries.reserve(num_pipelines);
466+
467+
std::vector<VkSpecializationInfo> specialization_infos;
468+
specialization_infos.reserve(num_pipelines);
469+
470+
std::vector<VkPipelineShaderStageCreateInfo> shader_stage_create_infos;
471+
shader_stage_create_infos.reserve(num_pipelines);
472+
473+
std::vector<VkComputePipelineCreateInfo> create_infos;
474+
create_infos.reserve(num_pipelines);
475+
476+
for (auto& key : descriptors) {
477+
map_entries.push_back(key.specialization_constants.generate_map_entries());
478+
479+
specialization_infos.push_back(VkSpecializationInfo{
480+
key.specialization_constants.size(), // mapEntryCount
481+
map_entries.back().data(), // pMapEntries
482+
key.specialization_constants.data_nbytes(), // dataSize
483+
key.specialization_constants.data(), // pData
484+
});
485+
486+
shader_stage_create_infos.push_back(VkPipelineShaderStageCreateInfo{
487+
VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, // sType
488+
nullptr, // pNext
489+
0u, // flags
490+
VK_SHADER_STAGE_COMPUTE_BIT, // stage
491+
key.shader_module, // module
492+
"main", // pName
493+
&specialization_infos.back(), // pSpecializationInfo
494+
});
495+
496+
create_infos.push_back(VkComputePipelineCreateInfo{
497+
VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO, // sType
498+
nullptr, // pNext
499+
0u, // flags
500+
shader_stage_create_infos.back(), // stage
501+
key.pipeline_layout, // layout
502+
VK_NULL_HANDLE, // basePipelineHandle
503+
0u, // basePipelineIndex
504+
});
505+
}
506+
507+
VK_CHECK(vkCreateComputePipelines(
508+
device_,
509+
pipeline_cache_,
510+
create_infos.size(),
511+
create_infos.data(),
512+
nullptr,
513+
pipelines.data()));
514+
515+
uint32_t i = 0;
516+
for (auto& key : descriptors) {
517+
auto it = cache_.find(key);
518+
if (it != cache_.cend()) {
519+
continue;
520+
}
521+
cache_.insert({key, ComputePipelineCache::Value(device_, pipelines[i])});
522+
++i;
523+
}
524+
}
525+
447526
// Returns the pipeline handle cached for `key`. When no entry exists yet, a
// pipeline is constructed from the key and the Vulkan pipeline cache, inserted
// into the cache, and its handle is returned. The cache mutex is held
// throughout.
VkPipeline ComputePipelineCache::retrieve(
    const ComputePipelineCache::Key& key) {
  std::lock_guard<std::mutex> lock(cache_mutex_);

  auto found = cache_.find(key);
  if (found != cache_.cend()) {
    return found->second.handle();
  }

  auto inserted = cache_.insert(
      {key, ComputePipelineCache::Value(device_, key, pipeline_cache_)});
  return inserted.first->second.handle();
}
462540

backends/vulkan/runtime/vk_api/Pipeline.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
#include <mutex>
2121
#include <unordered_map>
22+
#include <unordered_set>
2223

2324
#define SV(x) ::vkcompute::vkapi::SpecVar(x)
2425

@@ -158,6 +159,8 @@ class ComputePipeline final {
158159
SpecVarList specialization_constants;
159160
};
160161

162+
// Constructs a ComputePipeline around an existing VkPipeline handle and the
// device given with it.
// NOTE(review): presumably the object takes ownership of `handle`; the
// destructor is not visible here — confirm.
explicit ComputePipeline(VkDevice device, VkPipeline handle);
163+
161164
explicit ComputePipeline(
162165
VkDevice device,
163166
const Descriptor& descriptor,
@@ -185,6 +188,10 @@ class ComputePipeline final {
185188
// does not allow for move assignment. The swap function will
186189
// be used in the hash map.
187190
friend void swap(ComputePipeline& lhs, ComputePipeline& rhs) noexcept;
191+
192+
// Equality comparison between two pipeline descriptors; the definition is
// not part of this header excerpt.
friend bool operator==(
193+
const ComputePipeline::Descriptor& _1,
194+
const ComputePipeline::Descriptor& _2);
188195
};
189196

190197
class PipelineLayoutCache final {
@@ -293,6 +300,8 @@ class ComputePipelineCache final {
293300
const std::string cache_data_path_;
294301

295302
public:
303+
// Returns true if the cache already holds a pipeline for the given key.
bool contains(const Key&);
304+
// Creates pipelines for the given set of keys with one
// vkCreateComputePipelines call and inserts them into the cache.
void create_pipelines(const std::unordered_set<Key, Hasher>&);
296305
VkPipeline retrieve(const Key&);
297306
void purge();
298307
};

0 commit comments

Comments
 (0)