Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds ability to serve mocked data for MR API calls via cli flag #377

Merged
merged 1 commit into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion clients/ui/bff/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ CONTAINER_TOOL ?= docker
IMG ?= model-registry-bff:latest
PORT ?= 4000
MOCK_K8S_CLIENT ?= false
MOCK_MR_CLIENT ?= false

.PHONY: all
all: build
Expand Down Expand Up @@ -32,7 +33,7 @@ build: fmt vet test

.PHONY: run
run: fmt vet
go run ./cmd/main.go --port=$(PORT) --mock-k8s-client=$(MOCK_K8S_CLIENT)
go run ./cmd/main.go --port=$(PORT) --mock-k8s-client=$(MOCK_K8S_CLIENT) --mock-mr-client=$(MOCK_MR_CLIENT)

.PHONY: docker-build
docker-build:
Expand Down
4 changes: 2 additions & 2 deletions clients/ui/bff/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ After building it, you can run our app with:
```shell
make run
```
If you want to use a different port or mock kubernetes client, useful for front-end development, you can run:
If you want to use a different port, mock kubernetes client or model registry client - useful for front-end development, you can run:
```shell
make run PORT=8000 MOCK_K8S_CLIENT=true
make run PORT=8000 MOCK_K8S_CLIENT=true MOCK_MR_CLIENT=true
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a review, it's more of a question, are we gonna get 8000 as the default port? We have 4000 in the frontend for the api, if so, I'll do a follow up PR to change the references there.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right, the default port seems to be PORT ?= 4000 (on Makefile)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But on those instructions, I just to show how to do with an alternative one.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes... if you leave out the PORT param it defaults to 4000 - I think this is just a doc issue of "how to I make that clearer"

```

# Building and Deploying
Expand Down
30 changes: 22 additions & 8 deletions clients/ui/bff/api/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@ const (
)

type App struct {
config config.EnvConfig
logger *slog.Logger
models data.Models
kubernetesClient integrations.KubernetesClientInterface
config config.EnvConfig
logger *slog.Logger
models data.Models
kubernetesClient integrations.KubernetesClientInterface
modelRegistryClient data.ModelRegistryClientInterface
}

func NewApp(cfg config.EnvConfig, logger *slog.Logger) (*App, error) {
Expand All @@ -43,10 +44,23 @@ func NewApp(cfg config.EnvConfig, logger *slog.Logger) (*App, error) {
return nil, fmt.Errorf("failed to create Kubernetes client: %w", err)
}

var mrClient data.ModelRegistryClientInterface

if cfg.MockMRClient {
mrClient, err = mocks.NewModelRegistryClient(logger)
} else {
mrClient, err = data.NewModelRegistryClient(logger)
}

if err != nil {
return nil, fmt.Errorf("failed to create ModelRegistry client: %w", err)
}

app := &App{
config: cfg,
logger: logger,
kubernetesClient: k8sClient,
config: cfg,
logger: logger,
kubernetesClient: k8sClient,
modelRegistryClient: mrClient,
}
return app, nil
}
Expand All @@ -59,7 +73,7 @@ func (app *App) Routes() http.Handler {

// HTTP client routes
router.GET(HealthCheckPath, app.HealthcheckHandler)
router.GET(RegisteredModelsPath, app.AttachRESTClient(app.GetRegisteredModelsHandler))
router.GET(RegisteredModelsPath, app.AttachRESTClient(app.GetAllRegisteredModelsHandler))
router.GET(RegisteredModelPath, app.AttachRESTClient(app.GetRegisteredModelHandler))
router.POST(RegisteredModelsPath, app.AttachRESTClient(app.CreateRegisteredModelHandler))

Expand Down
2 changes: 2 additions & 0 deletions clients/ui/bff/api/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (

type Envelope map[string]interface{}

type TypedEnvelope[T any] map[string]T
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@alexcreasy in a next iteration shall we remove the untyped one?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ederign yes, I'd like to, the only reason I didn't just change it was for speed right now :)


func (app *App) WriteJSON(w http.ResponseWriter, status int, data any, headers http.Header) error {

js, err := json.MarshalIndent(data, "", "\t")
Expand Down
13 changes: 6 additions & 7 deletions clients/ui/bff/api/registered_models_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,27 @@ import (
"fmt"
"github.com/julienschmidt/httprouter"
"github.com/kubeflow/model-registry/pkg/openapi"
"github.com/kubeflow/model-registry/ui/bff/data"
"github.com/kubeflow/model-registry/ui/bff/integrations"
"github.com/kubeflow/model-registry/ui/bff/validation"
"net/http"
)

func (app *App) GetRegisteredModelsHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
func (app *App) GetAllRegisteredModelsHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
//TODO (ederign) implement pagination
client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface)
if !ok {
app.serverErrorResponse(w, r, errors.New("REST client not found"))
return
}

modelList, err := data.GetAllRegisteredModels(client)
modelList, err := app.modelRegistryClient.GetAllRegisteredModels(client)
if err != nil {
app.serverErrorResponse(w, r, err)
return
}

modelRegistryRes := Envelope{
"registered_models": modelList,
"registered_model_list": modelList,
}

err = app.WriteJSON(w, http.StatusOK, modelRegistryRes, nil)
Expand Down Expand Up @@ -60,7 +59,7 @@ func (app *App) CreateRegisteredModelHandler(w http.ResponseWriter, r *http.Requ
return
}

createdModel, err := data.CreateRegisteredModel(client, jsonData)
createdModel, err := app.modelRegistryClient.CreateRegisteredModel(client, jsonData)
if err != nil {
var httpErr *integrations.HTTPError
if errors.As(err, &httpErr) {
Expand Down Expand Up @@ -91,13 +90,13 @@ func (app *App) GetRegisteredModelHandler(w http.ResponseWriter, r *http.Request
return
}

model, err := data.GetRegisteredModel(client, ps.ByName(RegisteredModelId))
model, err := app.modelRegistryClient.GetRegisteredModel(client, ps.ByName(RegisteredModelId))
if err != nil {
app.serverErrorResponse(w, r, err)
return
}

if _, ok := model.GetNameOk(); !ok {
if _, ok := model.GetIdOk(); !ok {
app.notFoundResponse(w, r)
return
}
Expand Down
135 changes: 135 additions & 0 deletions clients/ui/bff/api/registered_models_handler_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package api

import (
"bytes"
"context"
"encoding/json"
"github.com/kubeflow/model-registry/pkg/openapi"
"github.com/kubeflow/model-registry/ui/bff/internals/mocks"
"github.com/stretchr/testify/assert"
"io"
"net/http"
"net/http/httptest"
"testing"
)

func TestGetRegisteredModelHandler(t *testing.T) {
mockMRClient, _ := mocks.NewModelRegistryClient(nil)
mockClient := new(mocks.MockHTTPClient)

testApp := App{
modelRegistryClient: mockMRClient,
}

req, err := http.NewRequest(http.MethodGet,
"/api/v1/model-registry/model-registry/registered_models/1", nil)
assert.NoError(t, err)

ctx := context.WithValue(req.Context(), httpClientKey, mockClient)
req = req.WithContext(ctx)

rr := httptest.NewRecorder()

testApp.GetRegisteredModelHandler(rr, req, nil)
rs := rr.Result()

defer rs.Body.Close()

body, err := io.ReadAll(rs.Body)
assert.NoError(t, err)
var registeredModelRes TypedEnvelope[openapi.RegisteredModel]
err = json.Unmarshal(body, &registeredModelRes)
assert.NoError(t, err)

assert.Equal(t, http.StatusOK, rr.Code)

var expected = TypedEnvelope[openapi.RegisteredModel]{
"registered_model": mocks.GetRegisteredModelMocks()[0],
}

//TODO assert the full structure, I couldn't get unmarshalling to work for the full customProperties values
// this issue is in the test only
assert.Equal(t, expected["registered_model"].Name, registeredModelRes["registered_model"].Name)
}

func TestGetAllRegisteredModelsHandler(t *testing.T) {
mockMRClient, _ := mocks.NewModelRegistryClient(nil)
mockClient := new(mocks.MockHTTPClient)

testApp := App{
modelRegistryClient: mockMRClient,
}

req, err := http.NewRequest(http.MethodGet,
"/api/v1/model-registry/model-registry/registered_models", nil)
assert.NoError(t, err)

ctx := context.WithValue(req.Context(), httpClientKey, mockClient)
req = req.WithContext(ctx)

rr := httptest.NewRecorder()

testApp.GetAllRegisteredModelsHandler(rr, req, nil)
rs := rr.Result()

defer rs.Body.Close()

body, err := io.ReadAll(rs.Body)
assert.NoError(t, err)
var registeredModelsListRes TypedEnvelope[openapi.RegisteredModelList]
err = json.Unmarshal(body, &registeredModelsListRes)
assert.NoError(t, err)

assert.Equal(t, http.StatusOK, rr.Code)

var expected = TypedEnvelope[openapi.RegisteredModelList]{
"registered_model_list": mocks.GetRegisteredModelListMock(),
}

assert.Equal(t, expected["registered_model_list"].Size, registeredModelsListRes["registered_model_list"].Size)
assert.Equal(t, expected["registered_model_list"].PageSize, registeredModelsListRes["registered_model_list"].PageSize)
assert.Equal(t, expected["registered_model_list"].NextPageToken, registeredModelsListRes["registered_model_list"].NextPageToken)
assert.Equal(t, len(expected["registered_model_list"].Items), len(registeredModelsListRes["registered_model_list"].Items))
}

func TestCreateRegisteredModelHandler(t *testing.T) {
mockMRClient, _ := mocks.NewModelRegistryClient(nil)
mockClient := new(mocks.MockHTTPClient)

testApp := App{
modelRegistryClient: mockMRClient,
}

newModel := openapi.NewRegisteredModelCreate("Model One")
newModelJSON, err := newModel.MarshalJSON()
assert.NoError(t, err)

reqBody := bytes.NewReader(newModelJSON)

req, err := http.NewRequest(http.MethodPost,
"/api/v1/model-registry/model-registry/registered_models", reqBody)
assert.NoError(t, err)

ctx := context.WithValue(req.Context(), httpClientKey, mockClient)
req = req.WithContext(ctx)

rr := httptest.NewRecorder()

testApp.CreateRegisteredModelHandler(rr, req, nil)
rs := rr.Result()

defer rs.Body.Close()

body, err := io.ReadAll(rs.Body)
assert.NoError(t, err)
var registeredModelRes openapi.RegisteredModel
err = json.Unmarshal(body, &registeredModelRes)
assert.NoError(t, err)

assert.Equal(t, http.StatusCreated, rr.Code)

var expected = mocks.GetRegisteredModelMocks()[0]

assert.Equal(t, expected.Name, registeredModelRes.Name)
assert.NotEmpty(t, rs.Header.Get("location"))
}
1 change: 1 addition & 0 deletions clients/ui/bff/cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ func main() {
var cfg config.EnvConfig
flag.IntVar(&cfg.Port, "port", getEnvAsInt("PORT", 4000), "API server port")
flag.BoolVar(&cfg.MockK8Client, "mock-k8s-client", false, "Use mock Kubernetes client")
flag.BoolVar(&cfg.MockMRClient, "mock-mr-client", false, "Use mock Model Registry client")
flag.Parse()

logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
Expand Down
1 change: 1 addition & 0 deletions clients/ui/bff/config/environment.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ package config
type EnvConfig struct {
Port int
MockK8Client bool
MockMRClient bool
}
18 changes: 18 additions & 0 deletions clients/ui/bff/data/model_registry_client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package data

import (
"log/slog"
)

type ModelRegistryClientInterface interface {
RegisteredModelInterface
}

type ModelRegistryClient struct {
logger *slog.Logger
RegisteredModel
}

func NewModelRegistryClient(logger *slog.Logger) (ModelRegistryClientInterface, error) {
return &ModelRegistryClient{logger: logger}, nil
}
16 changes: 13 additions & 3 deletions clients/ui/bff/data/registered_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,17 @@ import (

const registerModelPath = "/registered_models"

func GetAllRegisteredModels(client integrations.HTTPClientInterface) (*openapi.RegisteredModelList, error) {
type RegisteredModelInterface interface {
GetAllRegisteredModels(client integrations.HTTPClientInterface) (*openapi.RegisteredModelList, error)
CreateRegisteredModel(client integrations.HTTPClientInterface, jsonData []byte) (*openapi.RegisteredModel, error)
GetRegisteredModel(client integrations.HTTPClientInterface, id string) (*openapi.RegisteredModel, error)
}

type RegisteredModel struct {
RegisteredModelInterface
}

func (m RegisteredModel) GetAllRegisteredModels(client integrations.HTTPClientInterface) (*openapi.RegisteredModelList, error) {

responseData, err := client.GET(registerModelPath)
if err != nil {
Expand All @@ -26,7 +36,7 @@ func GetAllRegisteredModels(client integrations.HTTPClientInterface) (*openapi.R
return &modelList, nil
}

func CreateRegisteredModel(client integrations.HTTPClientInterface, jsonData []byte) (*openapi.RegisteredModel, error) {
func (m RegisteredModel) CreateRegisteredModel(client integrations.HTTPClientInterface, jsonData []byte) (*openapi.RegisteredModel, error) {
responseData, err := client.POST(registerModelPath, bytes.NewBuffer(jsonData))

if err != nil {
Expand All @@ -41,7 +51,7 @@ func CreateRegisteredModel(client integrations.HTTPClientInterface, jsonData []b
return &model, nil
}

func GetRegisteredModel(client integrations.HTTPClientInterface, id string) (*openapi.RegisteredModel, error) {
func (m RegisteredModel) GetRegisteredModel(client integrations.HTTPClientInterface, id string) (*openapi.RegisteredModel, error) {
path, err := url.JoinPath(registerModelPath, id)
if err != nil {
return nil, err
Expand Down
12 changes: 9 additions & 3 deletions clients/ui/bff/data/registered_model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ func TestGetAllRegisteredModels(t *testing.T) {
mockData, err := json.Marshal(expected)
assert.NoError(t, err)

mrClient := ModelRegistryClient{}

mockClient := new(mocks.MockHTTPClient)
mockClient.On("GET", registerModelPath).Return(mockData, nil)

actual, err := GetAllRegisteredModels(mockClient)
actual, err := mrClient.GetAllRegisteredModels(mockClient)
assert.NoError(t, err)
assert.NotNil(t, actual)
assert.Equal(t, expected.NextPageToken, actual.NextPageToken)
Expand All @@ -39,13 +41,15 @@ func TestCreateRegisteredModel(t *testing.T) {
mockData, err := json.Marshal(expected)
assert.NoError(t, err)

mrClient := ModelRegistryClient{}

mockClient := new(mocks.MockHTTPClient)
mockClient.On("POST", registerModelPath, mock.Anything).Return(mockData, nil)

jsonInput, err := json.Marshal(expected)
assert.NoError(t, err)

actual, err := CreateRegisteredModel(mockClient, jsonInput)
actual, err := mrClient.CreateRegisteredModel(mockClient, jsonInput)
assert.NoError(t, err)
assert.NotNil(t, actual)
assert.Equal(t, expected.Name, actual.Name)
Expand All @@ -62,10 +66,12 @@ func TestGetRegisteredModel(t *testing.T) {
mockData, err := json.Marshal(expected)
assert.NoError(t, err)

mrClient := ModelRegistryClient{}

mockClient := new(mocks.MockHTTPClient)
mockClient.On("GET", registerModelPath+"/"+expected.GetId()).Return(mockData, nil)

actual, err := GetRegisteredModel(mockClient, expected.GetId())
actual, err := mrClient.GetRegisteredModel(mockClient, expected.GetId())
assert.NoError(t, err)
assert.NotNil(t, actual)
assert.Equal(t, expected.Name, actual.Name)
Expand Down
Loading