Skip to content

Commit

Permalink
fix: do not call read auth model on batchcheck and write
Browse files Browse the repository at this point in the history
  • Loading branch information
ewanharris committed Feb 26, 2024
1 parent e5d8d41 commit bb75335
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 131 deletions.
48 changes: 17 additions & 31 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package client
import (
_context "context"
"encoding/json"
"errors"
"fmt"
"math"
_nethttp "net/http"
Expand Down Expand Up @@ -448,19 +449,6 @@ func (client *OpenFgaClient) getAuthorizationModelId(authorizationModelId *strin
return &modelId, nil
}

// helper function to validate the connection (i.e., get token)
func (client *OpenFgaClient) checkValidApiConnection(ctx _context.Context, authorizationModelId *string) error {
if authorizationModelId != nil && *authorizationModelId != "" {
_, _, err := client.OpenFgaApi.ReadAuthorizationModel(ctx, *authorizationModelId).Execute()
return err
} else {
_, err := client.ReadAuthorizationModels(ctx).Options(ClientReadAuthorizationModelsOptions{
PageSize: fgaSdk.PtrInt32(1),
}).Execute()
return err
}
}

/* Stores */

// / ListStores
Expand Down Expand Up @@ -1313,9 +1301,9 @@ func (client *OpenFgaClient) WriteExecute(request SdkClientWriteRequestInterface
Deletes: []ClientWriteRequestDeleteResponse{},
}

authorizationModelId, err := client.getAuthorizationModelId(request.GetAuthorizationModelIdOverride())
if err != nil {
return nil, err
authorizationModelId := request.GetAuthorizationModelIdOverride()
if authorizationModelId != nil && *authorizationModelId != "" && !internalutils.IsWellFormedUlidString(*authorizationModelId) {
return nil, FgaInvalidError{param: "AuthorizationModelId", description: "ULID"}
}

// Unless explicitly disabled, transaction mode is enabled
Expand Down Expand Up @@ -1387,10 +1375,6 @@ func (client *OpenFgaClient) WriteExecute(request SdkClientWriteRequestInterface
}

writeGroup, ctx := errgroup.WithContext(request.GetContext())
err = client.checkValidApiConnection(ctx, authorizationModelId)
if err != nil {
return nil, err
}

writeGroup.SetLimit(int(maxParallelReqs))
writeResponses := make([]ClientWriteResponse, len(writeChunks))
Expand Down Expand Up @@ -1681,10 +1665,12 @@ func (client *OpenFgaClient) CheckExecute(request SdkClientCheckRequestInterface
contextualTuples = append(contextualTuples, (request.GetBody().ContextualTuples)[index])
}
}
authorizationModelId, err := client.getAuthorizationModelId(request.GetAuthorizationModelIdOverride())
if err != nil {
return nil, err
authorizationModelId := request.GetAuthorizationModelIdOverride()

if authorizationModelId != nil && *authorizationModelId != "" && !internalutils.IsWellFormedUlidString(*authorizationModelId) {
return nil, FgaInvalidError{param: "AuthorizationModelId", description: "ULID"}
}

requestBody := fgaSdk.CheckRequest{
TupleKey: fgaSdk.CheckRequestTupleKey{
User: request.GetBody().User,
Expand Down Expand Up @@ -1787,16 +1773,12 @@ func (client *OpenFgaClient) BatchCheckExecute(request SdkClientBatchCheckReques
group.SetLimit(maxParallelReqs)
var numOfChecks = len(*request.GetBody())
response := make(ClientBatchCheckResponse, numOfChecks)
authorizationModelId, err := client.getAuthorizationModelId(request.GetAuthorizationModelIdOverride())
if err != nil {
return nil, err
authorizationModelId := request.GetAuthorizationModelIdOverride()

if authorizationModelId != nil && *authorizationModelId != "" && !internalutils.IsWellFormedUlidString(*authorizationModelId) {
return nil, FgaInvalidError{param: "AuthorizationModelId", description: "ULID"}
}

group.Go(func() error {
// if the connection is probelmatic, we will return error to the overall
// response rather than individual response
return client.checkValidApiConnection(ctx, authorizationModelId)
})
for index, checkBody := range *request.GetBody() {
index, checkBody := index, checkBody
group.Go(func() error {
Expand All @@ -1809,6 +1791,10 @@ func (client *OpenFgaClient) BatchCheckExecute(request SdkClientBatchCheckReques
},
})

if errors.Is(err, fgaSdk.FgaApiAuthenticationError{}) {
return err
}

response[index] = ClientBatchCheckSingleResponse{
Request: checkBody,
ClientCheckResponse: *singleResponse,
Expand Down
171 changes: 71 additions & 100 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import (
"testing"

"github.com/jarcoal/httpmock"
"github.com/openfga/go-sdk"
openfga "github.com/openfga/go-sdk"
. "github.com/openfga/go-sdk/client"
)

Expand Down Expand Up @@ -1466,23 +1466,14 @@ func TestOpenFgaClient(t *testing.T) {
return resp, nil
},
)
httpmock.RegisterResponder("GET", fmt.Sprintf("%s/stores/%s/authorization-models", fgaClient.GetConfig().ApiUrl, fgaClient.GetConfig().StoreId),
func(req *http.Request) (*http.Response, error) {
return httpmock.NewStringResponse(http.StatusOK, ""), nil
},
)
httpmock.RegisterResponder("GET", fmt.Sprintf("%s/stores/%s/authorization-models/%s", fgaClient.GetConfig().ApiUrl, fgaClient.GetConfig().StoreId, authModelId),
func(req *http.Request) (*http.Response, error) {
return httpmock.NewStringResponse(http.StatusOK, ""), nil
},
)

got, err := fgaClient.BatchCheck(context.Background()).Body(requestBody).Options(options).Execute()
if err != nil {
t.Fatalf("%v", err)
}

if httpmock.GetTotalCallCount() != 5 {
t.Fatalf("OpenFgaClient.%v() - wanted %v calls to /check + 1 call to validate auth model, got %v", test.Name, 4, httpmock.GetTotalCallCount())
if httpmock.GetTotalCallCount() != 4 {
t.Fatalf("OpenFgaClient.%v() - wanted %v calls to /check got %v", test.Name, 4, httpmock.GetTotalCallCount())
}

if len(*got) != len(requestBody) {
Expand Down Expand Up @@ -1524,81 +1515,71 @@ func TestOpenFgaClient(t *testing.T) {

})

t.Run("BatchCheckConnectionProblem", func(t *testing.T) {
test := TestDefinition{
Name: "Check",
JsonResponse: `{"allowed":true, "resolution":""}`,
ResponseStatus: http.StatusOK,
Method: http.MethodPost,
RequestPath: "check",
}
requestBody := ClientBatchCheckBody{{
User: "user:81684243-9356-4421-8fbf-a4f8d36aa31b",
Relation: "viewer",
Object: "document:roadmap",
ContextualTuples: []ClientContextualTupleKey{{
User: "user:81684243-9356-4421-8fbf-a4f8d36aa31b",
Relation: "editor",
Object: "document:roadmap",
}},
}, {
User: "user:81684243-9356-4421-8fbf-a4f8d36aa31b",
Relation: "admin",
Object: "document:roadmap",
ContextualTuples: []ClientContextualTupleKey{{
User: "user:81684243-9356-4421-8fbf-a4f8d36aa31b",
Relation: "editor",
Object: "document:roadmap",
}},
}, {
User: "user:81684243-9356-4421-8fbf-a4f8d36aa31b",
Relation: "creator",
Object: "document:roadmap",
}, {
User: "user:81684243-9356-4421-8fbf-a4f8d36aa31b",
Relation: "deleter",
Object: "document:roadmap",
}}

const authModelId = "01GAHCE4YVKPQEKZQHT2R89MQV"

options := ClientBatchCheckOptions{
AuthorizationModelId: openfga.PtrString(authModelId),
MaxParallelRequests: openfga.PtrInt32(5),
}

var expectedResponse openfga.CheckResponse
if err := json.Unmarshal([]byte(test.JsonResponse), &expectedResponse); err != nil {
t.Fatalf("%v", err)
}

httpmock.Activate()
defer httpmock.DeactivateAndReset()
httpmock.RegisterResponder(test.Method, fmt.Sprintf("%s/stores/%s/%s", fgaClient.GetConfig().ApiUrl, fgaClient.GetConfig().StoreId, test.RequestPath),
func(req *http.Request) (*http.Response, error) {
resp, err := httpmock.NewJsonResponse(test.ResponseStatus, expectedResponse)
if err != nil {
return httpmock.NewStringResponse(http.StatusInternalServerError, ""), nil
}
return resp, nil
},
)
httpmock.RegisterResponder("GET", fmt.Sprintf("%s/stores/%s/authorization-models", fgaClient.GetConfig().ApiUrl, fgaClient.GetConfig().StoreId),
func(req *http.Request) (*http.Response, error) {
return httpmock.NewStringResponse(http.StatusUnauthorized, ""), nil
},
)
httpmock.RegisterResponder("GET", fmt.Sprintf("%s/stores/%s/authorization-models/%s", fgaClient.GetConfig().ApiUrl, fgaClient.GetConfig().StoreId, authModelId),
func(req *http.Request) (*http.Response, error) {
return httpmock.NewStringResponse(http.StatusUnauthorized, ""), nil
},
)
_, err := fgaClient.BatchCheck(context.Background()).Body(requestBody).Options(options).Execute()
if err == nil {
t.Fatalf("Expect error but there is none")
}

})
// t.Run("BatchCheckConnectionProblem", func(t *testing.T) {
// test := TestDefinition{
// Name: "Check",
// JsonResponse: `{"allowed":true, "resolution":""}`,
// ResponseStatus: http.StatusOK,
// Method: http.MethodPost,
// RequestPath: "check",
// }
// requestBody := ClientBatchCheckBody{{
// User: "user:81684243-9356-4421-8fbf-a4f8d36aa31b",
// Relation: "viewer",
// Object: "document:roadmap",
// ContextualTuples: []ClientContextualTupleKey{{
// User: "user:81684243-9356-4421-8fbf-a4f8d36aa31b",
// Relation: "editor",
// Object: "document:roadmap",
// }},
// }, {
// User: "user:81684243-9356-4421-8fbf-a4f8d36aa31b",
// Relation: "admin",
// Object: "document:roadmap",
// ContextualTuples: []ClientContextualTupleKey{{
// User: "user:81684243-9356-4421-8fbf-a4f8d36aa31b",
// Relation: "editor",
// Object: "document:roadmap",
// }},
// }, {
// User: "user:81684243-9356-4421-8fbf-a4f8d36aa31b",
// Relation: "creator",
// Object: "document:roadmap",
// }, {
// User: "user:81684243-9356-4421-8fbf-a4f8d36aa31b",
// Relation: "deleter",
// Object: "document:roadmap",
// }}

// const authModelId = "01GAHCE4YVKPQEKZQHT2R89MQV"

// options := ClientBatchCheckOptions{
// AuthorizationModelId: openfga.PtrString(authModelId),
// MaxParallelRequests: openfga.PtrInt32(5),
// }

// var expectedResponse openfga.CheckResponse
// if err := json.Unmarshal([]byte(test.JsonResponse), &expectedResponse); err != nil {
// t.Fatalf("%v", err)
// }

// httpmock.Activate()
// defer httpmock.DeactivateAndReset()
// httpmock.RegisterResponder(test.Method, fmt.Sprintf("%s/stores/%s/%s", fgaClient.GetConfig().ApiUrl, fgaClient.GetConfig().StoreId, test.RequestPath),
// func(req *http.Request) (*http.Response, error) {
// resp, err := httpmock.NewJsonResponse(test.ResponseStatus, expectedResponse)
// if err != nil {
// return httpmock.NewStringResponse(http.StatusInternalServerError, ""), nil
// }
// return resp, nil
// },
// )

// _, err := fgaClient.BatchCheck(context.Background()).Body(requestBody).Options(options).Execute()
// if err == nil {
// t.Fatalf("Expect error but there is none")
// }
// })

t.Run("Expand", func(t *testing.T) {
test := TestDefinition{
Expand Down Expand Up @@ -1784,16 +1765,6 @@ func TestOpenFgaClient(t *testing.T) {
return resp, nil
},
)
httpmock.RegisterResponder("GET", fmt.Sprintf("%s/stores/%s/authorization-models", fgaClient.GetConfig().ApiUrl, fgaClient.GetConfig().StoreId),
func(req *http.Request) (*http.Response, error) {
return httpmock.NewStringResponse(http.StatusOK, ""), nil
},
)
httpmock.RegisterResponder("GET", fmt.Sprintf("%s/stores/%s/authorization-models/%s", fgaClient.GetConfig().ApiUrl, fgaClient.GetConfig().StoreId, authModelId),
func(req *http.Request) (*http.Response, error) {
return httpmock.NewStringResponse(http.StatusOK, ""), nil
},
)

got, err := fgaClient.ListRelations(context.Background()).
Body(requestBody).
Expand All @@ -1803,8 +1774,8 @@ func TestOpenFgaClient(t *testing.T) {
t.Fatalf("%v", err)
}

if httpmock.GetTotalCallCount() != 5 {
t.Fatalf("OpenFgaClient.%v() - wanted %v calls to /check + 1 call to validate auth model, got %v", test.Name, 4, httpmock.GetTotalCallCount())
if httpmock.GetTotalCallCount() != 4 {
t.Fatalf("OpenFgaClient.%v() - wanted %v calls to /check got %v", test.Name, 4, httpmock.GetTotalCallCount())
}

_, err = got.MarshalJSON()
Expand Down

0 comments on commit bb75335

Please sign in to comment.