Skip to content

Commit

Permalink
Fix rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
vshampor committed Jan 9, 2025
1 parent 39d5ced commit 1cbc1a0
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 38 deletions.
2 changes: 1 addition & 1 deletion src/cpp/src/cache_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ class CacheManager {
std::vector<ov::Tensor> m_key_cache;
std::vector<ov::Tensor> m_value_cache;
size_t m_num_allocated_kv_blocks = 0;
ov::Core m_core;
ov::InferRequest m_request;
ov::Core m_core;

ov::Shape set_first_dim_and_make_static(const ov::PartialShape& shape, size_t dim) {
ov::PartialShape res_shape = shape;
Expand Down
9 changes: 2 additions & 7 deletions src/cpp/src/continuous_batching_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::initialize_pipeline(
m_num_decoder_layers = device_config.get_num_layers();

// setup KV caches
m_cache_manager = std::make_shared<CacheManager>(device_config, core);
for (size_t decoder_layer_id = 0; decoder_layer_id < m_num_decoder_layers; ++decoder_layer_id) {
infer_request.set_tensor(std::string("key_cache.") + std::to_string(decoder_layer_id), m_cache_manager->get_key_cache(decoder_layer_id));
infer_request.set_tensor(std::string("value_cache.") + std::to_string(decoder_layer_id), m_cache_manager->get_value_cache(decoder_layer_id));
}
m_cache_manager = std::make_shared<CacheManager>(device_config, infer_request, core);

SchedulerConfig updated_config = scheduler_config;
// update KV blocks number in scheduler config
Expand All @@ -85,8 +81,6 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::initialize_pipeline(
}
m_scheduler = std::make_shared<Scheduler>(device_config.get_block_size(), m_cache_manager, updated_config, device_config.get_num_layers(), can_use_partial_preemption);

m_scheduler = std::make_shared<Scheduler>(device_config.get_block_size(), updated_config, m_num_decoder_layers, can_use_partial_preemption);

// and finally create model runner
bool is_use_cache_eviction = m_scheduler->get_config().use_cache_eviction;
if (is_use_cache_eviction) {
Expand Down Expand Up @@ -204,6 +198,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() {
_register_step_cache_usage(scheduler_output.m_cache_usage);
m_pipeline_metrics.avg_cache_usage = _get_current_running_average_cache_usage();

const auto& sched_config = m_scheduler->get_config();
if (sched_config.use_cache_eviction && sched_config.cache_eviction_config.apply_rotation) {
_compute_cache_rotation_data(m_requests, scheduler_output);
m_model_runner->set_cache_rotation_data(std::move(m_current_step_rotated_block_indices_per_sequence),
Expand Down
5 changes: 1 addition & 4 deletions src/cpp/src/continuous_batching_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,6 @@ class ContinuousBatchingPipeline::ContinuousBatchingImpl : public ContinuousBatc
const ov::AnyMap& plugin_config,
const DeviceConfig& device_config,
ov::Core& core);
void _register_step_cache_usage(float step_cache_usage);
float _get_current_running_average_cache_usage() const;
void _maybe_evict_cache_blocks(const SchedulerConfig& sched_config);
void _compute_cache_rotation_data(const std::vector<SequenceGroup::Ptr>& sequence_groups, const Scheduler::Output& scheduler_output);


/**
Expand Down Expand Up @@ -92,6 +88,7 @@ class ContinuousBatchingPipeline::ContinuousBatchingImpl : public ContinuousBatc

void _register_step_cache_usage(float step_cache_usage);
float _get_current_running_average_cache_usage() const;
void _compute_cache_rotation_data(const std::vector<SequenceGroup::Ptr>& sequence_groups, const Scheduler::Output& scheduler_output);

public:
ContinuousBatchingImpl(const std::shared_ptr<ov::Model>& model,
Expand Down
2 changes: 0 additions & 2 deletions src/cpp/src/device_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,6 @@ class DeviceConfig {
return m_head_size;
}

}

ov::PartialShape get_value_cache_shape(size_t id) const {
OPENVINO_ASSERT(m_value_cache_shape.size());
return m_value_cache_shape[id];
Expand Down
42 changes: 18 additions & 24 deletions tests/python_tests/test_whisper_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,15 +320,10 @@ def test_max_new_tokens(model_descr, sample_from_dataset):


@pytest.mark.parametrize("model_descr", get_whisper_models_list(tiny_only=True))
@pytest.mark.parametrize(
"test_samples",
[
(get_samples_from_dataset(language="fr", length=1), "fr"),
(get_samples_from_dataset(language="de", length=1), "de"),
],
)
@pytest.mark.parametrize("sample_from_dataset", [*get_fixture_params_for_n_whisper_dataset_samples(n=1, language="fr"),
*get_fixture_params_for_n_whisper_dataset_samples(n=1, language="de")], indirect=True)
@pytest.mark.precommit
def test_language_mode(model_descr, test_samples):
def test_language_mode(model_descr, sample_from_dataset):
assert genai_result.texts[0] == expected

genai_result = pipe.generate(sample_from_dataset)
Expand All @@ -347,7 +342,7 @@ def test_language_mode(model_descr, test_samples):
def test_language_mode_fr(model_descr, sample_from_dataset):
model_id, path = model_descr
model_id, path, opt_pipe, pipe = read_whisper_model(model_descr)
samples, language = test_samples
samples, language = sample_from_dataset

expected = opt_pipe(
samples[0], max_new_tokens=30, generate_kwargs={"language": language}
Expand All @@ -370,7 +365,7 @@ def test_language_mode_fr(model_descr, sample_from_dataset):
@pytest.mark.parametrize("model_descr", get_whisper_models_list(tiny_only=True))
@pytest.mark.parametrize("sample_from_dataset", get_fixture_params_for_n_whisper_dataset_samples(n=3, language="fr"), indirect=True)
@pytest.mark.precommit
def test_task_mode(model_descr, test_sample):
def test_task_mode(model_descr, sample_from_dataset):
model_id, path, opt_pipe, pipe = read_whisper_model(model_descr)

expected = opt_pipe(
Expand Down Expand Up @@ -432,57 +427,55 @@ def test_language_autodetect(model_descr, sample_from_dataset):
run_pipeline_with_ref(
model_id=model_descr[0],
tmp_path=model_descr[1],
sample=test_sample,
sample=sample_from_dataset,
generation_config=ov_genai.WhisperGenerationConfig(max_new_tokens=30),
)


@pytest.mark.parametrize("model_descr", get_whisper_models_list(tiny_only=True))
@pytest.mark.parametrize("test_sample", get_samples_from_dataset(length=1))
@pytest.mark.parametrize("sample_from_dataset", [*get_fixture_params_for_n_whisper_dataset_samples(n=1)], indirect=True)
@pytest.mark.precommit
def test_return_timestamps_short_form(model_descr, test_sample):
def test_return_timestamps_short_form(model_descr, sample_from_dataset):
run_pipeline_with_ref(
model_id=model_descr[0],
tmp_path=model_descr[1],
sample=test_sample,
sample=sample_from_dataset,
generation_config=ov_genai.WhisperGenerationConfig(return_timestamps=True),
)


@pytest.mark.parametrize("model_descr", get_whisper_models_list(tiny_only=True))
@pytest.mark.parametrize("test_sample", get_samples_from_dataset(length=1))
@pytest.mark.parametrize("sample_from_dataset", [*get_fixture_params_for_n_whisper_dataset_samples(n=1)], indirect=True)
@pytest.mark.precommit
def test_return_timestamps_max_new_tokens_short_form(model_descr, test_sample):
def test_return_timestamps_max_new_tokens_short_form(model_descr, sample_from_dataset):
run_pipeline_with_ref(
model_id=model_descr[0],
tmp_path=model_descr[1],
sample=test_sample,
sample=sample_from_dataset,
generation_config=ov_genai.WhisperGenerationConfig(
return_timestamps=True, language="en", max_new_tokens=30
),
)


@pytest.mark.parametrize("model_descr", get_whisper_models_list())
@pytest.mark.parametrize(
"test_sample", get_samples_from_dataset(length=10, long_form=True)
)
@pytest.mark.parametrize("sample_from_dataset", [*get_fixture_params_for_n_whisper_dataset_samples(n=10, long_form=True)], indirect=True)
@pytest.mark.precommit
def test_longform_audio(model_descr, test_sample):
def test_longform_audio(model_descr, sample_from_dataset):
_, _, hf_pipe, genai_pipe = read_whisper_model(model_descr)

streamer_result = []

genai_result = run_genai(
genai_pipe,
test_sample,
sample_from_dataset,
config=ov_genai.WhisperGenerationConfig(return_timestamps=True),
streamer=lambda x: streamer_result.append(x),
)

hf_result = run_huggingface(
hf_pipe,
test_sample,
sample_from_dataset,
config=ov_genai.WhisperGenerationConfig(return_timestamps=True),
)

Expand Down Expand Up @@ -520,7 +513,8 @@ def test_initial_prompt_hotwords(model_descr, sample_from_dataset):
assert "Joel Keaton" in result.texts[0]
assert "Joel Kyton" not in result.texts[0]

result = pipe.generate(test_sample, initial_prompt="Joel Kyton")
result = pipe.generate(sample_from_dataset
, initial_prompt="Joel Kyton")

assert "Joel Keaton" not in result.texts[0]
assert "Joel Kyton" in result.texts[0]
Expand Down

0 comments on commit 1cbc1a0

Please sign in to comment.