Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

openapi: require name for contexts #253

Merged
merged 1 commit into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions api/openapi/model-registry.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1023,10 +1023,18 @@ components:
- $ref: "#/components/schemas/BaseResourceList"
RegisteredModelCreate:
description: A registered model in model registry. A registered model has ModelVersion children.
required:
- name
allOf:
- type: object
- $ref: "#/components/schemas/BaseResourceCreate"
- $ref: "#/components/schemas/RegisteredModelUpdate"
properties:
name:
description: |-
The client provided name of the model. It must be unique among all the RegisteredModels of the same
type within a Model Registry instance and cannot be changed once set.
type: string
RegisteredModelUpdate:
description: A registered model in model registry. A registered model has ModelVersion children.
allOf:
Expand All @@ -1045,6 +1053,7 @@ components:
ModelVersionCreate:
description: Represents a ModelVersion belonging to a RegisteredModel.
required:
- name
- registeredModelId
allOf:
- $ref: "#/components/schemas/BaseResourceCreate"
Expand All @@ -1054,6 +1063,11 @@ components:
registeredModelId:
description: ID of the `RegisteredModel` to which this version belongs.
type: string
name:
description: |-
The client provided name of the model's version. It must be unique among all the ModelVersions of the same
type within a Model Registry instance and cannot be changed once set.
type: string
ModelVersionUpdate:
description: Represents a ModelVersion belonging to a RegisteredModel.
allOf:
Expand Down
5 changes: 2 additions & 3 deletions internal/converter/generated/mlmd_openapi_converter.gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 10 additions & 18 deletions internal/converter/generated/openapi_converter.gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion internal/converter/generated/openapi_mlmd_converter.gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions internal/converter/mlmd_converter_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ func TestMapModelVersionProperties(t *testing.T) {
ParentResourceId: of("123"),
ModelName: of("MyModel"),
Model: &openapi.ModelVersion{
Name: of("v1"),
Name: "v1",
Description: of("my model version description"),
Author: of("John Doe"),
},
Expand Down Expand Up @@ -244,7 +244,7 @@ func TestMapModelVersionName(t *testing.T) {
ParentResourceId: of("123"),
ModelName: of("MyModel"),
Model: &openapi.ModelVersion{
Name: of("v1"),
Name: "v1",
},
})
assertion.NotNil(name)
Expand Down
2 changes: 1 addition & 1 deletion internal/converter/mlmd_openapi_converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type MLMDToOpenAPIConverter interface {
// goverter:map Properties State | MapRegisteredModelState
ConvertRegisteredModel(source *proto.Context) (*openapi.RegisteredModel, error)

// goverter:map Name | MapNameFromOwned
// goverter:map Name | MapName
// goverter:map Name RegisteredModelId | MapRegisteredModelIdFromOwned
// goverter:map Properties Description | MapDescription
// goverter:map Properties State | MapModelVersionState
Expand Down
14 changes: 14 additions & 0 deletions internal/converter/mlmd_openapi_converter_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,20 @@ func MapNameFromOwned(source *string) *string {
return &exploded[1]
}

// MapName derive the entity name from the mlmd fullname
// for owned entity such as ModelVersion
func MapName(source *string) string {
if source == nil {
return ""
}

exploded := strings.Split(*source, ":")
if len(exploded) == 1 {
return *source
}
return exploded[1]
}

// REGISTERED MODEL

// MODEL VERSION
Expand Down
6 changes: 3 additions & 3 deletions internal/converter/openapi_mlmd_converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ import (
type OpenAPIModelWrapper[
M OpenAPIModel,
] struct {
TypeId int64
Model *M
ParentResourceId *string // optional parent id
ModelName *string // optional registered model name
ParentResourceId *string
ModelName *string
TypeId int64
}

// goverter:converter
Expand Down
4 changes: 2 additions & 2 deletions internal/converter/openapi_mlmd_converter_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func MapModelVersionProperties(source *OpenAPIModelWrapper[openapi.ModelVersion]
}
props["version"] = &proto.Value{
Value: &proto.Value_StringValue{
StringValue: *(*source.Model).Name,
StringValue: (*source.Model).Name,
},
}

Expand Down Expand Up @@ -208,7 +208,7 @@ func MapModelVersionType(_ *openapi.ModelVersion) *string {
// MapModelVersionName maps the user-provided name into MLMD one, i.e., prefixing it with
// either the parent resource id or a generated uuid
func MapModelVersionName(source *OpenAPIModelWrapper[openapi.ModelVersion]) *string {
return of(PrefixWhenOwned(source.ParentResourceId, *(*source).Model.Name))
return of(PrefixWhenOwned(source.ParentResourceId, (*source).Model.Name))
}

// ARTIFACT
Expand Down
9 changes: 7 additions & 2 deletions internal/mapper/mapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ func NewMapper(mlmdTypes map[string]int64) *Mapper {
}
}

// of returns a pointer to the provided literal/const input
func of[E any](e E) *E {
return &e
}

// Utilities for OpenAPI --> MLMD mapping, make use of generated Converters

func (m *Mapper) MapFromRegisteredModel(registeredModel *openapi.RegisteredModel) (*proto.Context, error) {
Expand All @@ -33,12 +38,12 @@ func (m *Mapper) MapFromRegisteredModel(registeredModel *openapi.RegisteredModel
})
}

func (m *Mapper) MapFromModelVersion(modelVersion *openapi.ModelVersion, registeredModelId string, registeredModelName *string) (*proto.Context, error) {
func (m *Mapper) MapFromModelVersion(modelVersion *openapi.ModelVersion, registeredModelId string, registeredModelName string) (*proto.Context, error) {
return m.OpenAPIConverter.ConvertModelVersion(&converter.OpenAPIModelWrapper[openapi.ModelVersion]{
TypeId: m.MLMDTypes[defaults.ModelVersionTypeName],
Model: modelVersion,
ParentResourceId: &registeredModelId,
ModelName: registeredModelName,
ModelName: of(registeredModelName),
})
}

Expand Down
11 changes: 3 additions & 8 deletions internal/mapper/mapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func setup(t *testing.T) (*assert.Assertions, *Mapper) {
func TestMapFromRegisteredModel(t *testing.T) {
assertion, m := setup(t)

ctx, err := m.MapFromRegisteredModel(&openapi.RegisteredModel{Name: of("ModelName")})
ctx, err := m.MapFromRegisteredModel(&openapi.RegisteredModel{Name: "ModelName"})
assertion.Nil(err)
assertion.Equal("ModelName", ctx.GetName())
assertion.Equal(registeredModelTypeId, ctx.GetTypeId())
Expand All @@ -49,7 +49,7 @@ func TestMapFromRegisteredModel(t *testing.T) {
func TestMapFromModelVersion(t *testing.T) {
assertion, m := setup(t)

ctx, err := m.MapFromModelVersion(&openapi.ModelVersion{Name: of("v1")}, "1", of("ModelName"))
ctx, err := m.MapFromModelVersion(&openapi.ModelVersion{Name: "v1"}, "1", "ModelName")
assertion.Nil(err)
assertion.Equal("1:v1", ctx.GetName())
assertion.Equal(modelVersionTypeId, ctx.GetTypeId())
Expand Down Expand Up @@ -277,12 +277,7 @@ func TestMapToServeModelInvalid(t *testing.T) {
}

func TestMapTo(t *testing.T) {
_, err := mapTo[*proto.Execution, any](&proto.Execution{TypeId: of(registeredModelTypeId)}, typesMap, "notExisitingTypeName", func(e *proto.Execution) (*any, error) { return nil, nil })
_, err := mapTo(&proto.Execution{TypeId: of(registeredModelTypeId)}, typesMap, "notExisitingTypeName", func(e *proto.Execution) (*any, error) { return nil, nil })
assert.NotNil(t, err)
assert.Equal(t, "unknown type name provided: notExisitingTypeName", err.Error())
}

// of returns a pointer to the provided literal/const input
func of[E any](e E) *E {
return &e
}
2 changes: 1 addition & 1 deletion internal/server/openapi/impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ package openapi

// ImplResponse defines an implementation response with error code and the associated body
type ImplResponse struct {
Code int
Body interface{}
Code int
}
20 changes: 20 additions & 0 deletions internal/server/openapi/type_asserts.go
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,7 @@ func AssertModelArtifactUpdateConstraints(obj model.ModelArtifactUpdate) error {
// AssertModelVersionRequired checks if the required fields are not zero-ed
func AssertModelVersionRequired(obj model.ModelVersion) error {
elements := map[string]interface{}{
"name": obj.Name,
"registeredModelId": obj.RegisteredModelId,
}
for name, el := range elements {
Expand All @@ -543,6 +544,7 @@ func AssertModelVersionConstraints(obj model.ModelVersion) error {
// AssertModelVersionCreateRequired checks if the required fields are not zero-ed
func AssertModelVersionCreateRequired(obj model.ModelVersionCreate) error {
elements := map[string]interface{}{
"name": obj.Name,
"registeredModelId": obj.RegisteredModelId,
}
for name, el := range elements {
Expand Down Expand Up @@ -617,6 +619,15 @@ func AssertOrderByFieldConstraints(obj model.OrderByField) error {

// AssertRegisteredModelRequired checks if the required fields are not zero-ed
func AssertRegisteredModelRequired(obj model.RegisteredModel) error {
elements := map[string]interface{}{
"name": obj.Name,
}
for name, el := range elements {
if isZero := IsZeroValue(el); isZero {
return &RequiredError{Field: name}
}
}

return nil
}

Expand All @@ -627,6 +638,15 @@ func AssertRegisteredModelConstraints(obj model.RegisteredModel) error {

// AssertRegisteredModelCreateRequired checks if the required fields are not zero-ed
func AssertRegisteredModelCreateRequired(obj model.RegisteredModelCreate) error {
elements := map[string]interface{}{
"name": obj.Name,
}
for name, el := range elements {
if isZero := IsZeroValue(el); isZero {
return &RequiredError{Field: name}
}
}

return nil
}

Expand Down
Loading
Loading