Skip to content

Commit

Permalink
added support for special tokens as stop sequences
Browse files Browse the repository at this point in the history
  • Loading branch information
LostRuins committed Apr 20, 2024
1 parent b01820d commit 3170284
Showing 1 changed file with 49 additions and 4 deletions.
53 changes: 49 additions & 4 deletions gpttype_adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ static size_t mem_per_token = 0;
static std::vector<float> logits;
static std::vector<int> smartcontext;
static std::vector<std::string> stop_sequence;
static std::vector<int> special_stop_sequence; //for stop sequences that don't have a string representation
static std::vector<std::string> banned_tokens;
static std::vector<int> banned_token_ids;
static std::vector<llama_token_data> top_picks;
Expand Down Expand Up @@ -158,25 +159,40 @@ static std::string FileFormatTokenizeID(int id, FileFormat file_format)
}
}

static void TokenizeString(const std::string & str_to_tokenize, std::vector<int> & output_tokens, FileFormat file_format)
static void TokenizeString(const std::string & str_to_tokenize, std::vector<int> & output_tokens, FileFormat file_format, bool add_bos=true)
{
if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 || file_format == FileFormat::GGJT_3 || file_format == FileFormat::GGUF_GENERIC)
{
if(file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 )
{
output_tokens = ::llama_v2_tokenize(llama_ctx_v2, str_to_tokenize, true);
output_tokens = ::llama_v2_tokenize(llama_ctx_v2, str_to_tokenize, add_bos);
}
else if (file_format == FileFormat::GGML)
{
output_tokens = ::legacy_llama_v2_tokenize(llama_ctx_v2, str_to_tokenize, true);
output_tokens = ::legacy_llama_v2_tokenize(llama_ctx_v2, str_to_tokenize, add_bos);
}
else if (file_format == FileFormat::GGJT_3)
{
output_tokens = ::llama_v3_tokenize(llama_ctx_v3, str_to_tokenize, true);
output_tokens = ::llama_v3_tokenize(llama_ctx_v3, str_to_tokenize, add_bos);
}
else
{
output_tokens = ::llama_tokenize(llama_ctx_v4, str_to_tokenize, true, true);
if(add_bos)
{
llama_token bostoadd = llama_token_bos(&(llama_ctx_v4->model));
if(output_tokens.size()==0)
{
output_tokens.push_back(bostoadd);
}
else
{
if(output_tokens[0]!=bostoadd)
{
output_tokens.insert(output_tokens.begin(), 1, bostoadd);
}
}
}
}
}
else
Expand Down Expand Up @@ -1578,12 +1594,26 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
concat_output_mtx.unlock();
last_stop_reason = stop_reason::OUT_OF_TOKENS;
stop_sequence.clear();
special_stop_sequence.clear();
for(int x=0;x<stop_token_max;++x)
{
std::string stopper = inputs.stop_sequence[x];
if(stopper!="")
{
stop_sequence.push_back(stopper);

//if it tokenizes to a single token, AND it's a single non-printable special token, use that
std::vector<int> tmp;
TokenizeString(stopper, tmp, file_format, false);
if(tmp.size()==1) //tokenizes to exactly 1 special token
{
int specialid = tmp[0];
std::string tokenizedstr = FileFormatTokenizeID(specialid, file_format);
if(tokenizedstr=="") //must NOT have a text representation
{
special_stop_sequence.push_back(specialid);
}
}
}
}

Expand Down Expand Up @@ -2217,6 +2247,21 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
last_stop_reason = stop_reason::EOS_TOKEN_HIT;
}

for (const auto &matched : special_stop_sequence)
{
if(id==matched)
{
stopper_unused_tokens = remaining_tokens;
if(allow_regular_prints)
{
printf("\n(Special Stop Token Triggered! ID:%d)",matched);
}
remaining_tokens = 0;
last_stop_reason = stop_reason::EOS_TOKEN_HIT;
break;
}
}

for (const auto &matched : stop_sequence)
{
if (concat_output.find(matched) != std::string::npos)
Expand Down

0 comments on commit 3170284

Please sign in to comment.