From 5d79518c2d0e6de0b5a04bacde8fadc4eb0e6005 Mon Sep 17 00:00:00 2001 From: Daulet Zhanguzin Date: Thu, 7 Nov 2024 13:03:03 -0800 Subject: [PATCH] fix: clean up nits --- README.md | 2 +- tokenizer_test.go | 49 +++++++++++++++++------------------------------ 2 files changed, 19 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index 6883f71..2ea06a3 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ Go bindings for the [HuggingFace Tokenizers](https://github.com/huggingface/toke ## Installation -`make build` to build `libtokenizers.a` that you need to run your application that uses bindings. In addition, you need to inform the linker where to find that static library: `go run -ldflags="-extldflags '-L./path/to/libtokenizers-directory' -ltokenizers'" .` or just add it to the `CGO_LDFLAGS` environment variable: `CGO_LDFLAGS="-L./path/to/libtokenizers-directory" -ltokenizers` to avoid specifying it every time. +`make build` to build `libtokenizers.a` that you need to run your application that uses bindings. In addition, you need to inform the linker where to find that static library: `go run -ldflags="-extldflags '-L./path/to/libtokenizers/directory'" .` or just add it to the `CGO_LDFLAGS` environment variable: `CGO_LDFLAGS="-L./path/to/libtokenizers/directory"` to avoid specifying it every time. ### Using pre-built binaries diff --git a/tokenizer_test.go b/tokenizer_test.go index 1f6d223..62547d5 100644 --- a/tokenizer_test.go +++ b/tokenizer_test.go @@ -2,12 +2,13 @@ package tokenizers_test import ( _ "embed" - "github.com/daulet/tokenizers" "math/rand" "os" "path/filepath" "testing" + "github.com/daulet/tokenizers" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -479,9 +480,8 @@ func TestFromPretrained(t *testing.T) { name string modelID string setupOpts func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string) - expectedError bool + wantErr bool expectedToken bool - skipCache bool }{ { name: "valid public model with cache dir", @@ -501,7 +501,6 @@ func TestFromPretrained(t *testing.T) { setupOpts: func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string) { return nil, "" }, - skipCache: true, }, { name: "private model with valid auth token", @@ -516,9 +515,9 @@ func TestFromPretrained(t *testing.T) { }, }, { - name: "private model with invalid auth token", - modelID: "private-model", - expectedError: true, + name: "private model with invalid auth token", + modelID: "private-model", + wantErr: true, setupOpts: func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string) { tmpDir := t.TempDir() return []tokenizers.TokenizerConfigOption{ @@ -526,28 +525,25 @@ func TestFromPretrained(t *testing.T) { tokenizers.WithAuthToken("invalid-token"), }, tmpDir }, - skipCache: true, }, { - name: "empty model ID", - modelID: "", - expectedError: true, + name: "empty model ID", + modelID: "", + wantErr: true, setupOpts: func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string) { return nil, "" }, - skipCache: true, }, { - name: "nonexistent model", - modelID: "nonexistent/model", - expectedError: true, + name: "nonexistent model", + modelID: "nonexistent/model", + wantErr: true, setupOpts: func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string) { tmpDir := t.TempDir() return []tokenizers.TokenizerConfigOption{ tokenizers.WithCacheDir(tmpDir), }, tmpDir }, - skipCache: true, }, } @@ -556,22 +552,13 @@ func TestFromPretrained(t *testing.T) { opts, cacheDir := tt.setupOpts(t) tokenizer, err := tokenizers.FromPretrained(tt.modelID, opts...) - if tt.expectedError && err == nil { - t.Fatalf("expected error for case %s, got nil", tt.name) - } - if !tt.expectedError && err != nil { - t.Fatalf("unexpected error for case %s: %v", tt.name, err) - } - if !tt.expectedToken && tokenizer != nil { - t.Fatalf("expected nil tokenizer for case %s, got non-nil", tt.name) - } - if tt.expectedToken && tokenizer == nil { - t.Fatalf("expected non-nil tokenizer for case %s", tt.name) + if gotErr := err != nil; gotErr != tt.wantErr { + t.Fatalf("expected error: %v, got error: %v", tt.wantErr, err) } - if tt.expectedError { + if tt.wantErr { return } - if !tt.skipCache && cacheDir != "" { + if cacheDir != "" { validateCache(t, cacheDir, tt.modelID) } if err := tokenizer.Close(); err != nil { @@ -586,8 +573,8 @@ func validateCache(t *testing.T, dir string, modelID string) { files := []string{"tokenizer.json", "vocab.txt"} for _, file := range files { path := filepath.Join(dir, modelID, file) - if _, err := os.Stat(path); os.IsNotExist(err) { - t.Fatalf("expected file %s to exist in cache for model %s", file, modelID) + if _, err := os.Stat(path); err != nil { + t.Errorf("expected file %s to exist in cache for model %s", file, modelID) } } }