Skip to content

Commit

Permalink
Pass properties in whisper static pipeline
Browse files (browse the repository at this point in the history)
  • Loading branch information
eshiryae committed Jan 13, 2025
1 parent 3833d82 commit a1d0954
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
8 changes: 4 additions & 4 deletions src/cpp/src/whisper_pipeline_static.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ ov::InferRequest DecoderCache::get_model(uint8_t input_ids_size) {
reshape_input_ids(m_decoder_model, input_ids_size);

ov::Core core = utils::singleton_core();
ov::CompiledModel compiled_model = core.compile_model(m_decoder_model, "NPU");
ov::CompiledModel compiled_model = core.compile_model(m_decoder_model, "NPU", m_properties);
ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder model");
m_cache.emplace(input_ids_size, compiled_model.create_infer_request());
}
Expand Down Expand Up @@ -544,14 +544,14 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
preprocess_decoder(decoder_with_past_model);

ov::CompiledModel compiled_model;
compiled_model = core.compile_model(encoder_model, "NPU");
compiled_model = core.compile_model(encoder_model, "NPU", properties);
ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper encoder model");
m_models.encoder = compiled_model.create_infer_request();

// Will compile decoder model when it's needed
m_decoder_cache = DecoderCache(decoder_model);
m_decoder_cache = DecoderCache(decoder_model, properties);

compiled_model = core.compile_model(decoder_with_past_model, "NPU");
compiled_model = core.compile_model(decoder_with_past_model, "NPU", properties);
ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder with past model");
m_models.decoder_with_past = compiled_model.create_infer_request();

Expand Down
5 changes: 4 additions & 1 deletion src/cpp/src/whisper_pipeline_static.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@ namespace genai {
class DecoderCache {
public:
DecoderCache() = default;
DecoderCache(std::shared_ptr<ov::Model> model) : m_decoder_model(model) {}
DecoderCache(std::shared_ptr<ov::Model> model, ov::AnyMap properties)
: m_decoder_model(model)
, m_properties(properties) {}

ov::InferRequest get_model(uint8_t input_ids_size);
private:
std::unordered_map<uint8_t, ov::InferRequest> m_cache;
std::shared_ptr<ov::Model> m_decoder_model;
ov::AnyMap m_properties;
};

class WhisperPipeline::StaticWhisperPipeline : public WhisperPipeline::WhisperPipelineImplBase {
Expand Down

0 comments on commit a1d0954

Please sign in to comment.