Skip to content

Commit

Permalink
add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
snowyu committed Feb 6, 2024
1 parent a4ad764 commit e38592d
Show file tree
Hide file tree
Showing 6 changed files with 284 additions and 0 deletions.
7 changes: 7 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ option(GGML_FMA "ggml: enable FMA"
option(GGML_CUBLAS "ggml: use cuBLAS" OFF)
option(GGML_METAL "ggml: use Metal" OFF)

option(BERT_BUILD_TESTS "bert: Build tests" ON)

#
# Compile flags
#
Expand Down Expand Up @@ -93,3 +95,8 @@ add_subdirectory(src)

# for shared library
set_target_properties(ggml PROPERTIES POSITION_INDEPENDENT_CODE ON)

if (BERT_BUILD_TESTS)
include(CTest)
add_subdirectory(tests)
endif ()
24 changes: 24 additions & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
include_directories(${CMAKE_SOURCE_DIR}/src)

# add_executable(test_tokenizer test_tokenizer.cpp)
# target_link_libraries(test_tokenizer PRIVATE bert ggml)

set(TEST_MODEL_NAME "bge-large-zh-v1.5")

function(bert_build_executable source)
get_filename_component(TEST_TARGET ${source} NAME_WE)
add_executable(${TEST_TARGET} ${source})
install(TARGETS ${TEST_TARGET} RUNTIME)
target_link_libraries(${TEST_TARGET} PRIVATE bert ggml)
endfunction()

function(bert_test_executable name source)
get_filename_component(TEST_TARGET ${source} NAME_WE)
add_test(NAME "Generate_HF_tokens" COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/test_hf_tokenizer.py ${TEST_MODEL_NAME})
add_test(NAME ${name} COMMAND $<TARGET_FILE:${TEST_TARGET}> ${ARGN})
set_property(TEST ${name} PROPERTY LABELS "main")
endfunction()


bert_build_executable(test_tokenizer.cpp)
bert_test_executable (test_tokenizer test_tokenizer.cpp -m ${CMAKE_CURRENT_SOURCE_DIR}/../models/${TEST_MODEL_NAME}/bge-large-zh-v1.5-q4_1.gguf)
19 changes: 19 additions & 0 deletions tests/test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/usr/bin/env bash

set -e

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
MODEL_NAME=${1:-bge-large-zh-v1.5}
MODEL_DIR=$(realpath "$SCRIPT_DIR/../models/$MODEL_NAME")

if [ ! -d "$MODEL_DIR" ]; then
python3 $SCRIPT_DIR/../models/download-repo.py $MODEL_NAME
fi

if [ ! -d "$MODEL_DIR/ggml-model-q4_1.gguf" ]; then
$SCRIPT_DIR/../models/run_conversions.sh $MODEL_NAME q4_1
fi

python3 $SCRIPT_DIR/test_hf_tokenizer.py $MODEL_DIR

$SCRIPT_DIR/../build/bin/test_tokenizer -m $MODEL_DIR/ggml-model-q4_1.gguf
39 changes: 39 additions & 0 deletions tests/test_hf_tokenizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from ast import arg
from transformers import AutoTokenizer, AutoModel
import argparse
import os

SCRIPT_PATH=os.path.dirname(os.path.realpath(__file__))

def main(args):
# tokenizer_name = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
if "/" in args.model_name:
tokenizer_name = args.model_name
elif "MiniLM" in args.model_name:
tokenizer_name = f"sentence-transformers/{args.model_name}"
elif "bge-" in args.model_name:
tokenizer_name = f"BAAI/{args.model_name}"
else:
raise ValueError(f"Unknown model name: {args.model_name}")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

with open(SCRIPT_PATH + "/test_prompts.txt", "r", encoding="utf-8") as f:
inps = f.readlines()
inps = list(map(lambda x: x.strip(), inps))

print("Using tokenizer:", tokenizer_name)
output = []
for inp in inps:
oup = tokenizer(inp, return_tensors="pt").input_ids[0].tolist()
output.append(",".join([str(x) for x in oup]))
for token in oup:
print(f"{token} <--> {tokenizer.decode([token])}")

with open(SCRIPT_PATH + "/hf_tokenized_ids.txt", "w", encoding="utf-8") as f:
f.write("\n".join(output))

if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Download original repo files')
parser.add_argument('model_name', type=str, help='Name of the repo')
args = parser.parse_args()
main(args)
8 changes: 8 additions & 0 deletions tests/test_prompts.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
hello world
i'm going to the store to buy 3 apples and a banana! you're welcome to come along if you'd like. the time is 2:30 p.m. and it's partly cloudy outside. i'll be back soon, so don't go anywhere.
"5 2 + 3 * 4 -"; int stack[1000], top = -1; int calculate(int a, int b, char operator) { return operator == \'+\' ? a + b : operator == \'-\' ? a - b : operator == \'*\' ? a * b : a / b; } void push(int x) { stack[++top] = x; } int pop() { return stack[top--]; } int evaluatepostfix(char* expression) { for (int i = 0; expression[i]; i++) { if (isdigit(expression[i])) push(expression[i] - \'0\'); else { int a = pop(), b = pop(); push(calculate(b, a, expression[i])); } } return pop(); } int result = evaluatepostfix(input);
你好,世界!
こんにちは、世界!
1231 2431431
你好我是gpt
然而,分音符号(diaeresis)和变音符号(umlaut)在一些情况下也可以被泛称为 "accent",这是因为它们都是附加在字母上的符号,用于改变字母的原始发音。
187 changes: 187 additions & 0 deletions tests/test_tokenizer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
#include "bert.h"
#include "ggml.h"

#include <unistd.h>
#include <map>
#include <algorithm>
#include <stdio.h>
#include <string>
#include <vector>
#include <fstream>
#include <sstream>
#define ANSI_COLOR_RED "\x1b[31m"
#define ANSI_COLOR_RESET "\x1b[0m"
#define ANSI_COLOR_GREEN "\x1b[32m"


std::vector<std::string> txt2list(const std::string& filename) {
std::ifstream file(filename);
std::vector<std::string> all_lines;

if (!file.is_open()) {
printf("can not open file: %s\n", filename.c_str());
return all_lines;
}

std::string line;
while (std::getline(file, line)) {
all_lines.push_back(line);
}

file.close();
return all_lines;
}

std::vector<std::vector<int>> read_expected_tokenids(const std::string& filename) {
std::ifstream file(filename);
std::vector<std::vector<int>> all_numbers;

if (!file.is_open()) {
printf("can not open file: %s\n", filename.c_str());
return all_numbers;
}


std::string line;
while (std::getline(file, line)) {
std::vector<int> line_numbers;
std::istringstream iss(line);
std::string number_str;

while (std::getline(iss, number_str, ',')) {
line_numbers.push_back(std::stoi(number_str));
}

all_numbers.push_back(line_numbers);
}

file.close();
return all_numbers;
}

void tokenizer_test(bert_ctx * ctx, const std::string& input, const bert_tokens& expected) {
int N = bert_n_max_tokens(ctx);
bert_tokens result = bert_tokenize(ctx, input, N);
int n_tokens;

if (result != expected) {
printf("tokenizer test failed: '%.*s'\n", 16000, input.data());

printf("[");
for (auto& tok : result) {
printf("%d, ", tok);
}
printf("]\n");

for (size_t i = 0; i < result.size(); i++) {
bert_token a = expected[std::min(i, expected.size()-1)];
bert_token b = result[i];
const char *color_start = (a == b) ? ANSI_COLOR_GREEN : ANSI_COLOR_RED;
const char *color_end = ANSI_COLOR_RESET;

printf("%s%d -> %s : %d -> %s%s\n", color_start, a, bert_vocab_id_to_token(ctx, a), b, bert_vocab_id_to_token(ctx, b), color_end);
}
} else {
printf("Success '%.*s...'\n", 16, input.data());
}
}


struct bert_params
{
int32_t n_threads = 6;
const char* model = "models/all-MiniLM-L6-v2/ggml-model-q4_0.bin";
const char* prompt = "test prompt";
int32_t batch_size = 32;
bool use_cpu = false;
};

void bert_print_usage(char **argv, const bert_params &params) {
fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help show this help message and exit\n");
fprintf(stderr, " -m FNAME, --model FNAME\n");
fprintf(stderr, " model path (default: %s)\n", params.model);
fprintf(stderr, " batch size to use when executing model\n");
fprintf(stderr, " -c, --cpu use CPU backend (default: use CUDA if available)\n");
fprintf(stderr, "\n");
}

bool bert_params_parse(int argc, char **argv, bert_params &params) {
for (int i = 1; i < argc; i++)
{
std::string arg = argv[i];

if (arg == "-m" || arg == "--model") {
params.model = argv[++i];
} else if (arg == "-c" || arg == "--cpu") {
params.use_cpu = true;
} else if (arg == "-h" || arg == "--help") {
bert_print_usage(argv, params);
exit(0);
} else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
bert_print_usage(argv, params);
exit(0);
}
}

return true;
}

int main(int argc, char ** argv) {

bert_params params;
params.model = "models/all-MiniLM-L6-v2/ggml-model-q4_0.bin";

if (bert_params_parse(argc, argv, params) == false) {
return 1;
}


bert_ctx * bctx;

// load the model
{
if ((bctx = bert_load_from_file(params.model, params.use_cpu)) == nullptr) {
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model);
return 1;
}
}
std::string dir = params.model;
std::size_t i = dir.rfind("/models/");
if (i != std::string::npos) {
dir.resize(i);
} else {
dir = ".";
}

auto expected = read_expected_tokenids(dir + "/tests/hf_tokenized_ids.txt");
auto prompts = txt2list(dir + "/tests/test_prompts.txt");

if (expected.size() == 0 || prompts.size() == 0) {
printf("failed to read test data\n");
return 1;
}

if (expected.size() != prompts.size()) {
printf("test data size mismatch\n");
return 1;
}

// tokenizer tests:
for (size_t i = 0; i < prompts.size(); i++) {
tokenizer_test(bctx, prompts[i], expected[i]);
}

// tokenizer_test(bctx, "1231 2431431", {101, 13138, 2487, 22884, 16932, 21486, 102});
// tokenizer_test(bctx, "Québec", {101, 5447, 102});
// tokenizer_test(bctx, "syömme \t täällä tänään", {101, 25353, 5358, 4168, 11937, 25425, 9092, 14634, 102});
// tokenizer_test(bctx, "I'm going to the store to buy 3 apples and a banana! You're welcome to come along if you'd like. The time is 2:30 p.m. and it's partly cloudy outside. I'll be back soon, so don't go anywhere.", {101, 1045, 1005, 1049, 2183, 2000, 1996, 3573, 2000, 4965, 1017, 18108, 1998, 1037, 15212, 999, 2017, 1005, 2128, 6160, 2000, 2272, 2247, 2065, 2017, 1005, 1040, 2066, 1012, 1996, 2051, 2003, 1016, 1024, 2382, 1052, 1012, 1049, 1012, 1998, 2009, 1005, 1055, 6576, 24706, 2648, 1012, 1045, 1005, 2222, 2022, 2067, 2574, 1010, 2061, 2123, 1005, 1056, 2175, 5973, 1012, 102});
// tokenizer_test(bctx, "\"5 2 + 3 * 4 -\"; int stack[1000], top = -1; int calculate(int a, int b, char operator) { return operator == '+' ? a + b : operator == '-' ? a - b : operator == '*' ? a * b : a / b; } void push(int x) { stack[++top] = x; } int pop() { return stack[top--]; } int evaluatePostfix(char* expression) { for (int i = 0; expression[i]; i++) { if (isdigit(expression[i])) push(expression[i] - '0'); else { int a = pop(), b = pop(); push(calculate(b, a, expression[i])); } } return pop(); } int result = evaluatePostfix(input);", {101, 1000, 1019, 1016, 1009, 1017, 1008, 1018, 1011, 1000, 1025, 20014, 9991, 1031, 6694, 1033, 1010, 2327, 1027, 1011, 1015, 1025, 20014, 18422, 1006, 20014, 1037, 1010, 20014, 1038, 1010, 25869, 6872, 1007, 1063, 2709, 6872, 1027, 1027, 1005, 1009, 1005, 1029, 1037, 1009, 1038, 1024, 6872, 1027, 1027, 1005, 1011, 1005, 1029, 1037, 1011, 1038, 1024, 6872, 1027, 1027, 1005, 1008, 1005, 1029, 1037, 1008, 1038, 1024, 1037, 1013, 1038, 1025, 1065, 11675, 5245, 1006, 20014, 1060, 1007, 1063, 9991, 1031, 1009, 1009, 2327, 1033, 1027, 1060, 1025, 1065, 20014, 3769, 1006, 1007, 1063, 2709, 9991, 1031, 2327, 1011, 1011, 1033, 1025, 1065, 20014, 16157, 19894, 8873, 2595, 1006, 25869, 1008, 3670, 1007, 1063, 2005, 1006, 20014, 1045, 1027, 1014, 1025, 3670, 1031, 1045, 1033, 1025, 1045, 1009, 1009, 1007, 1063, 2065, 1006, 2003, 4305, 23806, 1006, 3670, 1031, 1045, 1033, 1007, 1007, 5245, 1006, 3670, 1031, 1045, 1033, 1011, 1005, 1014, 1005, 1007, 1025, 2842, 1063, 20014, 1037, 1027, 3769, 1006, 1007, 1010, 1038, 1027, 3769, 1006, 1007, 1025, 5245, 1006, 18422, 1006, 1038, 1010, 1037, 1010, 3670, 1031, 1045, 1033, 1007, 1007, 1025, 1065, 1065, 2709, 3769, 1006, 1007, 1025, 1065, 20014, 2765, 1027, 16157, 19894, 8873, 2595, 1006, 7953, 1007, 1025, 102});

// tokenizer_test(bctx, "Hello world!", {101, 7592, 2088, 999, 102});
// tokenizer_test(bctx, "你好,世界!", {101, 100, 100, 1989, 1745, 100, 1986, 102});
// tokenizer_test(bctx, "こんにちは、世界!", {101, 1655, 30217, 30194, 30188, 30198, 1635, 1745, 100, 1986, 102});
}

0 comments on commit e38592d

Please sign in to comment.