Skip to content

Commit

Permalink
migrate to table driven tests, unify/simplify test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
berkayersoyy committed Nov 7, 2024
1 parent 9f2fd96 commit d8c5cab
Showing 1 changed file with 69 additions and 218 deletions.
287 changes: 69 additions & 218 deletions tokenizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,8 @@ import (
_ "embed"
"github.com/daulet/tokenizers"
"math/rand"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -18,23 +15,6 @@ import (
//go:embed test/data/sentence-transformers-labse.json
var embeddedBytes []byte

// mockTransport is an http.RoundTripper that redirects every outgoing request
// to a local httptest.Server, so that code aimed at a remote host is answered
// by the mock instead.
type mockTransport struct {
	server  *httptest.Server // mock server that receives the rewritten requests
	modelID string           // model identifier substituted into rewritten paths
}

// RoundTrip forwards a rewritten copy of req to the mock server and returns
// the server's response.
//
// The copy's scheme is forced to "http" and its URL host is set to the mock
// server's address. When the request path splits into more than two
// "/"-separated parts, the path becomes
// "/<modelID>/resolve/main/<last path segment>"; shorter paths are forwarded
// unchanged.
//
// The caller's request is never modified: the http.RoundTripper contract
// states that RoundTrip should not modify the request, and mutating it in
// place would leak the rewritten URL back to the caller (and into any retry
// or redirect handling that reuses the same *http.Request).
func (t *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	// Clone deep-copies the URL, so the edits below stay local to this call.
	out := req.Clone(req.Context())
	out.URL.Scheme = "http"
	out.URL.Host = strings.TrimPrefix(t.server.URL, "http://")

	parts := strings.Split(out.URL.Path, "/")
	if len(parts) > 2 {
		out.URL.Path = "/" + t.modelID + "/resolve/main/" + parts[len(parts)-1]
	}

	return t.server.Client().Transport.RoundTrip(out)
}

// TODO test for leaks

func TestInvalidConfigPath(t *testing.T) {
Expand Down Expand Up @@ -494,249 +474,120 @@ func BenchmarkDecodeNTokens(b *testing.B) {
assert.Greater(b, len(text), b.N)
}

// mockHuggingFaceServer starts an in-process HTTP server that answers
// file-download requests with canned content. The caller owns the returned
// server and must Close it.
//
// Routing, evaluated in order on the request path (leading "/" removed):
//   - fewer than four "/"-separated segments: 404;
//   - path starts with "private/" and the Authorization header does not
//     contain "test-token": 401;
//   - path contains "nonexistent": 404 for tokenizer.json, otherwise 200 "{}";
//   - otherwise 200 with the canned payload for the requested file name, or
//     "{}" when the file name is not one of the known ones.
func mockHuggingFaceServer(t *testing.T) *httptest.Server {
	// Canned payloads keyed by the final path segment.
	knownFiles := map[string]string{
		"tokenizer.json": `{
"type": "mock_tokenizer",
"vocab_size": 1000,
"model_max_length": 512
}`,
		"vocab.txt": "[PAD]\n[UNK]\ntest\ntoken",
		"special_tokens_map.json": `{
"pad_token": "[PAD]",
"unk_token": "[UNK]"
}`,
	}

	// sendOK writes a 200 response carrying payload.
	sendOK := func(w http.ResponseWriter, payload string) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(payload))
	}

	handler := func(w http.ResponseWriter, r *http.Request) {
		relPath := strings.TrimPrefix(r.URL.Path, "/")

		// Three separators means four segments; anything shorter is rejected.
		if strings.Count(relPath, "/") < 3 {
			w.WriteHeader(http.StatusNotFound)
			return
		}
		requested := relPath[strings.LastIndex(relPath, "/")+1:]

		// Private models require the test token in the Authorization header.
		isPrivate := strings.HasPrefix(relPath, "private/")
		if isPrivate && !strings.Contains(r.Header.Get("Authorization"), "test-token") {
			w.WriteHeader(http.StatusUnauthorized)
			return
		}

		// A nonexistent model is missing only its tokenizer.json; every other
		// file is served as an empty JSON object.
		if strings.Contains(relPath, "nonexistent") {
			if requested == "tokenizer.json" {
				w.WriteHeader(http.StatusNotFound)
				return
			}
			sendOK(w, "{}")
			return
		}

		payload, known := knownFiles[requested]
		if !known {
			// Unknown files get an empty JSON object.
			payload = "{}"
		}
		sendOK(w, payload)
	}

	return httptest.NewServer(http.HandlerFunc(handler))
}

func TestFromPretrained(t *testing.T) {
server := mockHuggingFaceServer(t)
defer server.Close()

tests := []struct {
name string
modelID string
setupOpts func() []tokenizers.TokenizerConfigOption
wantError bool
checkDir func(t *testing.T, dir string) // Add function to verify directory
name string
modelID string
setupOpts func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string)
expectedError bool
expectedToken bool
skipCache bool
}{
{
name: "valid public model with cache dir",
modelID: "bert-base-uncased",
setupOpts: func() []tokenizers.TokenizerConfigOption {
name: "valid public model with cache dir",
modelID: "bert-base-uncased",
expectedToken: true,
setupOpts: func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string) {
tmpDir := t.TempDir()
return []tokenizers.TokenizerConfigOption{
tokenizers.WithCacheDir(tmpDir),
}
}, tmpDir
},
wantError: false,
},
{
name: "valid public model without cache dir",
modelID: "bert-base-uncased",
setupOpts: func() []tokenizers.TokenizerConfigOption {
return nil // No cache dir specified
name: "valid public model without cache dir",
modelID: "bert-base-uncased",
expectedToken: true,
setupOpts: func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string) {
return nil, ""
},
wantError: false,
skipCache: true,
},
{
name: "private model with auth token and cache dir",
modelID: "private/model",
setupOpts: func() []tokenizers.TokenizerConfigOption {
name: "private model with valid auth token",
modelID: "bert-base-uncased",
expectedToken: true,
setupOpts: func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string) {
tmpDir := t.TempDir()
return []tokenizers.TokenizerConfigOption{
tokenizers.WithCacheDir(tmpDir),
tokenizers.WithAuthToken("test-token"),
}
},
wantError: false,
checkDir: func(t *testing.T, dir string) {
path := filepath.Join(dir, "private", "model", "tokenizer.json")
if _, err := os.Stat(path); os.IsNotExist(err) {
t.Errorf("expected tokenizer.json to exist in cache dir")
}
}, tmpDir
},
},
{
name: "private model with auth token without cache dir",
modelID: "private/model",
setupOpts: func() []tokenizers.TokenizerConfigOption {
name: "private model with invalid auth token",
modelID: "private-model",
expectedError: true,
setupOpts: func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string) {
tmpDir := t.TempDir()
return []tokenizers.TokenizerConfigOption{
tokenizers.WithAuthToken("test-token"),
}
},
wantError: false,
checkDir: func(t *testing.T, dir string) {
if !strings.Contains(dir, "huggingface-tokenizer-") {
t.Errorf("expected temp directory name to contain 'huggingface-tokenizer-', got %s", dir)
}
tokenizers.WithCacheDir(tmpDir),
tokenizers.WithAuthToken("invalid-token"),
}, tmpDir
},
skipCache: true,
},
{
name: "empty model ID",
modelID: "",
setupOpts: func() []tokenizers.TokenizerConfigOption { return nil },
wantError: true,
checkDir: nil, // No directory check needed for error case
name: "empty model ID",
modelID: "",
expectedError: true,
setupOpts: func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string) {
return nil, ""
},
skipCache: true,
},
{
name: "nonexistent model",
modelID: "nonexistent/model",
setupOpts: func() []tokenizers.TokenizerConfigOption {
name: "nonexistent model",
modelID: "nonexistent/model",
expectedError: true,
setupOpts: func(t *testing.T) ([]tokenizers.TokenizerConfigOption, string) {
tmpDir := t.TempDir()
return []tokenizers.TokenizerConfigOption{
tokenizers.WithCacheDir(tmpDir),
}
}, tmpDir
},
wantError: true,
checkDir: nil, // No directory check needed for error case
skipCache: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
origTransport := http.DefaultClient.Transport
http.DefaultClient.Transport = &mockTransport{
server: server,
modelID: tt.modelID,
}
defer func() { http.DefaultClient.Transport = origTransport }()

tokenizer, err := tokenizers.FromPretrained(tt.modelID, tt.setupOpts()...)
opts, cacheDir := tt.setupOpts(t)
tokenizer, err := tokenizers.FromPretrained(tt.modelID, opts...)

if tt.wantError {
if err == nil {
t.Error("expected error, got nil")
}
return
if tt.expectedError && err == nil {
t.Fatalf("expected error for case %s, got nil", tt.name)
}

if err != nil {
t.Errorf("unexpected error: %v", err)
return
if !tt.expectedError && err != nil {
t.Fatalf("unexpected error for case %s: %v", tt.name, err)
}

if tokenizer == nil {
t.Error("expected tokenizer, got nil")
if !tt.expectedToken && tokenizer != nil {
t.Fatalf("expected nil tokenizer for case %s, got non-nil", tt.name)
}
if tt.expectedToken && tokenizer == nil {
t.Fatalf("expected non-nil tokenizer for case %s", tt.name)
}
if tt.expectedError {
return
}

if !tt.skipCache && cacheDir != "" {
validateCache(t, cacheDir, tt.modelID)
}
if err := tokenizer.Close(); err != nil {
t.Errorf("error closing tokenizer: %v", err)
t.Fatalf("error closing tokenizer: %v", err)
}
})
}
}

func TestConfigOptions(t *testing.T) {
server := mockHuggingFaceServer(t)
defer server.Close()

t.Run("WithCacheDir", func(t *testing.T) {
tmpDir := t.TempDir()

origTransport := http.DefaultClient.Transport
http.DefaultClient.Transport = &mockTransport{
server: server,
modelID: "bert-base-uncased",
}
defer func() { http.DefaultClient.Transport = origTransport }()

tokenizer, err := tokenizers.FromPretrained(
"bert-base-uncased",
tokenizers.WithCacheDir(tmpDir),
)

if err != nil {
t.Errorf("unexpected error: %v", err)
return
}
defer tokenizer.Close()

files := []string{"tokenizer.json", "vocab.txt", "special_tokens_map.json"}
for _, file := range files {
path := filepath.Join(tmpDir, "bert-base-uncased", file)
if _, err := os.Stat(path); os.IsNotExist(err) {
t.Errorf("expected file %s to exist", file)
}
}
})

t.Run("WithAuthToken", func(t *testing.T) {
origTransport := http.DefaultClient.Transport
http.DefaultClient.Transport = &mockTransport{
server: server,
modelID: "private/model",
}
defer func() { http.DefaultClient.Transport = origTransport }()

tokenizer, err := tokenizers.FromPretrained(
"private/model",
tokenizers.WithAuthToken("test-token"),
tokenizers.WithCacheDir(t.TempDir()),
)

if err != nil {
t.Errorf("unexpected error with valid auth token: %v", err)
return
func validateCache(t *testing.T, dir string, modelID string) {
t.Helper()
files := []string{"tokenizer.json", "vocab.txt"}
for _, file := range files {
path := filepath.Join(dir, modelID, file)
if _, err := os.Stat(path); os.IsNotExist(err) {
t.Fatalf("expected file %s to exist in cache for model %s", file, modelID)
}
if tokenizer != nil {
tokenizer.Close()
}

tokenizer, err = tokenizers.FromPretrained(
"private/model",
tokenizers.WithAuthToken("invalid-token"),
tokenizers.WithCacheDir(t.TempDir()),
)

if err == nil {
t.Error("expected error with invalid auth token, got nil")
tokenizer.Close()
}
})
}
}

0 comments on commit d8c5cab

Please sign in to comment.