From d8c5cab8fbd7578b87c0f806c55fd3f2e83f5c4c Mon Sep 17 00:00:00 2001 From: Berkay Ersoy Date: Thu, 7 Nov 2024 10:11:38 +0300 Subject: [PATCH] migrate to table driven tests, unify/simplify test cases --- tokenizer_test.go | 287 +++++++++++----------------------------------- 1 file changed, 69 insertions(+), 218 deletions(-) diff --git a/tokenizer_test.go b/tokenizer_test.go index bfcf5de..1f6d223 100644 --- a/tokenizer_test.go +++ b/tokenizer_test.go @@ -4,11 +4,8 @@ import ( _ "embed" "github.com/daulet/tokenizers" "math/rand" - "net/http" - "net/http/httptest" "os" "path/filepath" - "strings" "testing" "github.com/stretchr/testify/assert" @@ -18,23 +15,6 @@ import ( //go:embed test/data/sentence-transformers-labse.json var embeddedBytes []byte -type mockTransport struct { - server *httptest.Server - modelID string -} - -func (t *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) { - req.URL.Scheme = "http" - req.URL.Host = strings.TrimPrefix(t.server.URL, "http://") - - parts := strings.Split(req.URL.Path, "/") - if len(parts) > 2 { - req.URL.Path = "/" + t.modelID + "/resolve/main/" + parts[len(parts)-1] - } - - return t.server.Client().Transport.RoundTrip(req) -} - // TODO test for leaks func TestInvalidConfigPath(t *testing.T) { @@ -494,249 +474,120 @@ func BenchmarkDecodeNTokens(b *testing.B) { assert.Greater(b, len(text), b.N) } -func mockHuggingFaceServer(t *testing.T) *httptest.Server { - return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - path := strings.TrimPrefix(r.URL.Path, "/") - parts := strings.Split(path, "/") - - if len(parts) < 4 { - w.WriteHeader(http.StatusNotFound) - return - } - - fileName := parts[len(parts)-1] - - // Check authentication for private models - if strings.HasPrefix(path, "private/") { - authHeader := r.Header.Get("Authorization") - if !strings.Contains(authHeader, "test-token") { - w.WriteHeader(http.StatusUnauthorized) - return - } - } - - // For nonexistent model, only return 404 for tokenizer.json - if strings.Contains(path, "nonexistent") { - if fileName == "tokenizer.json" { - w.WriteHeader(http.StatusNotFound) - return - } - // Return empty response for optional files - w.WriteHeader(http.StatusOK) - w.Write([]byte("{}")) - return - } - - // Handle regular file requests - switch fileName { - case "tokenizer.json": - w.WriteHeader(http.StatusOK) - w.Write([]byte(`{ - "type": "mock_tokenizer", - "vocab_size": 1000, - "model_max_length": 512 - }`)) - case "vocab.txt": - w.WriteHeader(http.StatusOK) - w.Write([]byte("[PAD]\n[UNK]\ntest\ntoken")) - case "special_tokens_map.json": - w.WriteHeader(http.StatusOK) - w.Write([]byte(`{ - "pad_token": "[PAD]", - "unk_token": "[UNK]" - }`)) - default: - // Return empty response for other optional files - w.WriteHeader(http.StatusOK) - w.Write([]byte("{}")) - } - })) -} - func TestFromPretrained(t *testing.T) { - server := mockHuggingFaceServer(t) - defer server.Close() - tests := []struct { - name string - modelID string - setupOpts func() []tokenizers.TokenizerConfigOption - wantError bool - checkDir func(t *testing.T, dir string) // Add function to verify directory + name string + modelID string + setupOpts func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string) + expectedError bool + expectedToken bool + skipCache bool }{ { - name: "valid public model with cache dir", - modelID: "bert-base-uncased", - setupOpts: func() []tokenizers.TokenizerConfigOption { + name: "valid public model with cache dir", + modelID: "bert-base-uncased", + expectedToken: true, + setupOpts: func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string) { tmpDir := t.TempDir() return []tokenizers.TokenizerConfigOption{ tokenizers.WithCacheDir(tmpDir), - } + }, tmpDir }, - wantError: false, }, { - name: "valid public model without cache dir", - modelID: "bert-base-uncased", - setupOpts: func() []tokenizers.TokenizerConfigOption { - return nil // No cache dir specified + name: "valid public model without cache dir", + modelID: "bert-base-uncased", + expectedToken: true, + setupOpts: func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string) { + return nil, "" }, - wantError: false, + skipCache: true, }, { - name: "private model with auth token and cache dir", - modelID: "private/model", - setupOpts: func() []tokenizers.TokenizerConfigOption { + name: "private model with valid auth token", + modelID: "bert-base-uncased", + expectedToken: true, + setupOpts: func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string) { tmpDir := t.TempDir() return []tokenizers.TokenizerConfigOption{ tokenizers.WithCacheDir(tmpDir), tokenizers.WithAuthToken("test-token"), - } - }, - wantError: false, - checkDir: func(t *testing.T, dir string) { - path := filepath.Join(dir, "private", "model", "tokenizer.json") - if _, err := os.Stat(path); os.IsNotExist(err) { - t.Errorf("expected tokenizer.json to exist in cache dir") - } + }, tmpDir }, }, { - name: "private model with auth token without cache dir", - modelID: "private/model", - setupOpts: func() []tokenizers.TokenizerConfigOption { + name: "private model with invalid auth token", + modelID: "private-model", + expectedError: true, + setupOpts: func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string) { + tmpDir := t.TempDir() return []tokenizers.TokenizerConfigOption{ - tokenizers.WithAuthToken("test-token"), - } - }, - wantError: false, - checkDir: func(t *testing.T, dir string) { - if !strings.Contains(dir, "huggingface-tokenizer-") { - t.Errorf("expected temp directory name to contain 'huggingface-tokenizer-', got %s", dir) - } + tokenizers.WithCacheDir(tmpDir), + tokenizers.WithAuthToken("invalid-token"), + }, tmpDir }, + skipCache: true, }, { - name: "empty model ID", - modelID: "", - setupOpts: func() []tokenizers.TokenizerConfigOption { return nil }, - wantError: true, - checkDir: nil, // No directory check needed for error case + name: "empty model ID", + modelID: "", + expectedError: true, + setupOpts: func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string) { + return nil, "" + }, + skipCache: true, }, { - name: "nonexistent model", - modelID: "nonexistent/model", - setupOpts: func() []tokenizers.TokenizerConfigOption { + name: "nonexistent model", + modelID: "nonexistent/model", + expectedError: true, + setupOpts: func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string) { tmpDir := t.TempDir() return []tokenizers.TokenizerConfigOption{ tokenizers.WithCacheDir(tmpDir), - } + }, tmpDir }, - wantError: true, - checkDir: nil, // No directory check needed for error case + skipCache: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - origTransport := http.DefaultClient.Transport - http.DefaultClient.Transport = &mockTransport{ - server: server, - modelID: tt.modelID, - } - defer func() { http.DefaultClient.Transport = origTransport }() - - tokenizer, err := tokenizers.FromPretrained(tt.modelID, tt.setupOpts()...) + opts, cacheDir := tt.setupOpts(t) + tokenizer, err := tokenizers.FromPretrained(tt.modelID, opts...) - if tt.wantError { - if err == nil { - t.Error("expected error, got nil") - } - return + if tt.expectedError && err == nil { + t.Fatalf("expected error for case %s, got nil", tt.name) } - - if err != nil { - t.Errorf("unexpected error: %v", err) - return + if !tt.expectedError && err != nil { + t.Fatalf("unexpected error for case %s: %v", tt.name, err) } - - if tokenizer == nil { - t.Error("expected tokenizer, got nil") + if !tt.expectedToken && tokenizer != nil { + t.Fatalf("expected nil tokenizer for case %s, got non-nil", tt.name) + } + if tt.expectedToken && tokenizer == nil { + t.Fatalf("expected non-nil tokenizer for case %s", tt.name) + } + if tt.expectedError { return } - + if !tt.skipCache && cacheDir != "" { + validateCache(t, cacheDir, tt.modelID) + } if err := tokenizer.Close(); err != nil { - t.Errorf("error closing tokenizer: %v", err) + t.Fatalf("error closing tokenizer: %v", err) } }) } } -func TestConfigOptions(t *testing.T) { - server := mockHuggingFaceServer(t) - defer server.Close() - - t.Run("WithCacheDir", func(t *testing.T) { - tmpDir := t.TempDir() - - origTransport := http.DefaultClient.Transport - http.DefaultClient.Transport = &mockTransport{ - server: server, - modelID: "bert-base-uncased", - } - defer func() { http.DefaultClient.Transport = origTransport }() - - tokenizer, err := tokenizers.FromPretrained( - "bert-base-uncased", - tokenizers.WithCacheDir(tmpDir), - ) - - if err != nil { - t.Errorf("unexpected error: %v", err) - return - } - defer tokenizer.Close() - - files := []string{"tokenizer.json", "vocab.txt", "special_tokens_map.json"} - for _, file := range files { - path := filepath.Join(tmpDir, "bert-base-uncased", file) - if _, err := os.Stat(path); os.IsNotExist(err) { - t.Errorf("expected file %s to exist", file) - } - } - }) - - t.Run("WithAuthToken", func(t *testing.T) { - origTransport := http.DefaultClient.Transport - http.DefaultClient.Transport = &mockTransport{ - server: server, - modelID: "private/model", - } - defer func() { http.DefaultClient.Transport = origTransport }() - - tokenizer, err := tokenizers.FromPretrained( - "private/model", - tokenizers.WithAuthToken("test-token"), - tokenizers.WithCacheDir(t.TempDir()), - ) - - if err != nil { - t.Errorf("unexpected error with valid auth token: %v", err) - return +func validateCache(t *testing.T, dir string, modelID string) { + t.Helper() + files := []string{"tokenizer.json", "vocab.txt"} + for _, file := range files { + path := filepath.Join(dir, modelID, file) + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Fatalf("expected file %s to exist in cache for model %s", file, modelID) } - if tokenizer != nil { - tokenizer.Close() - } - - tokenizer, err = tokenizers.FromPretrained( - "private/model", - tokenizers.WithAuthToken("invalid-token"), - tokenizers.WithCacheDir(t.TempDir()), - ) - - if err == nil { - t.Error("expected error with invalid auth token, got nil") - tokenizer.Close() - } - }) + } }