Skip to content

Commit

Permalink
Merge pull request #13 from InfuseAI/feature/sc-24993/local-repo-enha…
Browse files Browse the repository at this point in the history
…ncement

Local repo support atomic upload
  • Loading branch information
popcornylu authored Mar 10, 2022
2 parents febfe3b + e9ed141 commit c74e3d2
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 8 deletions.
14 changes: 12 additions & 2 deletions cmd/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,18 @@ var configCommand = &cobra.Command{
return
}

_, err := repository.NewRepository(value)
exitWithError(err)
cwd, _ := os.Getwd()
repo, err := transformRepoUrl(cwd, value)
if err != nil {
exitWithError(err)
return
}

_, err = repository.NewRepository(repo)
if err != nil {
exitWithError(err)
return
}
}

config.Set(key, value)
Expand Down
8 changes: 6 additions & 2 deletions cmd/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,18 @@ var initCommand = &cobra.Command{
Args: cobra.ExactArgs(1),
Run: func(cmd *cobra.Command, args []string) {
cwd, _ := os.Getwd()
repo := args[0]
repo, err := transformRepoUrl(cwd, args[0])
if err != nil {
exitWithError(err)
return
}

if strings.HasPrefix(repo, "http") {
exitWithError(errors.New("init not support under http(s) repo"))
return
}

_, err := repository.NewRepository(repo)
_, err = repository.NewRepository(repo)
if err != nil {
exitWithError(err)
return
Expand Down
33 changes: 33 additions & 0 deletions cmd/util_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package cmd

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestTransformRepoUrl(t *testing.T) {
baseDir := "/tmp/artiv"
testCases := []struct {
desc string
in string
out string
}{
{desc: "local file", in: "/this/is/my/path", out: "/this/is/my/path"},
{desc: "relative path", in: "../path", out: "/tmp/path"},
{desc: "relative path2", in: "../../../path", out: "/path"},
{desc: "normal url (file)", in: "file://mybucket/this/is/my/path", out: "file://mybucket/this/is/my/path"},
{desc: "normal url (s3)", in: "s3://mybucket/this/is/my/path", out: "s3://mybucket/this/is/my/path"},
}

for _, tC := range testCases {
t.Run(tC.desc, func(t *testing.T) {
result, err := transformRepoUrl(baseDir, tC.in)
if err != nil {
assert.Empty(t, tC.out)
} else {
assert.Equal(t, tC.out, result)
}
})
}
}
19 changes: 19 additions & 0 deletions cmd/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package cmd
import (
"errors"
"fmt"
neturl "net/url"
"path/filepath"
"strings"

"github.com/spf13/cobra"
Expand Down Expand Up @@ -33,3 +35,20 @@ func parseRepoStr(repoAndRef string) (repoUrl string, ref string, err error) {
}
return
}

func transformRepoUrl(base string, repo string) (string, error) {
url, err := neturl.Parse(repo)
if err != nil {
return "", err
}

if url.Scheme != "" {
return repo, nil
}

if strings.HasPrefix(repo, "/") {
return repo, nil
}

return filepath.Abs(filepath.Join(base, url.Path))
}
34 changes: 30 additions & 4 deletions internal/repository/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,19 +54,45 @@ func (repo *LocalFileSystemRepository) Upload(localPath, repoPath string, m *met
}
defer source.Close()

// Copy from source to tmp
tmpDir := path.Join(repo.RepoDir, "tmp")
err = os.MkdirAll(tmpDir, fs.ModePerm)
if err != nil {
return err
}

tmp, err := os.CreateTemp(tmpDir, "*")
if err != nil {
return err
}
tmpPath := tmp.Name()
defer os.Remove(tmpPath)
_, err = meter.CopyWithMeter(tmp, source, m)
if err != nil {
return err
}
err = tmp.Close()
if err != nil {
return err
}

// Move from tmp to dest
destPath := path.Join(repo.RepoDir, repoPath)
err = os.MkdirAll(filepath.Dir(destPath), fs.ModePerm)
if err != nil {
return err
}
err = os.Remove(destPath)
if err != nil && !os.IsNotExist(err) {
return err
}

destination, err := os.Create(destPath)
err = os.Rename(tmpPath, destPath)
if err != nil {
return err
}
defer destination.Close()
_, err = meter.CopyWithMeter(destination, source, m)
return err

return nil
}

func (repo *LocalFileSystemRepository) Download(repoPath, localPath string, m *meter.Meter) error {
Expand Down
93 changes: 93 additions & 0 deletions internal/repository/local_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package repository

import (
"os"
"testing"

"github.com/stretchr/testify/assert"
)

func TestLocalUpload(t *testing.T) {
testCases := []struct {
desc string
data string
}{
{
desc: "empty file", data: "",
},
{
desc: "non empty file", data: "hello",
},
}
for _, tC := range testCases {
t.Run(tC.desc, func(t *testing.T) {
repoDir := t.TempDir()
tmpDir := t.TempDir()

repo, err := NewLocalFileSystemRepository(repoDir)
if err != nil {
t.Error(err)
}

err = os.WriteFile(tmpDir+"/test", []byte(tC.data), 0644)
if err != nil {
t.Error(err)
}

err = repo.Upload(tmpDir+"/test", "path/to/the/test", nil)
if err != nil {
t.Error(err)
}
data, err := os.ReadFile(repoDir + "/path/to/the/test")
if err != nil {
t.Error(err)
}
assert.Equal(t, []byte(tC.data), []byte(data))
})
}
}

func TestLocalDownload(t *testing.T) {
testCases := []struct {
desc string
data string
}{
{
desc: "empty file", data: "",
},
{
desc: "non empty file", data: "hello",
},
}
for _, tC := range testCases {
t.Run(tC.desc, func(t *testing.T) {
repoDir := t.TempDir()
tmpDir := t.TempDir()

repo, err := NewLocalFileSystemRepository(repoDir)
if err != nil {
t.Error(err)
}

err = os.MkdirAll(repoDir+"/path/to/the", os.ModePerm)
if err != nil {
t.Error(err)
}

err = os.WriteFile(repoDir+"/path/to/the/test", []byte(tC.data), 0644)
if err != nil {
t.Error(err)
}

err = repo.Download("path/to/the/test", tmpDir+"/test", nil)
if err != nil {
t.Error(err)
}
data, err := os.ReadFile(tmpDir + "/test")
if err != nil {
t.Error(err)
}
assert.Equal(t, []byte(tC.data), []byte(data))
})
}
}

0 comments on commit c74e3d2

Please sign in to comment.