From 9c972d99ce1543a63edacf3667a31b8afa442ecd Mon Sep 17 00:00:00 2001
From: Berkay Ersoy
Date: Fri, 8 Nov 2024 00:05:36 +0300
Subject: [PATCH] feat: FromPretrained to load tokenizer directly from HF (#27)

* add LoadTokenizerFromHuggingFace function to load tokenizer directly from huggingface, update README.md

* use unbuffered channels, update channel names and simplify some approaches

* fix: rename new func to FromPretrained, improve example

* fix: clean up downloadFile

* fix: concurrency issues in case of an error

* fix: make optional params optional

* fix: cache path has to be model specific

* add unit tests for `FromPretrained`

* migrate to table driven tests, unify/simplify test cases

* fix: clean up nits

---------

Co-authored-by: Resul Berkay Ersoy
Co-authored-by: Daulet Zhanguzin
---
 README.md         |  44 ++++++++++++-
 example/main.go   |  61 ++++++++++++++++-
 tokenizer.go      | 162 +++++++++++++++++++++++++++++++++++++++++++++-
 tokenizer_test.go | 106 ++++++++++++++++++++++++++++++
 4 files changed, 369 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index 119606ac..2ea06a3a 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@ Go bindings for the [HuggingFace Tokenizers](https://github.com/huggingface/tokenizers) library.
 
 ## Installation
 
-`make build` to build `libtokenizers.a` that you need to run your application that uses bindings. In addition, you need to inform the linker where to find that static library: `go run -ldflags="-extldflags '-L./path/to/libtokenizers.a'" .` or just add it to the `CGO_LDFLAGS` environment variable: `CGO_LDFLAGS="-L./path/to/libtokenizers.a"` to avoid specifying it every time.
+`make build` to build `libtokenizers.a` that you need to run your application that uses bindings. In addition, you need to inform the linker where to find that static library: `go run -ldflags="-extldflags '-L./path/to/libtokenizers/directory'" .` or just add it to the `CGO_LDFLAGS` environment variable: `CGO_LDFLAGS="-L./path/to/libtokenizers/directory"` to avoid specifying it every time.
 
 ### Using pre-built binaries
 
@@ -31,6 +31,20 @@ if err != nil {
 defer tk.Close()
 ```
 
+Load a tokenizer directly from Hugging Face:
+
+```go
+import "github.com/daulet/tokenizers"
+
+// tokenizer files are cached under ./.cache/tokenizers/google-bert/bert-base-uncased
+tk, err := tokenizers.FromPretrained("google-bert/bert-base-uncased", tokenizers.WithCacheDir("./.cache/tokenizers"))
+if err != nil {
+    return err
+}
+// release native resources
+defer tk.Close()
+```
+
 Encode text and decode tokens:
 
 ```go
@@ -44,6 +58,34 @@ fmt.Println(tk.Decode([]uint32{2829, 4419, 14523, 2058, 1996, 13971, 3899}, true))
 // brown fox jumps over the lazy dog
 ```
 
+Encode text with options:
+
+```go
+var encodeOptions []tokenizers.EncodeOption
+encodeOptions = append(encodeOptions, tokenizers.WithReturnTypeIDs())
+encodeOptions = append(encodeOptions, tokenizers.WithReturnAttentionMask())
+encodeOptions = append(encodeOptions, tokenizers.WithReturnTokens())
+encodeOptions = append(encodeOptions, tokenizers.WithReturnOffsets())
+encodeOptions = append(encodeOptions, tokenizers.WithReturnSpecialTokensMask())
+
+// Or simply:
+// encodeOptions = append(encodeOptions, tokenizers.WithReturnAllAttributes())
+
+encodingResponse := tk.EncodeWithOptions("brown fox jumps over the lazy dog", false, encodeOptions...)
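+
+// Note: only the attributes requested via the options above are populated in
+// the response; attributes that were not requested are left empty.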
+fmt.Println(encodingResponse.IDs)
+// [2829 4419 14523 2058 1996 13971 3899]
+fmt.Println(encodingResponse.TypeIDs)
+// [0 0 0 0 0 0 0]
+fmt.Println(encodingResponse.SpecialTokensMask)
+// [0 0 0 0 0 0 0]
+fmt.Println(encodingResponse.AttentionMask)
+// [1 1 1 1 1 1 1]
+fmt.Println(encodingResponse.Tokens)
+// [brown fox jumps over the lazy dog]
+fmt.Println(encodingResponse.Offsets)
+// [[0 5] [6 9] [10 15] [16 20] [21 24] [25 29] [30 33]]
+```
+
 ## Benchmarks
 
 `go test . -run=^\$ -bench=. -benchmem -count=10 > test/benchmark/$(git rev-parse HEAD).txt`
diff --git a/example/main.go b/example/main.go
index 43d07830..db65bf51 100644
--- a/example/main.go
+++ b/example/main.go
@@ -2,17 +2,19 @@ package main
 
 import (
 	"fmt"
+	"log"
 
 	"github.com/daulet/tokenizers"
 )
 
-func main() {
+func simple() error {
 	tk, err := tokenizers.FromFile("../test/data/bert-base-uncased.json")
 	if err != nil {
-		panic(err)
+		return err
 	}
 	// release native resources
 	defer tk.Close()
+
 	fmt.Println("Vocab size:", tk.VocabSize())
 	// Vocab size: 30522
 	fmt.Println(tk.Encode("brown fox jumps over the lazy dog", false))
@@ -21,4 +23,59 @@
 	// [101 2829 4419 14523 2058 1996 13971 3899 102] [[CLS] brown fox jumps over the lazy dog [SEP]]
 	fmt.Println(tk.Decode([]uint32{2829, 4419, 14523, 2058, 1996, 13971, 3899}, true))
 	// brown fox jumps over the lazy dog
+	return nil
+}
+
+func advanced() error {
+	// Load tokenizer from local config file
+	tk, err := tokenizers.FromFile("../test/data/bert-base-uncased.json")
+	if err != nil {
+		return err
+	}
+	defer tk.Close()
+
+	// Load pretrained tokenizer from HuggingFace
+	tkFromHf, err := tokenizers.FromPretrained("google-bert/bert-base-uncased", tokenizers.WithCacheDir("./.cache/tokenizers"))
+	if err != nil {
+		return err
+	}
+	defer tkFromHf.Close()
+
+	// Encode with specific options
+	encodeOptions := []tokenizers.EncodeOption{
+		tokenizers.WithReturnTypeIDs(),
+		tokenizers.WithReturnAttentionMask(),
+		tokenizers.WithReturnTokens(),
+		tokenizers.WithReturnOffsets(),
+		tokenizers.WithReturnSpecialTokensMask(),
+	}
+	// Or simply:
+	// encodeOptions = append(encodeOptions, tokenizers.WithReturnAllAttributes())
+
+	// regardless of how the tokenizer was initialized, the output is the same
+	for _, tkzr := range []*tokenizers.Tokenizer{tk, tkFromHf} {
+		encodingResponse := tkzr.EncodeWithOptions("brown fox jumps over the lazy dog", true, encodeOptions...)
+		fmt.Println(encodingResponse.IDs)
+		// [101 2829 4419 14523 2058 1996 13971 3899 102]
+		fmt.Println(encodingResponse.TypeIDs)
+		// [0 0 0 0 0 0 0 0 0]
+		fmt.Println(encodingResponse.SpecialTokensMask)
+		// [1 0 0 0 0 0 0 0 1]
+		fmt.Println(encodingResponse.AttentionMask)
+		// [1 1 1 1 1 1 1 1 1]
+		fmt.Println(encodingResponse.Tokens)
+		// [[CLS] brown fox jumps over the lazy dog [SEP]]
+		fmt.Println(encodingResponse.Offsets)
+		// [[0 0] [0 5] [6 9] [10 15] [16 20] [21 24] [25 29] [30 33] [0 0]]
+	}
+	return nil
+}
+
+func main() {
+	if err := simple(); err != nil {
+		log.Fatal(err)
+	}
+	if err := advanced(); err != nil {
+		log.Fatal(err)
+	}
 }
diff --git a/tokenizer.go b/tokenizer.go
index 08b9e6bd..765163b5 100644
--- a/tokenizer.go
+++ b/tokenizer.go
@@ -13,10 +13,29 @@ import "C"
 import (
 	"fmt"
 	"io"
+	"net/http"
+	"os"
+	"path/filepath"
+	"strings"
+	"sync"
 	"unsafe"
 )
 
-const WANT_VERSION = "1.20.2"
+const (
+	WANT_VERSION = "1.20.2"
+
+	baseURL = "https://huggingface.co"
+)
+
+// List of necessary tokenizer files and their mandatory status.
+// True means mandatory, false means optional.
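+// Only tokenizer.json is strictly required: FromPretrained fails if a mandatory
+// file cannot be downloaded, while the optional files are fetched best-effort
+// and ignored when missing.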
+var tokenizerFiles = map[string]bool{
+	"tokenizer.json":          true,
+	"vocab.txt":               false,
+	"merges.txt":              false,
+	"special_tokens_map.json": false,
+	"added_tokens.json":       false,
+}
 
 func init() {
 	version := C.version()
@@ -78,6 +97,147 @@ func FromFile(path string) (*Tokenizer, error) {
 	return &Tokenizer{tokenizer: tokenizer}, nil
 }
 
+type tokenizerConfig struct {
+	cacheDir  *string
+	authToken *string
+}
+
+type TokenizerConfigOption func(cfg *tokenizerConfig)
+
+func WithCacheDir(path string) TokenizerConfigOption {
+	return func(cfg *tokenizerConfig) {
+		cfg.cacheDir = &path
+	}
+}
+
+func WithAuthToken(token string) TokenizerConfigOption {
+	return func(cfg *tokenizerConfig) {
+		cfg.authToken = &token
+	}
+}
+
+// FromPretrained downloads the necessary tokenizer files from Hugging Face and initializes the tokenizer.
+// Parameters:
+// - modelID: the Hugging Face model identifier (e.g., "bert-base-uncased").
+// - opts: optional configuration. WithCacheDir sets the directory files are downloaded to
+// (into a model-specific subdirectory; a temporary directory is used if unset).
+// WithAuthToken sets a token used to authenticate requests, e.g. for private models.
+func FromPretrained(modelID string, opts ...TokenizerConfigOption) (*Tokenizer, error) {
+	cfg := &tokenizerConfig{}
+	for _, opt := range opts {
+		opt(cfg)
+	}
+	if strings.TrimSpace(modelID) == "" {
+		return nil, fmt.Errorf("modelID cannot be empty")
+	}
+
+	// Construct the model URL
+	modelURL := fmt.Sprintf("%s/%s/resolve/main", baseURL, modelID)
+
+	// Determine the download directory
+	var downloadDir string
+	if cfg.cacheDir != nil {
+		downloadDir = fmt.Sprintf("%s/%s", *cfg.cacheDir, modelID)
+		// Create the destination directory if it doesn't exist
+		err := os.MkdirAll(downloadDir, os.ModePerm)
+		if err != nil {
+			return nil, fmt.Errorf("failed to create destination directory %s: %w", downloadDir, err)
+		}
+	} else {
+		// Create a temporary directory
+		tmpDir, err := os.MkdirTemp("", "huggingface-tokenizer-*")
+		if err != nil {
+			return nil, fmt.Errorf("error creating temporary directory: %w", err)
+		}
+		downloadDir = tmpDir
+	}
+
+	var wg sync.WaitGroup
+	errCh := make(chan error)
+
+	// Download each tokenizer file concurrently
+	for filename, isMandatory := range tokenizerFiles {
+		wg.Add(1)
+		go func(fn string, mandatory bool) {
+			defer wg.Done()
+			fileURL := fmt.Sprintf("%s/%s", modelURL, fn)
+			destPath := filepath.Join(downloadDir, fn)
+			err := downloadFile(fileURL, destPath, cfg.authToken)
+			if err != nil && mandatory {
+				// If the file is mandatory, report an error
+				errCh <- fmt.Errorf("failed to download mandatory file %s: %w", fn, err)
+			}
+		}(filename, isMandatory)
+	}
+
+	go func() {
+		wg.Wait()
+		close(errCh)
+	}()
+
+	var errs []error
+	for err := range errCh {
+		errs = append(errs, err)
+	}
+
+	if len(errs) > 0 {
+		if err := os.RemoveAll(downloadDir); err != nil {
+			fmt.Printf("Warning: failed to clean up directory %s: %v\n", downloadDir, err)
+		}
+		return nil, errs[0]
+	}
+
+	return FromFile(filepath.Join(downloadDir, "tokenizer.json"))
+}
+
+// downloadFile downloads a file from the given URL and saves it to the specified destination.
+// If authToken is provided (non-nil), it will be used for authorization.
+// Returns an error if the download fails.
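+// Note that an existing file at destination is reused as-is and never re-fetched;
+// delete the cached file (or the cache directory) to force a fresh download.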
+func downloadFile(url, destination string, authToken *string) error {
+	// Check if the file already exists
+	if _, err := os.Stat(destination); err == nil {
+		return nil
+	}
+
+	// Create a new HTTP request
+	req, err := http.NewRequest("GET", url, nil)
+	if err != nil {
+		return fmt.Errorf("failed to create request for %s: %w", url, err)
+	}
+
+	// If authToken is provided, set the Authorization header
+	if authToken != nil {
+		req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *authToken))
+	}
+
+	resp, err := http.DefaultClient.Do(req)
+	if err != nil {
+		return fmt.Errorf("failed to download from %s: %w", url, err)
+	}
+	defer resp.Body.Close()
+
+	// Check for successful response
+	if resp.StatusCode != http.StatusOK {
+		return fmt.Errorf("failed to download from %s: status code %d", url, resp.StatusCode)
+	}
+
+	// Create the destination file
+	out, err := os.Create(destination)
+	if err != nil {
+		return fmt.Errorf("failed to create file %s: %w", destination, err)
+	}
+	defer out.Close()
+
+	// Write the response body to the file
+	_, err = io.Copy(out, resp.Body)
+	if err != nil {
+		return fmt.Errorf("failed to write to file %s: %w", destination, err)
+	}
+
+	fmt.Printf("Successfully downloaded %s\n", destination)
+	return nil
+}
+
 func (t *Tokenizer) Close() error {
 	C.free_tokenizer(t.tokenizer)
 	t.tokenizer = nil
diff --git a/tokenizer_test.go b/tokenizer_test.go
index da2c1fda..62547d5c 100644
--- a/tokenizer_test.go
+++ b/tokenizer_test.go
@@ -3,6 +3,8 @@
 import (
 	_ "embed"
 	"math/rand"
+	"os"
+	"path/filepath"
 	"testing"
 
 	"github.com/daulet/tokenizers"
@@ -472,3 +474,107 @@ func BenchmarkDecodeNTokens(b *testing.B) {
 	// a token is one or more characters
 	assert.Greater(b, len(text), b.N)
 }
+
+func TestFromPretrained(t *testing.T) {
+	tests := []struct {
+		name      string
+		modelID   string
+		setupOpts func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string)
+		wantErr   bool
+	}{
+		{
+			name:    "valid public model with cache dir",
+			modelID: "bert-base-uncased",
+			setupOpts: func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string) {
+				tmpDir := t.TempDir()
+				return []tokenizers.TokenizerConfigOption{
+					tokenizers.WithCacheDir(tmpDir),
+				}, tmpDir
+			},
+		},
+		{
+			name:    "valid public model without cache dir",
+			modelID: "bert-base-uncased",
+			setupOpts: func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string) {
+				return nil, ""
+			},
+		},
+		{
+			name:    "public model with auth token",
+			modelID: "bert-base-uncased",
+			setupOpts: func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string) {
+				tmpDir := t.TempDir()
+				return []tokenizers.TokenizerConfigOption{
+					tokenizers.WithCacheDir(tmpDir),
+					tokenizers.WithAuthToken("test-token"),
+				}, tmpDir
+			},
+		},
+		{
+			name:    "private model with invalid auth token",
+			modelID: "private-model",
+			wantErr: true,
+			setupOpts: func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string) {
+				tmpDir := t.TempDir()
+				return []tokenizers.TokenizerConfigOption{
+					tokenizers.WithCacheDir(tmpDir),
+					tokenizers.WithAuthToken("invalid-token"),
+				}, tmpDir
+			},
+		},
+		{
+			name:    "empty model ID",
+			modelID: "",
+			wantErr: true,
+			setupOpts: func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string) {
+				return nil, ""
+			},
+		},
+		{
+			name:    "nonexistent model",
+			modelID: "nonexistent/model",
+			wantErr: true,
+			setupOpts: func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string) {
+				tmpDir := t.TempDir()
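+				// the mandatory tokenizer.json download should 404 here, so FromPretrained is expected to fail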
+				return []tokenizers.TokenizerConfigOption{
+					tokenizers.WithCacheDir(tmpDir),
+				}, tmpDir
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			opts, cacheDir := tt.setupOpts(t)
+			tokenizer, err := tokenizers.FromPretrained(tt.modelID, opts...)
+
+			if gotErr := err != nil; gotErr != tt.wantErr {
+				t.Fatalf("expected error: %v, got error: %v", tt.wantErr, err)
+			}
+			if tt.wantErr {
+				return
+			}
+			if cacheDir != "" {
+				validateCache(t, cacheDir, tt.modelID)
+			}
+			if err := tokenizer.Close(); err != nil {
+				t.Fatalf("error closing tokenizer: %v", err)
+			}
+		})
+	}
+}
+
+func validateCache(t *testing.T, dir string, modelID string) {
+	t.Helper()
+	files := []string{"tokenizer.json", "vocab.txt"}
+	for _, file := range files {
+		path := filepath.Join(dir, modelID, file)
+		if _, err := os.Stat(path); err != nil {
+			t.Errorf("expected file %s to exist in cache for model %s", file, modelID)
+		}
+	}
+}