Skip to content

Commit

Permalink
feat(csi): support multiple model registries
Browse files Browse the repository at this point in the history
Signed-off-by: Alessio Pragliola <[email protected]>
  • Loading branch information
Al-Pragliola committed Oct 23, 2024
1 parent 505b00e commit 5ad3b2b
Show file tree
Hide file tree
Showing 9 changed files with 324 additions and 114 deletions.
1 change: 1 addition & 0 deletions .github/workflows/csi-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ jobs:
uses: helm/[email protected]
with:
node_image: "kindest/node:v1.27.11"
config: ./csi/test/kind_config.yaml

- name: Install kustomize
run: ./csi/scripts/install_kustomize.sh
Expand Down
6 changes: 5 additions & 1 deletion csi/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"log"
"os"

"github.com/kubeflow/model-registry/csi/pkg/modelregistry"
"github.com/kubeflow/model-registry/csi/pkg/storage"
"github.com/kubeflow/model-registry/pkg/openapi"
)
Expand Down Expand Up @@ -38,7 +39,10 @@ func main() {
cfg := openapi.NewConfiguration()
cfg.Host = baseUrl
cfg.Scheme = scheme
provider, err := storage.NewModelRegistryProvider(cfg)

apiClient := modelregistry.NewAPIClient(cfg, sourceUri)

provider, err := storage.NewModelRegistryProvider(apiClient)
if err != nil {
log.Fatalf("Error initiliazing model registry provider: %v", err)
}
Expand Down
5 changes: 5 additions & 0 deletions csi/pkg/constants/constants.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package constants

import kserve "github.com/kserve/kserve/pkg/agent/storage"

const MR kserve.Protocol = "model-registry://"
41 changes: 41 additions & 0 deletions csi/pkg/modelregistry/api_client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package modelregistry

import (
"context"
"log"
"strings"

"github.com/kubeflow/model-registry/csi/pkg/constants"
"github.com/kubeflow/model-registry/pkg/openapi"
)

func NewAPIClient(cfg *openapi.Configuration, storageUri string) *openapi.APIClient {
client := openapi.NewAPIClient(cfg)

// Parse the URI to retrieve the needed information to query model registry (modelArtifact)
mrUri := strings.TrimPrefix(storageUri, string(constants.MR))

tokens := strings.SplitN(mrUri, "/", 3)

if len(tokens) < 2 {
return client
}

newCfg := openapi.NewConfiguration()
newCfg.Host = tokens[0]
newCfg.Scheme = "http"

newClient := openapi.NewAPIClient(newCfg)

if len(tokens) == 2 {
// Check if the model registry service is available
_, _, err := newClient.ModelRegistryServiceAPI.GetRegisteredModels(context.Background()).Execute()
if err != nil {
log.Printf("Falling back to base url %s for model registry service", cfg.Host)

return client
}
}

return newClient
}
148 changes: 100 additions & 48 deletions csi/pkg/storage/modelregistry_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,49 +2,52 @@ package storage

import (
"context"
"errors"
"fmt"
"log"
"regexp"
"strings"

kserve "github.com/kserve/kserve/pkg/agent/storage"
"github.com/kubeflow/model-registry/csi/pkg/constants"
"github.com/kubeflow/model-registry/pkg/openapi"
)

const MR kserve.Protocol = "model-registry://"
var (
_ kserve.Provider = (*ModelRegistryProvider)(nil)
ErrInvalidMRURI = errors.New("invalid model registry URI, use like model-registry://{dnsName}/{registeredModelName}/{versionName}")
ErrNoVersionAssociated = errors.New("no versions associated to registered model")
ErrNoArtifactAssociated = errors.New("no artifacts associated to model version")
ErrNoModelArtifact = errors.New("no model artifact found for model version")
ErrModelArtifactEmptyURI = errors.New("model artifact has empty URI")
ErrNoStorageURI = errors.New("there is no storageUri supplied")
ErrNoProtocolInSTorageURI = errors.New("there is no protocol specified for the storageUri")
ErrProtocolNotSupported = errors.New("protocol not supported for storageUri")
)

type ModelRegistryProvider struct {
Client *openapi.APIClient
Providers map[kserve.Protocol]kserve.Provider
}

func NewModelRegistryProvider(cfg *openapi.Configuration) (*ModelRegistryProvider, error) {
client := openapi.NewAPIClient(cfg)

func NewModelRegistryProvider(client *openapi.APIClient) (*ModelRegistryProvider, error) {
return &ModelRegistryProvider{
Client: client,
Providers: map[kserve.Protocol]kserve.Provider{},
}, nil
}

var _ kserve.Provider = (*ModelRegistryProvider)(nil)

// storageUri formatted like model-registry://{registeredModelName}/{versionName}
// storageUri formatted like model-registry://{modelRegistryUrl}/{registeredModelName}/{versionName}
func (p *ModelRegistryProvider) DownloadModel(modelDir string, modelName string, storageUri string) error {
log.Printf("Download model indexed in model registry: modelName=%s, storageUri=%s, modelDir=%s", modelName, storageUri, modelDir)

// Parse the URI to retrieve the needed information to query model registry (modelArtifact)
mrUri := strings.TrimPrefix(storageUri, string(MR))
tokens := strings.SplitN(mrUri, "/", 2)

if len(tokens) == 0 || len(tokens) > 2 {
return fmt.Errorf("invalid model registry URI, use like model-registry://{registeredModelName}/{versionName}")
}
log.Printf("Download model indexed in model registry: modelName=%s, storageUri=%s, modelDir=%s",
modelName,
storageUri,
modelDir,
)

registeredModelName := tokens[0]
var versionName *string
if len(tokens) == 2 {
versionName = &tokens[1]
registeredModelName, versionName, err := p.parseModelVersion(storageUri)
if err != nil {
return err
}

// Fetch the registered model
Expand All @@ -54,25 +57,9 @@ func (p *ModelRegistryProvider) DownloadModel(modelDir string, modelName string,
}

// Fetch model version by name or latest if not specified
var version *openapi.ModelVersion
if versionName != nil {
version, _, err = p.Client.ModelRegistryServiceAPI.FindModelVersion(context.Background()).Name(*versionName).ParentResourceId(*model.Id).Execute()
if err != nil {
return err
}
} else {
versions, _, err := p.Client.ModelRegistryServiceAPI.GetRegisteredModelVersions(context.Background(), *model.Id).
OrderBy(openapi.ORDERBYFIELD_CREATE_TIME).
SortOrder(openapi.SORTORDER_DESC).
Execute()
if err != nil {
return err
}

if versions.Size == 0 {
return fmt.Errorf("no versions associated to registered model %s", registeredModelName)
}
version = &versions.Items[0]
version, err := p.fetchModelVersion(versionName, registeredModelName, model)
if err != nil {
return err
}

artifacts, _, err := p.Client.ModelRegistryServiceAPI.GetModelVersionArtifacts(context.Background(), *version.Id).
Expand All @@ -84,20 +71,20 @@ func (p *ModelRegistryProvider) DownloadModel(modelDir string, modelName string,
}

if artifacts.Size == 0 {
return fmt.Errorf("no artifacts associated to model version %s", *version.Id)
return fmt.Errorf("%w %s", ErrNoArtifactAssociated, *version.Id)
}

modelArtifact := artifacts.Items[0].ModelArtifact
if modelArtifact == nil {
return fmt.Errorf("no model artifact found for model version %s", *version.Id)
return fmt.Errorf("%w %s", ErrNoModelArtifact, *version.Id)
}

// Call appropriate kserve provider based on the indexed model artifact URI
if modelArtifact.Uri == nil {
return fmt.Errorf("model artifact %s has empty URI", *modelArtifact.Id)
return fmt.Errorf("%w %s", ErrModelArtifactEmptyURI, *modelArtifact.Id)
}

protocol, err := extractProtocol(*modelArtifact.Uri)
protocol, err := p.extractProtocol(*modelArtifact.Uri)
if err != nil {
return err
}
Expand All @@ -110,19 +97,84 @@ func (p *ModelRegistryProvider) DownloadModel(modelDir string, modelName string,
return provider.DownloadModel(modelDir, "", *modelArtifact.Uri)
}

func extractProtocol(storageURI string) (kserve.Protocol, error) {
// Possible URIs:
// (1) model-registry://{modelName}
// (2) model-registry://{modelName}/{modelVersion}
// (3) model-registry://{modelRegistryUrl}/{modelName}
// (4) model-registry://{modelRegistryUrl}/{modelName}/{modelVersion}
func (p *ModelRegistryProvider) parseModelVersion(storageUri string) (string, *string, error) {
var versionName *string

// Parse the URI to retrieve the needed information to query model registry (modelArtifact)
mrUri := strings.TrimPrefix(storageUri, string(constants.MR))

tokens := strings.SplitN(mrUri, "/", 3)

if len(tokens) == 0 || len(tokens) > 3 {
return "", nil, ErrInvalidMRURI
}

// Check if the first token is the host and remove it so that we reduce cases (3) and (4) to (1) and (2)
if len(tokens) >= 2 && p.Client.GetConfig().Host == tokens[0] {
tokens = tokens[1:]
}

registeredModelName := tokens[0]

if len(tokens) == 2 {
versionName = &tokens[1]
}

return registeredModelName, versionName, nil
}

func (p *ModelRegistryProvider) fetchModelVersion(
versionName *string,
registeredModelName string,
model *openapi.RegisteredModel,
) (*openapi.ModelVersion, error) {
if versionName != nil {
version, _, err := p.Client.ModelRegistryServiceAPI.
FindModelVersion(context.Background()).
Name(*versionName).
ParentResourceId(*model.Id).
Execute()
if err != nil {
return nil, err
}

return version, nil
}

versions, _, err := p.Client.ModelRegistryServiceAPI.GetRegisteredModelVersions(context.Background(), *model.Id).
OrderBy(openapi.ORDERBYFIELD_CREATE_TIME).
SortOrder(openapi.SORTORDER_DESC).
Execute()
if err != nil {
return nil, err
}

if versions.Size == 0 {
return nil, fmt.Errorf("%w %s", ErrNoVersionAssociated, registeredModelName)
}

return &versions.Items[0], nil
}

func (*ModelRegistryProvider) extractProtocol(storageURI string) (kserve.Protocol, error) {
if storageURI == "" {
return "", fmt.Errorf("there is no storageUri supplied")
return "", ErrNoStorageURI
}

if !regexp.MustCompile("\\w+?://").MatchString(storageURI) {
return "", fmt.Errorf("there is no protocol specified for the storageUri")
if !regexp.MustCompile(`\w+?://`).MatchString(storageURI) {
return "", ErrNoProtocolInSTorageURI
}

for _, prefix := range kserve.SupportedProtocols {
if strings.HasPrefix(storageURI, string(prefix)) {
return prefix, nil
}
}
return "", fmt.Errorf("protocol not supported for storageUri")

return "", ErrProtocolNotSupported
}
8 changes: 5 additions & 3 deletions csi/scripts/install_modelregistry.sh
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,11 @@ if ! kubectl get namespace "$namespace" &> /dev/null; then
fi
# Apply model-registry kustomize manifests
echo Using model registry image: $image
cd $MR_ROOT/manifests/kustomize/base && kustomize edit set image kubeflow/model-registry:latest=${image} && cd -
cd $MR_ROOT/manifests/kustomize/base && kustomize edit set image kubeflow/model-registry:latest=${image} && \
kustomize edit set namespace $namespace && cd -
cd $MR_ROOT/manifests/kustomize/overlays/db && kustomize edit set namespace $namespace && cd -
kubectl -n $namespace apply -k "$MR_ROOT/manifests/kustomize/overlays/db"

# Wait for model registry deployment
modelregistry=$(kubectl get pod -n kubeflow --selector="component=model-registry-server" --output jsonpath='{.items[0].metadata.name}')
kubectl wait --for=condition=Ready pod/$modelregistry -n $namespace --timeout=6m
modelregistry=$(kubectl get pod -n $namespace --selector="component=model-registry-server" --output jsonpath='{.items[0].metadata.name}')
kubectl wait --for=condition=Ready pod/$modelregistry -n $namespace --timeout=6m
Loading

0 comments on commit 5ad3b2b

Please sign in to comment.