Commit 81122c4

Introducing WordPiece and Bert tokenizers (#7275)
* Introducing WordPiece and Bert tokenizers
* Fix corner case in WordPiece
1 parent 32bac5e commit 81122c4

File tree

11 files changed: +2618 -16 lines changed

THIRD-PARTY-NOTICES.TXT

Lines changed: 19 additions & 0 deletions
@@ -152,6 +152,25 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 
+License notice for WordPiece and Bert tokenizers
+------------------------------------------------
+
+https://github.com/huggingface/transformers/blob/8e3e145b427196e014f37aa42ba890b9bc94275e/src/transformers/models/bert/tokenization_bert.py#L2
+
+Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
 License notice for BitUtility
 ------------------------------------------
 
src/Microsoft.ML.Tokenizers/EncodedToken.cs

Lines changed: 4 additions & 1 deletion
@@ -10,7 +10,7 @@ namespace Microsoft.ML.Tokenizers
     /// Represent the token produced from the tokenization process containing the token substring,
     /// the id associated to the token substring, and the offset mapping to the original string.
     /// </summary>
-    public readonly struct EncodedToken
+    public readonly struct EncodedToken : IEquatable<EncodedToken>
     {
         /// <summary>
         /// Gets the Id value associated to the token.
@@ -39,5 +39,8 @@ public EncodedToken(int id, string value, Range offset)
            Offset = offset;
            Value = value;
        }
+
+        /// <inheritdoc/>
+        public bool Equals(EncodedToken other) => Id == other.Id && Value == other.Value && Offset.Equals(other.Offset);
    }
 }
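
With IEquatable<EncodedToken> implemented, tokens now compare by value. A minimal sketch of what the new Equals enables, using the (id, value, offset) constructor shown above; the id and ranges are illustrative, not taken from a real vocabulary:

using System;
using Microsoft.ML.Tokenizers;

class EqualityDemo
{
    static void Main()
    {
        // Two tokens with the same Id, Value, and Offset are equal by value.
        var a = new EncodedToken(7592, "hello", new Range(0, 5));
        var b = new EncodedToken(7592, "hello", new Range(0, 5));
        Console.WriteLine(a.Equals(b)); // True

        // Any differing component breaks equality.
        var c = new EncodedToken(7592, "hello", new Range(6, 11));
        Console.WriteLine(a.Equals(c)); // False: different Offset
    }
}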

src/Microsoft.ML.Tokenizers/Model/BPETokenizer.cs

Lines changed: 3 additions & 3 deletions
@@ -87,7 +87,7 @@ private set
        /// <param name="vocabFile">The JSON file path containing the dictionary of string keys and their ids.</param>
        /// <param name="mergesFile">The file path containing the tokens' pairs list.</param>
        public static BpeTokenizer Create(string vocabFile, string? mergesFile)
-            => Create(vocabFile, mergesFile, preTokenizer: PreTokenizer.CreateWhiteSpace(), normalizer: null, unknownToken: null, continuingSubwordPrefix: null, endOfWordSuffix: null, fuseUnknownTokens: false);
+            => Create(vocabFile, mergesFile, preTokenizer: PreTokenizer.CreateWordOrNonWordPreTokenizer(), normalizer: null, unknownToken: null, continuingSubwordPrefix: null, endOfWordSuffix: null, fuseUnknownTokens: false);
 
        /// <summary>
        /// Create a new Bpe tokenizer object to use for text encoding.
@@ -131,7 +131,7 @@ public static BpeTokenizer Create(
        /// <param name="vocabStream">The JSON stream containing the dictionary of string keys and their ids.</param>
        /// <param name="mergesStream">The stream containing the tokens' pairs list.</param>
        public static BpeTokenizer Create(Stream vocabStream, Stream? mergesStream)
-            => Create(vocabStream, mergesStream, preTokenizer: PreTokenizer.CreateWhiteSpace(), normalizer: null, addedTokens: null, unknownToken: null, continuingSubwordPrefix: null, endOfWordSuffix: null, fuseUnknownTokens: false);
+            => Create(vocabStream, mergesStream, preTokenizer: PreTokenizer.CreateWordOrNonWordPreTokenizer(), normalizer: null, addedTokens: null, unknownToken: null, continuingSubwordPrefix: null, endOfWordSuffix: null, fuseUnknownTokens: false);
 
        /// <summary>
        /// Create a new Bpe tokenizer object to use for text encoding.
@@ -225,7 +225,7 @@ private BpeTokenizer(
            FuseUnknownTokens = fuseUnknownTokens;
            ContinuingSubwordPrefix = continuingSubwordPrefix;
            EndOfWordSuffix = endOfWordSuffix;
-            _preTokenizer = preTokenizer ?? PreTokenizer.CreateWhiteSpace(); // Default to WhiteSpace pre-tokenizer
+            _preTokenizer = preTokenizer ?? PreTokenizer.CreateWordOrNonWordPreTokenizer(); // Default to WordOrNonWord pre-tokenizer
            _normalizer = normalizer;
 
            _vocab = vocab ?? new Dictionary<StringSpanOrdinalKey, int>();
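
This change only renames the default pre-tokenizer (CreateWhiteSpace becomes CreateWordOrNonWordPreTokenizer); the \w+|[^\w\s]+ pattern itself is unchanged. A sketch of the two equivalent calls, using illustrative vocab/merges file paths:

using Microsoft.ML.Tokenizers;

// "vocab.json" and "merges.txt" are illustrative paths to a trained BPE vocabulary.
BpeTokenizer tokenizer = BpeTokenizer.Create("vocab.json", "merges.txt");

// Equivalent to spelling out the new default explicitly:
BpeTokenizer same = BpeTokenizer.Create(
    "vocab.json", "merges.txt",
    preTokenizer: PreTokenizer.CreateWordOrNonWordPreTokenizer(),
    normalizer: null, unknownToken: null,
    continuingSubwordPrefix: null, endOfWordSuffix: null, fuseUnknownTokens: false);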

src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs

Lines changed: 729 additions & 0 deletions
Large diffs are not rendered by default.
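
The 729-line BertTokenizer.cs diff is collapsed, so its API surface is not visible on this page. A hedged sketch of the expected shape, assuming a Create factory over a WordPiece vocab file and the EncodeToIds method from the Tokenizer base class (both are assumptions, not confirmed by this page):

using System.Collections.Generic;
using Microsoft.ML.Tokenizers;

// Assumption: BertTokenizer.Create loads a WordPiece vocab.txt, mirroring the
// other tokenizers' factories; "vocab.txt" is an illustrative path.
BertTokenizer bert = BertTokenizer.Create("vocab.txt");

// Assumption: encoding follows the BERT convention of wrapping the sequence
// in the [CLS] ... [SEP] special tokens before mapping to ids.
IReadOnlyList<int> ids = bert.EncodeToIds("Hello, world!");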

src/Microsoft.ML.Tokenizers/Model/WordPieceTokenizer.cs

Lines changed: 858 additions & 0 deletions
Large diffs are not rendered by default.
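
WordPieceTokenizer.cs is likewise collapsed. Independent of that implementation, the core WordPiece algorithm is a greedy longest-match-first scan over each pre-tokenized word, prefixing word-internal pieces with ##. A self-contained sketch of that loop (the vocabulary and helper names are illustrative, not the collapsed implementation):

using System.Collections.Generic;

static class WordPieceSketch
{
    // Greedy longest-match-first WordPiece over a single word.
    public static List<string> Split(string word, ISet<string> vocab, string unknownToken = "[UNK]")
    {
        var pieces = new List<string>();
        int start = 0;
        while (start < word.Length)
        {
            int end = word.Length;
            string? match = null;
            while (start < end)
            {
                // Word-internal pieces carry the "##" continuing-subword prefix.
                string candidate = (start > 0 ? "##" : "") + word.Substring(start, end - start);
                if (vocab.Contains(candidate)) { match = candidate; break; }
                end--; // shrink the window until a vocabulary entry matches
            }

            if (match is null)
            {
                // No piece matches: the whole word maps to the unknown token.
                return new List<string> { unknownToken };
            }

            pieces.Add(match);
            start = end;
        }
        return pieces;
    }
}

// Example: with vocab { "un", "##aff", "##able" }, Split("unaffable", vocab)
// yields ["un", "##aff", "##able"].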
src/Microsoft.ML.Tokenizers/Normalizer/BertNormalizer.cs

Lines changed: 200 additions & 0 deletions
@@ -0,0 +1,200 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Buffers;
+using System.Diagnostics;
+using System.Globalization;
+using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
+using System.Text;
+
+namespace Microsoft.ML.Tokenizers
+{
+    /// <summary>
+    /// Normalizer that performs the Bert model normalization.
+    /// </summary>
+    internal sealed class BertNormalizer : Normalizer
+    {
+        private readonly bool _doLowerCase;
+        private readonly bool _tokenizeChineseChars;
+        private readonly bool _stripAccents;
+
+        /// <summary>
+        /// Normalize the input string.
+        /// </summary>
+        /// <param name="original">The input string to normalize.</param>
+        /// <returns>The normalized string.</returns>
+        public override string Normalize(string original)
+        {
+            if (string.IsNullOrEmpty(original))
+            {
+                return string.Empty;
+            }
+
+            if (_stripAccents)
+            {
+                original = original.Normalize(NormalizationForm.FormD);
+            }
+
+            Span<char> casingBuffer = stackalloc char[10];
+            char[] buffer = ArrayPool<char>.Shared.Rent(original.Length);
+            int index = 0;
+
+            for (int i = 0; i < original.Length; i++)
+            {
+                char c = original[i];
+
+                if (c == '\u0000' || c == '\uFFFD')
+                {
+                    continue;
+                }
+
+                int inc = 0;
+                int codePoint = (int)c;
+                if (char.IsHighSurrogate(c) && i + 1 < original.Length && char.IsLowSurrogate(original[i + 1]))
+                {
+                    codePoint = char.ConvertToUtf32(c, original[i + 1]);
+                    inc = 1;
+                }
+
+                UnicodeCategory category = CharUnicodeInfo.GetUnicodeCategory(original, i);
+
+                if (category == UnicodeCategory.Control)
+                {
+                    i += inc;
+                    continue;
+                }
+
+                if (category == UnicodeCategory.SpaceSeparator)
+                {
+                    InsertChar(ref buffer, ref index, ' ');
+                    i += inc;
+                    continue;
+                }
+
+                if (_stripAccents && category is UnicodeCategory.NonSpacingMark or UnicodeCategory.SpacingCombiningMark)
+                {
+                    i += inc;
+                    continue;
+                }
+
+                if (_doLowerCase && category == UnicodeCategory.UppercaseLetter)
+                {
+                    int length = original.AsSpan().Slice(i, inc + 1).ToLowerInvariant(casingBuffer);
+                    Debug.Assert(length > 0);
+
+                    InsertSpan(ref buffer, ref index, casingBuffer.Slice(0, length));
+
+                    i += inc;
+                    continue;
+                }
+
+                if (_tokenizeChineseChars && IsChineseChar(codePoint))
+                {
+                    InsertChar(ref buffer, ref index, ' ');
+                    InsertChar(ref buffer, ref index, c);
+                    if (inc > 0)
+                    {
+                        InsertChar(ref buffer, ref index, original[i + 1]);
+                    }
+                    InsertChar(ref buffer, ref index, ' ');
+
+                    i += inc;
+                    continue;
+                }
+
+                InsertChar(ref buffer, ref index, c);
+                if (inc > 0)
+                {
+                    InsertChar(ref buffer, ref index, original[i + 1]);
+                }
+                i += inc;
+            }
+
+            string result = index == 0 ? string.Empty : new string(buffer, 0, index).Normalize(NormalizationForm.FormC);
+            ArrayPool<char>.Shared.Return(buffer);
+            return result;
+        }
+
+        /// <summary>
+        /// Normalize the input character span.
+        /// </summary>
+        /// <param name="original">The input character span to normalize.</param>
+        /// <returns>The normalized string.</returns>
+        public override string Normalize(ReadOnlySpan<char> original)
+        {
+            if (original.IsEmpty)
+            {
+                return string.Empty;
+            }
+
+            return Normalize(original.ToString());
+        }
+
+        /// <summary>
+        /// Initializes a new instance of the <see cref="BertNormalizer"/> class.
+        /// </summary>
+        /// <param name="doLowerCase">Whether to lowercase the input.</param>
+        /// <param name="tokenizeChineseChars">Whether to tokenize Chinese characters.</param>
+        /// <param name="stripAccents">Whether to strip accents from the input.</param>
+        public BertNormalizer(bool doLowerCase, bool tokenizeChineseChars, bool stripAccents)
+        {
+            _doLowerCase = doLowerCase;
+            _tokenizeChineseChars = tokenizeChineseChars;
+            _stripAccents = stripAccents;
+        }
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private static void InsertChar(ref char[] buffer, ref int index, char c)
+        {
+            if (index >= buffer.Length)
+            {
+                Helpers.ArrayPoolGrow(ref buffer, index + 40);
+            }
+
+            buffer[index++] = c;
+        }
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private static void InsertSpan(ref char[] buffer, ref int index, Span<char> chars)
+        {
+            // Grow the pooled buffer when the incoming span will not fit.
+            if (index + chars.Length >= buffer.Length)
+            {
+                Helpers.ArrayPoolGrow(ref buffer, index + chars.Length + 10);
+            }
+
+            chars.CopyTo(buffer.AsSpan(index));
+            index += chars.Length;
+        }
+
+        /// <summary>
+        /// Checks whether the codepoint is the codepoint of a CJK character.
+        /// This defines a "Chinese character" as anything in the CJK Unicode block:
+        /// https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+        /// </summary>
+        /// <param name="codePoint">The codepoint to check.</param>
+        /// <remarks>
+        /// The CJK Unicode block is NOT all Japanese and Korean characters,
+        /// despite its name. The modern Korean Hangul alphabet is a different block,
+        /// as are Japanese Hiragana and Katakana. Those alphabets are used to write
+        /// space-separated words, so they are not treated specially and are handled
+        /// like all of the other languages.
+        /// </remarks>
+        /// <returns>True if the codepoint is a CJK character, false otherwise.</returns>
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private static bool IsChineseChar(int codePoint)
+        {
+            return (codePoint >= 0x3400) && // Quick check to exit early if the codepoint is below the CJK range
+                (((uint)(codePoint - 0x3400) <= (uint)(0x4DBF - 0x3400)) ||
+                ((uint)(codePoint - 0xF900) <= (uint)(0xFAFF - 0xF900)) ||
+                ((uint)(codePoint - 0x4E00) <= (uint)(0x9FFF - 0x4E00)) ||
+                ((uint)(codePoint - 0x20000) <= (uint)(0x2A6DF - 0x20000)) ||
+                ((uint)(codePoint - 0x2A700) <= (uint)(0x2B73F - 0x2A700)) ||
+                ((uint)(codePoint - 0x2B740) <= (uint)(0x2B81F - 0x2B740)) ||
+                ((uint)(codePoint - 0x2B820) <= (uint)(0x2CEAF - 0x2B820)) ||
+                ((uint)(codePoint - 0x2F800) <= (uint)(0x2FA1F - 0x2F800)));
+        }
+    }
+}
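
BertNormalizer is internal, so it cannot be invoked directly from user code. As a standalone illustration, the accent-stripping path above is the classic decompose/filter/recompose pipeline; a sketch assuming both lowercasing and accent stripping are enabled (the helper name is hypothetical):

using System;
using System.Globalization;
using System.Text;

static class NormalizeSketch
{
    // Decompose (FormD), drop combining marks, lowercase, recompose (FormC):
    // the same effect as the _stripAccents and _doLowerCase branches above.
    public static string StripAccentsAndLower(string s)
    {
        var sb = new StringBuilder(s.Length);
        foreach (char c in s.Normalize(NormalizationForm.FormD))
        {
            UnicodeCategory cat = CharUnicodeInfo.GetUnicodeCategory(c);
            if (cat is UnicodeCategory.NonSpacingMark or UnicodeCategory.SpacingCombiningMark)
            {
                continue; // skip the accent marks split off by FormD
            }
            sb.Append(char.ToLowerInvariant(c));
        }
        return sb.ToString().Normalize(NormalizationForm.FormC);
    }
}

// NormalizeSketch.StripAccentsAndLower("Déjà Vu") returns "deja vu".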

src/Microsoft.ML.Tokenizers/PreTokenizer/PreTokenizer.cs

Lines changed: 57 additions & 5 deletions
@@ -40,8 +40,61 @@ public abstract partial class PreTokenizer
            }
        }
 
-        private const string WhiteSpacePattern = /*lang=regex*/ @"\w+|[^\w\s]+";
+        private const string WhiteSpaceOrPunctuationPattern = @"\w+|[\p{P}]";
+        private static PreTokenizer? _whiteSpaceOrPunctuationPreTokenizer;
+#if NET7_0_OR_GREATER
+        [GeneratedRegex(WhiteSpaceOrPunctuationPattern)]
+        private static partial Regex WhiteSpaceOrPunctuationRegex();
+#else
+        private static Regex WhiteSpaceOrPunctuationRegex() => new Regex(WhiteSpaceOrPunctuationPattern, RegexOptions.Compiled);
+#endif
+
+        /// <summary>
+        /// Create a new instance of the <see cref="PreTokenizer"/> class which splits the text at the whitespace or punctuation characters.
+        /// </summary>
+        /// <param name="specialTokensEncoder">The dictionary containing the special tokens and their corresponding ids.</param>
+        /// <returns>The pre-tokenizer that splits the text at the whitespace or punctuation characters.</returns>
+        public static PreTokenizer CreateWhiteSpaceOrPunctuationPreTokenizer(IReadOnlyDictionary<string, int>? specialTokensEncoder = null)
+        {
+            if (specialTokensEncoder is null)
+            {
+                // return a singleton instance of the WhiteSpaceOrPunctuation pre-tokenizer
+                return _whiteSpaceOrPunctuationPreTokenizer ??= new RegexPreTokenizer(WhiteSpaceOrPunctuationRegex(), null);
+            }
+
+            return new RegexPreTokenizer(WhiteSpaceOrPunctuationRegex(), specialTokensEncoder);
+        }
+
+        private const string WordOrNonWordPattern = /*lang=regex*/ @"\w+|[^\w\s]+";
+        private static PreTokenizer? _wordOrNonWordPreTokenizer;
+
+#if NET7_0_OR_GREATER
+        [GeneratedRegex(WordOrNonWordPattern)]
+        private static partial Regex WordOrNonWordRegex();
+#else
+        private static Regex WordOrNonWordRegex() => new Regex(WordOrNonWordPattern, RegexOptions.Compiled);
+#endif
+
+        /// <summary>
+        /// Create a new instance of the <see cref="PreTokenizer"/> class which splits the text at the word or non-word boundary.
+        /// The word is a set of alphabet, numeric, and underscore characters.
+        /// </summary>
+        /// <param name="specialTokensEncoder">The dictionary containing the special tokens and their corresponding ids.</param>
+        /// <returns>The pre-tokenizer that splits the text at the word boundary.</returns>
+        public static PreTokenizer CreateWordOrNonWordPreTokenizer(IReadOnlyDictionary<string, int>? specialTokensEncoder = null)
+        {
+            if (specialTokensEncoder is null)
+            {
+                // return a singleton instance of the WordOrNonWord pre-tokenizer
+                return _wordOrNonWordPreTokenizer ??= new RegexPreTokenizer(WordOrNonWordRegex(), null);
+            }
+
+            return new RegexPreTokenizer(WordOrNonWordRegex(), specialTokensEncoder);
+        }
+
+        private const string WhiteSpacePattern = @"\S+";
        private static PreTokenizer? _whiteSpacePreTokenizer;
+
 #if NET7_0_OR_GREATER
        [GeneratedRegex(WhiteSpacePattern)]
        private static partial Regex WhiteSpaceRegex();
@@ -50,12 +103,11 @@ public abstract partial class PreTokenizer
 #endif
 
        /// <summary>
-        /// Create a new instance of the <see cref="PreTokenizer"/> class which split the text at the word boundary.
-        /// The word is a set of alphabet, numeric, and underscore characters.
+        /// Create a new instance of the <see cref="PreTokenizer"/> class which splits the text at the white spaces.
        /// </summary>
        /// <param name="specialTokensEncoder">The dictionary containing the special tokens and their corresponding ids.</param>
-        /// <returns>The pre-tokenizer that splits the text at the word boundary.</returns>
-        public static PreTokenizer CreateWhiteSpace(IReadOnlyDictionary<string, int>? specialTokensEncoder = null)
+        /// <returns>The pre-tokenizer that splits the text at the white spaces.</returns>
+        public static PreTokenizer CreateWhiteSpacePreTokenizer(IReadOnlyDictionary<string, int>? specialTokensEncoder = null)
        {
            if (specialTokensEncoder is null)
            {
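
The three factories differ only in their regex: \w+|[\p{P}] splits punctuation into single-character fragments (the BERT basic-tokenizer behavior), \w+|[^\w\s]+ keeps runs of non-word characters together, and \S+ splits on whitespace alone. A usage sketch; it assumes PreTokenize(string) enumerates the matched (Offset, Length) ranges, which this diff does not show:

using System;
using Microsoft.ML.Tokenizers;

PreTokenizer wordOrNonWord = PreTokenizer.CreateWordOrNonWordPreTokenizer();
PreTokenizer whiteSpace = PreTokenizer.CreateWhiteSpacePreTokenizer();

string text = "don't stop";

// Assumption: PreTokenize(string) yields (Offset, Length) ranges of the regex
// matches; the exact return shape isn't visible on this page.
foreach ((int offset, int length) in wordOrNonWord.PreTokenize(text))
{
    Console.WriteLine(text.Substring(offset, length)); // "don", "'", "t", "stop"
}

foreach ((int offset, int length) in whiteSpace.PreTokenize(text))
{
    Console.WriteLine(text.Substring(offset, length)); // "don't", "stop"
}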