From 8732d30925b3d817401da40ae8e4bb90cd06e1cc Mon Sep 17 00:00:00 2001 From: Daulet Zhanguzin Date: Mon, 4 Nov 2024 17:19:20 -0800 Subject: [PATCH] fix: concurrency issues in case of an error --- tokenizer.go | 48 ++++++++++++++++-------------------------------- 1 file changed, 16 insertions(+), 32 deletions(-) diff --git a/tokenizer.go b/tokenizer.go index 0e57950..c4ac143 100644 --- a/tokenizer.go +++ b/tokenizer.go @@ -131,37 +131,31 @@ func FromPretrained(modelID string, destination, authToken *string) (*Tokenizer, fileURL := fmt.Sprintf("%s/%s", modelURL, fn) destPath := filepath.Join(downloadDir, fn) err := downloadFile(fileURL, destPath, authToken) - if err != nil { - if mandatory { - // If the file is mandatory, report an error - errCh <- fmt.Errorf("failed to download mandatory file %s: %w", fn, err) - } else { - // Optional files: log warning and continue - fmt.Printf("Warning: failed to download optional file %s: %v\n", fn, err) - } + if err != nil && mandatory { + // If the file is mandatory, report an error + errCh <- fmt.Errorf("failed to download mandatory file %s: %w", fn, err) } }(filename, isMandatory) } - // Wait for all downloads to complete - wg.Wait() - close(errCh) + go func() { + wg.Wait() + close(errCh) + }() - // Check for errors during downloads - for downloadErr := range errCh { - // Clean up the directory and return the error - cleanupDirectory(downloadDir) - return nil, downloadErr + var errs []error + for err := range errCh { + errs = append(errs, err) } - // Verify that tokenizer.json exists - tokenizerPath := filepath.Join(downloadDir, "tokenizer.json") - if _, err := os.Stat(tokenizerPath); os.IsNotExist(err) { - return nil, fmt.Errorf("mandatory file tokenizer.json does not exist in %s", downloadDir) + if len(errs) > 0 { + if err := os.RemoveAll(downloadDir); err != nil { + fmt.Printf("Warning: failed to clean up directory %s: %v\n", downloadDir, err) + } + return nil, errs[0] } - // Initialize the tokenizer using the downloaded tokenizer.json - return FromFile(tokenizerPath) + return FromFile(filepath.Join(downloadDir, "tokenizer.json")) } // downloadFile downloads a file from the given URL and saves it to the specified destination. @@ -212,16 +206,6 @@ func downloadFile(url, destination string, authToken *string) error { return nil } -// cleanupDirectory removes the specified directory and all its contents. -func cleanupDirectory(dir string) { - err := os.RemoveAll(dir) - if err != nil { - fmt.Printf("Warning: failed to clean up directory %s: %v\n", dir, err) - } else { - fmt.Printf("Successfully cleaned up directory %s\n", dir) - } -} - func (t *Tokenizer) Close() error { C.free_tokenizer(t.tokenizer) t.tokenizer = nil