From bdbc3eed8b7770db8a4bc5bda321dc851281ed81 Mon Sep 17 00:00:00 2001 From: Isabella Basso Date: Mon, 5 Aug 2024 04:53:38 -0300 Subject: [PATCH] wrap default openapi error handler (#235) Signed-off-by: Isabella do Amaral --- .openapi-generator-ignore | 2 + .../api_model_registry_service_service.go | 176 +++++++----------- internal/server/openapi/error.go | 12 +- internal/server/openapi/helpers.go | 9 + test/robot/Regression.robot | 16 ++ 5 files changed, 96 insertions(+), 119 deletions(-) create mode 100644 test/robot/Regression.robot diff --git a/.openapi-generator-ignore b/.openapi-generator-ignore index 18a46d80e..46a8017ea 100644 --- a/.openapi-generator-ignore +++ b/.openapi-generator-ignore @@ -45,3 +45,5 @@ internal/server/openapi/.openapi-generator-ignore internal/server/openapi/api_model_registry_service_service.go internal/server/openapi/README.md internal/server/openapi/main.go +internal/server/openapi/error.go +internal/server/openapi/helpers.go diff --git a/internal/server/openapi/api_model_registry_service_service.go b/internal/server/openapi/api_model_registry_service_service.go index 5ec322fa7..5e44fae6d 100644 --- a/internal/server/openapi/api_model_registry_service_service.go +++ b/internal/server/openapi/api_model_registry_service_service.go @@ -49,13 +49,12 @@ func (s *ModelRegistryServiceAPIService) CreateEnvironmentInferenceService(ctx c func (s *ModelRegistryServiceAPIService) CreateInferenceService(ctx context.Context, inferenceServiceCreate model.InferenceServiceCreate) (ImplResponse, error) { entity, err := s.converter.ConvertInferenceServiceCreate(&inferenceServiceCreate) if err != nil { - return Response(http.StatusBadRequest, model.Error{Message: err.Error()}), nil + return ErrorResponse(http.StatusBadRequest, err), err } result, err := s.coreApi.UpsertInferenceService(entity) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } return Response(http.StatusCreated, result), nil // TODO: return Response(http.StatusUnauthorized, Error{}), nil @@ -65,13 +64,12 @@ func (s *ModelRegistryServiceAPIService) CreateInferenceService(ctx context.Cont func (s *ModelRegistryServiceAPIService) CreateInferenceServiceServe(ctx context.Context, inferenceserviceId string, serveModelCreate model.ServeModelCreate) (ImplResponse, error) { entity, err := s.converter.ConvertServeModelCreate(&serveModelCreate) if err != nil { - return Response(http.StatusBadRequest, model.Error{Message: err.Error()}), nil + return ErrorResponse(http.StatusBadRequest, err), err } result, err := s.coreApi.UpsertServeModel(entity, &inferenceserviceId) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } return Response(http.StatusCreated, result), nil // TODO: return Response(http.StatusUnauthorized, Error{}), nil @@ -81,13 +79,12 @@ func (s *ModelRegistryServiceAPIService) CreateInferenceServiceServe(ctx context func (s *ModelRegistryServiceAPIService) CreateModelArtifact(ctx context.Context, modelArtifactCreate model.ModelArtifactCreate) (ImplResponse, error) { entity, err := s.converter.ConvertModelArtifactCreate(&modelArtifactCreate) if err != nil { - return Response(http.StatusBadRequest, model.Error{Message: err.Error()}), nil + return ErrorResponse(http.StatusBadRequest, err), err } result, err := s.coreApi.UpsertModelArtifact(entity, nil) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } return Response(http.StatusCreated, result), nil // TODO: return Response(http.StatusUnauthorized, Error{}), nil @@ -97,13 +94,12 @@ func (s *ModelRegistryServiceAPIService) CreateModelArtifact(ctx context.Context func (s *ModelRegistryServiceAPIService) CreateModelVersion(ctx context.Context, modelVersionCreate model.ModelVersionCreate) (ImplResponse, error) { modelVersion, err := s.converter.ConvertModelVersionCreate(&modelVersionCreate) if err != nil { - return Response(http.StatusBadRequest, model.Error{Message: err.Error()}), nil + return ErrorResponse(http.StatusBadRequest, err), err } result, err := s.coreApi.UpsertModelVersion(modelVersion, &modelVersionCreate.RegisteredModelId) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } return Response(http.StatusCreated, result), nil // TODO: return Response(http.StatusUnauthorized, Error{}), nil @@ -113,8 +109,7 @@ func (s *ModelRegistryServiceAPIService) CreateModelVersion(ctx context.Context, func (s *ModelRegistryServiceAPIService) CreateModelVersionArtifact(ctx context.Context, modelversionId string, artifact model.Artifact) (ImplResponse, error) { result, err := s.coreApi.UpsertArtifact(&artifact, &modelversionId) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } return Response(http.StatusCreated, result), nil // return Response(http.StatusNotImplemented, nil), errors.New("unsupported artifactType") @@ -126,13 +121,12 @@ func (s *ModelRegistryServiceAPIService) CreateModelVersionArtifact(ctx context. func (s *ModelRegistryServiceAPIService) CreateRegisteredModel(ctx context.Context, registeredModelCreate model.RegisteredModelCreate) (ImplResponse, error) { registeredModel, err := s.converter.ConvertRegisteredModelCreate(®isteredModelCreate) if err != nil { - return Response(http.StatusBadRequest, model.Error{Message: err.Error()}), nil + return ErrorResponse(http.StatusBadRequest, err), err } result, err := s.coreApi.UpsertRegisteredModel(registeredModel) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } return Response(http.StatusCreated, result), nil // TODO: return Response(http.StatusUnauthorized, Error{}), nil @@ -142,8 +136,7 @@ func (s *ModelRegistryServiceAPIService) CreateRegisteredModel(ctx context.Conte func (s *ModelRegistryServiceAPIService) CreateRegisteredModelVersion(ctx context.Context, registeredmodelId string, modelVersion model.ModelVersion) (ImplResponse, error) { result, err := s.coreApi.UpsertModelVersion(&modelVersion, apiutils.StrPtr(registeredmodelId)) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } return Response(http.StatusCreated, result), nil // TODO return Response(http.StatusUnauthorized, Error{}), nil @@ -153,13 +146,12 @@ func (s *ModelRegistryServiceAPIService) CreateRegisteredModelVersion(ctx contex func (s *ModelRegistryServiceAPIService) CreateServingEnvironment(ctx context.Context, servingEnvironmentCreate model.ServingEnvironmentCreate) (ImplResponse, error) { entity, err := s.converter.ConvertServingEnvironmentCreate(&servingEnvironmentCreate) if err != nil { - return Response(http.StatusBadRequest, model.Error{Message: err.Error()}), nil + return ErrorResponse(http.StatusBadRequest, err), err } result, err := s.coreApi.UpsertServingEnvironment(entity) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } return Response(http.StatusCreated, result), nil // TODO: return Response(http.StatusUnauthorized, Error{}), nil @@ -169,8 +161,7 @@ func (s *ModelRegistryServiceAPIService) CreateServingEnvironment(ctx context.Co func (s *ModelRegistryServiceAPIService) FindInferenceService(ctx context.Context, name string, externalId string, parentResourceId string) (ImplResponse, error) { result, err := s.coreApi.GetInferenceServiceByParams(apiutils.StrPtr(name), apiutils.StrPtr(parentResourceId), apiutils.StrPtr(externalId)) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } return Response(http.StatusOK, result), nil // TODO return Response(http.StatusUnauthorized, Error{}), nil @@ -180,8 +171,7 @@ func (s *ModelRegistryServiceAPIService) FindInferenceService(ctx context.Contex func (s *ModelRegistryServiceAPIService) FindModelArtifact(ctx context.Context, name string, externalId string, parentResourceId string) (ImplResponse, error) { result, err := s.coreApi.GetModelArtifactByParams(apiutils.StrPtr(name), apiutils.StrPtr(parentResourceId), apiutils.StrPtr(externalId)) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } return Response(http.StatusOK, result), nil // TODO return Response(http.StatusUnauthorized, Error{}), nil @@ -191,8 +181,7 @@ func (s *ModelRegistryServiceAPIService) FindModelArtifact(ctx context.Context, func (s *ModelRegistryServiceAPIService) FindModelVersion(ctx context.Context, name string, externalId string, registeredModelId string) (ImplResponse, error) { result, err := s.coreApi.GetModelVersionByParams(apiutils.StrPtr(name), apiutils.StrPtr(registeredModelId), apiutils.StrPtr(externalId)) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } return Response(http.StatusOK, result), nil // TODO return Response(http.StatusUnauthorized, Error{}), nil @@ -202,8 +191,7 @@ func (s *ModelRegistryServiceAPIService) FindModelVersion(ctx context.Context, n func (s *ModelRegistryServiceAPIService) FindRegisteredModel(ctx context.Context, name string, externalID string) (ImplResponse, error) { result, err := s.coreApi.GetRegisteredModelByParams(apiutils.StrPtr(name), apiutils.StrPtr(externalID)) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } return Response(http.StatusOK, result), nil // TODO return Response(http.StatusUnauthorized, Error{}), nil @@ -213,8 +201,7 @@ func (s *ModelRegistryServiceAPIService) FindRegisteredModel(ctx context.Context func (s *ModelRegistryServiceAPIService) FindServingEnvironment(ctx context.Context, name string, externalID string) (ImplResponse, error) { result, err := s.coreApi.GetServingEnvironmentByParams(apiutils.StrPtr(name), apiutils.StrPtr(externalID)) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } return Response(http.StatusOK, result), nil // TODO return Response(http.StatusUnauthorized, Error{}), nil @@ -224,13 +211,11 @@ func (s *ModelRegistryServiceAPIService) FindServingEnvironment(ctx context.Cont func (s *ModelRegistryServiceAPIService) GetEnvironmentInferenceServices(ctx context.Context, servingenvironmentId string, name string, externalID string, pageSize string, orderBy model.OrderByField, sortOrder model.SortOrder, nextPageToken string) (ImplResponse, error) { listOpts, err := apiutils.BuildListOption(pageSize, orderBy, sortOrder, nextPageToken) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } result, err := s.coreApi.GetInferenceServices(listOpts, apiutils.StrPtr(servingenvironmentId), nil) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } return Response(http.StatusOK, result), nil // TODO return Response(http.StatusUnauthorized, Error{}), nil @@ -240,8 +225,7 @@ func (s *ModelRegistryServiceAPIService) GetEnvironmentInferenceServices(ctx con func (s *ModelRegistryServiceAPIService) GetInferenceService(ctx context.Context, inferenceserviceId string) (ImplResponse, error) { result, err := s.coreApi.GetInferenceServiceById(inferenceserviceId) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } return Response(http.StatusOK, result), nil // TODO: return Response(http.StatusUnauthorized, Error{}), nil @@ -251,8 +235,7 @@ func (s *ModelRegistryServiceAPIService) GetInferenceService(ctx context.Context func (s *ModelRegistryServiceAPIService) GetInferenceServiceModel(ctx context.Context, inferenceserviceId string) (ImplResponse, error) { result, err := s.coreApi.GetRegisteredModelByInferenceService(inferenceserviceId) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } return Response(http.StatusOK, result), nil // TODO: return Response(http.StatusUnauthorized, Error{}), nil @@ -262,13 +245,11 @@ func (s *ModelRegistryServiceAPIService) GetInferenceServiceModel(ctx context.Co func (s *ModelRegistryServiceAPIService) GetInferenceServiceServes(ctx context.Context, inferenceserviceId string, name string, externalID string, pageSize string, orderBy model.OrderByField, sortOrder model.SortOrder, nextPageToken string) (ImplResponse, error) { listOpts, err := apiutils.BuildListOption(pageSize, orderBy, sortOrder, nextPageToken) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } result, err := s.coreApi.GetServeModels(listOpts, apiutils.StrPtr(inferenceserviceId)) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } return Response(http.StatusOK, result), nil // TODO return Response(http.StatusUnauthorized, Error{}), nil @@ -278,8 +259,7 @@ func (s *ModelRegistryServiceAPIService) GetInferenceServiceServes(ctx context.C func (s *ModelRegistryServiceAPIService) GetInferenceServiceVersion(ctx context.Context, inferenceserviceId string) (ImplResponse, error) { result, err := s.coreApi.GetModelVersionByInferenceService(inferenceserviceId) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } return Response(http.StatusOK, result), nil // TODO: return Response(http.StatusUnauthorized, Error{}), nil @@ -289,13 +269,11 @@ func (s *ModelRegistryServiceAPIService) GetInferenceServiceVersion(ctx context. func (s *ModelRegistryServiceAPIService) GetInferenceServices(ctx context.Context, pageSize string, orderBy model.OrderByField, sortOrder model.SortOrder, nextPageToken string) (ImplResponse, error) { listOpts, err := apiutils.BuildListOption(pageSize, orderBy, sortOrder, nextPageToken) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } result, err := s.coreApi.GetInferenceServices(listOpts, nil, nil) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } return Response(http.StatusOK, result), nil // TODO return Response(http.StatusUnauthorized, Error{}), nil @@ -305,8 +283,7 @@ func (s *ModelRegistryServiceAPIService) GetInferenceServices(ctx context.Contex func (s *ModelRegistryServiceAPIService) GetModelArtifact(ctx context.Context, modelartifactId string) (ImplResponse, error) { result, err := s.coreApi.GetModelArtifactById(modelartifactId) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } return Response(http.StatusOK, result), nil // TODO: return Response(http.StatusUnauthorized, Error{}), nil @@ -316,13 +293,11 @@ func (s *ModelRegistryServiceAPIService) GetModelArtifact(ctx context.Context, m func (s *ModelRegistryServiceAPIService) GetModelArtifacts(ctx context.Context, pageSize string, orderBy model.OrderByField, sortOrder model.SortOrder, nextPageToken string) (ImplResponse, error) { listOpts, err := apiutils.BuildListOption(pageSize, orderBy, sortOrder, nextPageToken) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } result, err := s.coreApi.GetModelArtifacts(listOpts, nil) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } return Response(http.StatusOK, result), nil // TODO return Response(http.StatusUnauthorized, Error{}), nil @@ -332,8 +307,7 @@ func (s *ModelRegistryServiceAPIService) GetModelArtifacts(ctx context.Context, func (s *ModelRegistryServiceAPIService) GetModelVersion(ctx context.Context, modelversionId string) (ImplResponse, error) { result, err := s.coreApi.GetModelVersionById(modelversionId) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } return Response(http.StatusOK, result), nil // TODO: return Response(http.StatusUnauthorized, Error{}), nil @@ -345,13 +319,11 @@ func (s *ModelRegistryServiceAPIService) GetModelVersionArtifacts(ctx context.Co // TODO externalID unused listOpts, err := apiutils.BuildListOption(pageSize, orderBy, sortOrder, nextPageToken) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } result, err := s.coreApi.GetArtifacts(listOpts, apiutils.StrPtr(modelversionId)) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } return Response(http.StatusOK, result), nil // TODO return Response(http.StatusUnauthorized, Error{}), nil @@ -361,13 +333,11 @@ func (s *ModelRegistryServiceAPIService) GetModelVersionArtifacts(ctx context.Co func (s *ModelRegistryServiceAPIService) GetModelVersions(ctx context.Context, pageSize string, orderBy model.OrderByField, sortOrder model.SortOrder, nextPageToken string) (ImplResponse, error) { listOpts, err := apiutils.BuildListOption(pageSize, orderBy, sortOrder, nextPageToken) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } result, err := s.coreApi.GetModelVersions(listOpts, nil) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } return Response(http.StatusOK, result), nil // TODO return Response(http.StatusUnauthorized, Error{}), nil @@ -377,8 +347,7 @@ func (s *ModelRegistryServiceAPIService) GetModelVersions(ctx context.Context, p func (s *ModelRegistryServiceAPIService) GetRegisteredModel(ctx context.Context, registeredmodelId string) (ImplResponse, error) { result, err := s.coreApi.GetRegisteredModelById(registeredmodelId) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } return Response(http.StatusOK, result), nil // TODO: return Response(http.StatusUnauthorized, Error{}), nil @@ -390,13 +359,11 @@ func (s *ModelRegistryServiceAPIService) GetRegisteredModelVersions(ctx context. // TODO externalID unused listOpts, err := apiutils.BuildListOption(pageSize, orderBy, sortOrder, nextPageToken) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } result, err := s.coreApi.GetModelVersions(listOpts, apiutils.StrPtr(registeredmodelId)) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } return Response(http.StatusOK, result), nil // TODO return Response(http.StatusUnauthorized, Error{}), nil @@ -406,13 +373,11 @@ func (s *ModelRegistryServiceAPIService) GetRegisteredModelVersions(ctx context. func (s *ModelRegistryServiceAPIService) GetRegisteredModels(ctx context.Context, pageSize string, orderBy model.OrderByField, sortOrder model.SortOrder, nextPageToken string) (ImplResponse, error) { listOpts, err := apiutils.BuildListOption(pageSize, orderBy, sortOrder, nextPageToken) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } result, err := s.coreApi.GetRegisteredModels(listOpts) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } return Response(http.StatusOK, result), nil // TODO return Response(http.StatusUnauthorized, Error{}), nil @@ -422,8 +387,7 @@ func (s *ModelRegistryServiceAPIService) GetRegisteredModels(ctx context.Context func (s *ModelRegistryServiceAPIService) GetServingEnvironment(ctx context.Context, servingenvironmentId string) (ImplResponse, error) { result, err := s.coreApi.GetServingEnvironmentById(servingenvironmentId) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } return Response(http.StatusOK, result), nil // TODO: return Response(http.StatusUnauthorized, Error{}), nil @@ -433,13 +397,11 @@ func (s *ModelRegistryServiceAPIService) GetServingEnvironment(ctx context.Conte func (s *ModelRegistryServiceAPIService) GetServingEnvironments(ctx context.Context, pageSize string, orderBy model.OrderByField, sortOrder model.SortOrder, nextPageToken string) (ImplResponse, error) { listOpts, err := apiutils.BuildListOption(pageSize, orderBy, sortOrder, nextPageToken) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } result, err := s.coreApi.GetServingEnvironments(listOpts) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } return Response(http.StatusOK, result), nil // TODO return Response(http.StatusUnauthorized, Error{}), nil @@ -449,22 +411,20 @@ func (s *ModelRegistryServiceAPIService) GetServingEnvironments(ctx context.Cont func (s *ModelRegistryServiceAPIService) UpdateInferenceService(ctx context.Context, inferenceserviceId string, inferenceServiceUpdate model.InferenceServiceUpdate) (ImplResponse, error) { entity, err := s.converter.ConvertInferenceServiceUpdate(&inferenceServiceUpdate) if err != nil { - return Response(http.StatusBadRequest, model.Error{Message: err.Error()}), nil + return ErrorResponse(http.StatusBadRequest, err), err } entity.Id = &inferenceserviceId existing, err := s.coreApi.GetInferenceServiceById(inferenceserviceId) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } update, err := s.reconciler.UpdateExistingInferenceService(converter.NewOpenapiUpdateWrapper(existing, entity)) if err != nil { - return Response(http.StatusBadRequest, model.Error{Message: err.Error()}), nil + return ErrorResponse(http.StatusBadRequest, err), err } result, err := s.coreApi.UpsertInferenceService(&update) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } return Response(http.StatusOK, result), nil // TODO return Response(http.StatusUnauthorized, Error{}), nil @@ -474,22 +434,20 @@ func (s *ModelRegistryServiceAPIService) UpdateInferenceService(ctx context.Cont func (s *ModelRegistryServiceAPIService) UpdateModelArtifact(ctx context.Context, modelartifactId string, modelArtifactUpdate model.ModelArtifactUpdate) (ImplResponse, error) { modelArtifact, err := s.converter.ConvertModelArtifactUpdate(&modelArtifactUpdate) if err != nil { - return Response(http.StatusBadRequest, model.Error{Message: err.Error()}), nil + return ErrorResponse(http.StatusBadRequest, err), err } modelArtifact.Id = &modelartifactId existing, err := s.coreApi.GetModelArtifactById(modelartifactId) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } update, err := s.reconciler.UpdateExistingModelArtifact(converter.NewOpenapiUpdateWrapper(existing, modelArtifact)) if err != nil { - return Response(http.StatusBadRequest, model.Error{Message: err.Error()}), nil + return ErrorResponse(http.StatusBadRequest, err), err } result, err := s.coreApi.UpsertModelArtifact(&update, nil) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } return Response(http.StatusOK, result), nil // TODO return Response(http.StatusUnauthorized, Error{}), nil @@ -499,22 +457,20 @@ func (s *ModelRegistryServiceAPIService) UpdateModelArtifact(ctx context.Context func (s *ModelRegistryServiceAPIService) UpdateModelVersion(ctx context.Context, modelversionId string, modelVersionUpdate model.ModelVersionUpdate) (ImplResponse, error) { modelVersion, err := s.converter.ConvertModelVersionUpdate(&modelVersionUpdate) if err != nil { - return Response(http.StatusBadRequest, model.Error{Message: err.Error()}), nil + return ErrorResponse(http.StatusBadRequest, err), err } modelVersion.Id = &modelversionId existing, err := s.coreApi.GetModelVersionById(modelversionId) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } update, err := s.reconciler.UpdateExistingModelVersion(converter.NewOpenapiUpdateWrapper(existing, modelVersion)) if err != nil { - return Response(http.StatusBadRequest, model.Error{Message: err.Error()}), nil + return ErrorResponse(http.StatusBadRequest, err), err } result, err := s.coreApi.UpsertModelVersion(&update, nil) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } return Response(http.StatusOK, result), nil // TODO return Response(http.StatusUnauthorized, Error{}), nil @@ -524,22 +480,20 @@ func (s *ModelRegistryServiceAPIService) UpdateModelVersion(ctx context.Context, func (s *ModelRegistryServiceAPIService) UpdateRegisteredModel(ctx context.Context, registeredmodelId string, registeredModelUpdate model.RegisteredModelUpdate) (ImplResponse, error) { registeredModel, err := s.converter.ConvertRegisteredModelUpdate(®isteredModelUpdate) if err != nil { - return Response(http.StatusBadRequest, model.Error{Message: err.Error()}), nil + return ErrorResponse(http.StatusBadRequest, err), err } registeredModel.Id = ®isteredmodelId existing, err := s.coreApi.GetRegisteredModelById(registeredmodelId) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } update, err := s.reconciler.UpdateExistingRegisteredModel(converter.NewOpenapiUpdateWrapper(existing, registeredModel)) if err != nil { - return Response(http.StatusBadRequest, model.Error{Message: err.Error()}), nil + return ErrorResponse(http.StatusBadRequest, err), err } result, err := s.coreApi.UpsertRegisteredModel(&update) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } return Response(http.StatusOK, result), nil // TODO return Response(http.StatusUnauthorized, Error{}), nil @@ -549,22 +503,20 @@ func (s *ModelRegistryServiceAPIService) UpdateRegisteredModel(ctx context.Conte func (s *ModelRegistryServiceAPIService) UpdateServingEnvironment(ctx context.Context, servingenvironmentId string, servingEnvironmentUpdate model.ServingEnvironmentUpdate) (ImplResponse, error) { entity, err := s.converter.ConvertServingEnvironmentUpdate(&servingEnvironmentUpdate) if err != nil { - return Response(http.StatusBadRequest, model.Error{Message: err.Error()}), nil + return ErrorResponse(http.StatusBadRequest, err), err } entity.Id = &servingenvironmentId existing, err := s.coreApi.GetServingEnvironmentById(servingenvironmentId) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } update, err := s.reconciler.UpdateExistingServingEnvironment(converter.NewOpenapiUpdateWrapper(existing, entity)) if err != nil { - return Response(http.StatusBadRequest, model.Error{Message: err.Error()}), nil + return ErrorResponse(http.StatusBadRequest, err), err } result, err := s.coreApi.UpsertServingEnvironment(&update) if err != nil { - status := api.ErrToStatus(err) - return Response(status, model.Error{Message: err.Error()}), nil + return ErrorResponse(api.ErrToStatus(err), err), err } return Response(http.StatusOK, result), nil // TODO return Response(http.StatusUnauthorized, Error{}), nil diff --git a/internal/server/openapi/error.go b/internal/server/openapi/error.go index 670c7ca9d..34c18d408 100644 --- a/internal/server/openapi/error.go +++ b/internal/server/openapi/error.go @@ -15,10 +15,8 @@ import ( "net/http" ) -var ( - // ErrTypeAssertionError is thrown when type an interface does not match the asserted type - ErrTypeAssertionError = errors.New("unable to assert type") -) +// ErrTypeAssertionError is thrown when type an interface does not match the asserted type +var ErrTypeAssertionError = errors.New("unable to assert type") // ParsingError indicates that an error has occurred when parsing request parameters type ParsingError struct { @@ -51,12 +49,12 @@ type ErrorHandler func(w http.ResponseWriter, r *http.Request, err error, result func DefaultErrorHandler(w http.ResponseWriter, r *http.Request, err error, result *ImplResponse) { if _, ok := err.(*ParsingError); ok { // Handle parsing errors - EncodeJSONResponse(err.Error(), func(i int) *int { return &i }(http.StatusBadRequest), w) + EncodeJSONResponse(ErrorResponse(http.StatusBadRequest, err).Body, func(i int) *int { return &i }(http.StatusBadRequest), w) } else if _, ok := err.(*RequiredError); ok { // Handle missing required errors - EncodeJSONResponse(err.Error(), func(i int) *int { return &i }(http.StatusUnprocessableEntity), w) + EncodeJSONResponse(ErrorResponse(http.StatusBadRequest, err).Body, func(i int) *int { return &i }(http.StatusUnprocessableEntity), w) } else { // Handle all other errors - EncodeJSONResponse(err.Error(), &result.Code, w) + EncodeJSONResponse(result.Body, &result.Code, w) } } diff --git a/internal/server/openapi/helpers.go b/internal/server/openapi/helpers.go index 57e2b44b2..0a19264fc 100644 --- a/internal/server/openapi/helpers.go +++ b/internal/server/openapi/helpers.go @@ -11,6 +11,8 @@ package openapi import ( "reflect" + + model "github.com/kubeflow/model-registry/pkg/openapi" ) // Response return a ImplResponse struct filled @@ -21,6 +23,13 @@ func Response(code int, body interface{}) ImplResponse { } } +func ErrorResponse(code int, err error) ImplResponse { + return ImplResponse{ + Code: code, + Body: model.Error{Message: err.Error()}, + } +} + // IsZeroValue checks if the val is the zero-ed value. func IsZeroValue(val interface{}) bool { return val == nil || reflect.DeepEqual(val, reflect.Zero(reflect.TypeOf(val)).Interface()) diff --git a/test/robot/Regression.robot b/test/robot/Regression.robot new file mode 100644 index 000000000..fed9d76b2 --- /dev/null +++ b/test/robot/Regression.robot @@ -0,0 +1,16 @@ +*** Settings *** +Resource Setup.resource +Resource MRkeywords.resource +Test Setup Test Setup with dummy data + + +*** Comments *** +Regression tests for Model Registry + + +*** Test Cases *** +As a MLOps engineer if I try to store a malformed RegisteredModel I get a structured error message + ${rm} Create Dictionary name="model" ext_id=123 + ${err} POST url=http://${MR_HOST}:${MR_PORT}/api/model_registry/v1alpha3/registered_models json=&{rm} expected_status=400 + ${rm_err} Create Dictionary code= message=json: unknown field "ext_id" + And Should be equal ${rm_err} ${err.json()}