@@ -290,9 +290,25 @@ public IReadOnlyList<int> BuildInputsWithSpecialTokens(IEnumerable<int> tokenIds
290
290
throw new ArgumentNullException ( nameof ( tokenIds0 ) ) ;
291
291
}
292
292
293
- // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
294
- int capacity = tokenIds0 . Count ( ) + 2 + ( tokenIds1 is null ? 0 : tokenIds1 . Count ( ) + 1 ) ;
295
- List < int > ids = new List < int > ( capacity : capacity ) { ClsTokenId } ;
293
+ List < int > ids ;
294
+
295
+ if ( tokenIds0 is ICollection < int > c1 )
296
+ {
297
+ int capacity = c1 . Count + 2 ; // Add 2 for [CLS] and two [SEP] tokens.
298
+
299
+ if ( tokenIds1 is not null )
300
+ {
301
+ capacity += tokenIds1 is ICollection < int > c2 ? c2 . Count + 1 : c1 . Count + 1 ;
302
+ }
303
+
304
+ ids = new ( capacity ) { ClsTokenId } ;
305
+ }
306
+ else
307
+ {
308
+ // slow path
309
+ ids = new List < int > ( 10 ) { ClsTokenId } ;
310
+ }
311
+
296
312
ids . AddRange ( tokenIds0 ) ;
297
313
ids . Add ( SepTokenId ) ;
298
314
@@ -323,29 +339,48 @@ public OperationStatus BuildInputsWithSpecialTokens(IEnumerable<int> tokenIds0,
323
339
throw new ArgumentNullException ( nameof ( tokenIds0 ) ) ;
324
340
}
325
341
326
- // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
327
- int capacity = tokenIds0 . Count ( ) + 2 + ( tokenIds1 is null ? 0 : tokenIds1 . Count ( ) + 1 ) ;
328
- if ( buffer . Length < capacity )
342
+ written = 0 ;
343
+ if ( buffer . Length < 1 )
329
344
{
330
- written = 0 ;
331
345
return OperationStatus . DestinationTooSmall ;
332
346
}
333
347
334
- written = 0 ;
335
348
buffer [ written ++ ] = ClsTokenId ;
336
349
foreach ( int id in tokenIds0 )
337
350
{
351
+ if ( buffer . Length <= written )
352
+ {
353
+ written = 0 ;
354
+ return OperationStatus . DestinationTooSmall ;
355
+ }
356
+
338
357
buffer [ written ++ ] = id ;
339
358
}
359
+
360
+ if ( buffer . Length <= written )
361
+ {
362
+ written = 0 ;
363
+ return OperationStatus . DestinationTooSmall ;
364
+ }
340
365
buffer [ written ++ ] = SepTokenId ;
341
366
342
367
if ( tokenIds1 is not null )
343
368
{
344
369
foreach ( int id in tokenIds1 )
345
370
{
371
+ if ( buffer . Length <= written )
372
+ {
373
+ written = 0 ;
374
+ return OperationStatus . DestinationTooSmall ;
375
+ }
346
376
buffer [ written ++ ] = id ;
347
377
}
348
378
379
+ if ( buffer . Length <= written )
380
+ {
381
+ written = 0 ;
382
+ return OperationStatus . DestinationTooSmall ;
383
+ }
349
384
buffer [ written ++ ] = SepTokenId ;
350
385
}
351
386
@@ -367,11 +402,22 @@ public IReadOnlyList<int> GetSpecialTokensMask(IEnumerable<int> tokenIds0, IEnum
367
402
throw new ArgumentNullException ( nameof ( tokenIds0 ) ) ;
368
403
}
369
404
370
- int capacity = alreadyHasSpecialTokens ?
371
- tokenIds0 . Count ( ) + ( tokenIds1 ? . Count ( ) ?? 0 ) :
372
- tokenIds0 . Count ( ) + 2 + ( tokenIds1 is null ? 0 : 1 ) ; // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
405
+ List < int > mask ;
406
+ if ( tokenIds0 is ICollection < int > c1 )
407
+ {
408
+ int capcity = c1 . Count + 2 ;
409
+
410
+ if ( tokenIds1 is not null )
411
+ {
412
+ capcity += tokenIds1 is ICollection < int > c2 ? c2 . Count + 1 : c1 . Count + 1 ;
413
+ }
373
414
374
- List < int > mask = new List < int > ( capacity : capacity ) ;
415
+ mask = new List < int > ( capcity ) ;
416
+ }
417
+ else
418
+ {
419
+ mask = new List < int > ( 10 ) ;
420
+ }
375
421
376
422
if ( ! alreadyHasSpecialTokens )
377
423
{
@@ -420,31 +466,49 @@ public OperationStatus GetSpecialTokensMask(IEnumerable<int> tokenIds0, Span<int
420
466
throw new ArgumentNullException ( nameof ( tokenIds0 ) ) ;
421
467
}
422
468
423
- int capacity = alreadyHasSpecialTokens ?
424
- tokenIds0 . Count ( ) + ( tokenIds1 ? . Count ( ) ?? 0 ) :
425
- tokenIds0 . Count ( ) + 2 + ( tokenIds1 is null ? 0 : tokenIds1 . Count ( ) + 1 ) ; // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
426
-
427
469
written = 0 ;
428
- if ( buffer . Length < capacity )
429
- {
430
- return OperationStatus . DestinationTooSmall ;
431
- }
432
-
433
470
if ( ! alreadyHasSpecialTokens )
434
471
{
472
+ if ( buffer . Length < 1 )
473
+ {
474
+ return OperationStatus . DestinationTooSmall ;
475
+ }
435
476
buffer [ written ++ ] = 1 ; // CLS
477
+
436
478
foreach ( int id in tokenIds0 )
437
479
{
480
+ if ( buffer . Length <= written )
481
+ {
482
+ written = 0 ;
483
+ return OperationStatus . DestinationTooSmall ;
484
+ }
438
485
buffer [ written ++ ] = 0 ;
439
486
}
487
+
488
+ if ( buffer . Length <= written )
489
+ {
490
+ written = 0 ;
491
+ return OperationStatus . DestinationTooSmall ;
492
+ }
440
493
buffer [ written ++ ] = 1 ; // SEP
441
494
442
495
if ( tokenIds1 is not null )
443
496
{
444
497
foreach ( int id in tokenIds1 )
445
498
{
499
+ if ( buffer . Length <= written )
500
+ {
501
+ written = 0 ;
502
+ return OperationStatus . DestinationTooSmall ;
503
+ }
446
504
buffer [ written ++ ] = 0 ;
447
505
}
506
+
507
+ if ( buffer . Length <= written )
508
+ {
509
+ written = 0 ;
510
+ return OperationStatus . DestinationTooSmall ;
511
+ }
448
512
buffer [ written ++ ] = 1 ; // SEP
449
513
}
450
514
@@ -453,13 +517,23 @@ public OperationStatus GetSpecialTokensMask(IEnumerable<int> tokenIds0, Span<int
453
517
454
518
foreach ( int id in tokenIds0 )
455
519
{
520
+ if ( buffer . Length <= written )
521
+ {
522
+ written = 0 ;
523
+ return OperationStatus . DestinationTooSmall ;
524
+ }
456
525
buffer [ written ++ ] = id == ClsTokenId || id == SepTokenId || id == PadTokenId || id == MaskTokenId || id == UnknownTokenId ? 1 : 0 ;
457
526
}
458
527
459
528
if ( tokenIds1 is not null )
460
529
{
461
530
foreach ( int id in tokenIds1 )
462
531
{
532
+ if ( buffer . Length <= written )
533
+ {
534
+ written = 0 ;
535
+ return OperationStatus . DestinationTooSmall ;
536
+ }
463
537
buffer [ written ++ ] = id == ClsTokenId || id == SepTokenId || id == PadTokenId || id == MaskTokenId || id == UnknownTokenId ? 1 : 0 ;
464
538
}
465
539
}
@@ -484,21 +558,38 @@ public IReadOnlyList<int> CreateTokenTypeIdsFromSequences(IEnumerable<int> token
484
558
throw new ArgumentNullException ( nameof ( tokenIds0 ) ) ;
485
559
}
486
560
487
- // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
488
- int capacity = tokenIds0 . Count ( ) + 2 + ( tokenIds1 is null ? 0 : tokenIds1 . Count ( ) + 1 ) ;
561
+ List < int > typeIds ;
562
+ if ( tokenIds0 is ICollection < int > c1 )
563
+ {
564
+ int capacity = c1 . Count + 2 ; // Add 2 for [CLS] and [SEP] tokens.
565
+
566
+ if ( tokenIds1 is not null )
567
+ {
568
+ capacity += tokenIds1 is ICollection < int > c2 ? c2 . Count + 1 : c1 . Count + 1 ;
569
+ }
489
570
490
- List < int > typeIds = new List < int > ( capacity ) ;
491
- for ( int i = 0 ; i < tokenIds0 . Count ( ) + 2 ; i ++ ) // Add 2 for [CLS] and [SEP] tokens.
571
+ typeIds = new List < int > ( capacity ) ;
572
+ }
573
+ else
574
+ {
575
+ typeIds = new List < int > ( 10 ) ;
576
+ }
577
+
578
+ foreach ( var id in tokenIds0 )
492
579
{
493
580
typeIds . Add ( 0 ) ;
494
581
}
582
+ typeIds . Add ( 0 ) ; // [CLS]
583
+ typeIds . Add ( 0 ) ; // [SEP]
495
584
496
585
if ( tokenIds1 is not null )
497
586
{
498
- for ( int i = 0 ; i < tokenIds1 . Count ( ) + 1 ; i ++ ) // Add 1 for [SEP] token.
587
+ foreach ( int id in tokenIds1 )
499
588
{
500
589
typeIds . Add ( 1 ) ;
501
590
}
591
+
592
+ typeIds . Add ( 1 ) ; // [SEP]
502
593
}
503
594
504
595
return typeIds ;
@@ -515,22 +606,40 @@ public OperationStatus CreateTokenTypeIdsFromSequences(IEnumerable<int> tokenIds
515
606
516
607
// Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
517
608
int capacity = tokenIds0 . Count ( ) + 2 + ( tokenIds1 is null ? 0 : tokenIds1 . Count ( ) + 1 ) ;
518
- if ( buffer . Length < capacity )
609
+ if ( buffer . Length < 2 )
519
610
{
520
611
return OperationStatus . DestinationTooSmall ;
521
612
}
613
+ buffer [ written ++ ] = 0 ; // [CLS]
614
+ buffer [ written ++ ] = 0 ; // [SEP]
522
615
523
- for ( int i = 0 ; i < tokenIds0 . Count ( ) + 2 ; i ++ ) // Add 2 for [CLS] and [SEP] tokens.
616
+ foreach ( int id in tokenIds0 )
524
617
{
618
+ if ( buffer . Length <= written )
619
+ {
620
+ written = 0 ;
621
+ return OperationStatus . DestinationTooSmall ;
622
+ }
525
623
buffer [ written ++ ] = 0 ;
526
624
}
527
625
528
626
if ( tokenIds1 is not null )
529
627
{
530
- for ( int i = 0 ; i < tokenIds1 . Count ( ) + 1 ; i ++ ) // Add 1 for [SEP] token.
628
+ foreach ( int id in tokenIds1 )
531
629
{
630
+ if ( buffer . Length <= written )
631
+ {
632
+ written = 0 ;
633
+ return OperationStatus . DestinationTooSmall ;
634
+ }
532
635
buffer [ written ++ ] = 1 ;
533
636
}
637
+
638
+ if ( buffer . Length < written )
639
+ {
640
+ return OperationStatus . DestinationTooSmall ;
641
+ }
642
+ buffer [ written ++ ] = 1 ; // [SEP]
534
643
}
535
644
536
645
return OperationStatus . Done ;
0 commit comments