Further optimize multithreaded processing
georg-jung committed Sep 18, 2023
1 parent 6a4379f commit 53f23b1
Showing 2 changed files with 33 additions and 16 deletions.
21 changes: 13 additions & 8 deletions src/Benchmarks/TokenizeSpeed.cs
@@ -20,7 +20,7 @@ namespace Benchmarks;
*/
public class TokenizeSpeed
{
-private readonly List<string> _corpus;
+private readonly string[] _corpus;
private readonly List<string> _otherLibCorpus;
private readonly ConcreteUncasedTokenizer _otherLibTokenizer;
private readonly BertTokenizer _tokenizer;
@@ -38,17 +38,20 @@ public TokenizeSpeed(string corpusPath, string vocabTxtFile, string tokenizerJso
using var uncompress = new BrotliStream(fs, CompressionMode.Decompress);
var dict = JsonSerializer.Deserialize<Dictionary<int, string>>(uncompress)!;

-_corpus = new(dict.Count);
+_corpus = new string[dict.Count];
_otherLibCorpus = new(dict.Count);
+var cnt = 0;
foreach (var tx in dict.Values)
{
-_corpus.Add(tx);
+_corpus[cnt] = tx;

// this preprocessing gives the other lib kind of an unfair advantage, but it throws otherwise
var otherLib = tx.Substring(0, Math.Min(tx.Length, 1250)); // other lib throws if text is too long; 1250 works with 512 tokens, 1500 doesn't; 5000 works with 2048 tokens
otherLib = Regex.Replace(otherLib, @"\s+", " "); // required due to bad whitespace processing of other lib
otherLib = Regex.Replace(otherLib, @"[^A-Za-z0-9\s\.\,;:\\/?!#$%()=+\-*\""'–_`<>&^@{}[\]\|~']+", string.Empty); // other lib doesn't handle unknown characters
_otherLibCorpus.Add(otherLib);

+cnt++;
}

_otherLibTokenizer = new(vocabTxtFile);
@@ -87,7 +90,7 @@ public object RustHuggingfaceWrapperSinglethreadedMemReuse()
[Benchmark(Baseline = true)]
public IReadOnlyCollection<object> FastBertTokenizerSinglethreadedAllocating()
{
-List<object> res = new(_corpus.Count);
+List<object> res = new(_corpus.Length);
foreach (var text in _corpus)
{
res.Add(_tokenizer.Tokenize(text, _maxSequenceLength));
@@ -115,7 +118,7 @@ public object FastBertTokenizerSingleThreadedMemReuse()
public IReadOnlyCollection<(Memory<long> InputIds, Memory<long> AttentionMask, Memory<long> TokenTypeIds)> FastBertTokenizerMultithreadedAllocating()
{
// this might be interesting to benchmark but doesn't make much sense as a real world use case
-List<(Memory<long> InputIds, Memory<long> AttentionMask, Memory<long> TokenTypeIds)> res = new(_corpus.Count);
+List<(Memory<long> InputIds, Memory<long> AttentionMask, Memory<long> TokenTypeIds)> res = new(_corpus.Length);
var x = _corpus.AsParallel().AsOrdered().Select(x => _tokenizer.Tokenize(x, _maxSequenceLength));
res.AddRange(x);
return res;
@@ -130,12 +133,14 @@ public object FastBertTokenizerSingleThreadedMemReuse()
var toktyp = new long[_maxSequenceLength * batchSize];
Array.Fill(toktyp, 0);

-foreach (var batch in _corpus.Buffer(batchSize).Cast<IReadOnlyList<string>>())
+var corpMem = _corpus.AsMemory();
+for (var i = 0; i < corpMem.Length; i += batchSize)
{
-var batchSeqLen = _maxSequenceLength * batch.Count;
+var len = Math.Min(batchSize, corpMem.Length - i);
+var batchSeqLen = _maxSequenceLength * len;
var iidsM = iids.AsMemory(0, batchSeqLen);
var attmM = attm.AsMemory(0, batchSeqLen);
-_tokenizer.Tokenize(batch, iidsM, attmM, _maxSequenceLength);
+_tokenizer.Tokenize(corpMem.Slice(i, len), iidsM, attmM, _maxSequenceLength);
}

return (iids.AsMemory(), attm.AsMemory(), toktyp.AsMemory());
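As a standalone sketch of the calling pattern the updated benchmark exercises — slicing the corpus as ReadOnlyMemory<string> batches and reusing the same input-id / attention-mask buffers across batches — the following is illustrative only; the helper class, batch size, and tokenizer/corpus setup are assumptions, not part of this commit:

using System;
using FastBertTokenizer;

internal static class BatchedTokenizeExample
{
    // Tokenize a corpus in fixed-size batches, reusing the same buffers.
    // `tokenizer` must already have its vocabulary loaded; `corpus` is the raw texts.
    public static void TokenizeInBatches(BertTokenizer tokenizer, string[] corpus)
    {
        const int maxSequenceLength = 512;
        const int batchSize = 1000;

        var inputIds = new long[maxSequenceLength * batchSize];
        var attentionMask = new long[maxSequenceLength * batchSize];

        var corpusMem = corpus.AsMemory();
        for (var i = 0; i < corpusMem.Length; i += batchSize)
        {
            // The final batch may be shorter than batchSize.
            var len = Math.Min(batchSize, corpusMem.Length - i);
            var batchSeqLen = maxSequenceLength * len;

            tokenizer.Tokenize(
                corpusMem.Slice(i, len),
                inputIds.AsMemory(0, batchSeqLen),
                attentionMask.AsMemory(0, batchSeqLen),
                maxSequenceLength);

            // Consume this batch's inputIds/attentionMask here before the
            // next iteration overwrites the shared buffers.
        }
    }
}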
28 changes: 20 additions & 8 deletions src/FastBertTokenizer/BertTokenizer.Parallel.cs
@@ -7,6 +7,8 @@ namespace FastBertTokenizer;

public partial class BertTokenizer
{
+private (int Count, Tuple<int, int>[] Ranges)? _rangeCache;

/// <summary>
/// Encode the given batch of input strings to token ids per the loaded vocabulary using
/// multiple threads in parallel. Write the results to the given memory areas. When encoding
Expand Down Expand Up @@ -39,9 +41,9 @@ public partial class BertTokenizer
/// The maximum number of token ids to produce for every single input in the batch.
/// Most BERT models support a maximum of 512 tokens per input.
/// </param>
-public void Tokenize(IReadOnlyList<string> inputs, Memory<long> inputIds, Memory<long> attentionMask, Memory<long> tokenTypeIds, int maximumTokens = 512)
+public void Tokenize(ReadOnlyMemory<string> inputs, Memory<long> inputIds, Memory<long> attentionMask, Memory<long> tokenTypeIds, int maximumTokens = 512)
{
-var resultLen = maximumTokens * inputs.Count;
+var resultLen = maximumTokens * inputs.Length;
if (tokenTypeIds.Length != resultLen)
{
throw new ArgumentException($"{nameof(tokenTypeIds)} must have {resultLen} elements but had {tokenTypeIds.Length}.", nameof(tokenTypeIds));
@@ -51,17 +53,26 @@ public void Tokenize(IReadOnlyList<string> inputs, Memory<long> inputIds, Memory
tokenTypeIds.Span.Fill(0);
}

/// <inheritdoc cref="Tokenize(IReadOnlyList{string}, Memory{long}, Memory{long}, Memory{long}, int)"/>
public void Tokenize(IReadOnlyList<string> inputs, Memory<long> inputIds, Memory<long> attentionMask, int maximumTokens = 512)
/// <inheritdoc cref="Tokenize(ReadOnlyMemory{string}, Memory{long}, Memory{long}, Memory{long}, int)"/>
public void Tokenize(ReadOnlyMemory<string> inputs, Memory<long> inputIds, Memory<long> attentionMask, int maximumTokens = 512)
{
-var resultLen = maximumTokens * inputs.Count;
+var resultLen = maximumTokens * inputs.Length;
if (inputIds.Length != resultLen || attentionMask.Length != resultLen)
{
throw new ArgumentException($"{nameof(inputIds)} and {nameof(attentionMask)} must have {resultLen} elements, but had {inputIds.Length} and {attentionMask.Length}.");
}

-var rangePartitioner = Partitioner.Create(0, inputs.Count);
-var ranges = rangePartitioner.GetDynamicPartitions().ToArray();
+Tuple<int, int>[] ranges;
+if (_rangeCache is { } rc && rc.Count == inputs.Length)
+{
+ranges = rc.Ranges;
+}
+else
+{
+ranges = Partitioner.Create(0, inputs.Length).GetDynamicPartitions().ToArray();
+_rangeCache = (inputs.Length, ranges);
+}

using var cde = new CountdownEvent(ranges.Length);
foreach (var range in ranges)
{
@@ -72,10 +83,11 @@ public void Tokenize(IReadOnlyList<string> inputs, Memory<long> inputIds, Memory

void ParallelBody((int StartInclusive, int EndExclusive) param)
{
+var inputSpan = inputs.Span;
for (var i = param.StartInclusive; i < param.EndExclusive; i++)
{
var startIdx = maximumTokens * i;
-var (_, nonPad) = Tokenize(inputs[i], inputIds.Slice(startIdx, maximumTokens), maximumTokens);
+var (_, nonPad) = Tokenize(inputSpan[i], inputIds.Slice(startIdx, maximumTokens), maximumTokens);
var span = attentionMask.Slice(startIdx, maximumTokens).Span;
span.Slice(0, nonPad).Fill(1);
span.Slice(nonPad, maximumTokens - nonPad).Fill(0);
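The new _rangeCache field avoids recomputing partition ranges when consecutive calls pass the same number of inputs, since Partitioner.Create(0, n).GetDynamicPartitions() otherwise produces a fresh set of index ranges on every call. A minimal, self-contained sketch of that caching idea (the example class and GetRanges helper are hypothetical; the commit keeps this logic inline as shown above):

using System;
using System.Collections.Concurrent;
using System.Linq;

internal sealed class RangeCacheExample
{
    private (int Count, Tuple<int, int>[] Ranges)? _rangeCache;

    public Tuple<int, int>[] GetRanges(int count)
    {
        if (_rangeCache is { } cached && cached.Count == count)
        {
            // Same batch size as the previous call: reuse the ranges.
            return cached.Ranges;
        }

        // Partitioner.Create(0, count) splits [0, count) into contiguous
        // (fromInclusive, toExclusive) ranges; each range can then be
        // handed to one worker, as the parallel Tokenize overload does.
        var ranges = Partitioner.Create(0, count).GetDynamicPartitions().ToArray();
        _rangeCache = (count, ranges);
        return ranges;
    }
}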
