From 53ff61bf88e90b69d22f18115668e144db7f14b9 Mon Sep 17 00:00:00 2001 From: Daulet Zhanguzin Date: Mon, 4 Nov 2024 17:25:12 -0800 Subject: [PATCH] fix: make optional params optional --- example/main.go | 3 +-- tokenizer.go | 37 ++++++++++++++++++++++++++++++------- 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/example/main.go b/example/main.go index 5192894..db65bf5 100644 --- a/example/main.go +++ b/example/main.go @@ -35,8 +35,7 @@ func advanced() error { defer tk.Close() // Load pretrained tokenizer from HuggingFace - tokenizerPath := "./.cache/tokenizers/google-bert/bert-base-uncased" - tkFromHf, err := tokenizers.FromPretrained("google-bert/bert-base-uncased", &tokenizerPath, nil) + tkFromHf, err := tokenizers.FromPretrained("google-bert/bert-base-uncased", tokenizers.WithCacheDir("./.cache/tokenizers")) if err != nil { return err } diff --git a/tokenizer.go b/tokenizer.go index c4ac143..17d2277 100644 --- a/tokenizer.go +++ b/tokenizer.go @@ -88,13 +88,36 @@ func FromFile(path string) (*Tokenizer, error) { return &Tokenizer{tokenizer: tokenizer}, nil } +type tokenizerConfig struct { + cacheDir *string + authToken *string +} + +type TokenizerConfigOption func(cfg *tokenizerConfig) + +func WithCacheDir(path string) TokenizerConfigOption { + return func(cfg *tokenizerConfig) { + cfg.cacheDir = &path + } +} + +func WithAuthToken(token string) TokenizerConfigOption { + return func(cfg *tokenizerConfig) { + cfg.authToken = &token + } +} + // FromPretrained downloads necessary files and initializes the tokenizer. // Parameters: // - modelID: The Hugging Face model identifier (e.g., "bert-base-uncased"). -// - destination: Optional. If provided and not nil, files will be downloaded to this folder. -// If nil, a temporary directory will be used. -// - authToken: Optional. If provided and not nil, it will be used to authenticate requests. +// - opts: Optional. Zero or more TokenizerConfigOption values that adjust the download. +// WithCacheDir sets the folder files are downloaded to; without it, a temporary directory is used. +// WithAuthToken sets the token used to authenticate requests. 
-func FromPretrained(modelID string, destination, authToken *string) (*Tokenizer, error) { +func FromPretrained(modelID string, opts ...TokenizerConfigOption) (*Tokenizer, error) { + cfg := &tokenizerConfig{} + for _, opt := range opts { + opt(cfg) + } if strings.TrimSpace(modelID) == "" { return nil, fmt.Errorf("modelID cannot be empty") } @@ -104,8 +127,8 @@ func FromPretrained(modelID string, destination, authToken *string) (*Tokenizer, // Determine the download directory var downloadDir string - if destination != nil && *destination != "" { - downloadDir = *destination + if cfg.cacheDir != nil { + downloadDir = *cfg.cacheDir // Create the destination directory if it doesn't exist err := os.MkdirAll(downloadDir, os.ModePerm) if err != nil { @@ -130,7 +153,7 @@ func FromPretrained(modelID string, destination, authToken *string) (*Tokenizer, defer wg.Done() fileURL := fmt.Sprintf("%s/%s", modelURL, fn) destPath := filepath.Join(downloadDir, fn) - err := downloadFile(fileURL, destPath, authToken) + err := downloadFile(fileURL, destPath, cfg.authToken) if err != nil && mandatory { // If the file is mandatory, report an error errCh <- fmt.Errorf("failed to download mandatory file %s: %w", fn, err)