Skip to content

Commit

Permalink
fix: concurrency issues in case of an error
Browse files Browse the repository at this point in the history
  • Loading branch information
daulet committed Nov 5, 2024
1 parent 7de9a11 commit 8732d30
Showing 1 changed file with 16 additions and 32 deletions.
48 changes: 16 additions & 32 deletions tokenizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,37 +131,31 @@ func FromPretrained(modelID string, destination, authToken *string) (*Tokenizer,
fileURL := fmt.Sprintf("%s/%s", modelURL, fn)
destPath := filepath.Join(downloadDir, fn)
err := downloadFile(fileURL, destPath, authToken)
if err != nil {
if mandatory {
// If the file is mandatory, report an error
errCh <- fmt.Errorf("failed to download mandatory file %s: %w", fn, err)
} else {
// Optional files: log warning and continue
fmt.Printf("Warning: failed to download optional file %s: %v\n", fn, err)
}
if err != nil && mandatory {
// If the file is mandatory, report an error
errCh <- fmt.Errorf("failed to download mandatory file %s: %w", fn, err)
}
}(filename, isMandatory)
}

// Wait for all downloads to complete
wg.Wait()
close(errCh)
go func() {
wg.Wait()
close(errCh)
}()

// Check for errors during downloads
for downloadErr := range errCh {
// Clean up the directory and return the error
cleanupDirectory(downloadDir)
return nil, downloadErr
var errs []error
for err := range errCh {
errs = append(errs, err)
}

// Verify that tokenizer.json exists
tokenizerPath := filepath.Join(downloadDir, "tokenizer.json")
if _, err := os.Stat(tokenizerPath); os.IsNotExist(err) {
return nil, fmt.Errorf("mandatory file tokenizer.json does not exist in %s", downloadDir)
if len(errs) > 0 {
if err := os.RemoveAll(downloadDir); err != nil {
fmt.Printf("Warning: failed to clean up directory %s: %v\n", downloadDir, err)
}
return nil, errs[0]
}

// Initialize the tokenizer using the downloaded tokenizer.json
return FromFile(tokenizerPath)
return FromFile(filepath.Join(downloadDir, "tokenizer.json"))
}

// downloadFile downloads a file from the given URL and saves it to the specified destination.
Expand Down Expand Up @@ -212,16 +206,6 @@ func downloadFile(url, destination string, authToken *string) error {
return nil
}

// cleanupDirectory removes the specified directory and all its contents.
func cleanupDirectory(dir string) {
err := os.RemoveAll(dir)
if err != nil {
fmt.Printf("Warning: failed to clean up directory %s: %v\n", dir, err)
} else {
fmt.Printf("Successfully cleaned up directory %s\n", dir)
}
}

func (t *Tokenizer) Close() error {
C.free_tokenizer(t.tokenizer)
t.tokenizer = nil
Expand Down

0 comments on commit 8732d30

Please sign in to comment.