diff --git a/src/cpp/src/generation_config.cpp b/src/cpp/src/generation_config.cpp
index b392e44b3b..6a2e2e407b 100644
--- a/src/cpp/src/generation_config.cpp
+++ b/src/cpp/src/generation_config.cpp
@@ -11,6 +11,45 @@
 
 #include "generation_config_helper.hpp"
 
+namespace {
+    template <typename T>
+    struct json_type_traits {};
+
+    template <>
+    struct json_type_traits<int> { static constexpr auto json_value_t = nlohmann::json::value_t::number_integer; };
+
+    template <>
+    struct json_type_traits<int64_t> { static constexpr auto json_value_t = nlohmann::json::value_t::number_integer; };
+
+    template <>
+    struct json_type_traits<size_t> { static constexpr auto json_value_t = nlohmann::json::value_t::number_unsigned; };
+
+    template <>
+    struct json_type_traits<float> { static constexpr auto json_value_t = nlohmann::json::value_t::number_float; };
+
+    template <>
+    struct json_type_traits<std::string> { static constexpr auto json_value_t = nlohmann::json::value_t::string; };
+
+    template <>
+    struct json_type_traits<bool> { static constexpr auto json_value_t = nlohmann::json::value_t::boolean; };
+
+    template <typename T>
+    void read_json_param(const nlohmann::json& data, const std::string& name, T& param) {
+        if (data.contains(name) && data[name].type() == json_type_traits<T>::json_value_t) {
+            param = data[name];
+        }
+    }
+
+    template <typename T>
+    void read_anymap_param(const ov::AnyMap& config_map, const std::string& name, T& param) {
+        if (config_map.count(name)) {
+            param = config_map.at(name).as<T>();
+        }
+    }
+
+}  // namespace
+
+
 namespace ov {
 
 GenerationConfig::GenerationConfig(std::string json_path) {
@@ -18,32 +57,27 @@ GenerationConfig::GenerationConfig(std::string json_path) {
     OPENVINO_ASSERT(f.is_open(), "Failed to open '" + json_path + "' with generation config");
     nlohmann::json data = nlohmann::json::parse(f);
-
-    if (data.contains("max_new_tokens")) max_new_tokens = data["max_new_tokens"];
-    if (data.contains("max_length")) max_length = data["max_length"];
+
+    read_json_param(data, "max_new_tokens", max_new_tokens);
+    read_json_param(data, "max_length", max_length);
     // note that ignore_eos is not present in HF GenerationConfig
-    if (data.contains("num_beam_groups")) num_beam_groups = data["num_beam_groups"];
-    if (data.contains("num_beams")) num_beams = data["num_beams"];
-    if (data.contains("diversity_penalty")) diversity_penalty = data["diversity_penalty"];
-    if (data.contains("length_penalty")) length_penalty = data["length_penalty"];
-    if (data.contains("num_return_sequences")) num_return_sequences = data["num_return_sequences"];
-    if (data.contains("no_repeat_ngram_size")) no_repeat_ngram_size = data["no_repeat_ngram_size"];
+    read_json_param(data, "num_beam_groups", num_beam_groups);
+    read_json_param(data, "num_beams", num_beams);
+    read_json_param(data, "diversity_penalty", diversity_penalty);
+    read_json_param(data, "length_penalty", length_penalty);
+    read_json_param(data, "num_return_sequences", num_return_sequences);
+    read_json_param(data, "no_repeat_ngram_size", no_repeat_ngram_size);
     // stop_criteria will be processed below
-    if (data.contains("temperature")) temperature = data["temperature"];
-    if (data.contains("top_p")) top_p = data["top_p"];
-    if (data.contains("top_k")) top_k = data["top_k"];
-    if (data.contains("do_sample")) do_sample = data["do_sample"];
-    if (data.contains("repetition_penalty")) repetition_penalty = data["repetition_penalty"];
-    if (data.contains("pad_token_id")) pad_token_id = data["pad_token_id"];
-    if (data.contains("bos_token_id")) bos_token_id = data["bos_token_id"];
-
-    if (data.contains("eos_token_id") &&
-        data["eos_token_id"].type() == nlohmann::json::value_t::number_integer) {
-        // todo: qwen contains several eos_token_id
-        eos_token_id = data["eos_token_id"];
-    }
-
-    if (data.contains("bos_token")) bos_token = data["bos_token"];
-    if (data.contains("eos_token")) eos_token = data["eos_token"];
+    read_json_param(data, "temperature", temperature);
+    read_json_param(data, "top_p", top_p);
+    read_json_param(data, "top_k", top_k);
+    read_json_param(data, "do_sample", do_sample);
+    read_json_param(data, "repetition_penalty", repetition_penalty);
+    read_json_param(data, "pad_token_id", pad_token_id);
+    read_json_param(data, "bos_token_id", bos_token_id);
+    read_json_param(data, "eos_token_id", eos_token_id);
+    read_json_param(data, "bos_token", bos_token);
+    read_json_param(data, "eos_token", eos_token);
 
     if (data.contains("early_stopping")) {
         auto field_type = data["early_stopping"].type();
@@ -59,28 +93,27 @@ GenerationConfig::GenerationConfig(std::string json_path) {
 
 GenerationConfig GenerationConfigHelper::anymap_to_generation_config(const ov::AnyMap& config_map) {
     GenerationConfig config = m_config;
-
-    if (config_map.count("max_new_tokens")) config.max_new_tokens = config_map.at("max_new_tokens").as<size_t>();
-    if (config_map.count("max_length")) config.max_length = config_map.at("max_length").as<size_t>();
-    if (config_map.count("ignore_eos")) config.ignore_eos = config_map.at("ignore_eos").as<bool>();
-    if (config_map.count("num_beam_groups")) config.num_beam_groups = config_map.at("num_beam_groups").as<size_t>();
-    if (config_map.count("num_beams")) config.num_beams = config_map.at("num_beams").as<size_t>();
-    if (config_map.count("diversity_penalty")) config.diversity_penalty = config_map.at("diversity_penalty").as<float>();
-    if (config_map.count("length_penalty")) config.length_penalty = config_map.at("length_penalty").as<float>();
-    if (config_map.count("num_return_sequences")) config.num_return_sequences = config_map.at("num_return_sequences").as<size_t>();
-    if (config_map.count("no_repeat_ngram_size")) config.no_repeat_ngram_size = config_map.at("no_repeat_ngram_size").as<size_t>();
-    if (config_map.count("stop_criteria")) config.stop_criteria = config_map.at("stop_criteria").as<StopCriteria>();
-    if (config_map.count("temperature")) config.temperature = config_map.at("temperature").as<float>();
-    if (config_map.count("top_p")) config.top_p = config_map.at("top_p").as<float>();
-    if (config_map.count("top_k")) config.top_k = config_map.at("top_k").as<size_t>();
-    if (config_map.count("do_sample")) config.do_sample = config_map.at("do_sample").as<bool>();
-    if (config_map.count("repetition_penalty")) config.repetition_penalty = config_map.at("repetition_penalty").as<float>();
-    if (config_map.count("pad_token_id")) config.pad_token_id = config_map.at("pad_token_id").as<int64_t>();
-    if (config_map.count("bos_token_id")) config.bos_token_id = config_map.at("bos_token_id").as<int64_t>();
-    if (config_map.count("eos_token_id")) config.eos_token_id = config_map.at("eos_token_id").as<int64_t>();
-    if (config_map.count("bos_token")) config.bos_token = config_map.at("bos_token").as<std::string>();
-    if (config_map.count("eos_token")) config.eos_token = config_map.at("eos_token").as<std::string>();
-
+    read_anymap_param(config_map, "max_new_tokens", config.max_new_tokens);
+    read_anymap_param(config_map, "max_length", config.max_length);
+    read_anymap_param(config_map, "ignore_eos", config.ignore_eos);
+    read_anymap_param(config_map, "num_beam_groups", config.num_beam_groups);
+    read_anymap_param(config_map, "num_beams", config.num_beams);
+    read_anymap_param(config_map, "diversity_penalty", config.diversity_penalty);
+    read_anymap_param(config_map, "length_penalty", config.length_penalty);
+    read_anymap_param(config_map, "num_return_sequences", config.num_return_sequences);
+    read_anymap_param(config_map, "no_repeat_ngram_size", config.no_repeat_ngram_size);
+    read_anymap_param(config_map, "stop_criteria", config.stop_criteria);
+    read_anymap_param(config_map, "temperature", config.temperature);
+    read_anymap_param(config_map, "top_p", config.top_p);
+    read_anymap_param(config_map, "top_k", config.top_k);
+    read_anymap_param(config_map, "do_sample", config.do_sample);
+    read_anymap_param(config_map, "repetition_penalty", config.repetition_penalty);
+    read_anymap_param(config_map, "pad_token_id", config.pad_token_id);
+    read_anymap_param(config_map, "bos_token_id", config.bos_token_id);
+    read_anymap_param(config_map, "eos_token_id", config.eos_token_id);
+    read_anymap_param(config_map, "bos_token", config.bos_token);
+    read_anymap_param(config_map, "eos_token", config.eos_token);
+
     return config;
 }
"length_penalty", config.length_penalty); + read_anymap_param(config_map, "num_return_sequences", config.num_return_sequences); + read_anymap_param(config_map, "no_repeat_ngram_size", config.no_repeat_ngram_size); + read_anymap_param(config_map, "stop_criteria", config.stop_criteria); + read_anymap_param(config_map, "temperature", config.temperature); + read_anymap_param(config_map, "top_p", config.top_p); + read_anymap_param(config_map, "top_k", config.top_k); + read_anymap_param(config_map, "do_sample", config.do_sample); + read_anymap_param(config_map, "repetition_penalty", config.repetition_penalty); + read_anymap_param(config_map, "pad_token_id", config.pad_token_id); + read_anymap_param(config_map, "bos_token_id", config.bos_token_id); + read_anymap_param(config_map, "eos_token_id", config.eos_token_id); + read_anymap_param(config_map, "bos_token", config.bos_token); + read_anymap_param(config_map, "eos_token", config.eos_token); + return config; } diff --git a/src/tests/python_tests/test_generate_api.py b/src/tests/python_tests/test_generate_api.py new file mode 100644 index 0000000000..3e652d1fe3 --- /dev/null +++ b/src/tests/python_tests/test_generate_api.py @@ -0,0 +1,102 @@ +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +model_ids = [ + ("TinyLlama/TinyLlama-1.1B-Chat-v1.0", "TinyLlama-1.1B-Chat-v1.0/pytorch/dldt/FP16/"), + # ("meta-llama/Llama-2-7b-chat-hf", "Llama-2-7b-chat-hf/pytorch/dldt/FP16/"), + # ("microsoft/phi-1_5", "phi-1_5/"), + # ("google/gemma-2b-it", "gemma-2b-it/pytorch/dldt/FP16/"), +] + + +def run_hf_ov_genai_comparison(model_fixture, generation_config, prompt): + model_id, path, tokenizer, model = model_fixture + + generation_config_hf = generation_config.copy() + # in OpenVINO GenAI this parameter is called stop_criteria, + # while in HF it's called early_stopping. 
+ # HF values True, False and "never" correspond to OV GenAI values "early", "heuristic" and "never" + if generation_config_hf.get('stop_criteria'): + generation_config_hf['early_stopping'] = stop_criteria_map()[generation_config_hf.pop('stop_criteria')] + + encoded_prompt = tokenizer.encode(prompt, return_tensors='pt', add_special_tokens=True) + hf_encoded_output = model.generate(encoded_prompt, **generation_config_hf) + hf_output = tokenizer.decode(hf_encoded_output[0, encoded_prompt.shape[1]:]) + + import sys + # sys.path.append('../../src/python/openvino_genai/') + sys.path.append('/home/epavel/devel/openvino.genai/src/python/openvino_genai/') + import py_generate_pipeline as genai + + pipe = genai.LLMPipeline(path) + ov_output = pipe.generate(prompt, **generation_config) + + if hf_output != ov_output: + print(f'hf_output: {hf_output}') + print(f'ov_output: {ov_output}') + + assert hf_output == ov_output + + +def stop_criteria_map(): + return {"never": "never", "early": True, "heuristic": False} + + +@pytest.fixture(scope="module", params=model_ids) +def model_fixture(request): + model_id, path = request.param + from transformers import AutoTokenizer, AutoModelForCausalLM + tokenizer = AutoTokenizer.from_pretrained(model_id) + model = AutoModelForCausalLM.from_pretrained(model_id) + return model_id, path, tokenizer, model + + +test_cases = [ + (dict(max_new_tokens=20, do_sample=False), 'table is made of'), # generation_config, prompt + (dict(num_beam_groups=3, num_beams=15, num_return_sequences=15, max_new_tokens=20, diversity_penalty=1.0), 'table is made of'), + (dict(num_beam_groups=3, num_beams=15, num_return_sequences=15, max_new_tokens=20, diversity_penalty=1.0), 'Alan Turing was a'), + (dict(num_beam_groups=3, num_beams=15, num_return_sequences=15, max_new_tokens=30, diversity_penalty=1.0), 'Alan Turing was a'), + (dict(num_beam_groups=2, num_beams=8, num_return_sequences=8, max_new_tokens=20, diversity_penalty=1.0), 'table is made of'), + (dict(num_beam_groups=2, num_beams=8, num_return_sequences=8, max_new_tokens=20, diversity_penalty=1.0), 'The Sun is yellow because'), + (dict(num_beam_groups=2, num_beams=8, num_return_sequences=8, max_new_tokens=20, diversity_penalty=1.5), 'The Sun is yellow because'), +] +@pytest.mark.parametrize("generation_config,prompt", test_cases) +def test_greedy_decoding(model_fixture, generation_config, prompt): + run_hf_ov_genai_comparison(model_fixture, generation_config, prompt) + + +prompts = ['The Sun is yellow because'] #, 'Alan Turing was a', 'table is made of'] +@pytest.mark.parametrize("num_beam_groups", [2, 3]) +@pytest.mark.parametrize("group_size", [5, 3]) +@pytest.mark.parametrize("max_new_tokens", [20, 15]) +@pytest.mark.parametrize("diversity_penalty", [1.0, 1.5]) +@pytest.mark.parametrize("prompt", prompts) +def test_beam_search_decoding(model_fixture, num_beam_groups, group_size, + max_new_tokens, diversity_penalty, prompt): + generation_config = dict( + num_beam_groups=num_beam_groups, + num_beams=num_beam_groups * group_size, + diversity_penalty=diversity_penalty, + num_return_sequences=num_beam_groups * group_size, + max_new_tokens=max_new_tokens, + ) + run_hf_ov_genai_comparison(model_fixture, generation_config, prompt) + + +@pytest.mark.parametrize("stop_criteria", ["never", "early", "heuristic"]) +@pytest.mark.parametrize("prompt", prompts) +@pytest.mark.parametrize("max_new_tokens", [20, 15]) +def test_greedy_decoding(model_fixture, stop_criteria, prompt, max_new_tokens): + + generation_config = dict( + num_beam_groups=2, + 
+        num_beams=2 * 3,
+        diversity_penalty=1.0,
+        num_return_sequences=2 * 3,
+        max_new_tokens=max_new_tokens,
+        stop_criteria=stop_criteria,
+    )
+    run_hf_ov_genai_comparison(model_fixture, generation_config, prompt)
+
diff --git a/src/tests/python_tests/test_greedy.py b/src/tests/python_tests/test_greedy.py
deleted file mode 100644
index f33909721b..0000000000
--- a/src/tests/python_tests/test_greedy.py
+++ /dev/null
@@ -1,29 +0,0 @@
-# Copyright (C) 2023-2024 Intel Corporation
-# SPDX-License-Identifier: Apache-2.0
-
-def test_tiny_llama():
-    from transformers import AutoTokenizer, AutoModelForCausalLM
-
-    tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
-    model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
-
-    max_new_tokens = 32
-    prompt = 'table is made of'
-
-    encoded_prompt = tokenizer.encode(prompt, return_tensors='pt', add_special_tokens=True)
-    hf_encoded_output = model.generate(encoded_prompt, max_new_tokens=max_new_tokens, do_sample=False)
-    hf_output = tokenizer.decode(hf_encoded_output[0, encoded_prompt.shape[1]:])
-    print(f'hf_output: {hf_output}')
-
-    import sys
-    sys.path.append('src/python/openvino_genai/')
-    import py_generate_pipeline as genai
-
-    pipe = genai.LLMPipeline('text_generation/causal_lm/TinyLlama-1.1B-Chat-v1.0/pytorch/dldt/FP16/')
-    ov_output = pipe(prompt, max_new_tokens=max_new_tokens, do_sample=False)
-    print(f'ov_output: {ov_output}')
-
-    assert hf_output == ov_output
-
-if __name__ == '__main__':
-    test_tiny_llama()