Skip to content

Commit

Permalink
fix: make optional params optional
Browse files Browse the repository at this point in the history
  • Loading branch information
daulet committed Nov 5, 2024
1 parent 8732d30 commit 53ff61b
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 6 deletions.
3 changes: 1 addition & 2 deletions example/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ func advanced() error {
defer tk.Close()

// Load pretrained tokenizer from HuggingFace
tokenizerPath := "./.cache/tokenizers/google-bert/bert-base-uncased"
tkFromHf, err := tokenizers.FromPretrained("google-bert/bert-base-uncased", &tokenizerPath, nil)
tkFromHf, err := tokenizers.FromPretrained("google-bert/bert-base-uncased", tokenizers.WithCacheDir("./.cache/tokenizers"))
if err != nil {
return err
}
Expand Down
31 changes: 27 additions & 4 deletions tokenizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,36 @@ func FromFile(path string) (*Tokenizer, error) {
return &Tokenizer{tokenizer: tokenizer}, nil
}

// tokenizerConfig collects the optional settings that FromPretrained reads.
// FromPretrained starts from the zero value and lets each supplied option fill
// in a field, so a nil field means the corresponding option was not given.
type tokenizerConfig struct {
// cacheDir, when non-nil, is the directory FromPretrained downloads files into.
cacheDir *string
// authToken is passed to downloadFile for every file request; nil when unset.
authToken *string
}

// TokenizerConfigOption is a functional option that mutates the configuration
// FromPretrained builds before it downloads anything. Options are applied in
// the order they are passed to FromPretrained.
type TokenizerConfigOption func(cfg *tokenizerConfig)

// WithCacheDir returns an option that makes FromPretrained download the
// tokenizer files into path.
//
// An empty path is ignored and leaves the configuration untouched.
// FromPretrained only checks the cache directory for nil, so storing a pointer
// to "" would make it call os.MkdirAll on an empty path and fail instead of
// falling back to its default download location.
func WithCacheDir(path string) TokenizerConfigOption {
	return func(cfg *tokenizerConfig) {
		if path == "" {
			return
		}
		cfg.cacheDir = &path
	}
}

// WithAuthToken builds an option that records token in the configuration, so
// that FromPretrained hands it to downloadFile for each file it fetches.
func WithAuthToken(token string) TokenizerConfigOption {
	setToken := func(c *tokenizerConfig) {
		c.authToken = &token
	}
	return setToken
}

// FromPretrained downloads necessary files and initializes the tokenizer.
// Parameters:
// - modelID: The Hugging Face model identifier (e.g., "bert-base-uncased").
// - opts: Optional. Functional options that adjust the download:
// WithCacheDir stores the downloaded files in the given folder (otherwise a temporary directory is used);
// WithAuthToken supplies a token that will be used to authenticate requests.
func FromPretrained(modelID string, destination, authToken *string) (*Tokenizer, error) {
func FromPretrained(modelID string, opts ...TokenizerConfigOption) (*Tokenizer, error) {
cfg := &tokenizerConfig{}
for _, opt := range opts {
opt(cfg)
}
if strings.TrimSpace(modelID) == "" {
return nil, fmt.Errorf("modelID cannot be empty")
}
Expand All @@ -104,8 +127,8 @@ func FromPretrained(modelID string, destination, authToken *string) (*Tokenizer,

// Determine the download directory
var downloadDir string
if destination != nil && *destination != "" {
downloadDir = *destination
if cfg.cacheDir != nil {
downloadDir = *cfg.cacheDir
// Create the destination directory if it doesn't exist
err := os.MkdirAll(downloadDir, os.ModePerm)
if err != nil {
Expand All @@ -130,7 +153,7 @@ func FromPretrained(modelID string, destination, authToken *string) (*Tokenizer,
defer wg.Done()
fileURL := fmt.Sprintf("%s/%s", modelURL, fn)
destPath := filepath.Join(downloadDir, fn)
err := downloadFile(fileURL, destPath, authToken)
err := downloadFile(fileURL, destPath, cfg.authToken)
if err != nil && mandatory {
// If the file is mandatory, report an error
errCh <- fmt.Errorf("failed to download mandatory file %s: %w", fn, err)
Expand Down

0 comments on commit 53ff61b

Please sign in to comment.