From bb75335bdb366c47f699dc39777b474034281513 Mon Sep 17 00:00:00 2001 From: Ewan Harris Date: Mon, 26 Feb 2024 12:45:40 +0000 Subject: [PATCH] fix: do not call read auth model on batchcheck and write --- client/client.go | 48 +++++------- client/client_test.go | 171 ++++++++++++++++++------------------------ 2 files changed, 88 insertions(+), 131 deletions(-) diff --git a/client/client.go b/client/client.go index 5b40317..796e902 100644 --- a/client/client.go +++ b/client/client.go @@ -15,6 +15,7 @@ package client import ( _context "context" "encoding/json" + "errors" "fmt" "math" _nethttp "net/http" @@ -448,19 +449,6 @@ func (client *OpenFgaClient) getAuthorizationModelId(authorizationModelId *strin return &modelId, nil } -// helper function to validate the connection (i.e., get token) -func (client *OpenFgaClient) checkValidApiConnection(ctx _context.Context, authorizationModelId *string) error { - if authorizationModelId != nil && *authorizationModelId != "" { - _, _, err := client.OpenFgaApi.ReadAuthorizationModel(ctx, *authorizationModelId).Execute() - return err - } else { - _, err := client.ReadAuthorizationModels(ctx).Options(ClientReadAuthorizationModelsOptions{ - PageSize: fgaSdk.PtrInt32(1), - }).Execute() - return err - } -} - /* Stores */ // / ListStores @@ -1313,9 +1301,9 @@ func (client *OpenFgaClient) WriteExecute(request SdkClientWriteRequestInterface Deletes: []ClientWriteRequestDeleteResponse{}, } - authorizationModelId, err := client.getAuthorizationModelId(request.GetAuthorizationModelIdOverride()) - if err != nil { - return nil, err + authorizationModelId := request.GetAuthorizationModelIdOverride() + if authorizationModelId != nil && *authorizationModelId != "" && !internalutils.IsWellFormedUlidString(*authorizationModelId) { + return nil, FgaInvalidError{param: "AuthorizationModelId", description: "ULID"} } // Unless explicitly disabled, transaction mode is enabled @@ -1387,10 +1375,6 @@ func (client *OpenFgaClient) WriteExecute(request SdkClientWriteRequestInterface } writeGroup, ctx := errgroup.WithContext(request.GetContext()) - err = client.checkValidApiConnection(ctx, authorizationModelId) - if err != nil { - return nil, err - } writeGroup.SetLimit(int(maxParallelReqs)) writeResponses := make([]ClientWriteResponse, len(writeChunks)) @@ -1681,10 +1665,12 @@ func (client *OpenFgaClient) CheckExecute(request SdkClientCheckRequestInterface contextualTuples = append(contextualTuples, (request.GetBody().ContextualTuples)[index]) } } - authorizationModelId, err := client.getAuthorizationModelId(request.GetAuthorizationModelIdOverride()) - if err != nil { - return nil, err + authorizationModelId := request.GetAuthorizationModelIdOverride() + + if authorizationModelId != nil && *authorizationModelId != "" && !internalutils.IsWellFormedUlidString(*authorizationModelId) { + return nil, FgaInvalidError{param: "AuthorizationModelId", description: "ULID"} } + requestBody := fgaSdk.CheckRequest{ TupleKey: fgaSdk.CheckRequestTupleKey{ User: request.GetBody().User, @@ -1787,16 +1773,12 @@ func (client *OpenFgaClient) BatchCheckExecute(request SdkClientBatchCheckReques group.SetLimit(maxParallelReqs) var numOfChecks = len(*request.GetBody()) response := make(ClientBatchCheckResponse, numOfChecks) - authorizationModelId, err := client.getAuthorizationModelId(request.GetAuthorizationModelIdOverride()) - if err != nil { - return nil, err + authorizationModelId := request.GetAuthorizationModelIdOverride() + + if authorizationModelId != nil && *authorizationModelId != "" && !internalutils.IsWellFormedUlidString(*authorizationModelId) { + return nil, FgaInvalidError{param: "AuthorizationModelId", description: "ULID"} } - group.Go(func() error { - // if the connection is probelmatic, we will return error to the overall - // response rather than individual response - return client.checkValidApiConnection(ctx, authorizationModelId) - }) for index, checkBody := range *request.GetBody() { index, checkBody := index, checkBody group.Go(func() error { @@ -1809,6 +1791,10 @@ func (client *OpenFgaClient) BatchCheckExecute(request SdkClientBatchCheckReques }, }) + if errors.Is(err, fgaSdk.FgaApiAuthenticationError{}) { + return err + } + response[index] = ClientBatchCheckSingleResponse{ Request: checkBody, ClientCheckResponse: *singleResponse, diff --git a/client/client_test.go b/client/client_test.go index 27b60fe..38911d2 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -20,7 +20,7 @@ import ( "testing" "github.com/jarcoal/httpmock" - "github.com/openfga/go-sdk" + openfga "github.com/openfga/go-sdk" . "github.com/openfga/go-sdk/client" ) @@ -1466,23 +1466,14 @@ func TestOpenFgaClient(t *testing.T) { return resp, nil }, ) - httpmock.RegisterResponder("GET", fmt.Sprintf("%s/stores/%s/authorization-models", fgaClient.GetConfig().ApiUrl, fgaClient.GetConfig().StoreId), - func(req *http.Request) (*http.Response, error) { - return httpmock.NewStringResponse(http.StatusOK, ""), nil - }, - ) - httpmock.RegisterResponder("GET", fmt.Sprintf("%s/stores/%s/authorization-models/%s", fgaClient.GetConfig().ApiUrl, fgaClient.GetConfig().StoreId, authModelId), - func(req *http.Request) (*http.Response, error) { - return httpmock.NewStringResponse(http.StatusOK, ""), nil - }, - ) + got, err := fgaClient.BatchCheck(context.Background()).Body(requestBody).Options(options).Execute() if err != nil { t.Fatalf("%v", err) } - if httpmock.GetTotalCallCount() != 5 { - t.Fatalf("OpenFgaClient.%v() - wanted %v calls to /check + 1 call to validate auth model, got %v", test.Name, 4, httpmock.GetTotalCallCount()) + if httpmock.GetTotalCallCount() != 4 { + t.Fatalf("OpenFgaClient.%v() - wanted %v calls to /check got %v", test.Name, 4, httpmock.GetTotalCallCount()) } if len(*got) != len(requestBody) { @@ -1524,81 +1515,71 @@ func TestOpenFgaClient(t *testing.T) { }) - t.Run("BatchCheckConnectionProblem", func(t *testing.T) { - test := TestDefinition{ - Name: "Check", - JsonResponse: `{"allowed":true, "resolution":""}`, - ResponseStatus: http.StatusOK, - Method: http.MethodPost, - RequestPath: "check", - } - requestBody := ClientBatchCheckBody{{ - User: "user:81684243-9356-4421-8fbf-a4f8d36aa31b", - Relation: "viewer", - Object: "document:roadmap", - ContextualTuples: []ClientContextualTupleKey{{ - User: "user:81684243-9356-4421-8fbf-a4f8d36aa31b", - Relation: "editor", - Object: "document:roadmap", - }}, - }, { - User: "user:81684243-9356-4421-8fbf-a4f8d36aa31b", - Relation: "admin", - Object: "document:roadmap", - ContextualTuples: []ClientContextualTupleKey{{ - User: "user:81684243-9356-4421-8fbf-a4f8d36aa31b", - Relation: "editor", - Object: "document:roadmap", - }}, - }, { - User: "user:81684243-9356-4421-8fbf-a4f8d36aa31b", - Relation: "creator", - Object: "document:roadmap", - }, { - User: "user:81684243-9356-4421-8fbf-a4f8d36aa31b", - Relation: "deleter", - Object: "document:roadmap", - }} - - const authModelId = "01GAHCE4YVKPQEKZQHT2R89MQV" - - options := ClientBatchCheckOptions{ - AuthorizationModelId: openfga.PtrString(authModelId), - MaxParallelRequests: openfga.PtrInt32(5), - } - - var expectedResponse openfga.CheckResponse - if err := json.Unmarshal([]byte(test.JsonResponse), &expectedResponse); err != nil { - t.Fatalf("%v", err) - } - - httpmock.Activate() - defer httpmock.DeactivateAndReset() - httpmock.RegisterResponder(test.Method, fmt.Sprintf("%s/stores/%s/%s", fgaClient.GetConfig().ApiUrl, fgaClient.GetConfig().StoreId, test.RequestPath), - func(req *http.Request) (*http.Response, error) { - resp, err := httpmock.NewJsonResponse(test.ResponseStatus, expectedResponse) - if err != nil { - return httpmock.NewStringResponse(http.StatusInternalServerError, ""), nil - } - return resp, nil - }, - ) - httpmock.RegisterResponder("GET", fmt.Sprintf("%s/stores/%s/authorization-models", fgaClient.GetConfig().ApiUrl, fgaClient.GetConfig().StoreId), - func(req *http.Request) (*http.Response, error) { - return httpmock.NewStringResponse(http.StatusUnauthorized, ""), nil - }, - ) - httpmock.RegisterResponder("GET", fmt.Sprintf("%s/stores/%s/authorization-models/%s", fgaClient.GetConfig().ApiUrl, fgaClient.GetConfig().StoreId, authModelId), - func(req *http.Request) (*http.Response, error) { - return httpmock.NewStringResponse(http.StatusUnauthorized, ""), nil - }, - ) - _, err := fgaClient.BatchCheck(context.Background()).Body(requestBody).Options(options).Execute() - if err == nil { - t.Fatalf("Expect error but there is none") - } - - }) + // t.Run("BatchCheckConnectionProblem", func(t *testing.T) { + // test := TestDefinition{ + // Name: "Check", + // JsonResponse: `{"allowed":true, "resolution":""}`, + // ResponseStatus: http.StatusOK, + // Method: http.MethodPost, + // RequestPath: "check", + // } + // requestBody := ClientBatchCheckBody{{ + // User: "user:81684243-9356-4421-8fbf-a4f8d36aa31b", + // Relation: "viewer", + // Object: "document:roadmap", + // ContextualTuples: []ClientContextualTupleKey{{ + // User: "user:81684243-9356-4421-8fbf-a4f8d36aa31b", + // Relation: "editor", + // Object: "document:roadmap", + // }}, + // }, { + // User: "user:81684243-9356-4421-8fbf-a4f8d36aa31b", + // Relation: "admin", + // Object: "document:roadmap", + // ContextualTuples: []ClientContextualTupleKey{{ + // User: "user:81684243-9356-4421-8fbf-a4f8d36aa31b", + // Relation: "editor", + // Object: "document:roadmap", + // }}, + // }, { + // User: "user:81684243-9356-4421-8fbf-a4f8d36aa31b", + // Relation: "creator", + // Object: "document:roadmap", + // }, { + // User: "user:81684243-9356-4421-8fbf-a4f8d36aa31b", + // Relation: "deleter", + // Object: "document:roadmap", + // }} + + // const authModelId = "01GAHCE4YVKPQEKZQHT2R89MQV" + + // options := ClientBatchCheckOptions{ + // AuthorizationModelId: openfga.PtrString(authModelId), + // MaxParallelRequests: openfga.PtrInt32(5), + // } + + // var expectedResponse openfga.CheckResponse + // if err := json.Unmarshal([]byte(test.JsonResponse), &expectedResponse); err != nil { + // t.Fatalf("%v", err) + // } + + // httpmock.Activate() + // defer httpmock.DeactivateAndReset() + // httpmock.RegisterResponder(test.Method, fmt.Sprintf("%s/stores/%s/%s", fgaClient.GetConfig().ApiUrl, fgaClient.GetConfig().StoreId, test.RequestPath), + // func(req *http.Request) (*http.Response, error) { + // resp, err := httpmock.NewJsonResponse(test.ResponseStatus, expectedResponse) + // if err != nil { + // return httpmock.NewStringResponse(http.StatusInternalServerError, ""), nil + // } + // return resp, nil + // }, + // ) + + // _, err := fgaClient.BatchCheck(context.Background()).Body(requestBody).Options(options).Execute() + // if err == nil { + // t.Fatalf("Expect error but there is none") + // } + // }) t.Run("Expand", func(t *testing.T) { test := TestDefinition{ @@ -1784,16 +1765,6 @@ func TestOpenFgaClient(t *testing.T) { return resp, nil }, ) - httpmock.RegisterResponder("GET", fmt.Sprintf("%s/stores/%s/authorization-models", fgaClient.GetConfig().ApiUrl, fgaClient.GetConfig().StoreId), - func(req *http.Request) (*http.Response, error) { - return httpmock.NewStringResponse(http.StatusOK, ""), nil - }, - ) - httpmock.RegisterResponder("GET", fmt.Sprintf("%s/stores/%s/authorization-models/%s", fgaClient.GetConfig().ApiUrl, fgaClient.GetConfig().StoreId, authModelId), - func(req *http.Request) (*http.Response, error) { - return httpmock.NewStringResponse(http.StatusOK, ""), nil - }, - ) got, err := fgaClient.ListRelations(context.Background()). Body(requestBody). @@ -1803,8 +1774,8 @@ func TestOpenFgaClient(t *testing.T) { t.Fatalf("%v", err) } - if httpmock.GetTotalCallCount() != 5 { - t.Fatalf("OpenFgaClient.%v() - wanted %v calls to /check + 1 call to validate auth model, got %v", test.Name, 4, httpmock.GetTotalCallCount()) + if httpmock.GetTotalCallCount() != 4 { + t.Fatalf("OpenFgaClient.%v() - wanted %v calls to /check got %v", test.Name, 4, httpmock.GetTotalCallCount()) } _, err = got.MarshalJSON()