Add unit tests #9

Open · wants to merge 1 commit into base: master
7 changes: 7 additions & 0 deletions CMakeLists.txt
@@ -26,6 +26,8 @@ option(GGML_FMA "ggml: enable FMA"
option(GGML_CUBLAS "ggml: use cuBLAS" OFF)
option(GGML_METAL "ggml: use Metal" OFF)

option(BERT_BUILD_TESTS "bert: Build tests" ON)

#
# Compile flags
#
@@ -94,3 +96,8 @@ install(TARGETS ggml LIBRARY DESTINATION bert_cpp)

# add bert
add_subdirectory(src)

if (BERT_BUILD_TESTS)
    include(CTest)
    add_subdirectory(tests)
endif()
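
The new option gates the entire test tree at configure time. A minimal sketch of both configurations (the build directory name is only a convention):

cmake -B build -DBERT_BUILD_TESTS=ON     # default: build the tests and enable CTest
cmake -B build -DBERT_BUILD_TESTS=OFF    # skip the tests subdirectory entirely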
24 changes: 24 additions & 0 deletions tests/CMakeLists.txt
@@ -0,0 +1,24 @@
include_directories(${CMAKE_SOURCE_DIR}/src)

# add_executable(test_tokenizer test_tokenizer.cpp)
# target_link_libraries(test_tokenizer PRIVATE bert ggml)

set(TEST_MODEL_NAME "bge-large-zh-v1.5")

function(bert_build_executable source)
    get_filename_component(TEST_TARGET ${source} NAME_WE)
    add_executable(${TEST_TARGET} ${source})
    install(TARGETS ${TEST_TARGET} RUNTIME)
    target_link_libraries(${TEST_TARGET} PRIVATE bert ggml)
endfunction()

function(bert_test_executable name source)
    get_filename_component(TEST_TARGET ${source} NAME_WE)
    add_test(NAME ${name} COMMAND $<TARGET_FILE:${TEST_TARGET}> ${ARGN})
    set_property(TEST ${name} PROPERTY LABELS "main")
endfunction()

# generate the HF reference token ids once, before the tokenizer test runs
# (registered at top level so the test name cannot collide if more tests are added)
add_test(NAME "Generate_HF_tokens" COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/test_hf_tokenizer.py ${TEST_MODEL_NAME})

bert_build_executable(test_tokenizer.cpp)
bert_test_executable(test_tokenizer test_tokenizer.cpp -m ${CMAKE_CURRENT_SOURCE_DIR}/../models/${TEST_MODEL_NAME}/bge-large-zh-v1.5-q4_1.gguf)
set_tests_properties(test_tokenizer PROPERTIES DEPENDS "Generate_HF_tokens")
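
With the tests registered, everything can be driven through CTest after a build. A minimal sketch, assuming an out-of-source build directory named build (ctest --test-dir needs CMake 3.20 or newer; with older versions, run ctest from inside the build directory):

cmake -B build && cmake --build build
ctest --test-dir build --output-on-failure    # runs Generate_HF_tokens, then test_tokenizer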
19 changes: 19 additions & 0 deletions tests/test.sh
@@ -0,0 +1,19 @@
#!/usr/bin/env bash

set -e

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
MODEL_NAME=${1:-bge-large-zh-v1.5}
MODEL_DIR=$(realpath "$SCRIPT_DIR/../models/$MODEL_NAME")

if [ ! -d "$MODEL_DIR" ]; then
python3 $SCRIPT_DIR/../models/download-repo.py $MODEL_NAME
fi

if [ ! -d "$MODEL_DIR/ggml-model-q4_1.gguf" ]; then
$SCRIPT_DIR/../models/run_conversions.sh $MODEL_NAME q4_1
fi

python3 $SCRIPT_DIR/test_hf_tokenizer.py $MODEL_DIR

$SCRIPT_DIR/../build/bin/test_tokenizer -m $MODEL_DIR/ggml-model-q4_1.gguf
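
The script can also be invoked directly for a one-shot end-to-end check; its single optional argument is the model name, for example from the repository root:

bash tests/test.sh                      # defaults to bge-large-zh-v1.5
bash tests/test.sh all-MiniLM-L6-v2     # any model name the download/conversion scripts accept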
39 changes: 39 additions & 0 deletions tests/test_hf_tokenizer.py
@@ -0,0 +1,39 @@
from transformers import AutoTokenizer
import argparse
import os

SCRIPT_PATH = os.path.dirname(os.path.realpath(__file__))

def main(args):
    # tokenizer_name = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
    if "/" in args.model_name:
        tokenizer_name = args.model_name
    elif "MiniLM" in args.model_name:
        tokenizer_name = f"sentence-transformers/{args.model_name}"
    elif "bge-" in args.model_name:
        tokenizer_name = f"BAAI/{args.model_name}"
    else:
        raise ValueError(f"Unknown model name: {args.model_name}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    with open(SCRIPT_PATH + "/test_prompts.txt", "r", encoding="utf-8") as f:
        inps = f.readlines()
        inps = list(map(lambda x: x.strip(), inps))

    print("Using tokenizer:", tokenizer_name)
    output = []
    for inp in inps:
        oup = tokenizer(inp, return_tensors="pt").input_ids[0].tolist()
        output.append(",".join([str(x) for x in oup]))
        for token in oup:
            print(f"{token} <--> {tokenizer.decode([token])}")

    with open(SCRIPT_PATH + "/hf_tokenized_ids.txt", "w", encoding="utf-8") as f:
        f.write("\n".join(output))

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Generate reference token ids with the HF tokenizer')
    parser.add_argument('model_name', type=str, help='Model name or local path of the tokenizer')
    args = parser.parse_args()
    main(args)
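
The generated tests/hf_tokenized_ids.txt holds one comma-separated line of token ids per prompt, which test_tokenizer later reads back as the expected output. A minimal sketch of a standalone run from the repository root (the bare model name is resolved to BAAI/bge-large-zh-v1.5 on the Hugging Face hub):

python3 tests/test_hf_tokenizer.py bge-large-zh-v1.5
head -n 1 tests/hf_tokenized_ids.txt    # comma-separated ids for the first prompt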
8 changes: 8 additions & 0 deletions tests/test_prompts.txt
@@ -0,0 +1,8 @@
hello world
i'm going to the store to buy 3 apples and a banana! you're welcome to come along if you'd like. the time is 2:30 p.m. and it's partly cloudy outside. i'll be back soon, so don't go anywhere.
"5 2 + 3 * 4 -"; int stack[1000], top = -1; int calculate(int a, int b, char operator) { return operator == \'+\' ? a + b : operator == \'-\' ? a - b : operator == \'*\' ? a * b : a / b; } void push(int x) { stack[++top] = x; } int pop() { return stack[top--]; } int evaluatepostfix(char* expression) { for (int i = 0; expression[i]; i++) { if (isdigit(expression[i])) push(expression[i] - \'0\'); else { int a = pop(), b = pop(); push(calculate(b, a, expression[i])); } } return pop(); } int result = evaluatepostfix(input);
你好,世界!
こんにちは、世界!
1231 2431431
你好我是gpt
然而,分音符号(diaeresis)和变音符号(umlaut)在一些情况下也可以被泛称为 "accent",这是因为它们都是附加在字母上的符号,用于改变字母的原始发音。
193 changes: 193 additions & 0 deletions tests/test_tokenizer.cpp
@@ -0,0 +1,193 @@
#ifdef NDEBUG
#undef NDEBUG
#endif

#include "bert.h"
#include "ggml.h"

#include <cassert>
#include <unistd.h>
#include <map>
#include <algorithm>
#include <stdio.h>
#include <string>
#include <vector>
#include <fstream>
#include <sstream>
#define ANSI_COLOR_RED "\x1b[31m"
#define ANSI_COLOR_RESET "\x1b[0m"
#define ANSI_COLOR_GREEN "\x1b[32m"


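// Read a UTF-8 text file into a vector of lines, one test prompt per line.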
std::vector<std::string> txt2list(const std::string& filename) {
    std::ifstream file(filename);
    std::vector<std::string> all_lines;

    if (!file.is_open()) {
        printf("cannot open file: %s\n", filename.c_str());
        return all_lines;
    }

    std::string line;
    while (std::getline(file, line)) {
        all_lines.push_back(line);
    }

    file.close();
    return all_lines;
}

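// Parse the reference file written by test_hf_tokenizer.py:
// one comma-separated list of token ids per line.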
std::vector<std::vector<int>> read_expected_tokenids(const std::string& filename) {
    std::ifstream file(filename);
    std::vector<std::vector<int>> all_numbers;

    if (!file.is_open()) {
        printf("cannot open file: %s\n", filename.c_str());
        return all_numbers;
    }

    std::string line;
    while (std::getline(file, line)) {
        std::vector<int> line_numbers;
        std::istringstream iss(line);
        std::string number_str;

        while (std::getline(iss, number_str, ',')) {
            line_numbers.push_back(std::stoi(number_str));
        }

        all_numbers.push_back(line_numbers);
    }

    file.close();
    return all_numbers;
}

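// Tokenize `input` and compare against the expected HF token ids; on a
// mismatch, print a color-coded side-by-side diff before asserting.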
void tokenizer_test(bert_ctx * ctx, const std::string& input, const bert_tokens& expected) {
    int N = bert_n_max_tokens(ctx);
    bert_tokens result = bert_tokenize(ctx, input, N);

    if (result != expected) {
        printf("tokenizer test failed: '%.*s'\n", 16000, input.data());

        printf("[");
        for (auto& tok : result) {
            printf("%d, ", tok);
        }
        printf("]\n");

        for (size_t i = 0; i < result.size(); i++) {
            bert_token a = expected[std::min(i, expected.size()-1)];
            bert_token b = result[i];
            const char *color_start = (a == b) ? ANSI_COLOR_GREEN : ANSI_COLOR_RED;
            const char *color_end = ANSI_COLOR_RESET;

            printf("%s%d -> %s : %d -> %s%s\n", color_start, a, bert_vocab_id_to_token(ctx, a), b, bert_vocab_id_to_token(ctx, b), color_end);
        }
    } else {
        printf("Success '%.*s...'\n", 16, input.data());
    }
    assert(result == expected);
}


struct bert_params {
    int32_t n_threads = 6;
    const char* model = "models/all-MiniLM-L6-v2/ggml-model-q4_0.bin";
    const char* prompt = "test prompt";
    int32_t batch_size = 32;
    bool use_cpu = false;
};

void bert_print_usage(char **argv, const bert_params &params) {
    fprintf(stderr, "usage: %s [options]\n", argv[0]);
    fprintf(stderr, "\n");
    fprintf(stderr, "options:\n");
    fprintf(stderr, "  -h, --help            show this help message and exit\n");
    fprintf(stderr, "  -m FNAME, --model FNAME\n");
    fprintf(stderr, "                        model path (default: %s)\n", params.model);
    fprintf(stderr, "  -c, --cpu             use CPU backend (default: use CUDA if available)\n");
    fprintf(stderr, "\n");
}

bool bert_params_parse(int argc, char **argv, bert_params &params) {
    for (int i = 1; i < argc; i++) {
        std::string arg = argv[i];

        if (arg == "-m" || arg == "--model") {
            params.model = argv[++i];
        } else if (arg == "-c" || arg == "--cpu") {
            params.use_cpu = true;
        } else if (arg == "-h" || arg == "--help") {
            bert_print_usage(argv, params);
            exit(0);
        } else {
            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
            bert_print_usage(argv, params);
            exit(1);
        }
    }

    return true;
}

int main(int argc, char ** argv) {
    bert_params params;
    params.model = "models/all-MiniLM-L6-v2/ggml-model-q4_0.bin";

    if (bert_params_parse(argc, argv, params) == false) {
        return 1;
    }

    bert_ctx * bctx;

    // load the model
    if ((bctx = bert_load_from_file(params.model, params.use_cpu)) == nullptr) {
        fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model);
        return 1;
    }
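
    // derive the repository root from the model path (assumes the models/ directory layout)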
    std::string dir = params.model;
    std::size_t i = dir.rfind("/models/");
    if (i != std::string::npos) {
        dir.resize(i);
    } else {
        dir = ".";
    }

    auto expected = read_expected_tokenids(dir + "/tests/hf_tokenized_ids.txt");
    auto prompts = txt2list(dir + "/tests/test_prompts.txt");

    if (expected.size() == 0 || prompts.size() == 0) {
        printf("failed to read test data\n");
        return 1;
    }

    if (expected.size() != prompts.size()) {
        printf("test data size mismatch\n");
        return 1;
    }

    // tokenizer tests:
    for (size_t i = 0; i < prompts.size(); i++) {
        tokenizer_test(bctx, prompts[i], expected[i]);
    }

// tokenizer_test(bctx, "1231 2431431", {101, 13138, 2487, 22884, 16932, 21486, 102});
// tokenizer_test(bctx, "Québec", {101, 5447, 102});
// tokenizer_test(bctx, "syömme \t täällä tänään", {101, 25353, 5358, 4168, 11937, 25425, 9092, 14634, 102});
// tokenizer_test(bctx, "I'm going to the store to buy 3 apples and a banana! You're welcome to come along if you'd like. The time is 2:30 p.m. and it's partly cloudy outside. I'll be back soon, so don't go anywhere.", {101, 1045, 1005, 1049, 2183, 2000, 1996, 3573, 2000, 4965, 1017, 18108, 1998, 1037, 15212, 999, 2017, 1005, 2128, 6160, 2000, 2272, 2247, 2065, 2017, 1005, 1040, 2066, 1012, 1996, 2051, 2003, 1016, 1024, 2382, 1052, 1012, 1049, 1012, 1998, 2009, 1005, 1055, 6576, 24706, 2648, 1012, 1045, 1005, 2222, 2022, 2067, 2574, 1010, 2061, 2123, 1005, 1056, 2175, 5973, 1012, 102});
// tokenizer_test(bctx, "\"5 2 + 3 * 4 -\"; int stack[1000], top = -1; int calculate(int a, int b, char operator) { return operator == '+' ? a + b : operator == '-' ? a - b : operator == '*' ? a * b : a / b; } void push(int x) { stack[++top] = x; } int pop() { return stack[top--]; } int evaluatePostfix(char* expression) { for (int i = 0; expression[i]; i++) { if (isdigit(expression[i])) push(expression[i] - '0'); else { int a = pop(), b = pop(); push(calculate(b, a, expression[i])); } } return pop(); } int result = evaluatePostfix(input);", {101, 1000, 1019, 1016, 1009, 1017, 1008, 1018, 1011, 1000, 1025, 20014, 9991, 1031, 6694, 1033, 1010, 2327, 1027, 1011, 1015, 1025, 20014, 18422, 1006, 20014, 1037, 1010, 20014, 1038, 1010, 25869, 6872, 1007, 1063, 2709, 6872, 1027, 1027, 1005, 1009, 1005, 1029, 1037, 1009, 1038, 1024, 6872, 1027, 1027, 1005, 1011, 1005, 1029, 1037, 1011, 1038, 1024, 6872, 1027, 1027, 1005, 1008, 1005, 1029, 1037, 1008, 1038, 1024, 1037, 1013, 1038, 1025, 1065, 11675, 5245, 1006, 20014, 1060, 1007, 1063, 9991, 1031, 1009, 1009, 2327, 1033, 1027, 1060, 1025, 1065, 20014, 3769, 1006, 1007, 1063, 2709, 9991, 1031, 2327, 1011, 1011, 1033, 1025, 1065, 20014, 16157, 19894, 8873, 2595, 1006, 25869, 1008, 3670, 1007, 1063, 2005, 1006, 20014, 1045, 1027, 1014, 1025, 3670, 1031, 1045, 1033, 1025, 1045, 1009, 1009, 1007, 1063, 2065, 1006, 2003, 4305, 23806, 1006, 3670, 1031, 1045, 1033, 1007, 1007, 5245, 1006, 3670, 1031, 1045, 1033, 1011, 1005, 1014, 1005, 1007, 1025, 2842, 1063, 20014, 1037, 1027, 3769, 1006, 1007, 1010, 1038, 1027, 3769, 1006, 1007, 1025, 5245, 1006, 18422, 1006, 1038, 1010, 1037, 1010, 3670, 1031, 1045, 1033, 1007, 1007, 1025, 1065, 1065, 2709, 3769, 1006, 1007, 1025, 1065, 20014, 2765, 1027, 16157, 19894, 8873, 2595, 1006, 7953, 1007, 1025, 102});

// tokenizer_test(bctx, "Hello world!", {101, 7592, 2088, 999, 102});
// tokenizer_test(bctx, "你好,世界!", {101, 100, 100, 1989, 1745, 100, 1986, 102});
// tokenizer_test(bctx, "こんにちは、世界!", {101, 1655, 30217, 30194, 30188, 30198, 1635, 1745, 100, 1986, 102});
}