1
+ // Licensed to the .NET Foundation under one or more agreements.
2
+ // The .NET Foundation licenses this file to you under the MIT license.
3
+ // See the LICENSE file in the project root for more information.
4
+
5
+ using System ;
6
+ using System . Buffers ;
7
+ using System . Diagnostics ;
8
+ using System . Globalization ;
9
+ using System . Runtime . CompilerServices ;
10
+ using System . Runtime . InteropServices ;
11
+ using System . Text ;
12
+
13
+ namespace Microsoft . ML . Tokenizers
14
+ {
15
+ /// <summary>
16
+ /// Normalizer that performs the Bert model normalization.
17
+ /// </summary>
18
+ internal sealed class BertNormalizer : Normalizer
19
+ {
20
+ private readonly bool _doLowerCase ;
21
+ private readonly bool _tokenizeChineseChars ;
22
+ private readonly bool _stripAccents ;
23
+
24
+ /// <summary>
25
+ /// Normalize the input string.
26
+ /// </summary>
27
+ /// <param name="original">The input string to normalize.</param>
28
+ /// <returns>The normalized string.</returns>
29
+ public override string Normalize ( string original )
30
+ {
31
+ if ( string . IsNullOrEmpty ( original ) )
32
+ {
33
+ return string . Empty ;
34
+ }
35
+
36
+ if ( _stripAccents )
37
+ {
38
+ original = original . Normalize ( NormalizationForm . FormD ) ;
39
+ }
40
+
41
+ Span < char > casingBuffer = stackalloc char [ 10 ] ;
42
+ char [ ] buffer = ArrayPool < char > . Shared . Rent ( original . Length ) ;
43
+ int index = 0 ;
44
+
45
+ for ( int i = 0 ; i < original . Length ; i ++ )
46
+ {
47
+ char c = original [ i ] ;
48
+
49
+ if ( c == '\u0000 ' || c == '\uFFFD ' )
50
+ {
51
+ continue ;
52
+ }
53
+
54
+ int inc = 0 ;
55
+ int codePoint = ( int ) c ;
56
+ if ( char . IsHighSurrogate ( c ) && i + 1 < original . Length && char . IsLowSurrogate ( original [ i + 1 ] ) )
57
+ {
58
+ codePoint = char . ConvertToUtf32 ( c , original [ i + 1 ] ) ;
59
+ inc = 1 ;
60
+ }
61
+
62
+ UnicodeCategory category = CharUnicodeInfo . GetUnicodeCategory ( original , i ) ;
63
+
64
+ if ( category == UnicodeCategory . Control )
65
+ {
66
+ i += inc ;
67
+ continue ;
68
+ }
69
+
70
+ if ( category == UnicodeCategory . SpaceSeparator )
71
+ {
72
+ InsertChar ( ref buffer , ref index , ' ' ) ;
73
+ i += inc ;
74
+ continue ;
75
+ }
76
+
77
+ if ( _stripAccents && category is UnicodeCategory . NonSpacingMark or UnicodeCategory . SpacingCombiningMark )
78
+ {
79
+ i += inc ;
80
+ continue ;
81
+ }
82
+
83
+ if ( _doLowerCase && category == UnicodeCategory . UppercaseLetter )
84
+ {
85
+ int length = original . AsSpan ( ) . Slice ( i , inc + 1 ) . ToLowerInvariant ( casingBuffer ) ;
86
+ Debug . Assert ( length > 0 ) ;
87
+
88
+ InsertSpan ( ref buffer , ref index , casingBuffer . Slice ( 0 , length ) ) ;
89
+
90
+ i += inc ;
91
+ continue ;
92
+ }
93
+
94
+ if ( _tokenizeChineseChars && IsChineseChar ( codePoint ) )
95
+ {
96
+ InsertChar ( ref buffer , ref index , ' ' ) ;
97
+ InsertChar ( ref buffer , ref index , c ) ;
98
+ if ( inc > 0 )
99
+ {
100
+ InsertChar ( ref buffer , ref index , original [ i + 1 ] ) ;
101
+ }
102
+ InsertChar ( ref buffer , ref index , ' ' ) ;
103
+
104
+ i += inc ;
105
+ continue ;
106
+ }
107
+
108
+ InsertChar ( ref buffer , ref index , c ) ;
109
+ if ( inc > 0 )
110
+ {
111
+ InsertChar ( ref buffer , ref index , original [ i + 1 ] ) ;
112
+ }
113
+ i += inc ;
114
+ }
115
+
116
+ string result = index == 0 ? string . Empty : new string ( buffer , 0 , index ) . Normalize ( NormalizationForm . FormC ) ;
117
+ ArrayPool < char > . Shared . Return ( buffer ) ;
118
+ return result ;
119
+ }
120
+
121
+ /// <summary>
122
+ /// Normalize the input character span.
123
+ /// </summary>
124
+ /// <param name="original">The input character span to normalize.</param>
125
+ /// <returns>The normalized string.</returns>
126
+ public override string Normalize ( ReadOnlySpan < char > original )
127
+ {
128
+ if ( original . IsEmpty )
129
+ {
130
+ return string . Empty ;
131
+ }
132
+
133
+ return Normalize ( original . ToString ( ) ) ;
134
+ }
135
+
136
+ /// <summary>
137
+ /// Initializes a new instance of the <see cref="BertNormalizer"/> class.
138
+ /// </summary>
139
+ /// <param name="doLowerCase">Whether to lowercase the input.</param>
140
+ /// <param name="tokenizeChineseChars">Whether to tokenize Chinese characters.</param>
141
+ /// <param name="stripAccents">Whether to strip accents from the input.</param>
142
+ public BertNormalizer ( bool doLowerCase , bool tokenizeChineseChars , bool stripAccents )
143
+ {
144
+ _doLowerCase = doLowerCase ;
145
+ _tokenizeChineseChars = tokenizeChineseChars ;
146
+ _stripAccents = stripAccents ;
147
+ }
148
+
149
+ [ MethodImpl ( MethodImplOptions . AggressiveInlining ) ]
150
+ private static void InsertChar ( ref char [ ] buffer , ref int index , char c )
151
+ {
152
+ if ( index >= buffer . Length )
153
+ {
154
+ Helpers . ArrayPoolGrow ( ref buffer , index + 40 ) ;
155
+ }
156
+
157
+ buffer [ index ++ ] = c ;
158
+ }
159
+
160
+ [ MethodImpl ( MethodImplOptions . AggressiveInlining ) ]
161
+ private static void InsertSpan ( ref char [ ] buffer , ref int index , Span < char > chars )
162
+ {
163
+ if ( index + buffer . Length >= buffer . Length )
164
+ {
165
+ Helpers . ArrayPoolGrow ( ref buffer , index + buffer . Length + 10 ) ;
166
+ }
167
+
168
+ chars . CopyTo ( buffer . AsSpan ( index ) ) ;
169
+ index += chars . Length ;
170
+ }
171
+
172
+ /// <summary>
173
+ /// Checks whether CP is the codepoint of a CJK character.
174
+ /// This defines a "chinese character" as anything in the CJK Unicode block:
175
+ /// https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
176
+ /// </summary>
177
+ /// <param name="codePoint">The codepoint to check.</param>
178
+ /// <remarks>
179
+ /// The CJK Unicode block is NOT all Japanese and Korean characters,
180
+ /// despite its name. The modern Korean Hangul alphabet is a different block,
181
+ /// as is Japanese Hiragana and Katakana. Those alphabets are used to write
182
+ /// space-separated words, so they are not treated specially and handled
183
+ /// like the all of the other languages.
184
+ /// </remarks>
185
+ /// <returns>True if the codepoint is a CJK character, false otherwise.</returns>
186
+ [ MethodImpl ( MethodImplOptions . AggressiveInlining ) ]
187
+ private static bool IsChineseChar ( int codePoint )
188
+ {
189
+ return ( codePoint > 0x3400 ) && // Quick check to exit early if the codepoint is outside of the CJK range
190
+ ( ( ( uint ) ( codePoint - 0x3400 ) <= ( uint ) ( 0x4DBF - 0x3400 ) ) ||
191
+ ( ( uint ) ( codePoint - 0xF900 ) <= ( uint ) ( 0xFAFF - 0xF900 ) ) ||
192
+ ( ( uint ) ( codePoint - 0x4E00 ) <= ( uint ) ( 0x9FFF - 0x4E00 ) ) ||
193
+ ( ( uint ) ( codePoint - 0x20000 ) <= ( uint ) ( 0x2A6DF - 0x20000 ) ) ||
194
+ ( ( uint ) ( codePoint - 0x2A700 ) <= ( uint ) ( 0x2B73F - 0x2A700 ) ) ||
195
+ ( ( uint ) ( codePoint - 0x2B740 ) <= ( uint ) ( 0x2B81F - 0x2B740 ) ) ||
196
+ ( ( uint ) ( codePoint - 0x2B820 ) <= ( uint ) ( 0x2CEAF - 0x2B820 ) ) ||
197
+ ( ( uint ) ( codePoint - 0x2F800 ) <= ( uint ) ( 0x2FA1F - 0x2F800 ) ) ) ;
198
+ }
199
+ }
200
+ }
0 commit comments