Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add push force mode #8

Merged
merged 2 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions cmd/modelx/model/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,6 @@ limitations under the License.

package model

const (
ModelConfigFileName = "modelx.yaml"
ReadmeFileName = "README.md"
)

type ModelConfig struct {
Description string `json:"description"`
FrameWork string `json:"framework"`
Expand Down
6 changes: 4 additions & 2 deletions cmd/modelx/model/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import (

"github.com/spf13/cobra"
"gopkg.in/yaml.v3"

"kubegems.io/modelx/pkg/client"
)

func NewInitCmd() *cobra.Command {
Expand Down Expand Up @@ -103,15 +105,15 @@ func InitModelx(ctx context.Context, path string, force bool) error {
if err != nil {
return fmt.Errorf("encode model %w", err)
}
configfile := filepath.Join(path, ModelConfigFileName)
configfile := filepath.Join(path, client.ModelConfigFileName)
if err := os.WriteFile(configfile, configcontent.Bytes(), 0o755); err != nil {
return fmt.Errorf("write model config:%s %w", configfile, err)
}

// Init README.md
basefile := filepath.Base(path)
if basefile != "" {
readmefile := filepath.Join(path, ReadmeFileName)
readmefile := filepath.Join(path, client.ReadmeFileName)
_, err := os.Stat(readmefile)
if errors.Is(err, os.ErrNotExist) {
readmecontent := fmt.Sprintf("# %s\n\nAwesome model descrition.\n", basefile)
Expand Down
18 changes: 12 additions & 6 deletions cmd/modelx/model/pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,27 @@ import (
)

func NewPullCmd() *cobra.Command {
IsForce := false
cmd := &cobra.Command{
Use: "pull",
Short: "pull a model from a repository",
Long: "pull [--force/-f] <repo>/[project]/[name]@[version] .",
Example: `
# Pull project/demo version latest to dirctory demo by default

modelx pull https://myrepo/project/demo
modelx pull myrepo/project/demo

# Pull project/demo to current dirctoty

modelx pull https://myrepo/project/demo@version .
modelx pull myrepo/project/demo@version .

# Pull project/demo to dirctoty abc

modelx pull https://myrepo/project/demo@version abc
modelx pull myrepo/project/demo@version abc

# Pull project/demo to dirctoty abc

modelx pull -f myrepo/project/demo@version abc
`,
SilenceUsage: true,
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
Expand All @@ -63,13 +68,14 @@ func NewPullCmd() *cobra.Command {
if len(args) == 1 {
args = append(args, "")
}
return PullModelx(ctx, args[0], args[1])
return PullModelx(ctx, args[0], args[1], IsForce)
},
}
cmd.Flags().BoolVarP(&IsForce, "force", "f", false, "force pull clean local modelx file or directory")
return cmd
}

func PullModelx(ctx context.Context, ref string, into string) error {
func PullModelx(ctx context.Context, ref string, into string, force bool) error {
reference, err := ParseReference(ref)
if err != nil {
return err
Expand All @@ -81,5 +87,5 @@ func PullModelx(ctx context.Context, ref string, into string) error {
into = path.Base(reference.Repository)
}
fmt.Printf("Pulling %s into %s \n", reference.String(), into)
return reference.Client().Pull(ctx, reference.Repository, reference.Version, into)
return reference.Client().Pull(ctx, reference.Repository, reference.Version, into, force)
}
9 changes: 5 additions & 4 deletions cmd/modelx/model/push.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/spf13/cobra"
"gopkg.in/yaml.v3"
"kubegems.io/modelx/cmd/modelx/repo"
"kubegems.io/modelx/pkg/client"
)

func NewPushCmd() *cobra.Command {
Expand Down Expand Up @@ -83,14 +84,14 @@ func PushModel(ctx context.Context, ref string, dir string) error {
dir = "."
}
// parse annotations from model config
configcontent, err := os.ReadFile(filepath.Join(dir, ModelConfigFileName))
configcontent, err := os.ReadFile(filepath.Join(dir, client.ModelConfigFileName))
if err != nil {
return fmt.Errorf("read model config:%s %w", ModelConfigFileName, err)
return fmt.Errorf("read model config:%s %w", client.ModelConfigFileName, err)
}
var config ModelConfig
if err := yaml.Unmarshal(configcontent, &config); err != nil {
return fmt.Errorf("parse model config:%s %w", ModelConfigFileName, err)
return fmt.Errorf("parse model config:%s %w", client.ModelConfigFileName, err)
}
fmt.Printf("Pushing to %s \n", reference.String())
return reference.Client().Push(ctx, reference.Repository, reference.Version, ModelConfigFileName, dir)
return reference.Client().Push(ctx, reference.Repository, reference.Version, client.ModelConfigFileName, dir)
}
29 changes: 14 additions & 15 deletions cmd/modelxd/modelxd.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,20 +88,20 @@ func NewRegistryCmd() *cobra.Command {
}

flags := cmd.Flags()
flags.StringVar(&config.GlobalModelxdOptions.Listen, "listen", config.GlobalModelxdOptions.Listen, "listen address")
flags.StringVar(&config.GlobalModelxdOptions.TLS.CAFile, "tls-ca", config.GlobalModelxdOptions.TLS.CAFile, "tls ca file")
flags.StringVar(&config.GlobalModelxdOptions.TLS.CertFile, "tls-cert", config.GlobalModelxdOptions.TLS.CertFile, "tls cert file")
flags.StringVar(&config.GlobalModelxdOptions.TLS.KeyFile, "tls-key", config.GlobalModelxdOptions.TLS.KeyFile, "tls key file")
flags.StringVar(&config.GlobalModelxdOptions.S3.Buket, "s3-bucket", config.GlobalModelxdOptions.S3.Buket, "s3 bucket")
flags.StringVar(&config.GlobalModelxdOptions.S3.URL, "s3-url", config.GlobalModelxdOptions.S3.URL, "s3 url")
flags.StringVar(&config.GlobalModelxdOptions.S3.AccessKey, "s3-access-key", config.GlobalModelxdOptions.S3.AccessKey, "s3 access key")
flags.StringVar(&config.GlobalModelxdOptions.S3.SecretKey, "s3-secret-key", config.GlobalModelxdOptions.S3.SecretKey, "s3 secret key")
flags.DurationVar(&config.GlobalModelxdOptions.S3.PresignExpire, "s3-presign-expire", config.GlobalModelxdOptions.S3.PresignExpire, "s3 presign expire")
flags.StringVar(&config.GlobalModelxdOptions.S3.Region, "s3-region", config.GlobalModelxdOptions.S3.Region, "s3 region")
flags.StringVar(&config.GlobalModelxdOptions.OIDC.Issuer, "oidc-issuer", config.GlobalModelxdOptions.OIDC.Issuer, "oidc issuer")
flags.StringVar(&config.GlobalModelxdOptions.Local.Basepath, "path", config.GlobalModelxdOptions.Local.Basepath, "local metadate store path. Default: ./data/registry/")
flags.BoolVar(&config.GlobalModelxdOptions.EnableRedirect, "enable-redirect", config.GlobalModelxdOptions.EnableRedirect, "enable blob storage redirect. Default: false")
flags.BoolVar(&config.GlobalModelxdOptions.EnableMetrics, "enable-metrics", true, "enable metrics api. Default: true")
flags.StringVar(&config.GlobalModelxdOptions.Listen, "listen", config.GlobalModelxdOptions.Listen, "listen address.")
flags.StringVar(&config.GlobalModelxdOptions.TLS.CAFile, "tls-ca", config.GlobalModelxdOptions.TLS.CAFile, "tls ca file.")
flags.StringVar(&config.GlobalModelxdOptions.TLS.CertFile, "tls-cert", config.GlobalModelxdOptions.TLS.CertFile, "tls cert file.")
flags.StringVar(&config.GlobalModelxdOptions.TLS.KeyFile, "tls-key", config.GlobalModelxdOptions.TLS.KeyFile, "tls key file.")
flags.StringVar(&config.GlobalModelxdOptions.S3.Buket, "s3-bucket", config.GlobalModelxdOptions.S3.Buket, "s3 bucket.")
flags.StringVar(&config.GlobalModelxdOptions.S3.URL, "s3-url", config.GlobalModelxdOptions.S3.URL, "s3 url.")
flags.StringVar(&config.GlobalModelxdOptions.S3.AccessKey, "s3-access-key", config.GlobalModelxdOptions.S3.AccessKey, "s3 access key.")
flags.StringVar(&config.GlobalModelxdOptions.S3.SecretKey, "s3-secret-key", config.GlobalModelxdOptions.S3.SecretKey, "s3 secret key.")
flags.DurationVar(&config.GlobalModelxdOptions.S3.PresignExpire, "s3-presign-expire", config.GlobalModelxdOptions.S3.PresignExpire, "s3 presign expire.")
flags.StringVar(&config.GlobalModelxdOptions.S3.Region, "s3-region", config.GlobalModelxdOptions.S3.Region, "s3 region.")
flags.StringVar(&config.GlobalModelxdOptions.OIDC.Issuer, "oidc-issuer", config.GlobalModelxdOptions.OIDC.Issuer, "oidc issuer.")
flags.StringVar(&config.GlobalModelxdOptions.Local.Basepath, "path", config.GlobalModelxdOptions.Local.Basepath, "local metadate store path.")
flags.BoolVar(&config.GlobalModelxdOptions.EnableRedirect, "enable-redirect", false, "enable blob storage redirect.")
flags.BoolVar(&config.GlobalModelxdOptions.EnableMetrics, "enable-metrics", true, "enable metrics api.")

return cmd
}
Expand Down Expand Up @@ -157,7 +157,6 @@ func NewRegistryConfig(ctx context.Context, opt *config.Options) (*model.Registr
var registryStore registry.RegistryInterface
if registryStore == nil && opt.S3 != nil && opt.S3.URL != "" {
mainLogger.Info("start modelx registry with S3 type")

s3store, err := registry.NewS3RegistryStore(ctx, opt)
if err != nil {
return nil, err
Expand Down
6 changes: 5 additions & 1 deletion pkg/client/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,9 @@ const (
MediaTypeModelDirectoryTarGz = "application/vnd.modelx.model.directory.v1.tar+gz"

// default retry count
DefaultPullPushConcurrency = 3
DefaultPullPushConcurrency = 5

ModelConfigFileName = "modelx.yaml"
ReadmeFileName = "README.md"
ModelCacheDir = ".modelx"
)
39 changes: 32 additions & 7 deletions pkg/client/pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,16 @@ import (
"os"
"path/filepath"

"github.com/kubeservice-stack/common/pkg/utils"
"github.com/opencontainers/go-digest"
"golang.org/x/sync/errgroup"

"kubegems.io/modelx/pkg/progress"
"kubegems.io/modelx/pkg/response"
"kubegems.io/modelx/pkg/util"
)

func (c *Client) Pull(ctx context.Context, repo string, version string, into string) error {
func (c *Client) Pull(ctx context.Context, repo string, version string, into string, force bool) error {
// check if the directory exists and is empty
if dirInfo, err := os.Stat(into); err != nil {
if !os.IsNotExist(err) {
Expand All @@ -51,18 +53,38 @@ func (c *Client) Pull(ctx context.Context, repo string, version string, into str
if err != nil {
return err
}
return c.PullBlobs(ctx, repo, into, append(manifest.Blobs, manifest.Config))

blobs := append(manifest.Blobs, manifest.Config)
if force {
dirlists, err := utils.ListDir(into)
if err != nil {
return fmt.Errorf("force clean %s model fail, Please use pull model to other dirctoty.", into)
}

for _, dirlist := range dirlists {
if dirlist == ModelCacheDir || dirlist == ModelConfigFileName || dirlist == ReadmeFileName {
continue
}
flag := false
for _, blob := range blobs {
if dirlist == blob.Name {
flag = true
}
}
if !flag {
_ = utils.RemoveDir(filepath.Join(into, dirlist))
_ = utils.RemoveFile(filepath.Join(into, dirlist))
}
}
}

return c.PullBlobs(ctx, repo, into, blobs)
}

func (c *Client) PullBlobs(ctx context.Context, repo string, basedir string, blobs []util.Descriptor) error {
mb, ctx := progress.NewMuiltiBarContext(ctx, os.Stdout, 60, DefaultPullPushConcurrency)
for _, blob := range blobs {
mb.Go(blob.Name, "pending", func(b *progress.Bar) error {
if blob.MediaType == MediaTypeModelDirectoryTarGz {
if err := os.MkdirAll(filepath.Join(basedir, blob.Name), 0o755); err != nil {
return fmt.Errorf("create directory %s: %v", filepath.Join(basedir, blob.Name), err)
}
}
return c.pullBlobProgress(ctx, repo, blob, basedir, b)
})
}
Expand All @@ -72,6 +94,9 @@ func (c *Client) PullBlobs(ctx context.Context, repo string, basedir string, blo
func (c *Client) pullBlobProgress(ctx context.Context, repo string, desc util.Descriptor, basedir string, bar *progress.Bar) error {
switch desc.MediaType {
case MediaTypeModelDirectoryTarGz:
if err := os.MkdirAll(filepath.Join(basedir, desc.Name), 0o755); err != nil {
return fmt.Errorf("create directory %s: %v", filepath.Join(basedir, desc.Name), err)
}
return c.pullDirectory(ctx, repo, desc, basedir, bar, true)
case MediaTypeModelFile:
return c.pullFile(ctx, repo, desc, basedir, bar)
Expand Down
Loading