From 05bbe1f9500c0dfb728c355fa253407f9bead0b0 Mon Sep 17 00:00:00 2001
From: rogmann
Date: Sun, 7 Jul 2024 23:42:13 +0200
Subject: [PATCH] Support for split UTF-8 sequences.

---
Review note: the decode() hunk was amended during review (continuation bytes
are validated, the loop walks the decoded byte array), so the post-image blob
id on the "index" line below no longer matches the resulting file.

 Llama3.java | 50 ++++++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 46 insertions(+), 4 deletions(-)

diff --git a/Llama3.java b/Llama3.java
index 0a7525b..3922411 100755
--- a/Llama3.java
+++ b/Llama3.java
@@ -1092,6 +1092,12 @@ class Tokenizer {
     private final Vocabulary vocabulary;
     private final Map<Pair<Integer, Integer>, Integer> merges;
     private final Map<String, Integer> specialTokens;
+    /** Bytes of a UTF-8 sequence that is still incomplete; kept between decode() calls. */
+    private final byte[] bufUtf8 = new byte[4];
+    /** Number of bytes currently held in {@link #bufUtf8}; 0 means no sequence is pending. */
+    private int bufUtf8Index = 0;
+    /** Total length in bytes (2 to 4) of the sequence being collected in {@link #bufUtf8}. */
+    private int bufUtf8Size = 0;
 
     public String regexPattern() {
         if (compiledPattern == null) {
@@ -1324,11 +1330,47 @@ public List<Integer> encodeAsList(String text) {
     public String decode(List<Integer> tokens) {
         String decoded = decodeImpl(tokens);
         int[] decodedBytesAsInts = decoded.codePoints().map(BYTE_DECODER::get).toArray();
-        byte[] rawBytes = new byte[decodedBytesAsInts.length];
-        for (int i = 0; i < decoded.length(); i++) {
-            rawBytes[i] = (byte) decodedBytesAsInts[i];
+        // Bytes held back by an earlier call (at most 3) may be emitted in addition to this call's bytes.
+        byte[] rawBytes = new byte[decodedBytesAsInts.length + 3];
+        int indexRawByte = 0;
+        // Bound the loop by the number of decoded byte values, not by decoded.length().
+        for (int decodedByte : decodedBytesAsInts) {
+            byte b = (byte) decodedByte;
+            boolean isContinuation = (b & 0b11000000) == 0b10000000;
+            if (bufUtf8Index > 0 && !isContinuation) {
+                // Malformed input: the pending sequence was cut short. Release its bytes unchanged
+                // (the UTF-8 decoder substitutes them) and examine this byte as a fresh start.
+                System.arraycopy(bufUtf8, 0, rawBytes, indexRawByte, bufUtf8Index);
+                indexRawByte += bufUtf8Index;
+                bufUtf8Index = 0;
+            }
+            if (bufUtf8Index > 0) {
+                // Collect continuation bytes until the announced sequence length is reached.
+                bufUtf8[bufUtf8Index++] = b;
+                if (bufUtf8Index == bufUtf8Size) {
+                    System.arraycopy(bufUtf8, 0, rawBytes, indexRawByte, bufUtf8Size);
+                    indexRawByte += bufUtf8Size;
+                    bufUtf8Index = 0;
+                }
+                continue;
+            }
+            int sequenceLength = 1;
+            if ((b & 0b11100000) == 0b11000000) {
+                sequenceLength = 2; // Start of UTF-8 two bytes sequence.
+            } else if ((b & 0b11110000) == 0b11100000) {
+                sequenceLength = 3; // Start of UTF-8 three bytes sequence.
+            } else if ((b & 0b11111000) == 0b11110000) {
+                sequenceLength = 4; // Start of UTF-8 four bytes sequence.
+            }
+            if (sequenceLength > 1) {
+                // Hold the lead byte back; the rest of the sequence may arrive in a later call.
+                bufUtf8Size = sequenceLength;
+                bufUtf8[bufUtf8Index++] = b;
+                continue;
+            }
+            rawBytes[indexRawByte++] = b;
         }
-        return new String(rawBytes, StandardCharsets.UTF_8);
+        return new String(rawBytes, 0, indexRawByte, StandardCharsets.UTF_8);
     }
 }
 