From ce1fe87d7a561cf3e1e852f004217d4d3f498f4d Mon Sep 17 00:00:00 2001 From: Berkay Ersoy Date: Thu, 24 Oct 2024 22:51:11 +0300 Subject: [PATCH 01/10] add LoadTokenizerFromHuggingFace function to load tokenizer directly from huggingface, update README.md --- README.md | 44 ++++++++++++- example/main.go | 48 +++++++++++++- tokenizer.go | 170 +++++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 259 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 119606a..6883f71 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ Go bindings for the [HuggingFace Tokenizers](https://github.com/huggingface/toke ## Installation -`make build` to build `libtokenizers.a` that you need to run your application that uses bindings. In addition, you need to inform the linker where to find that static library: `go run -ldflags="-extldflags '-L./path/to/libtokenizers.a'" .` or just add it to the `CGO_LDFLAGS` environment variable: `CGO_LDFLAGS="-L./path/to/libtokenizers.a"` to avoid specifying it every time. +`make build` to build `libtokenizers.a` that you need to run your application that uses bindings. In addition, you need to inform the linker where to find that static library: `go run -ldflags="-extldflags '-L./path/to/libtokenizers-directory' -ltokenizers'" .` or just add it to the `CGO_LDFLAGS` environment variable: `CGO_LDFLAGS="-L./path/to/libtokenizers-directory" -ltokenizers` to avoid specifying it every time. ### Using pre-built binaries @@ -31,6 +31,20 @@ if err != nil { defer tk.Close() ``` +Load a tokenizer from Huggingface: + +```go +import "github.com/daulet/tokenizers" + +tokenizerPath := "../huggingface-tokenizers/google-bert/bert-base-uncased" +tk, err := tokenizers.LoadTokenizerFromHuggingFace("google-bert/bert-base-uncased", &tokenizerPath, nil) +if err != nil { + return err +} +// release native resources +defer tk.Close() +``` + Encode text and decode tokens: ```go @@ -44,6 +58,34 @@ fmt.Println(tk.Decode([]uint32{2829, 4419, 14523, 2058, 1996, 13971, 3899}, true // brown fox jumps over the lazy dog ``` +Encode text with options: + +```go +var encodeOptions []tokenizers.EncodeOption +encodeOptions = append(encodeOptions, tokenizers.WithReturnTypeIDs()) +encodeOptions = append(encodeOptions, tokenizers.WithReturnAttentionMask()) +encodeOptions = append(encodeOptions, tokenizers.WithReturnTokens()) +encodeOptions = append(encodeOptions, tokenizers.WithReturnOffsets()) +encodeOptions = append(encodeOptions, tokenizers.WithReturnSpecialTokensMask()) + +// Or just basically +// encodeOptions = append(encodeOptions, tokenizers.WithReturnAllAttributes()) + +encodingResponse := tk.EncodeWithOptions("brown fox jumps over the lazy dog", false, encodeOptions...) +fmt.Println(encodingResponse.IDs) +// [2829 4419 14523 2058 1996 13971 3899] +fmt.Println(encodingResponse.TypeIDs) +// [0 0 0 0 0 0 0] +fmt.Println(encodingResponse.SpecialTokensMask) +// [0 0 0 0 0 0 0] +fmt.Println(encodingResponse.AttentionMask) +// [1 1 1 1 1 1 1] +fmt.Println(encodingResponse.Tokens) +// [brown fox jumps over the lazy dog] +fmt.Println(encodingResponse.Offsets) +// [[0 5] [6 9] [10 15] [16 20] [21 24] [25 29] [30 33]] +``` + ## Benchmarks `go test . -run=^\$ -bench=. 
-benchmem -count=10 > test/benchmark/$(git rev-parse HEAD).txt` diff --git a/example/main.go b/example/main.go index 43d0783..3c73732 100644 --- a/example/main.go +++ b/example/main.go @@ -2,7 +2,6 @@ package main import ( "fmt" - "github.com/daulet/tokenizers" ) @@ -13,6 +12,7 @@ func main() { } // release native resources defer tk.Close() + fmt.Println("Vocab size:", tk.VocabSize()) // Vocab size: 30522 fmt.Println(tk.Encode("brown fox jumps over the lazy dog", false)) @@ -21,4 +21,50 @@ func main() { // [101 2829 4419 14523 2058 1996 13971 3899 102] [[CLS] brown fox jumps over the lazy dog [SEP]] fmt.Println(tk.Decode([]uint32{2829, 4419, 14523, 2058, 1996, 13971, 3899}, true)) // brown fox jumps over the lazy dog + + var encodeOptions []tokenizers.EncodeOption + encodeOptions = append(encodeOptions, tokenizers.WithReturnTypeIDs()) + encodeOptions = append(encodeOptions, tokenizers.WithReturnAttentionMask()) + encodeOptions = append(encodeOptions, tokenizers.WithReturnTokens()) + encodeOptions = append(encodeOptions, tokenizers.WithReturnOffsets()) + encodeOptions = append(encodeOptions, tokenizers.WithReturnSpecialTokensMask()) + + // Or just basically + // encodeOptions = append(encodeOptions, tokenizers.WithReturnAllAttributes()) + + encodingResponse := tk.EncodeWithOptions("brown fox jumps over the lazy dog", true, encodeOptions...) + fmt.Println(encodingResponse.IDs) + // [2829 4419 14523 2058 1996 13971 3899] + fmt.Println(encodingResponse.TypeIDs) + // [0 0 0 0 0 0 0] + fmt.Println(encodingResponse.SpecialTokensMask) + // [0 0 0 0 0 0 0] + fmt.Println(encodingResponse.AttentionMask) + // [1 1 1 1 1 1 1] + fmt.Println(encodingResponse.Tokens) + // [brown fox jumps over the lazy dog] + fmt.Println(encodingResponse.Offsets) + // [[0 5] [6 9] [10 15] [16 20] [21 24] [25 29] [30 33]] + + tokenizerPath := "../huggingface-tokenizers/google-bert/bert-base-uncased" + tkFromHf, errHf := tokenizers.LoadTokenizerFromHuggingFace("google-bert/bert-base-uncased", &tokenizerPath, nil) + if errHf != nil { + panic(errHf) + } + // release native resources + defer tkFromHf.Close() + + encodingResponseHf := tkFromHf.EncodeWithOptions("brown fox jumps over the lazy dog", true, encodeOptions...) + fmt.Println(encodingResponseHf.IDs) + // [101 2829 4419 14523 2058 1996 13971 3899 102] + fmt.Println(encodingResponseHf.TypeIDs) + // [0 0 0 0 0 0 0 0 0] + fmt.Println(encodingResponseHf.SpecialTokensMask) + // [1 0 0 0 0 0 0 0 1] + fmt.Println(encodingResponseHf.AttentionMask) + // [1 1 1 1 1 1 1 1 1] + fmt.Println(encodingResponseHf.Tokens) + // [[CLS] brown fox jumps over the lazy dog [SEP]] + fmt.Println(encodingResponseHf.Offsets) + // [[0 0] [0 5] [6 9] [10 15] [16 20] [21 24] [25 29] [30 33] [0 0]] } diff --git a/tokenizer.go b/tokenizer.go index 08b9e6b..9998185 100644 --- a/tokenizer.go +++ b/tokenizer.go @@ -13,10 +13,29 @@ import "C" import ( "fmt" "io" + "net/http" + "os" + "path/filepath" + "strings" + "sync" "unsafe" ) -const WANT_VERSION = "1.20.2" +const ( + WANT_VERSION = "1.20.2" + + baseURL = "https://huggingface.co" +) + +// List of necessary tokenizer files and their mandatory status. +// True means mandatory, false means optional. 
+var tokenizerFiles = map[string]bool{ + "tokenizer.json": true, + "vocab.txt": false, + "merges.txt": false, + "special_tokens_map.json": false, + "added_tokens.json": false, +} func init() { version := C.version() @@ -78,6 +97,155 @@ func FromFile(path string) (*Tokenizer, error) { return &Tokenizer{tokenizer: tokenizer}, nil } +// LoadTokenizerFromHuggingFace downloads necessary files and initializes the tokenizer. +// Parameters: +// - modelID: The Hugging Face model identifier (e.g., "bert-base-uncased"). +// - destination: Optional. If provided and not nil, files will be downloaded to this folder. +// If nil, a temporary directory will be used. +// - authToken: Optional. If provided and not nil, it will be used to authenticate requests. +func LoadTokenizerFromHuggingFace(modelID string, destination, authToken *string) (*Tokenizer, error) { + if strings.TrimSpace(modelID) == "" { + return nil, fmt.Errorf("modelID cannot be empty") + } + + // Construct the model URL + modelURL := fmt.Sprintf("%s/%s/resolve/main", baseURL, modelID) + + // Determine the download directory + var downloadDir string + if destination != nil && *destination != "" { + downloadDir = *destination + // Create the destination directory if it doesn't exist + err := os.MkdirAll(downloadDir, os.ModePerm) + if err != nil { + return nil, fmt.Errorf("failed to create destination directory %s: %w", downloadDir, err) + } + } else { + // Create a temporary directory + tmpDir, err := os.MkdirTemp("", "huggingface-tokenizer-*") + if err != nil { + return nil, fmt.Errorf("error creating temporary directory: %w", err) + } + downloadDir = tmpDir + } + + var wg sync.WaitGroup + errChan := make(chan error, len(tokenizerFiles)) + + // Mutex for synchronized logging + var logMutex sync.Mutex + + // Download each tokenizer file concurrently + for filename, isMandatory := range tokenizerFiles { + wg.Add(1) + go func(fn string, mandatory bool) { + defer wg.Done() + fileURL := fmt.Sprintf("%s/%s", modelURL, fn) + destPath := filepath.Join(downloadDir, fn) + err := downloadFile(fileURL, destPath, authToken) + if err != nil { + if mandatory { + // If the file is mandatory, report an error + errChan <- fmt.Errorf("failed to download mandatory file %s: %w", fn, err) + } else { + // Optional files: log warning and continue + logMutex.Lock() + fmt.Printf("Warning: failed to download optional file %s: %v\n", fn, err) + logMutex.Unlock() + } + } + }(filename, isMandatory) + } + + // Wait for all downloads to complete + wg.Wait() + close(errChan) + + // Check for errors during downloads + for downloadErr := range errChan { + if downloadErr != nil { + // Clean up the directory and return the error + cleanupDirectory(downloadDir) + return nil, downloadErr + } + } + + // Verify that tokenizer.json exists + tokenizerPath := filepath.Join(downloadDir, "tokenizer.json") + if _, err := os.Stat(tokenizerPath); os.IsNotExist(err) { + return nil, fmt.Errorf("mandatory file tokenizer.json does not exist in %s", downloadDir) + } + + // Initialize the tokenizer using the downloaded tokenizer.json + tokenizer, err := FromFile(tokenizerPath) + if err != nil { + return nil, err + } + + return tokenizer, nil +} + +// downloadFile downloads a file from the given URL and saves it to the specified destination. +// If authToken is provided (non-nil), it will be used for authorization. +// Returns an error if the download fails. 
+func downloadFile(url, destination string, authToken *string) error {
+	// Check if the file already exists
+	if _, err := os.Stat(destination); err == nil {
+		fmt.Printf("File %s already exists. Skipping download.\n", destination)
+		return nil
+	}
+
+	// Create the destination file
+	out, err := os.Create(destination)
+	if err != nil {
+		return fmt.Errorf("failed to create file %s: %w", destination, err)
+	}
+	defer out.Close()
+
+	// Create a new HTTP request
+	req, err := http.NewRequest("GET", url, nil)
+	if err != nil {
+		return fmt.Errorf("failed to create request for %s: %w", url, err)
+	}
+
+	// If authToken is provided, set the Authorization header
+	if authToken != nil && *authToken != "" {
+		req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *authToken))
+	}
+
+	// Perform the HTTP request
+	client := &http.Client{}
+	resp, err := client.Do(req)
+	if err != nil {
+		return fmt.Errorf("failed to download from %s: %w", url, err)
+	}
+	defer resp.Body.Close()
+
+	// Check for successful response
+	if resp.StatusCode != http.StatusOK {
+		return fmt.Errorf("failed to download from %s: status code %d", url, resp.StatusCode)
+	}
+
+	// Write the response body to the file
+	_, err = io.Copy(out, resp.Body)
+	if err != nil {
+		return fmt.Errorf("failed to write to file %s: %w", destination, err)
+	}
+
+	fmt.Printf("Successfully downloaded %s\n", destination)
+	return nil
+}
+
+// cleanupDirectory removes the specified directory and all its contents.
+func cleanupDirectory(dir string) {
+	err := os.RemoveAll(dir)
+	if err != nil {
+		fmt.Printf("Warning: failed to clean up directory %s: %v\n", dir, err)
+	} else {
+		fmt.Printf("Successfully cleaned up directory %s\n", dir)
+	}
+}
+
 func (t *Tokenizer) Close() error {
 	C.free_tokenizer(t.tokenizer)
 	t.tokenizer = nil

From 9b9e41702792ae815b508ab43233c2968bc6a64c Mon Sep 17 00:00:00 2001
From: Resul Berkay Ersoy
Date: Tue, 29 Oct 2024 18:31:03 +0300
Subject: [PATCH 02/10] use an unbuffered error channel, rename errChan to
 errCh, and simplify error handling and logging

---
 tokenizer.go | 28 ++++++++--------------------
 1 file changed, 8 insertions(+), 20 deletions(-)

diff --git a/tokenizer.go b/tokenizer.go
index 9998185..f44fd6d 100644
--- a/tokenizer.go
+++ b/tokenizer.go
@@ -130,10 +130,7 @@ func LoadTokenizerFromHuggingFace(modelID string, destination, authToken *string
 	}
 
 	var wg sync.WaitGroup
-	errChan := make(chan error, len(tokenizerFiles))
-
-	// Mutex for synchronized logging
-	var logMutex sync.Mutex
+	errCh := make(chan error)
 
 	// Download each tokenizer file concurrently
 	for filename, isMandatory := range tokenizerFiles {
@@ -146,12 +143,10 @@ func LoadTokenizerFromHuggingFace(modelID string, destination, authToken *string
 			if err != nil {
 				if mandatory {
 					// If the file is mandatory, report an error
-					errChan <- fmt.Errorf("failed to download mandatory file %s: %w", fn, err)
+					errCh <- fmt.Errorf("failed to download mandatory file %s: %w", fn, err)
 				} else {
 					// Optional files: log warning and continue
-					logMutex.Lock()
 					fmt.Printf("Warning: failed to download optional file %s: %v\n", fn, err)
-					logMutex.Unlock()
 				}
 			}
 		}(filename, isMandatory)
@@ -159,15 +154,13 @@ func LoadTokenizerFromHuggingFace(modelID string, destination, authToken *string
 
 	// Wait for all downloads to complete
 	wg.Wait()
-	close(errChan)
+	close(errCh)
 
 	// Check for errors during downloads
-	for downloadErr := range errChan {
-		if downloadErr != nil {
-			// Clean up the directory and return the error
-			cleanupDirectory(downloadDir)
-			return nil, downloadErr
-		}
+	for 
downloadErr := range errCh { + // Clean up the directory and return the error + cleanupDirectory(downloadDir) + return nil, downloadErr } // Verify that tokenizer.json exists @@ -177,12 +170,7 @@ func LoadTokenizerFromHuggingFace(modelID string, destination, authToken *string } // Initialize the tokenizer using the downloaded tokenizer.json - tokenizer, err := FromFile(tokenizerPath) - if err != nil { - return nil, err - } - - return tokenizer, nil + return FromFile(tokenizerPath) } // downloadFile downloads a file from the given URL and saves it to the specified destination. From 0b1b908b6dce9d8647775611cd6259c614409fe2 Mon Sep 17 00:00:00 2001 From: Daulet Zhanguzin Date: Mon, 4 Nov 2024 17:01:38 -0800 Subject: [PATCH 03/10] fix: rename new func to FromPretrained, improve example --- example/main.go | 96 +++++++++++++++++++++++++++---------------------- tokenizer.go | 5 ++- 2 files changed, 56 insertions(+), 45 deletions(-) diff --git a/example/main.go b/example/main.go index 3c73732..5192894 100644 --- a/example/main.go +++ b/example/main.go @@ -2,13 +2,15 @@ package main import ( "fmt" + "log" + "github.com/daulet/tokenizers" ) -func main() { +func simple() error { tk, err := tokenizers.FromFile("../test/data/bert-base-uncased.json") if err != nil { - panic(err) + return err } // release native resources defer tk.Close() @@ -21,50 +23,60 @@ func main() { // [101 2829 4419 14523 2058 1996 13971 3899 102] [[CLS] brown fox jumps over the lazy dog [SEP]] fmt.Println(tk.Decode([]uint32{2829, 4419, 14523, 2058, 1996, 13971, 3899}, true)) // brown fox jumps over the lazy dog + return nil +} - var encodeOptions []tokenizers.EncodeOption - encodeOptions = append(encodeOptions, tokenizers.WithReturnTypeIDs()) - encodeOptions = append(encodeOptions, tokenizers.WithReturnAttentionMask()) - encodeOptions = append(encodeOptions, tokenizers.WithReturnTokens()) - encodeOptions = append(encodeOptions, tokenizers.WithReturnOffsets()) - encodeOptions = append(encodeOptions, tokenizers.WithReturnSpecialTokensMask()) +func advanced() error { + // Load tokenizer from local config file + tk, err := tokenizers.FromFile("../test/data/bert-base-uncased.json") + if err != nil { + return err + } + defer tk.Close() - // Or just basically - // encodeOptions = append(encodeOptions, tokenizers.WithReturnAllAttributes()) + // Load pretrained tokenizer from HuggingFace + tokenizerPath := "./.cache/tokenizers/google-bert/bert-base-uncased" + tkFromHf, err := tokenizers.FromPretrained("google-bert/bert-base-uncased", &tokenizerPath, nil) + if err != nil { + return err + } + defer tkFromHf.Close() - encodingResponse := tk.EncodeWithOptions("brown fox jumps over the lazy dog", true, encodeOptions...) 
- fmt.Println(encodingResponse.IDs) - // [2829 4419 14523 2058 1996 13971 3899] - fmt.Println(encodingResponse.TypeIDs) - // [0 0 0 0 0 0 0] - fmt.Println(encodingResponse.SpecialTokensMask) - // [0 0 0 0 0 0 0] - fmt.Println(encodingResponse.AttentionMask) - // [1 1 1 1 1 1 1] - fmt.Println(encodingResponse.Tokens) - // [brown fox jumps over the lazy dog] - fmt.Println(encodingResponse.Offsets) - // [[0 5] [6 9] [10 15] [16 20] [21 24] [25 29] [30 33]] + // Encode with specific options + encodeOptions := []tokenizers.EncodeOption{ + tokenizers.WithReturnTypeIDs(), + tokenizers.WithReturnAttentionMask(), + tokenizers.WithReturnTokens(), + tokenizers.WithReturnOffsets(), + tokenizers.WithReturnSpecialTokensMask(), + } + // Or simply: + // encodeOptions = append(encodeOptions, tokenizers.WithReturnAllAttributes()) - tokenizerPath := "../huggingface-tokenizers/google-bert/bert-base-uncased" - tkFromHf, errHf := tokenizers.LoadTokenizerFromHuggingFace("google-bert/bert-base-uncased", &tokenizerPath, nil) - if errHf != nil { - panic(errHf) + // regardless of how the tokenizer was initialized, the output is the same + for _, tkzr := range []*tokenizers.Tokenizer{tk, tkFromHf} { + encodingResponse := tkzr.EncodeWithOptions("brown fox jumps over the lazy dog", true, encodeOptions...) + fmt.Println(encodingResponse.IDs) + // [101 2829 4419 14523 2058 1996 13971 3899 102] + fmt.Println(encodingResponse.TypeIDs) + // [0 0 0 0 0 0 0 0 0] + fmt.Println(encodingResponse.SpecialTokensMask) + // [1 0 0 0 0 0 0 0 1] + fmt.Println(encodingResponse.AttentionMask) + // [1 1 1 1 1 1 1 1 1] + fmt.Println(encodingResponse.Tokens) + // [[CLS] brown fox jumps over the lazy dog [SEP]] + fmt.Println(encodingResponse.Offsets) + // [[0 0] [0 5] [6 9] [10 15] [16 20] [21 24] [25 29] [30 33] [0 0]] } - // release native resources - defer tkFromHf.Close() + return nil +} - encodingResponseHf := tkFromHf.EncodeWithOptions("brown fox jumps over the lazy dog", true, encodeOptions...) - fmt.Println(encodingResponseHf.IDs) - // [101 2829 4419 14523 2058 1996 13971 3899 102] - fmt.Println(encodingResponseHf.TypeIDs) - // [0 0 0 0 0 0 0 0 0] - fmt.Println(encodingResponseHf.SpecialTokensMask) - // [1 0 0 0 0 0 0 0 1] - fmt.Println(encodingResponseHf.AttentionMask) - // [1 1 1 1 1 1 1 1 1] - fmt.Println(encodingResponseHf.Tokens) - // [[CLS] brown fox jumps over the lazy dog [SEP]] - fmt.Println(encodingResponseHf.Offsets) - // [[0 0] [0 5] [6 9] [10 15] [16 20] [21 24] [25 29] [30 33] [0 0]] +func main() { + if err := simple(); err != nil { + log.Fatal(err) + } + if err := advanced(); err != nil { + log.Fatal(err) + } } diff --git a/tokenizer.go b/tokenizer.go index f44fd6d..b0afdd9 100644 --- a/tokenizer.go +++ b/tokenizer.go @@ -97,13 +97,13 @@ func FromFile(path string) (*Tokenizer, error) { return &Tokenizer{tokenizer: tokenizer}, nil } -// LoadTokenizerFromHuggingFace downloads necessary files and initializes the tokenizer. +// FromPretrained downloads necessary files and initializes the tokenizer. // Parameters: // - modelID: The Hugging Face model identifier (e.g., "bert-base-uncased"). // - destination: Optional. If provided and not nil, files will be downloaded to this folder. // If nil, a temporary directory will be used. // - authToken: Optional. If provided and not nil, it will be used to authenticate requests. 
-func LoadTokenizerFromHuggingFace(modelID string, destination, authToken *string) (*Tokenizer, error) { +func FromPretrained(modelID string, destination, authToken *string) (*Tokenizer, error) { if strings.TrimSpace(modelID) == "" { return nil, fmt.Errorf("modelID cannot be empty") } @@ -179,7 +179,6 @@ func LoadTokenizerFromHuggingFace(modelID string, destination, authToken *string func downloadFile(url, destination string, authToken *string) error { // Check if the file already exists if _, err := os.Stat(destination); err == nil { - fmt.Printf("File %s already exists. Skipping download.\n", destination) return nil } From 70dbc185885dd4a271942bfb08b4c3918ee079da Mon Sep 17 00:00:00 2001 From: Daulet Zhanguzin Date: Mon, 4 Nov 2024 17:09:39 -0800 Subject: [PATCH 04/10] fix: clean up downloadFile --- tokenizer.go | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/tokenizer.go b/tokenizer.go index b0afdd9..b56ca7c 100644 --- a/tokenizer.go +++ b/tokenizer.go @@ -182,13 +182,6 @@ func downloadFile(url, destination string, authToken *string) error { return nil } - // Create the destination file - out, err := os.Create(destination) - if err != nil { - return fmt.Errorf("failed to create file %s: %w", destination, err) - } - defer out.Close() - // Create a new HTTP request req, err := http.NewRequest("GET", url, nil) if err != nil { @@ -196,13 +189,11 @@ func downloadFile(url, destination string, authToken *string) error { } // If authToken is provided, set the Authorization header - if authToken != nil && *authToken != "" { + if authToken != nil { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *authToken)) } - // Perform the HTTP request - client := &http.Client{} - resp, err := client.Do(req) + resp, err := http.DefaultClient.Do(req) if err != nil { return fmt.Errorf("failed to download from %s: %w", url, err) } @@ -213,6 +204,13 @@ func downloadFile(url, destination string, authToken *string) error { return fmt.Errorf("failed to download from %s: status code %d", url, resp.StatusCode) } + // Create the destination file + out, err := os.Create(destination) + if err != nil { + return fmt.Errorf("failed to create file %s: %w", destination, err) + } + defer out.Close() + // Write the response body to the file _, err = io.Copy(out, resp.Body) if err != nil { From 2f5c3fa4ea44856bd8456f5c1bd8d0a321468a59 Mon Sep 17 00:00:00 2001 From: Daulet Zhanguzin Date: Mon, 4 Nov 2024 17:19:20 -0800 Subject: [PATCH 05/10] fix: concurrency issues in case of an error --- tokenizer.go | 48 ++++++++++++++++-------------------------------- 1 file changed, 16 insertions(+), 32 deletions(-) diff --git a/tokenizer.go b/tokenizer.go index b56ca7c..264a68b 100644 --- a/tokenizer.go +++ b/tokenizer.go @@ -140,37 +140,31 @@ func FromPretrained(modelID string, destination, authToken *string) (*Tokenizer, fileURL := fmt.Sprintf("%s/%s", modelURL, fn) destPath := filepath.Join(downloadDir, fn) err := downloadFile(fileURL, destPath, authToken) - if err != nil { - if mandatory { - // If the file is mandatory, report an error - errCh <- fmt.Errorf("failed to download mandatory file %s: %w", fn, err) - } else { - // Optional files: log warning and continue - fmt.Printf("Warning: failed to download optional file %s: %v\n", fn, err) - } + if err != nil && mandatory { + // If the file is mandatory, report an error + errCh <- fmt.Errorf("failed to download mandatory file %s: %w", fn, err) } }(filename, isMandatory) } - // Wait for all downloads to complete - wg.Wait() - close(errCh) + 
go func() { + wg.Wait() + close(errCh) + }() - // Check for errors during downloads - for downloadErr := range errCh { - // Clean up the directory and return the error - cleanupDirectory(downloadDir) - return nil, downloadErr + var errs []error + for err := range errCh { + errs = append(errs, err) } - // Verify that tokenizer.json exists - tokenizerPath := filepath.Join(downloadDir, "tokenizer.json") - if _, err := os.Stat(tokenizerPath); os.IsNotExist(err) { - return nil, fmt.Errorf("mandatory file tokenizer.json does not exist in %s", downloadDir) + if len(errs) > 0 { + if err := os.RemoveAll(downloadDir); err != nil { + fmt.Printf("Warning: failed to clean up directory %s: %v\n", downloadDir, err) + } + return nil, errs[0] } - // Initialize the tokenizer using the downloaded tokenizer.json - return FromFile(tokenizerPath) + return FromFile(filepath.Join(downloadDir, "tokenizer.json")) } // downloadFile downloads a file from the given URL and saves it to the specified destination. @@ -221,16 +215,6 @@ func downloadFile(url, destination string, authToken *string) error { return nil } -// cleanupDirectory removes the specified directory and all its contents. -func cleanupDirectory(dir string) { - err := os.RemoveAll(dir) - if err != nil { - fmt.Printf("Warning: failed to clean up directory %s: %v\n", dir, err) - } else { - fmt.Printf("Successfully cleaned up directory %s\n", dir) - } -} - func (t *Tokenizer) Close() error { C.free_tokenizer(t.tokenizer) t.tokenizer = nil From 52fc24ac4a3929d18073917160e974d30913f9a9 Mon Sep 17 00:00:00 2001 From: Daulet Zhanguzin Date: Mon, 4 Nov 2024 17:25:12 -0800 Subject: [PATCH 06/10] fix: make optional params optional --- example/main.go | 3 +-- tokenizer.go | 31 +++++++++++++++++++++++++++---- 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/example/main.go b/example/main.go index 5192894..db65bf5 100644 --- a/example/main.go +++ b/example/main.go @@ -35,8 +35,7 @@ func advanced() error { defer tk.Close() // Load pretrained tokenizer from HuggingFace - tokenizerPath := "./.cache/tokenizers/google-bert/bert-base-uncased" - tkFromHf, err := tokenizers.FromPretrained("google-bert/bert-base-uncased", &tokenizerPath, nil) + tkFromHf, err := tokenizers.FromPretrained("google-bert/bert-base-uncased", tokenizers.WithCacheDir("./.cache/tokenizers")) if err != nil { return err } diff --git a/tokenizer.go b/tokenizer.go index 264a68b..c197ca8 100644 --- a/tokenizer.go +++ b/tokenizer.go @@ -97,13 +97,36 @@ func FromFile(path string) (*Tokenizer, error) { return &Tokenizer{tokenizer: tokenizer}, nil } +type tokenizerConfig struct { + cacheDir *string + authToken *string +} + +type TokenizerConfigOption func(cfg *tokenizerConfig) + +func WithCacheDir(path string) TokenizerConfigOption { + return func(cfg *tokenizerConfig) { + cfg.cacheDir = &path + } +} + +func WithAuthToken(token string) TokenizerConfigOption { + return func(cfg *tokenizerConfig) { + cfg.authToken = &token + } +} + // FromPretrained downloads necessary files and initializes the tokenizer. // Parameters: // - modelID: The Hugging Face model identifier (e.g., "bert-base-uncased"). // - destination: Optional. If provided and not nil, files will be downloaded to this folder. // If nil, a temporary directory will be used. // - authToken: Optional. If provided and not nil, it will be used to authenticate requests. 
-func FromPretrained(modelID string, destination, authToken *string) (*Tokenizer, error) { +func FromPretrained(modelID string, opts ...TokenizerConfigOption) (*Tokenizer, error) { + cfg := &tokenizerConfig{} + for _, opt := range opts { + opt(cfg) + } if strings.TrimSpace(modelID) == "" { return nil, fmt.Errorf("modelID cannot be empty") } @@ -113,8 +136,8 @@ func FromPretrained(modelID string, destination, authToken *string) (*Tokenizer, // Determine the download directory var downloadDir string - if destination != nil && *destination != "" { - downloadDir = *destination + if cfg.cacheDir != nil { + downloadDir = *cfg.cacheDir // Create the destination directory if it doesn't exist err := os.MkdirAll(downloadDir, os.ModePerm) if err != nil { @@ -139,7 +162,7 @@ func FromPretrained(modelID string, destination, authToken *string) (*Tokenizer, defer wg.Done() fileURL := fmt.Sprintf("%s/%s", modelURL, fn) destPath := filepath.Join(downloadDir, fn) - err := downloadFile(fileURL, destPath, authToken) + err := downloadFile(fileURL, destPath, cfg.authToken) if err != nil && mandatory { // If the file is mandatory, report an error errCh <- fmt.Errorf("failed to download mandatory file %s: %w", fn, err) From 2c7471001fd6a3742cd9b2599da7c77701e505aa Mon Sep 17 00:00:00 2001 From: Daulet Zhanguzin Date: Mon, 4 Nov 2024 17:28:46 -0800 Subject: [PATCH 07/10] fix: cache path has to be model specific --- tokenizer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tokenizer.go b/tokenizer.go index c197ca8..765163b 100644 --- a/tokenizer.go +++ b/tokenizer.go @@ -137,7 +137,7 @@ func FromPretrained(modelID string, opts ...TokenizerConfigOption) (*Tokenizer, // Determine the download directory var downloadDir string if cfg.cacheDir != nil { - downloadDir = *cfg.cacheDir + downloadDir = fmt.Sprintf("%s/%s", *cfg.cacheDir, modelID) // Create the destination directory if it doesn't exist err := os.MkdirAll(downloadDir, os.ModePerm) if err != nil { From 6aae8ce13e92e38a65f119684a79d44e78c79c10 Mon Sep 17 00:00:00 2001 From: Berkay Ersoy Date: Tue, 5 Nov 2024 10:13:58 +0300 Subject: [PATCH 08/10] add unit tests for `FromPretrained` --- tokenizer_test.go | 272 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 270 insertions(+), 2 deletions(-) diff --git a/tokenizer_test.go b/tokenizer_test.go index da2c1fd..bfcf5de 100644 --- a/tokenizer_test.go +++ b/tokenizer_test.go @@ -2,11 +2,15 @@ package tokenizers_test import ( _ "embed" + "github.com/daulet/tokenizers" "math/rand" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" "testing" - "github.com/daulet/tokenizers" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -14,6 +18,23 @@ import ( //go:embed test/data/sentence-transformers-labse.json var embeddedBytes []byte +type mockTransport struct { + server *httptest.Server + modelID string +} + +func (t *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req.URL.Scheme = "http" + req.URL.Host = strings.TrimPrefix(t.server.URL, "http://") + + parts := strings.Split(req.URL.Path, "/") + if len(parts) > 2 { + req.URL.Path = "/" + t.modelID + "/resolve/main/" + parts[len(parts)-1] + } + + return t.server.Client().Transport.RoundTrip(req) +} + // TODO test for leaks func TestInvalidConfigPath(t *testing.T) { @@ -472,3 +493,250 @@ func BenchmarkDecodeNTokens(b *testing.B) { // a token is one or more characters assert.Greater(b, len(text), b.N) } + +func mockHuggingFaceServer(t *testing.T) *httptest.Server { + return 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + path := strings.TrimPrefix(r.URL.Path, "/") + parts := strings.Split(path, "/") + + if len(parts) < 4 { + w.WriteHeader(http.StatusNotFound) + return + } + + fileName := parts[len(parts)-1] + + // Check authentication for private models + if strings.HasPrefix(path, "private/") { + authHeader := r.Header.Get("Authorization") + if !strings.Contains(authHeader, "test-token") { + w.WriteHeader(http.StatusUnauthorized) + return + } + } + + // For nonexistent model, only return 404 for tokenizer.json + if strings.Contains(path, "nonexistent") { + if fileName == "tokenizer.json" { + w.WriteHeader(http.StatusNotFound) + return + } + // Return empty response for optional files + w.WriteHeader(http.StatusOK) + w.Write([]byte("{}")) + return + } + + // Handle regular file requests + switch fileName { + case "tokenizer.json": + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{ + "type": "mock_tokenizer", + "vocab_size": 1000, + "model_max_length": 512 + }`)) + case "vocab.txt": + w.WriteHeader(http.StatusOK) + w.Write([]byte("[PAD]\n[UNK]\ntest\ntoken")) + case "special_tokens_map.json": + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{ + "pad_token": "[PAD]", + "unk_token": "[UNK]" + }`)) + default: + // Return empty response for other optional files + w.WriteHeader(http.StatusOK) + w.Write([]byte("{}")) + } + })) +} + +func TestFromPretrained(t *testing.T) { + server := mockHuggingFaceServer(t) + defer server.Close() + + tests := []struct { + name string + modelID string + setupOpts func() []tokenizers.TokenizerConfigOption + wantError bool + checkDir func(t *testing.T, dir string) // Add function to verify directory + }{ + { + name: "valid public model with cache dir", + modelID: "bert-base-uncased", + setupOpts: func() []tokenizers.TokenizerConfigOption { + tmpDir := t.TempDir() + return []tokenizers.TokenizerConfigOption{ + tokenizers.WithCacheDir(tmpDir), + } + }, + wantError: false, + }, + { + name: "valid public model without cache dir", + modelID: "bert-base-uncased", + setupOpts: func() []tokenizers.TokenizerConfigOption { + return nil // No cache dir specified + }, + wantError: false, + }, + { + name: "private model with auth token and cache dir", + modelID: "private/model", + setupOpts: func() []tokenizers.TokenizerConfigOption { + tmpDir := t.TempDir() + return []tokenizers.TokenizerConfigOption{ + tokenizers.WithCacheDir(tmpDir), + tokenizers.WithAuthToken("test-token"), + } + }, + wantError: false, + checkDir: func(t *testing.T, dir string) { + path := filepath.Join(dir, "private", "model", "tokenizer.json") + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Errorf("expected tokenizer.json to exist in cache dir") + } + }, + }, + { + name: "private model with auth token without cache dir", + modelID: "private/model", + setupOpts: func() []tokenizers.TokenizerConfigOption { + return []tokenizers.TokenizerConfigOption{ + tokenizers.WithAuthToken("test-token"), + } + }, + wantError: false, + checkDir: func(t *testing.T, dir string) { + if !strings.Contains(dir, "huggingface-tokenizer-") { + t.Errorf("expected temp directory name to contain 'huggingface-tokenizer-', got %s", dir) + } + }, + }, + { + name: "empty model ID", + modelID: "", + setupOpts: func() []tokenizers.TokenizerConfigOption { return nil }, + wantError: true, + checkDir: nil, // No directory check needed for error case + }, + { + name: "nonexistent model", + modelID: "nonexistent/model", + setupOpts: func() 
[]tokenizers.TokenizerConfigOption { + tmpDir := t.TempDir() + return []tokenizers.TokenizerConfigOption{ + tokenizers.WithCacheDir(tmpDir), + } + }, + wantError: true, + checkDir: nil, // No directory check needed for error case + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + origTransport := http.DefaultClient.Transport + http.DefaultClient.Transport = &mockTransport{ + server: server, + modelID: tt.modelID, + } + defer func() { http.DefaultClient.Transport = origTransport }() + + tokenizer, err := tokenizers.FromPretrained(tt.modelID, tt.setupOpts()...) + + if tt.wantError { + if err == nil { + t.Error("expected error, got nil") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if tokenizer == nil { + t.Error("expected tokenizer, got nil") + return + } + + if err := tokenizer.Close(); err != nil { + t.Errorf("error closing tokenizer: %v", err) + } + }) + } +} + +func TestConfigOptions(t *testing.T) { + server := mockHuggingFaceServer(t) + defer server.Close() + + t.Run("WithCacheDir", func(t *testing.T) { + tmpDir := t.TempDir() + + origTransport := http.DefaultClient.Transport + http.DefaultClient.Transport = &mockTransport{ + server: server, + modelID: "bert-base-uncased", + } + defer func() { http.DefaultClient.Transport = origTransport }() + + tokenizer, err := tokenizers.FromPretrained( + "bert-base-uncased", + tokenizers.WithCacheDir(tmpDir), + ) + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + defer tokenizer.Close() + + files := []string{"tokenizer.json", "vocab.txt", "special_tokens_map.json"} + for _, file := range files { + path := filepath.Join(tmpDir, "bert-base-uncased", file) + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Errorf("expected file %s to exist", file) + } + } + }) + + t.Run("WithAuthToken", func(t *testing.T) { + origTransport := http.DefaultClient.Transport + http.DefaultClient.Transport = &mockTransport{ + server: server, + modelID: "private/model", + } + defer func() { http.DefaultClient.Transport = origTransport }() + + tokenizer, err := tokenizers.FromPretrained( + "private/model", + tokenizers.WithAuthToken("test-token"), + tokenizers.WithCacheDir(t.TempDir()), + ) + + if err != nil { + t.Errorf("unexpected error with valid auth token: %v", err) + return + } + if tokenizer != nil { + tokenizer.Close() + } + + tokenizer, err = tokenizers.FromPretrained( + "private/model", + tokenizers.WithAuthToken("invalid-token"), + tokenizers.WithCacheDir(t.TempDir()), + ) + + if err == nil { + t.Error("expected error with invalid auth token, got nil") + tokenizer.Close() + } + }) +} From 91ace109f1c606d6494bde2613f12410165e4767 Mon Sep 17 00:00:00 2001 From: Berkay Ersoy Date: Thu, 7 Nov 2024 10:11:38 +0300 Subject: [PATCH 09/10] migrate to table driven tests, unify/simplify test cases --- tokenizer_test.go | 287 +++++++++++----------------------------------- 1 file changed, 69 insertions(+), 218 deletions(-) diff --git a/tokenizer_test.go b/tokenizer_test.go index bfcf5de..1f6d223 100644 --- a/tokenizer_test.go +++ b/tokenizer_test.go @@ -4,11 +4,8 @@ import ( _ "embed" "github.com/daulet/tokenizers" "math/rand" - "net/http" - "net/http/httptest" "os" "path/filepath" - "strings" "testing" "github.com/stretchr/testify/assert" @@ -18,23 +15,6 @@ import ( //go:embed test/data/sentence-transformers-labse.json var embeddedBytes []byte -type mockTransport struct { - server *httptest.Server - modelID string -} - -func (t *mockTransport) RoundTrip(req *http.Request) 
(*http.Response, error) { - req.URL.Scheme = "http" - req.URL.Host = strings.TrimPrefix(t.server.URL, "http://") - - parts := strings.Split(req.URL.Path, "/") - if len(parts) > 2 { - req.URL.Path = "/" + t.modelID + "/resolve/main/" + parts[len(parts)-1] - } - - return t.server.Client().Transport.RoundTrip(req) -} - // TODO test for leaks func TestInvalidConfigPath(t *testing.T) { @@ -494,249 +474,120 @@ func BenchmarkDecodeNTokens(b *testing.B) { assert.Greater(b, len(text), b.N) } -func mockHuggingFaceServer(t *testing.T) *httptest.Server { - return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - path := strings.TrimPrefix(r.URL.Path, "/") - parts := strings.Split(path, "/") - - if len(parts) < 4 { - w.WriteHeader(http.StatusNotFound) - return - } - - fileName := parts[len(parts)-1] - - // Check authentication for private models - if strings.HasPrefix(path, "private/") { - authHeader := r.Header.Get("Authorization") - if !strings.Contains(authHeader, "test-token") { - w.WriteHeader(http.StatusUnauthorized) - return - } - } - - // For nonexistent model, only return 404 for tokenizer.json - if strings.Contains(path, "nonexistent") { - if fileName == "tokenizer.json" { - w.WriteHeader(http.StatusNotFound) - return - } - // Return empty response for optional files - w.WriteHeader(http.StatusOK) - w.Write([]byte("{}")) - return - } - - // Handle regular file requests - switch fileName { - case "tokenizer.json": - w.WriteHeader(http.StatusOK) - w.Write([]byte(`{ - "type": "mock_tokenizer", - "vocab_size": 1000, - "model_max_length": 512 - }`)) - case "vocab.txt": - w.WriteHeader(http.StatusOK) - w.Write([]byte("[PAD]\n[UNK]\ntest\ntoken")) - case "special_tokens_map.json": - w.WriteHeader(http.StatusOK) - w.Write([]byte(`{ - "pad_token": "[PAD]", - "unk_token": "[UNK]" - }`)) - default: - // Return empty response for other optional files - w.WriteHeader(http.StatusOK) - w.Write([]byte("{}")) - } - })) -} - func TestFromPretrained(t *testing.T) { - server := mockHuggingFaceServer(t) - defer server.Close() - tests := []struct { - name string - modelID string - setupOpts func() []tokenizers.TokenizerConfigOption - wantError bool - checkDir func(t *testing.T, dir string) // Add function to verify directory + name string + modelID string + setupOpts func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string) + expectedError bool + expectedToken bool + skipCache bool }{ { - name: "valid public model with cache dir", - modelID: "bert-base-uncased", - setupOpts: func() []tokenizers.TokenizerConfigOption { + name: "valid public model with cache dir", + modelID: "bert-base-uncased", + expectedToken: true, + setupOpts: func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string) { tmpDir := t.TempDir() return []tokenizers.TokenizerConfigOption{ tokenizers.WithCacheDir(tmpDir), - } + }, tmpDir }, - wantError: false, }, { - name: "valid public model without cache dir", - modelID: "bert-base-uncased", - setupOpts: func() []tokenizers.TokenizerConfigOption { - return nil // No cache dir specified + name: "valid public model without cache dir", + modelID: "bert-base-uncased", + expectedToken: true, + setupOpts: func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string) { + return nil, "" }, - wantError: false, + skipCache: true, }, { - name: "private model with auth token and cache dir", - modelID: "private/model", - setupOpts: func() []tokenizers.TokenizerConfigOption { + name: "private model with valid auth token", + modelID: "bert-base-uncased", + 
expectedToken: true, + setupOpts: func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string) { tmpDir := t.TempDir() return []tokenizers.TokenizerConfigOption{ tokenizers.WithCacheDir(tmpDir), tokenizers.WithAuthToken("test-token"), - } - }, - wantError: false, - checkDir: func(t *testing.T, dir string) { - path := filepath.Join(dir, "private", "model", "tokenizer.json") - if _, err := os.Stat(path); os.IsNotExist(err) { - t.Errorf("expected tokenizer.json to exist in cache dir") - } + }, tmpDir }, }, { - name: "private model with auth token without cache dir", - modelID: "private/model", - setupOpts: func() []tokenizers.TokenizerConfigOption { + name: "private model with invalid auth token", + modelID: "private-model", + expectedError: true, + setupOpts: func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string) { + tmpDir := t.TempDir() return []tokenizers.TokenizerConfigOption{ - tokenizers.WithAuthToken("test-token"), - } - }, - wantError: false, - checkDir: func(t *testing.T, dir string) { - if !strings.Contains(dir, "huggingface-tokenizer-") { - t.Errorf("expected temp directory name to contain 'huggingface-tokenizer-', got %s", dir) - } + tokenizers.WithCacheDir(tmpDir), + tokenizers.WithAuthToken("invalid-token"), + }, tmpDir }, + skipCache: true, }, { - name: "empty model ID", - modelID: "", - setupOpts: func() []tokenizers.TokenizerConfigOption { return nil }, - wantError: true, - checkDir: nil, // No directory check needed for error case + name: "empty model ID", + modelID: "", + expectedError: true, + setupOpts: func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string) { + return nil, "" + }, + skipCache: true, }, { - name: "nonexistent model", - modelID: "nonexistent/model", - setupOpts: func() []tokenizers.TokenizerConfigOption { + name: "nonexistent model", + modelID: "nonexistent/model", + expectedError: true, + setupOpts: func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string) { tmpDir := t.TempDir() return []tokenizers.TokenizerConfigOption{ tokenizers.WithCacheDir(tmpDir), - } + }, tmpDir }, - wantError: true, - checkDir: nil, // No directory check needed for error case + skipCache: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - origTransport := http.DefaultClient.Transport - http.DefaultClient.Transport = &mockTransport{ - server: server, - modelID: tt.modelID, - } - defer func() { http.DefaultClient.Transport = origTransport }() - - tokenizer, err := tokenizers.FromPretrained(tt.modelID, tt.setupOpts()...) + opts, cacheDir := tt.setupOpts(t) + tokenizer, err := tokenizers.FromPretrained(tt.modelID, opts...) 
- if tt.wantError { - if err == nil { - t.Error("expected error, got nil") - } - return + if tt.expectedError && err == nil { + t.Fatalf("expected error for case %s, got nil", tt.name) } - - if err != nil { - t.Errorf("unexpected error: %v", err) - return + if !tt.expectedError && err != nil { + t.Fatalf("unexpected error for case %s: %v", tt.name, err) } - - if tokenizer == nil { - t.Error("expected tokenizer, got nil") + if !tt.expectedToken && tokenizer != nil { + t.Fatalf("expected nil tokenizer for case %s, got non-nil", tt.name) + } + if tt.expectedToken && tokenizer == nil { + t.Fatalf("expected non-nil tokenizer for case %s", tt.name) + } + if tt.expectedError { return } - + if !tt.skipCache && cacheDir != "" { + validateCache(t, cacheDir, tt.modelID) + } if err := tokenizer.Close(); err != nil { - t.Errorf("error closing tokenizer: %v", err) + t.Fatalf("error closing tokenizer: %v", err) } }) } } -func TestConfigOptions(t *testing.T) { - server := mockHuggingFaceServer(t) - defer server.Close() - - t.Run("WithCacheDir", func(t *testing.T) { - tmpDir := t.TempDir() - - origTransport := http.DefaultClient.Transport - http.DefaultClient.Transport = &mockTransport{ - server: server, - modelID: "bert-base-uncased", - } - defer func() { http.DefaultClient.Transport = origTransport }() - - tokenizer, err := tokenizers.FromPretrained( - "bert-base-uncased", - tokenizers.WithCacheDir(tmpDir), - ) - - if err != nil { - t.Errorf("unexpected error: %v", err) - return - } - defer tokenizer.Close() - - files := []string{"tokenizer.json", "vocab.txt", "special_tokens_map.json"} - for _, file := range files { - path := filepath.Join(tmpDir, "bert-base-uncased", file) - if _, err := os.Stat(path); os.IsNotExist(err) { - t.Errorf("expected file %s to exist", file) - } - } - }) - - t.Run("WithAuthToken", func(t *testing.T) { - origTransport := http.DefaultClient.Transport - http.DefaultClient.Transport = &mockTransport{ - server: server, - modelID: "private/model", - } - defer func() { http.DefaultClient.Transport = origTransport }() - - tokenizer, err := tokenizers.FromPretrained( - "private/model", - tokenizers.WithAuthToken("test-token"), - tokenizers.WithCacheDir(t.TempDir()), - ) - - if err != nil { - t.Errorf("unexpected error with valid auth token: %v", err) - return +func validateCache(t *testing.T, dir string, modelID string) { + t.Helper() + files := []string{"tokenizer.json", "vocab.txt"} + for _, file := range files { + path := filepath.Join(dir, modelID, file) + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Fatalf("expected file %s to exist in cache for model %s", file, modelID) } - if tokenizer != nil { - tokenizer.Close() - } - - tokenizer, err = tokenizers.FromPretrained( - "private/model", - tokenizers.WithAuthToken("invalid-token"), - tokenizers.WithCacheDir(t.TempDir()), - ) - - if err == nil { - t.Error("expected error with invalid auth token, got nil") - tokenizer.Close() - } - }) + } } From 5d79518c2d0e6de0b5a04bacde8fadc4eb0e6005 Mon Sep 17 00:00:00 2001 From: Daulet Zhanguzin Date: Thu, 7 Nov 2024 13:03:03 -0800 Subject: [PATCH 10/10] fix: clean up nits --- README.md | 2 +- tokenizer_test.go | 49 +++++++++++++++++------------------------------ 2 files changed, 19 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index 6883f71..2ea06a3 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ Go bindings for the [HuggingFace Tokenizers](https://github.com/huggingface/toke ## Installation -`make build` to build `libtokenizers.a` that you need to run your 
application that uses bindings. In addition, you need to inform the linker where to find that static library: `go run -ldflags="-extldflags '-L./path/to/libtokenizers-directory' -ltokenizers'" .` or just add it to the `CGO_LDFLAGS` environment variable: `CGO_LDFLAGS="-L./path/to/libtokenizers-directory" -ltokenizers` to avoid specifying it every time. +`make build` to build `libtokenizers.a` that you need to run your application that uses bindings. In addition, you need to inform the linker where to find that static library: `go run -ldflags="-extldflags '-L./path/to/libtokenizers/directory'" .` or just add it to the `CGO_LDFLAGS` environment variable: `CGO_LDFLAGS="-L./path/to/libtokenizers/directory"` to avoid specifying it every time. ### Using pre-built binaries diff --git a/tokenizer_test.go b/tokenizer_test.go index 1f6d223..62547d5 100644 --- a/tokenizer_test.go +++ b/tokenizer_test.go @@ -2,12 +2,13 @@ package tokenizers_test import ( _ "embed" - "github.com/daulet/tokenizers" "math/rand" "os" "path/filepath" "testing" + "github.com/daulet/tokenizers" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -479,9 +480,8 @@ func TestFromPretrained(t *testing.T) { name string modelID string setupOpts func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string) - expectedError bool + wantErr bool expectedToken bool - skipCache bool }{ { name: "valid public model with cache dir", @@ -501,7 +501,6 @@ func TestFromPretrained(t *testing.T) { setupOpts: func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string) { return nil, "" }, - skipCache: true, }, { name: "private model with valid auth token", @@ -516,9 +515,9 @@ func TestFromPretrained(t *testing.T) { }, }, { - name: "private model with invalid auth token", - modelID: "private-model", - expectedError: true, + name: "private model with invalid auth token", + modelID: "private-model", + wantErr: true, setupOpts: func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string) { tmpDir := t.TempDir() return []tokenizers.TokenizerConfigOption{ @@ -526,28 +525,25 @@ func TestFromPretrained(t *testing.T) { tokenizers.WithAuthToken("invalid-token"), }, tmpDir }, - skipCache: true, }, { - name: "empty model ID", - modelID: "", - expectedError: true, + name: "empty model ID", + modelID: "", + wantErr: true, setupOpts: func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string) { return nil, "" }, - skipCache: true, }, { - name: "nonexistent model", - modelID: "nonexistent/model", - expectedError: true, + name: "nonexistent model", + modelID: "nonexistent/model", + wantErr: true, setupOpts: func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string) { tmpDir := t.TempDir() return []tokenizers.TokenizerConfigOption{ tokenizers.WithCacheDir(tmpDir), }, tmpDir }, - skipCache: true, }, } @@ -556,22 +552,13 @@ func TestFromPretrained(t *testing.T) { opts, cacheDir := tt.setupOpts(t) tokenizer, err := tokenizers.FromPretrained(tt.modelID, opts...) 
- if tt.expectedError && err == nil { - t.Fatalf("expected error for case %s, got nil", tt.name) - } - if !tt.expectedError && err != nil { - t.Fatalf("unexpected error for case %s: %v", tt.name, err) - } - if !tt.expectedToken && tokenizer != nil { - t.Fatalf("expected nil tokenizer for case %s, got non-nil", tt.name) - } - if tt.expectedToken && tokenizer == nil { - t.Fatalf("expected non-nil tokenizer for case %s", tt.name) + if gotErr := err != nil; gotErr != tt.wantErr { + t.Fatalf("expected error: %v, got error: %v", tt.wantErr, err) } - if tt.expectedError { + if tt.wantErr { return } - if !tt.skipCache && cacheDir != "" { + if cacheDir != "" { validateCache(t, cacheDir, tt.modelID) } if err := tokenizer.Close(); err != nil { @@ -586,8 +573,8 @@ func validateCache(t *testing.T, dir string, modelID string) { files := []string{"tokenizer.json", "vocab.txt"} for _, file := range files { path := filepath.Join(dir, modelID, file) - if _, err := os.Stat(path); os.IsNotExist(err) { - t.Fatalf("expected file %s to exist in cache for model %s", file, modelID) + if _, err := os.Stat(path); err != nil { + t.Errorf("expected file %s to exist in cache for model %s", file, modelID) } } }
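
Editor's appendix: illustrative sketches

The sketches below are editor-added notes on patterns that appear in this series; they are not part of the patches, and every model ID, path, and helper name in them is illustrative.

After PATCH 06 and PATCH 07, the pretrained entry point takes variadic options and caches files under `<cacheDir>/<modelID>/`. A minimal usage sketch of that final API:

```go
package main

import (
	"fmt"
	"log"

	"github.com/daulet/tokenizers"
)

func main() {
	// Both options are optional: without WithCacheDir, files go to a fresh
	// os.MkdirTemp directory; WithAuthToken is only needed for private models.
	tk, err := tokenizers.FromPretrained(
		"google-bert/bert-base-uncased",
		tokenizers.WithCacheDir("./.cache/tokenizers"),
		// tokenizers.WithAuthToken("<hf token>"),
	)
	if err != nil {
		log.Fatal(err)
	}
	defer tk.Close() // release native resources

	ids, tokens := tk.Encode("brown fox jumps over the lazy dog", true)
	fmt.Println(ids)    // [101 2829 4419 14523 2058 1996 13971 3899 102]
	fmt.Println(tokens) // [[CLS] brown fox jumps over the lazy dog [SEP]]
}
```

With the cache option above, the mandatory file lands at `./.cache/tokenizers/google-bert/bert-base-uncased/tokenizer.json`, so one cache directory can hold many models.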
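PATCH 02 and PATCH 05 make a compact case study in channel fan-in. PATCH 02's version sends on an unbuffered channel from worker goroutines but only starts receiving after `wg.Wait()`, so the first failed mandatory download blocks its sender and the wait never returns. PATCH 05's fix closes the channel from a helper goroutine once the workers finish, while the caller drains it. A self-contained sketch of that shape (the `collectErrors` helper and its tasks are hypothetical):

```go
package main

import (
	"fmt"
	"sync"
)

// collectErrors fans tasks out to goroutines and gathers their errors.
func collectErrors(tasks []func() error) []error {
	var wg sync.WaitGroup
	errCh := make(chan error) // unbuffered: every send must be received

	for _, task := range tasks {
		wg.Add(1)
		go func(run func() error) {
			defer wg.Done()
			if err := run(); err != nil {
				errCh <- err
			}
		}(task)
	}

	// Calling wg.Wait() here in the caller, before the range below starts,
	// would deadlock as soon as one worker tried to send; closing from a
	// helper goroutine instead lets the range terminate once all are done.
	go func() {
		wg.Wait()
		close(errCh)
	}()

	var errs []error
	for err := range errCh {
		errs = append(errs, err)
	}
	return errs
}

func main() {
	errs := collectErrors([]func() error{
		func() error { return nil },
		func() error { return fmt.Errorf("tokenizer.json: status code 404") },
	})
	fmt.Println(errs) // [tokenizer.json: status code 404]
}
```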
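The same fan-out with first-error reporting can also be had from `golang.org/x/sync/errgroup`, which folds the WaitGroup and the error channel into a single primitive. This is an alternative the PR does not use, sketched here under the same mandatory/optional split as the `tokenizerFiles` map (`fetch` stands in for `downloadFile`):

```go
package main

import (
	"fmt"

	"golang.org/x/sync/errgroup"
)

// downloadAll fetches all files concurrently; a failure is fatal only for
// files marked mandatory, mirroring the tokenizerFiles map in this series.
func downloadAll(files map[string]bool, fetch func(name string) error) error {
	var g errgroup.Group
	for name, mandatory := range files {
		name, mandatory := name, mandatory // loop-variable capture before Go 1.22
		g.Go(func() error {
			if err := fetch(name); err != nil && mandatory {
				return fmt.Errorf("failed to download mandatory file %s: %w", name, err)
			}
			return nil
		})
	}
	return g.Wait() // waits for all goroutines, returns the first non-nil error
}

func main() {
	files := map[string]bool{"tokenizer.json": true, "vocab.txt": false}
	err := downloadAll(files, func(name string) error {
		fmt.Println("fetching", name)
		return nil // pretend the download succeeded
	})
	fmt.Println("err:", err)
}
```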
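PATCH 08's tests isolate the downloader from the network by swapping `http.DefaultClient.Transport` for a `RoundTripper` that rewrites every request toward a local `httptest` server (PATCH 09 later drops the mock and exercises the real endpoint). The redirection trick itself is reusable; a distilled sketch with a stand-in handler and URL:

```go
package main

import (
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"net/url"
)

// rewriteTransport points every outgoing request at a local test server.
// Cloning the request avoids mutating the caller's copy, which
// RoundTrippers are not supposed to do.
type rewriteTransport struct{ target *url.URL }

func (rt rewriteTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	clone := req.Clone(req.Context())
	clone.URL.Scheme = rt.target.Scheme
	clone.URL.Host = rt.target.Host
	return http.DefaultTransport.RoundTrip(clone)
}

func main() {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintf(w, "served %s locally", r.URL.Path)
	}))
	defer srv.Close()

	target, err := url.Parse(srv.URL)
	if err != nil {
		panic(err)
	}
	client := &http.Client{Transport: rewriteTransport{target: target}}

	resp, err := client.Get("https://huggingface.co/google-bert/bert-base-uncased/resolve/main/tokenizer.json")
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	body, _ := io.ReadAll(resp.Body)
	fmt.Println(string(body))
	// served /google-bert/bert-base-uncased/resolve/main/tokenizer.json locally
}
```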