Skip to content

Commit a9b4212

Browse files
authored
Address the feedback regarding Bert tokenizer (#7280)
* Address the feedback regarding Bert tokenizer * Small fix
1 parent a7a6d88 commit a9b4212

File tree

3 files changed

+151
-42
lines changed

3 files changed

+151
-42
lines changed

src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs

Lines changed: 138 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -290,9 +290,25 @@ public IReadOnlyList<int> BuildInputsWithSpecialTokens(IEnumerable<int> tokenIds
290290
throw new ArgumentNullException(nameof(tokenIds0));
291291
}
292292

293-
// Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
294-
int capacity = tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1);
295-
List<int> ids = new List<int>(capacity: capacity) { ClsTokenId };
293+
List<int> ids;
294+
295+
if (tokenIds0 is ICollection<int> c1)
296+
{
297+
int capacity = c1.Count + 2; // Add 2 for [CLS] and two [SEP] tokens.
298+
299+
if (tokenIds1 is not null)
300+
{
301+
capacity += tokenIds1 is ICollection<int> c2 ? c2.Count + 1 : c1.Count + 1;
302+
}
303+
304+
ids = new(capacity) { ClsTokenId };
305+
}
306+
else
307+
{
308+
// slow path
309+
ids = new List<int>(10) { ClsTokenId };
310+
}
311+
296312
ids.AddRange(tokenIds0);
297313
ids.Add(SepTokenId);
298314

@@ -323,29 +339,48 @@ public OperationStatus BuildInputsWithSpecialTokens(IEnumerable<int> tokenIds0,
323339
throw new ArgumentNullException(nameof(tokenIds0));
324340
}
325341

326-
// Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
327-
int capacity = tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1);
328-
if (buffer.Length < capacity)
342+
written = 0;
343+
if (buffer.Length < 1)
329344
{
330-
written = 0;
331345
return OperationStatus.DestinationTooSmall;
332346
}
333347

334-
written = 0;
335348
buffer[written++] = ClsTokenId;
336349
foreach (int id in tokenIds0)
337350
{
351+
if (buffer.Length <= written)
352+
{
353+
written = 0;
354+
return OperationStatus.DestinationTooSmall;
355+
}
356+
338357
buffer[written++] = id;
339358
}
359+
360+
if (buffer.Length <= written)
361+
{
362+
written = 0;
363+
return OperationStatus.DestinationTooSmall;
364+
}
340365
buffer[written++] = SepTokenId;
341366

342367
if (tokenIds1 is not null)
343368
{
344369
foreach (int id in tokenIds1)
345370
{
371+
if (buffer.Length <= written)
372+
{
373+
written = 0;
374+
return OperationStatus.DestinationTooSmall;
375+
}
346376
buffer[written++] = id;
347377
}
348378

379+
if (buffer.Length <= written)
380+
{
381+
written = 0;
382+
return OperationStatus.DestinationTooSmall;
383+
}
349384
buffer[written++] = SepTokenId;
350385
}
351386

@@ -367,11 +402,22 @@ public IReadOnlyList<int> GetSpecialTokensMask(IEnumerable<int> tokenIds0, IEnum
367402
throw new ArgumentNullException(nameof(tokenIds0));
368403
}
369404

370-
int capacity = alreadyHasSpecialTokens ?
371-
tokenIds0.Count() + (tokenIds1?.Count() ?? 0) :
372-
tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : 1); // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
405+
List<int> mask;
406+
if (tokenIds0 is ICollection<int> c1)
407+
{
408+
int capcity = c1.Count + 2;
409+
410+
if (tokenIds1 is not null)
411+
{
412+
capcity += tokenIds1 is ICollection<int> c2 ? c2.Count + 1 : c1.Count + 1;
413+
}
373414

374-
List<int> mask = new List<int>(capacity: capacity);
415+
mask = new List<int>(capcity);
416+
}
417+
else
418+
{
419+
mask = new List<int>(10);
420+
}
375421

376422
if (!alreadyHasSpecialTokens)
377423
{
@@ -420,31 +466,49 @@ public OperationStatus GetSpecialTokensMask(IEnumerable<int> tokenIds0, Span<int
420466
throw new ArgumentNullException(nameof(tokenIds0));
421467
}
422468

423-
int capacity = alreadyHasSpecialTokens ?
424-
tokenIds0.Count() + (tokenIds1?.Count() ?? 0) :
425-
tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1); // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
426-
427469
written = 0;
428-
if (buffer.Length < capacity)
429-
{
430-
return OperationStatus.DestinationTooSmall;
431-
}
432-
433470
if (!alreadyHasSpecialTokens)
434471
{
472+
if (buffer.Length < 1)
473+
{
474+
return OperationStatus.DestinationTooSmall;
475+
}
435476
buffer[written++] = 1; // CLS
477+
436478
foreach (int id in tokenIds0)
437479
{
480+
if (buffer.Length <= written)
481+
{
482+
written = 0;
483+
return OperationStatus.DestinationTooSmall;
484+
}
438485
buffer[written++] = 0;
439486
}
487+
488+
if (buffer.Length <= written)
489+
{
490+
written = 0;
491+
return OperationStatus.DestinationTooSmall;
492+
}
440493
buffer[written++] = 1; // SEP
441494

442495
if (tokenIds1 is not null)
443496
{
444497
foreach (int id in tokenIds1)
445498
{
499+
if (buffer.Length <= written)
500+
{
501+
written = 0;
502+
return OperationStatus.DestinationTooSmall;
503+
}
446504
buffer[written++] = 0;
447505
}
506+
507+
if (buffer.Length <= written)
508+
{
509+
written = 0;
510+
return OperationStatus.DestinationTooSmall;
511+
}
448512
buffer[written++] = 1; // SEP
449513
}
450514

@@ -453,13 +517,23 @@ public OperationStatus GetSpecialTokensMask(IEnumerable<int> tokenIds0, Span<int
453517

454518
foreach (int id in tokenIds0)
455519
{
520+
if (buffer.Length <= written)
521+
{
522+
written = 0;
523+
return OperationStatus.DestinationTooSmall;
524+
}
456525
buffer[written++] = id == ClsTokenId || id == SepTokenId || id == PadTokenId || id == MaskTokenId || id == UnknownTokenId ? 1 : 0;
457526
}
458527

459528
if (tokenIds1 is not null)
460529
{
461530
foreach (int id in tokenIds1)
462531
{
532+
if (buffer.Length <= written)
533+
{
534+
written = 0;
535+
return OperationStatus.DestinationTooSmall;
536+
}
463537
buffer[written++] = id == ClsTokenId || id == SepTokenId || id == PadTokenId || id == MaskTokenId || id == UnknownTokenId ? 1 : 0;
464538
}
465539
}
@@ -484,21 +558,38 @@ public IReadOnlyList<int> CreateTokenTypeIdsFromSequences(IEnumerable<int> token
484558
throw new ArgumentNullException(nameof(tokenIds0));
485559
}
486560

487-
// Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
488-
int capacity = tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1);
561+
List<int> typeIds;
562+
if (tokenIds0 is ICollection<int> c1)
563+
{
564+
int capacity = c1.Count + 2; // Add 2 for [CLS] and [SEP] tokens.
565+
566+
if (tokenIds1 is not null)
567+
{
568+
capacity += tokenIds1 is ICollection<int> c2 ? c2.Count + 1 : c1.Count + 1;
569+
}
489570

490-
List<int> typeIds = new List<int>(capacity);
491-
for (int i = 0; i < tokenIds0.Count() + 2; i++) // Add 2 for [CLS] and [SEP] tokens.
571+
typeIds = new List<int>(capacity);
572+
}
573+
else
574+
{
575+
typeIds = new List<int>(10);
576+
}
577+
578+
foreach (var id in tokenIds0)
492579
{
493580
typeIds.Add(0);
494581
}
582+
typeIds.Add(0); // [CLS]
583+
typeIds.Add(0); // [SEP]
495584

496585
if (tokenIds1 is not null)
497586
{
498-
for (int i = 0; i < tokenIds1.Count() + 1; i++) // Add 1 for [SEP] token.
587+
foreach (int id in tokenIds1)
499588
{
500589
typeIds.Add(1);
501590
}
591+
592+
typeIds.Add(1); // [SEP]
502593
}
503594

504595
return typeIds;
@@ -515,22 +606,40 @@ public OperationStatus CreateTokenTypeIdsFromSequences(IEnumerable<int> tokenIds
515606

516607
// Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
517608
int capacity = tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1);
518-
if (buffer.Length < capacity)
609+
if (buffer.Length < 2)
519610
{
520611
return OperationStatus.DestinationTooSmall;
521612
}
613+
buffer[written++] = 0; // [CLS]
614+
buffer[written++] = 0; // [SEP]
522615

523-
for (int i = 0; i < tokenIds0.Count() + 2; i++) // Add 2 for [CLS] and [SEP] tokens.
616+
foreach (int id in tokenIds0)
524617
{
618+
if (buffer.Length <= written)
619+
{
620+
written = 0;
621+
return OperationStatus.DestinationTooSmall;
622+
}
525623
buffer[written++] = 0;
526624
}
527625

528626
if (tokenIds1 is not null)
529627
{
530-
for (int i = 0; i < tokenIds1.Count() + 1; i++) // Add 1 for [SEP] token.
628+
foreach (int id in tokenIds1)
531629
{
630+
if (buffer.Length <= written)
631+
{
632+
written = 0;
633+
return OperationStatus.DestinationTooSmall;
634+
}
532635
buffer[written++] = 1;
533636
}
637+
638+
if (buffer.Length < written)
639+
{
640+
return OperationStatus.DestinationTooSmall;
641+
}
642+
buffer[written++] = 1; // [SEP]
534643
}
535644

536645
return OperationStatus.Done;

src/Microsoft.ML.Tokenizers/Model/WordPieceTokenizer.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ await CreateAsync(
233233
continuingSubwordPrefix,
234234
maxInputCharsPerWord,
235235
cancellationToken,
236-
disposeStream: true);
236+
disposeStream: true).ConfigureAwait(false);
237237

238238
/// <summary>
239239
/// Create a new instance of the <see cref="WordPieceTokenizer"/> class asynchronously.
@@ -259,7 +259,7 @@ public static async Task<WordPieceTokenizer> CreateAsync(
259259
string continuingSubwordPrefix = DefaultContinuingSubwordPrefix,
260260
int maxInputCharsPerWord = DefaultMaxInputCharsPerWord,
261261
CancellationToken cancellationToken = default) =>
262-
await CreateAsync(vocabStream, preTokenizer, normalizer, specialTokens, unknownToken, continuingSubwordPrefix, maxInputCharsPerWord, cancellationToken, disposeStream: false);
262+
await CreateAsync(vocabStream, preTokenizer, normalizer, specialTokens, unknownToken, continuingSubwordPrefix, maxInputCharsPerWord, cancellationToken, disposeStream: false).ConfigureAwait(false);
263263

264264
private static async Task<WordPieceTokenizer> CreateAsync(
265265
Stream vocabStream,

src/Microsoft.ML.Tokenizers/Normalizer/BertNormalizer.cs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ public override string Normalize(string original)
6969

7070
if (category == UnicodeCategory.SpaceSeparator)
7171
{
72-
InsertChar(ref buffer, ref index, ' ');
72+
AddChar(ref buffer, ref index, ' ');
7373
i += inc;
7474
continue;
7575
}
@@ -85,30 +85,30 @@ public override string Normalize(string original)
8585
int length = original.AsSpan().Slice(i, inc + 1).ToLowerInvariant(casingBuffer);
8686
Debug.Assert(length > 0);
8787

88-
InsertSpan(ref buffer, ref index, casingBuffer.Slice(0, length));
88+
AddSpan(ref buffer, ref index, casingBuffer.Slice(0, length));
8989

9090
i += inc;
9191
continue;
9292
}
9393

9494
if (_tokenizeChineseChars && IsChineseChar(codePoint))
9595
{
96-
InsertChar(ref buffer, ref index, ' ');
97-
InsertChar(ref buffer, ref index, c);
96+
AddChar(ref buffer, ref index, ' ');
97+
AddChar(ref buffer, ref index, c);
9898
if (inc > 0)
9999
{
100-
InsertChar(ref buffer, ref index, original[i + 1]);
100+
AddChar(ref buffer, ref index, original[i + 1]);
101101
}
102-
InsertChar(ref buffer, ref index, ' ');
102+
AddChar(ref buffer, ref index, ' ');
103103

104104
i += inc;
105105
continue;
106106
}
107107

108-
InsertChar(ref buffer, ref index, c);
108+
AddChar(ref buffer, ref index, c);
109109
if (inc > 0)
110110
{
111-
InsertChar(ref buffer, ref index, original[i + 1]);
111+
AddChar(ref buffer, ref index, original[i + 1]);
112112
}
113113
i += inc;
114114
}
@@ -147,7 +147,7 @@ public BertNormalizer(bool doLowerCase, bool tokenizeChineseChars, bool stripAcc
147147
}
148148

149149
[MethodImpl(MethodImplOptions.AggressiveInlining)]
150-
private static void InsertChar(ref char[] buffer, ref int index, char c)
150+
private static void AddChar(ref char[] buffer, ref int index, char c)
151151
{
152152
if (index >= buffer.Length)
153153
{
@@ -158,9 +158,9 @@ private static void InsertChar(ref char[] buffer, ref int index, char c)
158158
}
159159

160160
[MethodImpl(MethodImplOptions.AggressiveInlining)]
161-
private static void InsertSpan(ref char[] buffer, ref int index, Span<char> chars)
161+
private static void AddSpan(ref char[] buffer, ref int index, Span<char> chars)
162162
{
163-
if (index + buffer.Length >= buffer.Length)
163+
if (index + chars.Length >= buffer.Length)
164164
{
165165
Helpers.ArrayPoolGrow(ref buffer, index + buffer.Length + 10);
166166
}

0 commit comments

Comments
 (0)