feat: FromPretrained to load tokenizer directly from HF #27

Merged: 10 commits, Nov 7, 2024
44 changes: 43 additions & 1 deletion README.md
Go bindings for the [HuggingFace Tokenizers](https://github.com/huggingface/tokenizers) library.

## Installation

- Run `make build` to produce `libtokenizers.a`, which applications using these bindings must link against. You also need to tell the linker where to find the static library: `go run -ldflags="-extldflags '-L./path/to/libtokenizers.a'" .`, or add it to the `CGO_LDFLAGS` environment variable (`CGO_LDFLAGS="-L./path/to/libtokenizers.a"`) to avoid specifying it every time.
+ Run `make build` to produce `libtokenizers.a`, which applications using these bindings must link against. You also need to tell the linker where to find the static library: `go run -ldflags="-extldflags '-L./path/to/libtokenizers/directory'" .`, or add it to the `CGO_LDFLAGS` environment variable (`CGO_LDFLAGS="-L./path/to/libtokenizers/directory"`) to avoid specifying it every time.

### Using pre-built binaries

```go
if err != nil {
	return err
}
// release native resources
defer tk.Close()
```

Load a tokenizer from HuggingFace:

```go
import "github.com/daulet/tokenizers"

tk, err := tokenizers.FromPretrained("google-bert/bert-base-uncased", tokenizers.WithCacheDir("./.cache/tokenizers"))
if err != nil {
	return err
}
// release native resources
defer tk.Close()
```
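
For gated or private models, a token can be passed via `WithAuthToken`; it is sent as a `Bearer` header on each download request. A minimal sketch, assuming a valid token is available in the hypothetical `HF_TOKEN` environment variable:

```go
import (
	"os"

	"github.com/daulet/tokenizers"
)

// HF_TOKEN is an assumed environment variable holding a Hugging Face access token.
token := os.Getenv("HF_TOKEN")
tk, err := tokenizers.FromPretrained(
	"google-bert/bert-base-uncased",
	tokenizers.WithCacheDir("./.cache/tokenizers"),
	tokenizers.WithAuthToken(token),
)
if err != nil {
	return err
}
defer tk.Close()
```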

Encode text and decode tokens:

```go
fmt.Println(tk.Encode("brown fox jumps over the lazy dog", false))
// [2829 4419 14523 2058 1996 13971 3899] [brown fox jumps over the lazy dog]
fmt.Println(tk.Decode([]uint32{2829, 4419, 14523, 2058, 1996, 13971, 3899}, true))
// brown fox jumps over the lazy dog
```

Encode text with options:

```go
var encodeOptions []tokenizers.EncodeOption
encodeOptions = append(encodeOptions, tokenizers.WithReturnTypeIDs())
encodeOptions = append(encodeOptions, tokenizers.WithReturnAttentionMask())
encodeOptions = append(encodeOptions, tokenizers.WithReturnTokens())
encodeOptions = append(encodeOptions, tokenizers.WithReturnOffsets())
encodeOptions = append(encodeOptions, tokenizers.WithReturnSpecialTokensMask())

// Or simply:
// encodeOptions = append(encodeOptions, tokenizers.WithReturnAllAttributes())

encodingResponse := tk.EncodeWithOptions("brown fox jumps over the lazy dog", false, encodeOptions...)
fmt.Println(encodingResponse.IDs)
// [2829 4419 14523 2058 1996 13971 3899]
fmt.Println(encodingResponse.TypeIDs)
// [0 0 0 0 0 0 0]
fmt.Println(encodingResponse.SpecialTokensMask)
// [0 0 0 0 0 0 0]
fmt.Println(encodingResponse.AttentionMask)
// [1 1 1 1 1 1 1]
fmt.Println(encodingResponse.Tokens)
// [brown fox jumps over the lazy dog]
fmt.Println(encodingResponse.Offsets)
// [[0 5] [6 9] [10 15] [16 20] [21 24] [25 29] [30 33]]
```
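
The returned offsets are positions into the input string, so they can be used to map each token back to the exact span it came from. A small sketch, assuming each `Offsets` entry is a `[start, end)` pair of byte indices into the input, as the output above suggests:

```go
text := "brown fox jumps over the lazy dog"
res := tk.EncodeWithOptions(text, false,
	tokenizers.WithReturnTokens(),
	tokenizers.WithReturnOffsets(),
)
for i, off := range res.Offsets {
	// off[0]/off[1] are assumed to be start/end byte positions of token i.
	fmt.Printf("%-6s -> %q\n", res.Tokens[i], text[off[0]:off[1]])
}
// brown  -> "brown"
// fox    -> "fox"
// ...
```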

## Benchmarks

`go test . -run=^\$ -bench=. -benchmem -count=10 > test/benchmark/$(git rev-parse HEAD).txt`
61 changes: 59 additions & 2 deletions example/main.go
package main

import (
	"fmt"
	"log"

	"github.com/daulet/tokenizers"
)

-func main() {
+func simple() error {
	tk, err := tokenizers.FromFile("../test/data/bert-base-uncased.json")
	if err != nil {
-		panic(err)
+		return err
	}
	// release native resources
	defer tk.Close()

	fmt.Println("Vocab size:", tk.VocabSize())
	// Vocab size: 30522
	fmt.Println(tk.Encode("brown fox jumps over the lazy dog", false))
	// [2829 4419 14523 2058 1996 13971 3899] [brown fox jumps over the lazy dog]
	fmt.Println(tk.Encode("brown fox jumps over the lazy dog", true))
	// [101 2829 4419 14523 2058 1996 13971 3899 102] [[CLS] brown fox jumps over the lazy dog [SEP]]
	fmt.Println(tk.Decode([]uint32{2829, 4419, 14523, 2058, 1996, 13971, 3899}, true))
	// brown fox jumps over the lazy dog
+	return nil
+}

func advanced() error {
	// Load tokenizer from local config file
	tk, err := tokenizers.FromFile("../test/data/bert-base-uncased.json")
	if err != nil {
		return err
	}
	defer tk.Close()

	// Load pretrained tokenizer from HuggingFace
	tkFromHf, err := tokenizers.FromPretrained("google-bert/bert-base-uncased", tokenizers.WithCacheDir("./.cache/tokenizers"))
	if err != nil {
		return err
	}
	defer tkFromHf.Close()

	// Encode with specific options
	encodeOptions := []tokenizers.EncodeOption{
		tokenizers.WithReturnTypeIDs(),
		tokenizers.WithReturnAttentionMask(),
		tokenizers.WithReturnTokens(),
		tokenizers.WithReturnOffsets(),
		tokenizers.WithReturnSpecialTokensMask(),
	}
	// Or simply:
	// encodeOptions = append(encodeOptions, tokenizers.WithReturnAllAttributes())

	// regardless of how the tokenizer was initialized, the output is the same
	for _, tkzr := range []*tokenizers.Tokenizer{tk, tkFromHf} {
		encodingResponse := tkzr.EncodeWithOptions("brown fox jumps over the lazy dog", true, encodeOptions...)
		fmt.Println(encodingResponse.IDs)
		// [101 2829 4419 14523 2058 1996 13971 3899 102]
		fmt.Println(encodingResponse.TypeIDs)
		// [0 0 0 0 0 0 0 0 0]
		fmt.Println(encodingResponse.SpecialTokensMask)
		// [1 0 0 0 0 0 0 0 1]
		fmt.Println(encodingResponse.AttentionMask)
		// [1 1 1 1 1 1 1 1 1]
		fmt.Println(encodingResponse.Tokens)
		// [[CLS] brown fox jumps over the lazy dog [SEP]]
		fmt.Println(encodingResponse.Offsets)
		// [[0 0] [0 5] [6 9] [10 15] [16 20] [21 24] [25 29] [30 33] [0 0]]
	}
	return nil
}

func main() {
	if err := simple(); err != nil {
		log.Fatal(err)
	}
	if err := advanced(); err != nil {
		log.Fatal(err)
	}
}
162 changes: 161 additions & 1 deletion tokenizer.go
import "C"
import (
	"fmt"
	"io"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"unsafe"
)

-const WANT_VERSION = "1.20.2"
+const (
+	WANT_VERSION = "1.20.2"
+
+	baseURL = "https://huggingface.co"
+)

// List of necessary tokenizer files and their mandatory status.
// True means mandatory, false means optional.
var tokenizerFiles = map[string]bool{
	"tokenizer.json":          true,
	"vocab.txt":               false,
	"merges.txt":              false,
	"special_tokens_map.json": false,
	"added_tokens.json":       false,
}

func init() {
	version := C.version()
	// …
}

func FromFile(path string) (*Tokenizer, error) {
	// …
	return &Tokenizer{tokenizer: tokenizer}, nil
}

type tokenizerConfig struct {
	cacheDir  *string
	authToken *string
}

type TokenizerConfigOption func(cfg *tokenizerConfig)

func WithCacheDir(path string) TokenizerConfigOption {
	return func(cfg *tokenizerConfig) {
		cfg.cacheDir = &path
	}
}

func WithAuthToken(token string) TokenizerConfigOption {
	return func(cfg *tokenizerConfig) {
		cfg.authToken = &token
	}
}

// FromPretrained downloads the tokenizer files for the given model from the
// Hugging Face Hub and initializes the tokenizer.
// Parameters:
//   - modelID: the Hugging Face model identifier (e.g., "bert-base-uncased").
//   - opts: optional configuration. WithCacheDir downloads files to (and reuses
//     them from) the given directory; otherwise a temporary directory is used.
//     WithAuthToken authenticates requests, e.g. for gated or private models.
func FromPretrained(modelID string, opts ...TokenizerConfigOption) (*Tokenizer, error) {
	cfg := &tokenizerConfig{}
	for _, opt := range opts {
		opt(cfg)
	}
	if strings.TrimSpace(modelID) == "" {
		return nil, fmt.Errorf("modelID cannot be empty")
	}

	// Construct the model URL
	modelURL := fmt.Sprintf("%s/%s/resolve/main", baseURL, modelID)

	// Determine the download directory
	var downloadDir string
	if cfg.cacheDir != nil {
		downloadDir = fmt.Sprintf("%s/%s", *cfg.cacheDir, modelID)
		// Create the destination directory if it doesn't exist
		err := os.MkdirAll(downloadDir, os.ModePerm)
		if err != nil {
			return nil, fmt.Errorf("failed to create destination directory %s: %w", downloadDir, err)
		}
	} else {
		// Create a temporary directory
		tmpDir, err := os.MkdirTemp("", "huggingface-tokenizer-*")
		if err != nil {
			return nil, fmt.Errorf("error creating temporary directory: %w", err)
		}
		downloadDir = tmpDir
	}

	var wg sync.WaitGroup
	errCh := make(chan error)

	// Download each tokenizer file concurrently
	for filename, isMandatory := range tokenizerFiles {
		wg.Add(1)
		go func(fn string, mandatory bool) {
			defer wg.Done()
			fileURL := fmt.Sprintf("%s/%s", modelURL, fn)
			destPath := filepath.Join(downloadDir, fn)
			err := downloadFile(fileURL, destPath, cfg.authToken)
			if err != nil && mandatory {
				// If the file is mandatory, report an error
				errCh <- fmt.Errorf("failed to download mandatory file %s: %w", fn, err)
			}
		}(filename, isMandatory)
	}

	go func() {
		wg.Wait()
		close(errCh)
	}()

	var errs []error
	for err := range errCh {
		errs = append(errs, err)
	}

	if len(errs) > 0 {
		if err := os.RemoveAll(downloadDir); err != nil {
			fmt.Printf("Warning: failed to clean up directory %s: %v\n", downloadDir, err)
		}
		return nil, errs[0]
	}

	return FromFile(filepath.Join(downloadDir, "tokenizer.json"))
}

// downloadFile downloads a file from the given URL and saves it to the specified destination.
// If authToken is provided (non-nil), it will be used for authorization.
// Returns an error if the download fails.
func downloadFile(url, destination string, authToken *string) error {
	// Check if the file already exists
	if _, err := os.Stat(destination); err == nil {
		return nil
	}

	// Create a new HTTP request
	req, err := http.NewRequest("GET", url, nil)
	if err != nil {
		return fmt.Errorf("failed to create request for %s: %w", url, err)
	}

	// If authToken is provided, set the Authorization header
	if authToken != nil {
		req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *authToken))
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return fmt.Errorf("failed to download from %s: %w", url, err)
	}
	defer resp.Body.Close()

	// Check for successful response
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("failed to download from %s: status code %d", url, resp.StatusCode)
	}

	// Create the destination file
	out, err := os.Create(destination)
	if err != nil {
		return fmt.Errorf("failed to create file %s: %w", destination, err)
	}
	defer out.Close()

	// Write the response body to the file
	_, err = io.Copy(out, resp.Body)
	if err != nil {
		return fmt.Errorf("failed to write to file %s: %w", destination, err)
	}

	fmt.Printf("Successfully downloaded %s\n", destination)
	return nil
}

func (t *Tokenizer) Close() error {
	C.free_tokenizer(t.tokenizer)
	t.tokenizer = nil
	return nil
}