Commit

add tests

pavel-esir committed May 16, 2024
1 parent e7fa974 commit 93be036
Showing 3 changed files with 181 additions and 75 deletions.
125 changes: 79 additions & 46 deletions src/cpp/src/generation_config.cpp
@@ -11,39 +11,73 @@

#include "generation_config_helper.hpp"

+namespace {
+template <typename>
+struct json_type_traits {};
+
+template <>
+struct json_type_traits<int> { static constexpr auto json_value_t = nlohmann::json::value_t::number_integer; };
+
+template <>
+struct json_type_traits<int64_t> { static constexpr auto json_value_t = nlohmann::json::value_t::number_integer; };
+
+template <>
+struct json_type_traits<size_t> { static constexpr auto json_value_t = nlohmann::json::value_t::number_unsigned; };
+
+template <>
+struct json_type_traits<float> { static constexpr auto json_value_t = nlohmann::json::value_t::number_float; };
+
+template <>
+struct json_type_traits<std::string> { static constexpr auto json_value_t = nlohmann::json::value_t::string; };
+
+template <>
+struct json_type_traits<bool> { static constexpr auto json_value_t = nlohmann::json::value_t::boolean; };
+
+template <typename T>
+void read_json_param(const nlohmann::json& data, const std::string& name, T& param) {
+    if (data.contains(name) && data[name].type() == json_type_traits<T>::json_value_t) {
+        param = data[name];
+    }
+}
+
+template <typename T>
+void read_anymap_param(const ov::AnyMap& config_map, const std::string& name, T& param) {
+    if (config_map.count(name)) {
+        param = config_map.at(name).as<T>();
+    }
+}
+
+}  // namespace


namespace ov {

GenerationConfig::GenerationConfig(std::string json_path) {
    std::ifstream f(json_path);
    OPENVINO_ASSERT(f.is_open(), "Failed to open '" + json_path + "' with generation config");

    nlohmann::json data = nlohmann::json::parse(f);

if (data.contains("max_new_tokens")) max_new_tokens = data["max_new_tokens"];
if (data.contains("max_length")) max_length = data["max_length"];
read_json_param(data, "max_new_tokens", max_new_tokens);
read_json_param(data, "max_length", max_length);
// note that ignore_eos is not present in HF GenerationConfig
if (data.contains("num_beam_groups")) num_beam_groups = data["num_beam_groups"];
if (data.contains("num_beams")) num_beams = data["num_beams"];
if (data.contains("diversity_penalty")) diversity_penalty = data["diversity_penalty"];
if (data.contains("length_penalty")) length_penalty = data["length_penalty"];
if (data.contains("num_return_sequences")) num_return_sequences = data["num_return_sequences"];
if (data.contains("no_repeat_ngram_size")) no_repeat_ngram_size = data["no_repeat_ngram_size"];
read_json_param(data, "num_beam_groups", num_beam_groups);
read_json_param(data, "num_beams", num_beams);
read_json_param(data, "diversity_penalty", diversity_penalty);
read_json_param(data, "length_penalty", length_penalty);
read_json_param(data, "num_return_sequences", num_return_sequences);
read_json_param(data, "no_repeat_ngram_size", no_repeat_ngram_size);
// stop_criteria will be processed below
if (data.contains("temperature")) temperature = data["temperature"];
if (data.contains("top_p")) top_p = data["top_p"];
if (data.contains("top_k")) top_k = data["top_k"];
if (data.contains("do_sample")) do_sample = data["do_sample"];
if (data.contains("repetition_penalty")) repetition_penalty = data["repetition_penalty"];
if (data.contains("pad_token_id")) pad_token_id = data["pad_token_id"];
if (data.contains("bos_token_id")) bos_token_id = data["bos_token_id"];

if (data.contains("eos_token_id") && data["eos_token_id"].type() == nlohmann::json::value_t::number_integer) {
// todo: qwen contains several eos_token_id
eos_token_id = data["eos_token_id"];
}

if (data.contains("bos_token")) bos_token = data["bos_token"];
if (data.contains("eos_token")) eos_token = data["eos_token"];
read_json_param(data, "temperature", temperature);
read_json_param(data, "top_p", top_p);
read_json_param(data, "top_k", top_k);
read_json_param(data, "do_sample", do_sample);
read_json_param(data, "repetition_penalty", repetition_penalty);
read_json_param(data, "pad_token_id", pad_token_id);
read_json_param(data, "bos_token_id", bos_token_id);
read_json_param(data, "eos_token_id", eos_token_id);
read_json_param(data, "bos_token", bos_token);
read_json_param(data, "eos_token", eos_token);

if (data.contains("early_stopping")) {
auto field_type = data["early_stopping"].type();
@@ -59,28 +93,27 @@ GenerationConfig::GenerationConfig(std::string json_path) {

GenerationConfig GenerationConfigHelper::anymap_to_generation_config(const ov::AnyMap& config_map) {
    GenerationConfig config = m_config;

if (config_map.count("max_new_tokens")) config.max_new_tokens = config_map.at("max_new_tokens").as<size_t>();
if (config_map.count("max_length")) config.max_length = config_map.at("max_length").as<size_t>();
if (config_map.count("ignore_eos")) config.ignore_eos = config_map.at("ignore_eos").as<bool>();
if (config_map.count("num_beam_groups")) config.num_beam_groups = config_map.at("num_beam_groups").as<size_t>();
if (config_map.count("num_beams")) config.num_beams = config_map.at("num_beams").as<size_t>();
if (config_map.count("diversity_penalty")) config.diversity_penalty = config_map.at("diversity_penalty").as<float>();
if (config_map.count("length_penalty")) config.length_penalty = config_map.at("length_penalty").as<float>();
if (config_map.count("num_return_sequences")) config.num_return_sequences = config_map.at("num_return_sequences").as<size_t>();
if (config_map.count("no_repeat_ngram_size")) config.no_repeat_ngram_size = config_map.at("no_repeat_ngram_size").as<size_t>();
if (config_map.count("stop_criteria")) config.stop_criteria = config_map.at("stop_criteria").as<StopCriteria>();
if (config_map.count("temperature")) config.temperature = config_map.at("temperature").as<float>();
if (config_map.count("top_p")) config.top_p = config_map.at("top_p").as<float>();
if (config_map.count("top_k")) config.top_k = config_map.at("top_k").as<int>();
if (config_map.count("do_sample")) config.do_sample = config_map.at("do_sample").as<bool>();
if (config_map.count("repetition_penalty")) config.repetition_penalty = config_map.at("repetition_penalty").as<float>();
if (config_map.count("pad_token_id")) config.pad_token_id = config_map.at("pad_token_id").as<int64_t>();
if (config_map.count("bos_token_id")) config.bos_token_id = config_map.at("bos_token_id").as<int64_t>();
if (config_map.count("eos_token_id")) config.eos_token_id = config_map.at("eos_token_id").as<int64_t>();
if (config_map.count("bos_token")) config.bos_token = config_map.at("bos_token").as<std::string>();
if (config_map.count("eos_token")) config.eos_token = config_map.at("eos_token").as<std::string>();

read_anymap_param(config_map, "max_new_tokens", config.max_new_tokens);
read_anymap_param(config_map, "max_length", config.max_length);
read_anymap_param(config_map, "ignore_eos", config.ignore_eos);
read_anymap_param(config_map, "num_beam_groups", config.num_beam_groups);
read_anymap_param(config_map, "num_beams", config.num_beams);
read_anymap_param(config_map, "diversity_penalty", config.diversity_penalty);
read_anymap_param(config_map, "length_penalty", config.length_penalty);
read_anymap_param(config_map, "num_return_sequences", config.num_return_sequences);
read_anymap_param(config_map, "no_repeat_ngram_size", config.no_repeat_ngram_size);
read_anymap_param(config_map, "stop_criteria", config.stop_criteria);
read_anymap_param(config_map, "temperature", config.temperature);
read_anymap_param(config_map, "top_p", config.top_p);
read_anymap_param(config_map, "top_k", config.top_k);
read_anymap_param(config_map, "do_sample", config.do_sample);
read_anymap_param(config_map, "repetition_penalty", config.repetition_penalty);
read_anymap_param(config_map, "pad_token_id", config.pad_token_id);
read_anymap_param(config_map, "bos_token_id", config.bos_token_id);
read_anymap_param(config_map, "eos_token_id", config.eos_token_id);
read_anymap_param(config_map, "bos_token", config.bos_token);
read_anymap_param(config_map, "eos_token", config.eos_token);

    return config;
}
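For reference, a minimal sketch of how the new type-guarded read_json_param helper behaves. It is not part of the commit: the main() below, the sample JSON keys, and their values are made up for illustration, and the anonymous-namespace templates from the diff above are assumed to be in scope.

#include <iostream>
#include <nlohmann/json.hpp>

// Hypothetical demo, assuming json_type_traits and read_json_param
// from the diff above are visible in this translation unit.
int main() {
    auto data = nlohmann::json::parse(R"({"pad_token_id": -1, "temperature": "hot"})");

    int64_t pad_token_id = 0;
    float temperature = 1.0f;

    // -1 is stored as value_t::number_integer, which matches
    // json_type_traits<int64_t>::json_value_t, so the value is assigned.
    read_json_param(data, "pad_token_id", pad_token_id);

    // "hot" is a string, not number_float, so the mismatch is
    // silently skipped and temperature keeps its default.
    read_json_param(data, "temperature", temperature);

    std::cout << pad_token_id << " " << temperature << "\n";  // prints: -1 1
}

One subtlety worth noting: nlohmann::json stores non-negative integers as number_unsigned, so a positive value such as "pad_token_id": 2 would not satisfy the number_integer check used by the int and int64_t specializations and would be skipped.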

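And a hypothetical call site for anymap_to_generation_config, sketching how per-call overrides might be layered on the base config. The GenerationConfigHelper constructor and the default-constructibility of ov::GenerationConfig are assumptions here; neither is shown in this diff.

// Assumed: GenerationConfigHelper can be built around an existing config
// (it exposes m_config per the function above), and ov::GenerationConfig
// is default-constructible with the documented defaults.
ov::GenerationConfig base_config;
ov::AnyMap overrides = {
    {"max_new_tokens", size_t(32)},  // stored with the exact types that
    {"do_sample", true},             // the helper reads back via Any::as<T>()
    {"temperature", 0.7f},
};
GenerationConfigHelper helper{base_config};
ov::GenerationConfig config = helper.anymap_to_generation_config(overrides);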
102 changes: 102 additions & 0 deletions src/tests/python_tests/test_generate_api.py
@@ -0,0 +1,102 @@
# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest

model_ids = [
    ("TinyLlama/TinyLlama-1.1B-Chat-v1.0", "TinyLlama-1.1B-Chat-v1.0/pytorch/dldt/FP16/"),
    # ("meta-llama/Llama-2-7b-chat-hf", "Llama-2-7b-chat-hf/pytorch/dldt/FP16/"),
    # ("microsoft/phi-1_5", "phi-1_5/"),
    # ("google/gemma-2b-it", "gemma-2b-it/pytorch/dldt/FP16/"),
]


def run_hf_ov_genai_comparison(model_fixture, generation_config, prompt):
    model_id, path, tokenizer, model = model_fixture

    generation_config_hf = generation_config.copy()
    # In OpenVINO GenAI this parameter is called stop_criteria, while in HF it's called early_stopping.
    # HF values True, False and "never" correspond to OV GenAI values "early", "heuristic" and "never".
    if generation_config_hf.get('stop_criteria'):
        generation_config_hf['early_stopping'] = stop_criteria_map()[generation_config_hf.pop('stop_criteria')]

    encoded_prompt = tokenizer.encode(prompt, return_tensors='pt', add_special_tokens=True)
    hf_encoded_output = model.generate(encoded_prompt, **generation_config_hf)
    hf_output = tokenizer.decode(hf_encoded_output[0, encoded_prompt.shape[1]:])

    import sys
    sys.path.append('../../src/python/openvino_genai/')
    import py_generate_pipeline as genai

    pipe = genai.LLMPipeline(path)
    ov_output = pipe.generate(prompt, **generation_config)

    if hf_output != ov_output:
        print(f'hf_output: {hf_output}')
        print(f'ov_output: {ov_output}')

    assert hf_output == ov_output


def stop_criteria_map():
    return {"never": "never", "early": True, "heuristic": False}


@pytest.fixture(scope="module", params=model_ids)
def model_fixture(request):
    model_id, path = request.param
    from transformers import AutoTokenizer, AutoModelForCausalLM
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    return model_id, path, tokenizer, model


test_cases = [
    (dict(max_new_tokens=20, do_sample=False), 'table is made of'),  # generation_config, prompt
    (dict(num_beam_groups=3, num_beams=15, num_return_sequences=15, max_new_tokens=20, diversity_penalty=1.0), 'table is made of'),
    (dict(num_beam_groups=3, num_beams=15, num_return_sequences=15, max_new_tokens=20, diversity_penalty=1.0), 'Alan Turing was a'),
    (dict(num_beam_groups=3, num_beams=15, num_return_sequences=15, max_new_tokens=30, diversity_penalty=1.0), 'Alan Turing was a'),
    (dict(num_beam_groups=2, num_beams=8, num_return_sequences=8, max_new_tokens=20, diversity_penalty=1.0), 'table is made of'),
    (dict(num_beam_groups=2, num_beams=8, num_return_sequences=8, max_new_tokens=20, diversity_penalty=1.0), 'The Sun is yellow because'),
    (dict(num_beam_groups=2, num_beams=8, num_return_sequences=8, max_new_tokens=20, diversity_penalty=1.5), 'The Sun is yellow because'),
]
@pytest.mark.parametrize("generation_config,prompt", test_cases)
def test_decoding(model_fixture, generation_config, prompt):
    # covers both greedy (do_sample=False) and grouped beam search configs
    run_hf_ov_genai_comparison(model_fixture, generation_config, prompt)


prompts = ['The Sun is yellow because']  # , 'Alan Turing was a', 'table is made of'
@pytest.mark.parametrize("num_beam_groups", [2, 3])
@pytest.mark.parametrize("group_size", [5, 3])
@pytest.mark.parametrize("max_new_tokens", [20, 15])
@pytest.mark.parametrize("diversity_penalty", [1.0, 1.5])
@pytest.mark.parametrize("prompt", prompts)
def test_beam_search_decoding(model_fixture, num_beam_groups, group_size,
                              max_new_tokens, diversity_penalty, prompt):
    generation_config = dict(
        num_beam_groups=num_beam_groups,
        num_beams=num_beam_groups * group_size,
        diversity_penalty=diversity_penalty,
        num_return_sequences=num_beam_groups * group_size,
        max_new_tokens=max_new_tokens,
    )
    run_hf_ov_genai_comparison(model_fixture, generation_config, prompt)


@pytest.mark.parametrize("stop_criteria", ["never", "early", "heuristic"])
@pytest.mark.parametrize("prompt", prompts)
@pytest.mark.parametrize("max_new_tokens", [20, 15])
def test_greedy_decoding(model_fixture, stop_criteria, prompt, max_new_tokens):

generation_config = dict(
num_beam_groups=2,
num_beams=2 * 3,
diversity_penalty=1.0,
num_return_sequences=2 * 3,
max_new_tokens=max_new_tokens,
stop_criteria=stop_criteria,
)
run_hf_ov_genai_comparison(model_fixture, generation_config, prompt)

29 changes: 0 additions & 29 deletions src/tests/python_tests/test_greedy.py

This file was deleted.
