From 4c06881cfeec64e0f8bc2fe0eed0e2b9df2fb009 Mon Sep 17 00:00:00 2001 From: Stanley Phu Date: Wed, 3 Jul 2024 17:03:38 -0700 Subject: [PATCH 1/7] Fix typo for IdempotencyKey json encoding --- pkg/organizations/client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/organizations/client.go b/pkg/organizations/client.go index f68e9d27..c4a5fbf8 100644 --- a/pkg/organizations/client.go +++ b/pkg/organizations/client.go @@ -163,7 +163,7 @@ type CreateOrganizationOpts struct { DomainData []OrganizationDomainData `json:"domain_data"` // Optional unique identifier to ensure idempotency - IdempotencyKey string `json:"idempotency_iey,omitempty"` + IdempotencyKey string `json:"idempotency_key,omitempty"` } // UpdateOrganizationOpts contains the options to update an Organization. From 4c391f5bdc2dcccf09cbbc20761f045f377c1755 Mon Sep 17 00:00:00 2001 From: Stanley Phu Date: Wed, 3 Jul 2024 17:23:17 -0700 Subject: [PATCH 2/7] Add FGA client and methods --- pkg/fga/client.go | 822 ++++++++++++++ pkg/fga/client_live_example.go | 1860 ++++++++++++++++++++++++++++++++ pkg/fga/client_test.go | 1073 ++++++++++++++++++ pkg/fga/fga.go | 111 ++ pkg/fga/fga_test.go | 348 ++++++ 5 files changed, 4214 insertions(+) create mode 100644 pkg/fga/client.go create mode 100644 pkg/fga/client_live_example.go create mode 100644 pkg/fga/client_test.go create mode 100644 pkg/fga/fga.go create mode 100644 pkg/fga/fga_test.go diff --git a/pkg/fga/client.go b/pkg/fga/client.go new file mode 100644 index 00000000..0a9e0272 --- /dev/null +++ b/pkg/fga/client.go @@ -0,0 +1,822 @@ +package fga + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "sync" + "time" + + "github.com/google/go-querystring/query" + "github.com/workos/workos-go/v4/internal/workos" + "github.com/workos/workos-go/v4/pkg/common" + "github.com/workos/workos-go/v4/pkg/workos_errors" +) + +// ResponseLimit is the default number of records to limit a response to. +const ResponseLimit = 10 + +// Order represents the order of records. +type Order string + +// Constants that enumerate the available orders. +const ( + Asc Order = "asc" + Desc Order = "desc" +) + +// Client represents a client that performs FGA requests to the WorkOS API. +type Client struct { + // The WorkOS API Key. It can be found in https://dashboard.workos.com/api-keys. + APIKey string + + // The http.Client that is used to get Directory Sync records from WorkOS. + // Defaults to http.Client. + HTTPClient *http.Client + + // The endpoint to WorkOS API. Defaults to https://api.workos.com. + Endpoint string + + // The function used to encode in JSON. Defaults to json.Marshal. + JSONEncode func(v interface{}) ([]byte, error) + + once sync.Once +} + +func (c *Client) init() { + if c.HTTPClient == nil { + c.HTTPClient = &http.Client{Timeout: 10 * time.Second} + } + + if c.Endpoint == "" { + c.Endpoint = "https://api.workos.com" + } + + if c.JSONEncode == nil { + c.JSONEncode = json.Marshal + } +} + +// Objects +type Object struct { + // The type of the object. + ObjectType string `json:"object_type"` + + // The customer defined string identifier for this object. + ObjectId string `json:"object_id"` + + // Map containing additional information about this object. + Meta map[string]interface{} `json:"meta"` +} + +type GetObjectOpts struct { + // The type of the object. + ObjectType string + + // The customer defined string identifier for this object. + ObjectId string +} + +type ListObjectsOpts struct { + // The type of the object. + ObjectType string `url:"object_type,omitempty"` + + // Searchable text for an Object. Can be empty. + Search string `url:"search,omitempty"` + + // Maximum number of records to return. + Limit int `url:"limit,omitempty"` + + // The order in which to paginate records. + Order Order `url:"order,omitempty"` + + // Pagination cursor to receive records before a provided Object ID. + Before string `url:"before,omitempty"` + + // Pagination cursor to receive records after a provided Object ID. + After string `url:"after,omitempty"` +} + +// ListObjectsResponse describes the response structure when requesting Objects +type ListObjectsResponse struct { + // List of provisioned Objects. + Data []Object `json:"data"` + + // Cursor pagination options. + ListMetadata common.ListMetadata `json:"list_metadata"` +} + +type CreateObjectOpts struct { + // The type of the object. + ObjectType string `json:"object_type"` + + // The customer defined string identifier for this object. + ObjectId string `json:"object_id,omitempty"` + + // Map containing additional information about this object. + Meta map[string]interface{} `json:"meta,omitempty"` +} + +type UpdateObjectOpts struct { + // The type of the object. + ObjectType string `json:"object_type"` + + // The customer defined string identifier for this object. + ObjectId string `json:"object_id,omitempty"` + + // Map containing additional information about this object. + Meta map[string]interface{} `json:"meta,omitempty"` +} + +// DeleteObjectOpts contains the options to delete an object. +type DeleteObjectOpts struct { + // The type of the object. + ObjectType string + + // The customer defined string identifier for this object. + ObjectId string +} + +// Warrants +type Subject struct { + // The type of the subject. + ObjectType string `json:"object_type"` + + // The customer defined string identifier for this subject. + ObjectId string `json:"object_id"` + + // The relation of the subject. + Relation string `json:"relation,omitempty"` +} + +type Warrant struct { + // Type of object to assign a relation to. Must be an existing type. + ObjectType string `json:"object_type"` + + // Id of the object to assign a relation to. + ObjectId string `json:"object_id"` + + // Relation to assign to the object. + Relation string `json:"relation"` + + // Subject of the warrant + Subject Subject `json:"subject"` + + // Policy that must evaluate to true for warrant to be valid + Policy string `json:"policy,omitempty"` +} + +type ListWarrantsOpts struct { + // Only return warrants whose objectType matches this value. + ObjectType string `url:"object_type,omitempty"` + + // Only return warrants whose objectId matches this value. + ObjectId string `url:"object_id,omitempty"` + + // Only return warrants whose relation matches this value. + Relation string `url:"relation,omitempty"` + + // Only return warrants with a subject whose objectType matches this value. + SubjectType string `url:"subject_type,omitempty"` + + // Only return warrants with a subject whose objectId matches this value. + SubjectId string `url:"subject_id,omitempty"` + + // Only return warrants with a subject whose relation matches this value. + SubjectRelation string `url:"subject_relation,omitempty"` + + // Maximum number of records to return. + Limit int `url:"limit,omitempty"` + + // Pagination cursor to receive records after a provided Warrant ID. + After string `url:"after,omitempty"` + + // Optional token to specify desired read consistency + WarrantToken string `url:"-"` +} + +// ListWarrantsResponse describes the response structure when requesting Warrants +type ListWarrantsResponse struct { + // List of provisioned Warrants. + Data []Warrant `json:"data"` + + // Cursor pagination options. + ListMetadata common.ListMetadata `json:"list_metadata"` +} + +type WriteWarrantOpts struct { + // Operation to perform for the given warrant + Op string `json:"op,omitempty"` + + // Type of object to assign a relation to. Must be an existing type. + ObjectType string `json:"object_type"` + + // Id of the object to assign a relation to. + ObjectId string `json:"object_id"` + + // Relation to assign to the object. + Relation string `json:"relation"` + + // Subject of the warrant + Subject Subject `json:"subject"` + + // Policy that must evaluate to true for warrant to be valid + Policy string `json:"policy,omitempty"` +} + +type WriteWarrantResponse struct { + WarrantToken string `json:"warrant_token"` +} + +// Check +type Context map[string]interface{} + +type WarrantCheck struct { + // The type of the object. + ObjectType string `json:"object_type"` + + // Id of the specific object. + ObjectId string `json:"object_id"` + + // Relation to check between the object and subject. + Relation string `json:"relation"` + + // The subject that must have the specified relation. + Subject Subject `json:"subject"` + + // Contextual data to use for the access check. + Context Context `json:"context,omitempty"` +} + +type CheckOpts struct { + // Warrant to check + Warrant WarrantCheck `json:"warrant_check"` + + // Flag to include debug information in the response. + Debug bool `json:"debug,omitempty"` + + // Optional token to specify desired read consistency + WarrantToken string `json:"-"` +} + +type CheckManyOpts struct { + // The operator to use for the given warrants. + Op string `json:"op,omitempty"` + + // List of warrants to check. + Warrants []WarrantCheck `json:"warrants"` + + // Flag to include debug information in the response. + Debug bool `json:"debug,omitempty"` + + // Optional token to specify desired read consistency + WarrantToken string `json:"-"` +} + +type BatchCheckOpts struct { + // List of warrants to check. + Warrants []WarrantCheck `json:"warrants"` + + // Flag to include debug information in the response. + Debug bool `json:"debug,omitempty"` + + // Optional token to specify desired read consistency + WarrantToken string `json:"-"` +} + +type CheckResponse struct { + Code int64 `json:"code"` + Result string `json:"result"` + IsImplicit bool `json:"is_implicit"` + ProcessingTime int64 `json:"processing_time,omitempty"` + DecisionPath map[string][]Warrant `json:"decision_path,omitempty"` +} + +// Query +type QueryOpts struct { + // Query to be executed. + Query string `url:"q"` + + // Contextual data to use for the query. + Context Context `url:"context,omitempty"` + + // Maximum number of records to return. + Limit int `url:"limit,omitempty"` + + // The order in which to paginate records. + Order Order `url:"order,omitempty"` + + // Pagination cursor to receive records before a provided Warrant ID. + Before string `url:"before,omitempty"` + + // Pagination cursor to receive records after a provided Warrant ID. + After string `url:"after,omitempty"` + + // Optional token to specify desired read consistency + WarrantToken string `url:"-"` +} + +type QueryResult struct { + // The type of the object. + ObjectType string `json:"object_type"` + + // Id of the specific object. + ObjectId string `json:"object_id"` + + // Relation between the object and subject. + Relation string `json:"relation"` + + // Warrant matching the provided query + Warrant Warrant `json:"warrant"` + + // Specifies whether the warrant is implicitly defined. + IsImplicit bool `json:"is_implicit"` + + // Metadata of the object. + Meta map[string]interface{} `json:"meta,omitempty"` +} + +type QueryResponse struct { + // List of query results. + Data []QueryResult `json:"data"` + + // Cursor pagination options. + ListMetadata common.ListMetadata `json:"list_metadata"` +} + +// GetObject gets an Object. +func (c *Client) GetObject(ctx context.Context, opts GetObjectOpts) (Object, error) { + c.once.Do(c.init) + + endpoint := fmt.Sprintf("%s/fga/v1/objects/%s/%s", c.Endpoint, opts.ObjectType, opts.ObjectId) + req, err := http.NewRequest(http.MethodGet, endpoint, nil) + if err != nil { + return Object{}, err + } + + req = req.WithContext(ctx) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+c.APIKey) + req.Header.Set("User-Agent", "workos-go/"+workos.Version) + + res, err := c.HTTPClient.Do(req) + if err != nil { + return Object{}, err + } + defer res.Body.Close() + + if err = workos_errors.TryGetHTTPError(res); err != nil { + return Object{}, err + } + + var body Object + dec := json.NewDecoder(res.Body) + err = dec.Decode(&body) + return body, err +} + +// ListObjects gets a list of FGA objects. +func (c *Client) ListObjects(ctx context.Context, opts ListObjectsOpts) (ListObjectsResponse, error) { + c.once.Do(c.init) + + endpoint := fmt.Sprintf("%s/fga/v1/objects", c.Endpoint) + req, err := http.NewRequest(http.MethodGet, endpoint, nil) + if err != nil { + return ListObjectsResponse{}, err + } + + req = req.WithContext(ctx) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+c.APIKey) + req.Header.Set("User-Agent", "workos-go/"+workos.Version) + + if opts.Limit == 0 { + opts.Limit = ResponseLimit + } + + if opts.Order == "" { + opts.Order = Desc + } + + q, err := query.Values(opts) + if err != nil { + return ListObjectsResponse{}, err + } + + req.URL.RawQuery = q.Encode() + + res, err := c.HTTPClient.Do(req) + if err != nil { + return ListObjectsResponse{}, err + } + defer res.Body.Close() + + if err = workos_errors.TryGetHTTPError(res); err != nil { + return ListObjectsResponse{}, err + } + + var body ListObjectsResponse + dec := json.NewDecoder(res.Body) + err = dec.Decode(&body) + return body, err +} + +// CreateObject creates a new object +func (c *Client) CreateObject(ctx context.Context, opts CreateObjectOpts) (Object, error) { + c.once.Do(c.init) + + data, err := c.JSONEncode(opts) + if err != nil { + return Object{}, err + } + + endpoint := fmt.Sprintf("%s/fga/v1/objects", c.Endpoint) + req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewBuffer(data)) + if err != nil { + return Object{}, err + } + + req = req.WithContext(ctx) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+c.APIKey) + req.Header.Set("User-Agent", "workos-go/"+workos.Version) + + res, err := c.HTTPClient.Do(req) + if err != nil { + return Object{}, err + } + defer res.Body.Close() + + if err = workos_errors.TryGetHTTPError(res); err != nil { + return Object{}, err + } + + var body Object + dec := json.NewDecoder(res.Body) + err = dec.Decode(&body) + return body, err +} + +// UpdateObject updates an existing Object +func (c *Client) UpdateObject(ctx context.Context, opts UpdateObjectOpts) (Object, error) { + c.once.Do(c.init) + + // UpdateObjectChangeOpts contains the options to update an Object minus the ObjectType and ObjectId + type UpdateObjectChangeOpts struct { + Meta map[string]interface{} `json:"meta"` + } + + update_opts := UpdateObjectChangeOpts{Meta: opts.Meta} + + data, err := c.JSONEncode(update_opts) + if err != nil { + return Object{}, err + } + + endpoint := fmt.Sprintf("%s/fga/v1/objects/%s/%s", c.Endpoint, opts.ObjectType, opts.ObjectId) + req, err := http.NewRequest(http.MethodPut, endpoint, bytes.NewBuffer(data)) + if err != nil { + return Object{}, err + } + + req = req.WithContext(ctx) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+c.APIKey) + req.Header.Set("User-Agent", "workos-go/"+workos.Version) + + res, err := c.HTTPClient.Do(req) + if err != nil { + return Object{}, err + } + defer res.Body.Close() + + if err = workos_errors.TryGetHTTPError(res); err != nil { + return Object{}, err + } + + var body Object + dec := json.NewDecoder(res.Body) + err = dec.Decode(&body) + return body, err + +} + +// DeleteObject deletes an Object +func (c *Client) DeleteObject(ctx context.Context, opts DeleteObjectOpts) error { + c.once.Do(c.init) + + endpoint := fmt.Sprintf("%s/fga/v1/objects/%s/%s", c.Endpoint, opts.ObjectType, opts.ObjectId) + req, err := http.NewRequest(http.MethodDelete, endpoint, nil) + if err != nil { + return err + } + + req = req.WithContext(ctx) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+c.APIKey) + req.Header.Set("User-Agent", "workos-go/"+workos.Version) + + res, err := c.HTTPClient.Do(req) + if err != nil { + return err + } + defer res.Body.Close() + + return workos_errors.TryGetHTTPError(res) +} + +// ListWarrants gets a list of Warrants. +func (c *Client) ListWarrants(ctx context.Context, opts ListWarrantsOpts) (ListWarrantsResponse, error) { + c.once.Do(c.init) + + endpoint := fmt.Sprintf("%s/fga/v1/warrants", c.Endpoint) + req, err := http.NewRequest(http.MethodGet, endpoint, nil) + if err != nil { + return ListWarrantsResponse{}, err + } + + req = req.WithContext(ctx) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+c.APIKey) + req.Header.Set("User-Agent", "workos-go/"+workos.Version) + if opts.WarrantToken != "" { + req.Header.Set("Warrant-Token", opts.WarrantToken) + } + + if opts.Limit == 0 { + opts.Limit = ResponseLimit + } + + q, err := query.Values(opts) + if err != nil { + return ListWarrantsResponse{}, err + } + + req.URL.RawQuery = q.Encode() + + res, err := c.HTTPClient.Do(req) + if err != nil { + return ListWarrantsResponse{}, err + } + defer res.Body.Close() + + if err = workos_errors.TryGetHTTPError(res); err != nil { + return ListWarrantsResponse{}, err + } + + var body ListWarrantsResponse + dec := json.NewDecoder(res.Body) + err = dec.Decode(&body) + return body, err +} + +// WriteWarrant performs a write operation on a Warrant. +func (c *Client) WriteWarrant(ctx context.Context, opts WriteWarrantOpts) (WriteWarrantResponse, error) { + c.once.Do(c.init) + + data, err := c.JSONEncode(opts) + if err != nil { + return WriteWarrantResponse{}, err + } + + endpoint := fmt.Sprintf("%s/fga/v1/warrants", c.Endpoint) + req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewBuffer(data)) + if err != nil { + return WriteWarrantResponse{}, err + } + + req = req.WithContext(ctx) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+c.APIKey) + req.Header.Set("User-Agent", "workos-go/"+workos.Version) + + res, err := c.HTTPClient.Do(req) + if err != nil { + return WriteWarrantResponse{}, err + } + defer res.Body.Close() + + if err = workos_errors.TryGetHTTPError(res); err != nil { + return WriteWarrantResponse{}, err + } + + var body WriteWarrantResponse + dec := json.NewDecoder(res.Body) + err = dec.Decode(&body) + return body, err +} + +// BatchWriteWarrants performs a write operation on a Warrant. +func (c *Client) BatchWriteWarrants(ctx context.Context, opts []WriteWarrantOpts) (WriteWarrantResponse, error) { + c.once.Do(c.init) + + data, err := c.JSONEncode(opts) + if err != nil { + return WriteWarrantResponse{}, err + } + + endpoint := fmt.Sprintf("%s/fga/v1/warrants", c.Endpoint) + req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewBuffer(data)) + if err != nil { + return WriteWarrantResponse{}, err + } + + req = req.WithContext(ctx) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+c.APIKey) + req.Header.Set("User-Agent", "workos-go/"+workos.Version) + + res, err := c.HTTPClient.Do(req) + if err != nil { + return WriteWarrantResponse{}, err + } + defer res.Body.Close() + + if err = workos_errors.TryGetHTTPError(res); err != nil { + return WriteWarrantResponse{}, err + } + + var body WriteWarrantResponse + dec := json.NewDecoder(res.Body) + err = dec.Decode(&body) + return body, err +} + +func (c *Client) Check(ctx context.Context, opts CheckOpts) (bool, error) { + return c.CheckMany(ctx, CheckManyOpts{ + Warrants: []WarrantCheck{opts.Warrant}, + Debug: opts.Debug, + WarrantToken: opts.WarrantToken, + }) +} + +func (c *Client) CheckMany(ctx context.Context, opts CheckManyOpts) (bool, error) { + c.once.Do(c.init) + + data, err := c.JSONEncode(opts) + if err != nil { + return false, err + } + + endpoint := fmt.Sprintf("%s/fga/v1/check", c.Endpoint) + req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewBuffer(data)) + if err != nil { + return false, err + } + + req = req.WithContext(ctx) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+c.APIKey) + req.Header.Set("User-Agent", "workos-go/"+workos.Version) + if opts.WarrantToken != "" { + req.Header.Set("Warrant-Token", opts.WarrantToken) + } + + res, err := c.HTTPClient.Do(req) + if err != nil { + return false, err + } + defer res.Body.Close() + + if err = workos_errors.TryGetHTTPError(res); err != nil { + return false, err + } + + var checkResponse CheckResponse + dec := json.NewDecoder(res.Body) + err = dec.Decode(&checkResponse) + if err != nil { + return false, err + } + + return checkResponse.Result == "Authorized", nil +} + +func (c *Client) BatchCheck(ctx context.Context, opts BatchCheckOpts) ([]bool, error) { + c.once.Do(c.init) + + checkOpts := CheckManyOpts{ + Op: "batch", + Warrants: opts.Warrants, + Debug: opts.Debug, + WarrantToken: opts.WarrantToken, + } + data, err := c.JSONEncode(checkOpts) + if err != nil { + return []bool{}, err + } + + endpoint := fmt.Sprintf("%s/fga/v1/check", c.Endpoint) + req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewBuffer(data)) + if err != nil { + return []bool{}, err + } + + req = req.WithContext(ctx) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+c.APIKey) + req.Header.Set("User-Agent", "workos-go/"+workos.Version) + if opts.WarrantToken != "" { + req.Header.Set("Warrant-Token", opts.WarrantToken) + } + + res, err := c.HTTPClient.Do(req) + if err != nil { + return []bool{}, err + } + defer res.Body.Close() + + if err = workos_errors.TryGetHTTPError(res); err != nil { + return []bool{}, err + } + + var checkResponses []CheckResponse + dec := json.NewDecoder(res.Body) + err = dec.Decode(&checkResponses) + if err != nil { + return []bool{}, err + } + + var results []bool + for _, checkResponse := range checkResponses { + results = append(results, checkResponse.Result == "Authorized") + } + return results, nil +} + +// Query executes a query for a set of resources. +func (c *Client) Query(ctx context.Context, opts QueryOpts) (QueryResponse, error) { + c.once.Do(c.init) + + endpoint := fmt.Sprintf("%s/fga/v1/query", c.Endpoint) + req, err := http.NewRequest(http.MethodGet, endpoint, nil) + if err != nil { + return QueryResponse{}, err + } + + req = req.WithContext(ctx) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+c.APIKey) + req.Header.Set("User-Agent", "workos-go/"+workos.Version) + if opts.WarrantToken != "" { + req.Header.Set("Warrant-Token", opts.WarrantToken) + } + + if opts.Limit == 0 { + opts.Limit = ResponseLimit + } + + if opts.Order == "" { + opts.Order = Desc + } + + type QueryUrlOpts struct { + Query string `url:"q"` + Context string `url:"context,omitempty"` + Limit int `url:"limit,omitempty"` + Order Order `url:"order,omitempty"` + Before string `url:"before,omitempty"` + After string `url:"after,omitempty"` + WarrantToken string `url:"-"` + } + + var jsonCtx []byte + if opts.Context != nil { + jsonCtx, err = json.Marshal(opts.Context) + if err != nil { + return QueryResponse{}, err + } + } + queryUrlOpts := QueryUrlOpts{ + Query: opts.Query, + Context: string(jsonCtx), + Limit: opts.Limit, + Order: opts.Order, + Before: opts.Before, + After: opts.After, + } + + q, err := query.Values(queryUrlOpts) + if err != nil { + return QueryResponse{}, err + } + + req.URL.RawQuery = q.Encode() + + res, err := c.HTTPClient.Do(req) + if err != nil { + return QueryResponse{}, err + } + defer res.Body.Close() + + if err = workos_errors.TryGetHTTPError(res); err != nil { + return QueryResponse{}, err + } + + var body QueryResponse + dec := json.NewDecoder(res.Body) + err = dec.Decode(&body) + return body, err +} diff --git a/pkg/fga/client_live_example.go b/pkg/fga/client_live_example.go new file mode 100644 index 00000000..230f32ef --- /dev/null +++ b/pkg/fga/client_live_example.go @@ -0,0 +1,1860 @@ +package fga + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func setup() { + SetAPIKey("") +} + +func TestCrudObjects(t *testing.T) { + setup() + + object1, err := CreateObject(context.Background(), CreateObjectOpts{ + ObjectType: "document", + }) + if err != nil { + t.Fatal(err) + } + require.Equal(t, "document", object1.ObjectType) + require.NotEmpty(t, object1.ObjectId) + require.Empty(t, object1.Meta) + + object2, err := CreateObject(context.Background(), CreateObjectOpts{ + ObjectType: "folder", + ObjectId: "planning", + }) + if err != nil { + t.Fatal(err) + } + refetchedObject, err := GetObject(context.Background(), GetObjectOpts{ + ObjectType: object2.ObjectType, + ObjectId: object2.ObjectId, + }) + if err != nil { + t.Fatal(err) + } + require.Equal(t, object2.ObjectType, refetchedObject.ObjectType) + require.Equal(t, object2.ObjectId, refetchedObject.ObjectId) + require.EqualValues(t, object2.Meta, refetchedObject.Meta) + + object2, err = UpdateObject(context.Background(), UpdateObjectOpts{ + ObjectType: object2.ObjectType, + ObjectId: object2.ObjectId, + Meta: map[string]interface{}{ + "description": "Folder object", + }, + }) + if err != nil { + t.Fatal(err) + } + refetchedObject, err = GetObject(context.Background(), GetObjectOpts{ + ObjectType: object2.ObjectType, + ObjectId: object2.ObjectId, + }) + if err != nil { + t.Fatal(err) + } + require.Equal(t, object2.ObjectType, refetchedObject.ObjectType) + require.Equal(t, object2.ObjectId, refetchedObject.ObjectId) + require.EqualValues(t, object2.Meta, refetchedObject.Meta) + + objectsList, err := ListObjects(context.Background(), ListObjectsOpts{ + Limit: 10, + }) + if err != nil { + t.Fatal(err) + } + require.Len(t, objectsList.Data, 2) + require.Equal(t, object2.ObjectType, objectsList.Data[0].ObjectType) + require.Equal(t, object2.ObjectId, objectsList.Data[0].ObjectId) + require.Equal(t, object1.ObjectType, objectsList.Data[1].ObjectType) + require.Equal(t, object1.ObjectId, objectsList.Data[1].ObjectId) + + // Sort in ascending order + objectsList, err = ListObjects(context.Background(), ListObjectsOpts{ + Limit: 10, + Order: Asc, + }) + if err != nil { + t.Fatal(err) + } + require.Len(t, objectsList.Data, 2) + require.Equal(t, object1.ObjectType, objectsList.Data[0].ObjectType) + require.Equal(t, object1.ObjectId, objectsList.Data[0].ObjectId) + require.Equal(t, object2.ObjectType, objectsList.Data[1].ObjectType) + require.Equal(t, object2.ObjectId, objectsList.Data[1].ObjectId) + + objectsList, err = ListObjects(context.Background(), ListObjectsOpts{ + Limit: 10, + Search: "planning", + }) + if err != nil { + t.Fatal(err) + } + require.Len(t, objectsList.Data, 1) + require.Equal(t, object2.ObjectType, objectsList.Data[0].ObjectType) + require.Equal(t, object2.ObjectId, objectsList.Data[0].ObjectId) + + err = DeleteObject(context.Background(), DeleteObjectOpts{ + ObjectType: object1.ObjectType, + ObjectId: object1.ObjectId, + }) + if err != nil { + t.Fatal(err) + } + err = DeleteObject(context.Background(), DeleteObjectOpts{ + ObjectType: object2.ObjectType, + ObjectId: object2.ObjectId, + }) + if err != nil { + t.Fatal(err) + } + objectsList, err = ListObjects(context.Background(), ListObjectsOpts{ + Limit: 10, + Search: "planning", + }) + if err != nil { + t.Fatal(err) + } + require.Len(t, objectsList.Data, 0) +} + +func TestMultiTenancy(t *testing.T) { + setup() + + // Create users + user1, err := CreateObject(context.Background(), CreateObjectOpts{ + ObjectType: "user", + }) + if err != nil { + t.Fatal(err) + } + user2, err := CreateObject(context.Background(), CreateObjectOpts{ + ObjectType: "user", + }) + if err != nil { + t.Fatal(err) + } + + // Create tenants + tenant1, err := CreateObject(context.Background(), CreateObjectOpts{ + ObjectType: "tenant", + ObjectId: "tenant-1", + Meta: map[string]interface{}{ + "name": "Tenant 1", + }, + }) + if err != nil { + t.Fatal(err) + } + tenant2, err := CreateObject(context.Background(), CreateObjectOpts{ + ObjectType: "tenant", + ObjectId: "tenant-2", + Meta: map[string]interface{}{ + "name": "Tenant 2", + }, + }) + if err != nil { + t.Fatal(err) + } + + user1TenantsList, err := Query(context.Background(), QueryOpts{ + Query: fmt.Sprintf("select tenant where user:%s is member", user1.ObjectId), + Limit: 10, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.Len(t, user1TenantsList.Data, 0) + tenant1UsersList, err := Query(context.Background(), QueryOpts{ + Query: fmt.Sprintf("select member of type user for tenant:%s", tenant1.ObjectId), + Limit: 10, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.Len(t, tenant1UsersList.Data, 0) + + // Assign user1 -> tenant1 + warrantResponse, err := WriteWarrant(context.Background(), WriteWarrantOpts{ + ObjectType: tenant1.ObjectType, + ObjectId: tenant1.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: user1.ObjectType, + ObjectId: user1.ObjectId, + }, + }) + if err != nil { + t.Fatal(err) + } + require.NotEmpty(t, warrantResponse.WarrantToken) + + user1TenantsList, err = Query(context.Background(), QueryOpts{ + Query: fmt.Sprintf("select tenant where user:%s is member", user1.ObjectId), + Limit: 10, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.Len(t, user1TenantsList.Data, 1) + require.Equal(t, "tenant", user1TenantsList.Data[0].ObjectType) + require.Equal(t, "tenant-1", user1TenantsList.Data[0].ObjectId) + require.EqualValues(t, map[string]interface{}{ + "name": "Tenant 1", + }, user1TenantsList.Data[0].Meta) + + tenant1UsersList, err = Query(context.Background(), QueryOpts{ + Query: fmt.Sprintf("select member of type user for tenant:%s", tenant1.ObjectId), + Limit: 10, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.Len(t, tenant1UsersList.Data, 1) + require.Equal(t, "user", tenant1UsersList.Data[0].ObjectType) + require.Equal(t, user1.ObjectId, tenant1UsersList.Data[0].ObjectId) + require.Empty(t, tenant1UsersList.Data[0].Meta) + + // Remove user1 -> tenant1 + warrantResponse, err = WriteWarrant(context.Background(), WriteWarrantOpts{ + Op: "delete", + ObjectType: tenant1.ObjectType, + ObjectId: tenant1.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: user1.ObjectType, + ObjectId: user1.ObjectId, + }, + }) + if err != nil { + t.Fatal(err) + } + require.NotEmpty(t, warrantResponse.WarrantToken) + + user1TenantsList, err = Query(context.Background(), QueryOpts{ + Query: fmt.Sprintf("select tenant where user:%s is member", user1.ObjectId), + Limit: 10, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.Len(t, user1TenantsList.Data, 0) + tenant1UsersList, err = Query(context.Background(), QueryOpts{ + Query: fmt.Sprintf("select member of type user for tenant:%s", tenant1.ObjectId), + Limit: 10, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.Len(t, tenant1UsersList.Data, 0) + + // Clean up + err = DeleteObject(context.Background(), DeleteObjectOpts{ + ObjectType: user1.ObjectType, + ObjectId: user1.ObjectId, + }) + if err != nil { + t.Fatal(err) + } + err = DeleteObject(context.Background(), DeleteObjectOpts{ + ObjectType: user2.ObjectType, + ObjectId: user2.ObjectId, + }) + if err != nil { + t.Fatal(err) + } + err = DeleteObject(context.Background(), DeleteObjectOpts{ + ObjectType: tenant1.ObjectType, + ObjectId: tenant1.ObjectId, + }) + if err != nil { + t.Fatal(err) + } + err = DeleteObject(context.Background(), DeleteObjectOpts{ + ObjectType: tenant2.ObjectType, + ObjectId: tenant2.ObjectId, + }) + if err != nil { + t.Fatal(err) + } +} + +func TestRBAC(t *testing.T) { + setup() + + // Create users + adminUser, err := CreateObject(context.Background(), CreateObjectOpts{ + ObjectType: "user", + }) + if err != nil { + t.Fatal(err) + } + viewerUser, err := CreateObject(context.Background(), CreateObjectOpts{ + ObjectType: "user", + }) + if err != nil { + t.Fatal(err) + } + + // Create roles + adminRole, err := CreateObject(context.Background(), CreateObjectOpts{ + ObjectType: "role", + ObjectId: "administrator", + Meta: map[string]interface{}{ + "name": "Administrator", + "description": "The admin role", + }, + }) + if err != nil { + t.Fatal(err) + } + viewerRole, err := CreateObject(context.Background(), CreateObjectOpts{ + ObjectType: "role", + ObjectId: "viewer", + Meta: map[string]interface{}{ + "name": "Viewer", + "description": "The viewer role", + }, + }) + if err != nil { + t.Fatal(err) + } + + // Create permissions + createPermission, err := CreateObject(context.Background(), CreateObjectOpts{ + ObjectType: "permission", + ObjectId: "create-report", + Meta: map[string]interface{}{ + "name": "Create Report", + "description": "Permission to create reports", + }, + }) + if err != nil { + t.Fatal(err) + } + viewPermission, err := CreateObject(context.Background(), CreateObjectOpts{ + ObjectType: "permission", + ObjectId: "view-report", + Meta: map[string]interface{}{ + "name": "View Report", + "description": "Permission to view reports", + }, + }) + if err != nil { + t.Fatal(err) + } + + adminUserRolesList, err := Query(context.Background(), QueryOpts{ + Query: fmt.Sprintf("select role where user:%s is member", adminUser.ObjectId), + Limit: 10, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.Len(t, adminUserRolesList.Data, 0) + + adminRolePermissionsList, err := Query(context.Background(), QueryOpts{ + Query: fmt.Sprintf("select permission where role:%s is member", adminRole.ObjectId), + Limit: 10, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.Len(t, adminRolePermissionsList.Data, 0) + + adminUserHasPermission, err := Check(context.Background(), CheckOpts{ + Warrant: WarrantCheck{ + ObjectType: createPermission.ObjectType, + ObjectId: createPermission.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: adminUser.ObjectType, + ObjectId: adminUser.ObjectId, + }, + }, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.False(t, adminUserHasPermission) + + // Assign create-report permission -> admin role -> admin user + warrantResponse, err := WriteWarrant(context.Background(), WriteWarrantOpts{ + ObjectType: createPermission.ObjectType, + ObjectId: createPermission.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: adminRole.ObjectType, + ObjectId: adminRole.ObjectId, + }, + }) + if err != nil { + t.Fatal(err) + } + require.NotEmpty(t, warrantResponse.WarrantToken) + + warrantResponse, err = WriteWarrant(context.Background(), WriteWarrantOpts{ + ObjectType: adminRole.ObjectType, + ObjectId: adminRole.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: adminUser.ObjectType, + ObjectId: adminUser.ObjectId, + }, + }) + if err != nil { + t.Fatal(err) + } + require.NotEmpty(t, warrantResponse.WarrantToken) + + adminUserHasPermission, err = Check(context.Background(), CheckOpts{ + Warrant: WarrantCheck{ + ObjectType: createPermission.ObjectType, + ObjectId: createPermission.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: adminUser.ObjectType, + ObjectId: adminUser.ObjectId, + }, + }, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.True(t, adminUserHasPermission) + + adminUserRolesList, err = Query(context.Background(), QueryOpts{ + Query: fmt.Sprintf("select role where user:%s is member", adminUser.ObjectId), + Limit: 10, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.Len(t, adminUserRolesList.Data, 1) + require.Equal(t, "role", adminUserRolesList.Data[0].ObjectType) + require.Equal(t, adminRole.ObjectId, adminUserRolesList.Data[0].ObjectId) + require.Equal(t, map[string]interface{}{ + "name": "Administrator", + "description": "The admin role", + }, adminUserRolesList.Data[0].Meta) + + adminRolePermissionsList, err = Query(context.Background(), QueryOpts{ + Query: fmt.Sprintf("select permission where role:%s is member", adminRole.ObjectId), + Limit: 10, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.Len(t, adminRolePermissionsList.Data, 1) + require.Equal(t, "permission", adminRolePermissionsList.Data[0].ObjectType) + require.Equal(t, createPermission.ObjectId, adminRolePermissionsList.Data[0].ObjectId) + require.Equal(t, map[string]interface{}{ + "name": "Create Report", + "description": "Permission to create reports", + }, adminRolePermissionsList.Data[0].Meta) + + // Remove create-report permission -> admin role -> admin user + warrantResponse, err = WriteWarrant(context.Background(), WriteWarrantOpts{ + Op: "delete", + ObjectType: createPermission.ObjectType, + ObjectId: createPermission.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: adminRole.ObjectType, + ObjectId: adminRole.ObjectId, + }, + }) + if err != nil { + t.Fatal(err) + } + require.NotEmpty(t, warrantResponse.WarrantToken) + + warrantResponse, err = WriteWarrant(context.Background(), WriteWarrantOpts{ + Op: "delete", + ObjectType: adminRole.ObjectType, + ObjectId: adminRole.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: adminUser.ObjectType, + ObjectId: adminUser.ObjectId, + }, + }) + if err != nil { + t.Fatal(err) + } + require.NotEmpty(t, warrantResponse.WarrantToken) + + adminUserHasPermission, err = Check(context.Background(), CheckOpts{ + Warrant: WarrantCheck{ + ObjectType: createPermission.ObjectType, + ObjectId: createPermission.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: adminUser.ObjectType, + ObjectId: adminUser.ObjectId, + }, + }, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.False(t, adminUserHasPermission) + + adminUserRolesList, err = Query(context.Background(), QueryOpts{ + Query: fmt.Sprintf("select role where user:%s is member", adminUser.ObjectId), + Limit: 10, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.Len(t, adminUserRolesList.Data, 0) + + adminRolePermissionsList, err = Query(context.Background(), QueryOpts{ + Query: fmt.Sprintf("select permission where role:%s is member", adminRole.ObjectId), + Limit: 10, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.Len(t, adminRolePermissionsList.Data, 0) + + // Assign view-report -> viewer user + viewerUserHasPermission, err := Check(context.Background(), CheckOpts{ + Warrant: WarrantCheck{ + ObjectType: viewPermission.ObjectType, + ObjectId: viewPermission.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: viewerUser.ObjectType, + ObjectId: viewerUser.ObjectId, + }, + }, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.False(t, viewerUserHasPermission) + + viewerUserPermissionsList, err := Query(context.Background(), QueryOpts{ + Query: fmt.Sprintf("select permission where user:%s is member", viewerUser.ObjectId), + Limit: 10, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.Empty(t, viewerUserPermissionsList.Data) + + warrantResponse, err = WriteWarrant(context.Background(), WriteWarrantOpts{ + ObjectType: viewPermission.ObjectType, + ObjectId: viewPermission.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: viewerUser.ObjectType, + ObjectId: viewerUser.ObjectId, + }, + }) + if err != nil { + t.Fatal(err) + } + require.NotEmpty(t, warrantResponse.WarrantToken) + + viewerUserHasPermission, err = Check(context.Background(), CheckOpts{ + Warrant: WarrantCheck{ + ObjectType: viewPermission.ObjectType, + ObjectId: viewPermission.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: viewerUser.ObjectType, + ObjectId: viewerUser.ObjectId, + }, + }, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.True(t, viewerUserHasPermission) + + viewerUserPermissionsList, err = Query(context.Background(), QueryOpts{ + Query: fmt.Sprintf("select permission where user:%s is member", viewerUser.ObjectId), + Limit: 10, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.Len(t, viewerUserPermissionsList.Data, 1) + require.Equal(t, "permission", viewerUserPermissionsList.Data[0].ObjectType) + require.Equal(t, viewPermission.ObjectId, viewerUserPermissionsList.Data[0].ObjectId) + require.Equal(t, map[string]interface{}{ + "name": "View Report", + "description": "Permission to view reports", + }, viewerUserPermissionsList.Data[0].Meta) + + warrantResponse, err = WriteWarrant(context.Background(), WriteWarrantOpts{ + Op: "delete", + ObjectType: viewPermission.ObjectType, + ObjectId: viewPermission.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: viewerUser.ObjectType, + ObjectId: viewerUser.ObjectId, + }, + }) + if err != nil { + t.Fatal(err) + } + require.NotEmpty(t, warrantResponse.WarrantToken) + + viewerUserHasPermission, err = Check(context.Background(), CheckOpts{ + Warrant: WarrantCheck{ + ObjectType: viewPermission.ObjectType, + ObjectId: viewPermission.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: viewerUser.ObjectType, + ObjectId: viewerUser.ObjectId, + }, + }, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.False(t, viewerUserHasPermission) + + viewerUserPermissionsList, err = Query(context.Background(), QueryOpts{ + Query: fmt.Sprintf("select permission where user:%s is member", viewerUser.ObjectId), + Limit: 10, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.Empty(t, viewerUserPermissionsList.Data) + + // Clean up + err = DeleteObject(context.Background(), DeleteObjectOpts{ + ObjectType: adminUser.ObjectType, + ObjectId: adminUser.ObjectId, + }) + if err != nil { + t.Fatal(err) + } + err = DeleteObject(context.Background(), DeleteObjectOpts{ + ObjectType: viewerUser.ObjectType, + ObjectId: viewerUser.ObjectId, + }) + if err != nil { + t.Fatal(err) + } + err = DeleteObject(context.Background(), DeleteObjectOpts{ + ObjectType: adminRole.ObjectType, + ObjectId: adminRole.ObjectId, + }) + if err != nil { + t.Fatal(err) + } + err = DeleteObject(context.Background(), DeleteObjectOpts{ + ObjectType: viewerRole.ObjectType, + ObjectId: viewerRole.ObjectId, + }) + if err != nil { + t.Fatal(err) + } + err = DeleteObject(context.Background(), DeleteObjectOpts{ + ObjectType: createPermission.ObjectType, + ObjectId: createPermission.ObjectId, + }) + if err != nil { + t.Fatal(err) + } + err = DeleteObject(context.Background(), DeleteObjectOpts{ + ObjectType: viewPermission.ObjectType, + ObjectId: viewPermission.ObjectId, + }) + if err != nil { + t.Fatal(err) + } +} + +func TestPricingTiersFeaturesAndUsers(t *testing.T) { + setup() + + // Create users + freeUser, err := CreateObject(context.Background(), CreateObjectOpts{ + ObjectType: "user", + }) + if err != nil { + t.Fatal(err) + } + paidUser, err := CreateObject(context.Background(), CreateObjectOpts{ + ObjectType: "user", + }) + if err != nil { + t.Fatal(err) + } + + // Create pricing tiers + freeTier, err := CreateObject(context.Background(), CreateObjectOpts{ + ObjectType: "pricing-tier", + ObjectId: "free", + Meta: map[string]interface{}{ + "name": "Free Tier", + }, + }) + if err != nil { + t.Fatal(err) + } + paidTier, err := CreateObject(context.Background(), CreateObjectOpts{ + ObjectType: "pricing-tier", + ObjectId: "paid", + }) + if err != nil { + t.Fatal(err) + } + + // Create features + customFeature, err := CreateObject(context.Background(), CreateObjectOpts{ + ObjectType: "feature", + ObjectId: "custom", + Meta: map[string]interface{}{ + "name": "Custom Feature", + }, + }) + if err != nil { + t.Fatal(err) + } + feature1, err := CreateObject(context.Background(), CreateObjectOpts{ + ObjectType: "feature", + ObjectId: "feature-1", + }) + if err != nil { + t.Fatal(err) + } + feature2, err := CreateObject(context.Background(), CreateObjectOpts{ + ObjectType: "feature", + ObjectId: "feature-2", + }) + if err != nil { + t.Fatal(err) + } + + // Assign custom-feature -> paid user + paidUserHasFeature, err := Check(context.Background(), CheckOpts{ + Warrant: WarrantCheck{ + ObjectType: customFeature.ObjectType, + ObjectId: customFeature.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: paidUser.ObjectType, + ObjectId: paidUser.ObjectId, + }, + }, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.False(t, paidUserHasFeature) + + paidUserFeaturesList, err := Query(context.Background(), QueryOpts{ + Query: fmt.Sprintf("select feature where user:%s is member", paidUser.ObjectId), + Limit: 10, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.Empty(t, paidUserFeaturesList.Data) + + warrantResponse, err := WriteWarrant(context.Background(), WriteWarrantOpts{ + ObjectType: customFeature.ObjectType, + ObjectId: customFeature.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: paidUser.ObjectType, + ObjectId: paidUser.ObjectId, + }, + }) + if err != nil { + t.Fatal(err) + } + require.NotEmpty(t, warrantResponse.WarrantToken) + + paidUserHasFeature, err = Check(context.Background(), CheckOpts{ + Warrant: WarrantCheck{ + ObjectType: customFeature.ObjectType, + ObjectId: customFeature.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: paidUser.ObjectType, + ObjectId: paidUser.ObjectId, + }, + }, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.True(t, paidUserHasFeature) + + paidUserFeaturesList, err = Query(context.Background(), QueryOpts{ + Query: fmt.Sprintf("select feature where user:%s is member", paidUser.ObjectId), + Limit: 10, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.Len(t, paidUserFeaturesList.Data, 1) + require.Equal(t, "feature", paidUserFeaturesList.Data[0].ObjectType) + require.Equal(t, customFeature.ObjectId, paidUserFeaturesList.Data[0].ObjectId) + require.Equal(t, map[string]interface{}{ + "name": "Custom Feature", + }, paidUserFeaturesList.Data[0].Meta) + + warrantResponse, err = WriteWarrant(context.Background(), WriteWarrantOpts{ + Op: "delete", + ObjectType: customFeature.ObjectType, + ObjectId: customFeature.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: paidUser.ObjectType, + ObjectId: paidUser.ObjectId, + }, + }) + if err != nil { + t.Fatal(err) + } + require.NotEmpty(t, warrantResponse.WarrantToken) + + paidUserHasFeature, err = Check(context.Background(), CheckOpts{ + Warrant: WarrantCheck{ + ObjectType: customFeature.ObjectType, + ObjectId: customFeature.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: paidUser.ObjectType, + ObjectId: paidUser.ObjectId, + }, + }, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.False(t, paidUserHasFeature) + + paidUserFeaturesList, err = Query(context.Background(), QueryOpts{ + Query: fmt.Sprintf("select feature where user:%s is member", paidUser.ObjectId), + Limit: 10, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.Empty(t, paidUserFeaturesList.Data) + + // Assign feature-1 -> free tier -> free user + freeUserHasFeature, err := Check(context.Background(), CheckOpts{ + Warrant: WarrantCheck{ + ObjectType: feature1.ObjectType, + ObjectId: feature1.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: freeUser.ObjectType, + ObjectId: freeUser.ObjectId, + }, + }, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.False(t, freeUserHasFeature) + + freeUserFeaturesList, err := Query(context.Background(), QueryOpts{ + Query: fmt.Sprintf("select feature where user:%s is member", freeUser.ObjectId), + Limit: 10, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.Empty(t, freeUserFeaturesList.Data) + + featureUserTiersList, err := Query(context.Background(), QueryOpts{ + Query: fmt.Sprintf("select pricing-tier where user:%s is member", freeUser.ObjectId), + Limit: 10, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.Empty(t, featureUserTiersList.Data) + + warrantResponse, err = WriteWarrant(context.Background(), WriteWarrantOpts{ + ObjectType: feature1.ObjectType, + ObjectId: feature1.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: freeTier.ObjectType, + ObjectId: freeTier.ObjectId, + }, + }) + if err != nil { + t.Fatal(err) + } + require.NotEmpty(t, warrantResponse.WarrantToken) + + warrantResponse, err = WriteWarrant(context.Background(), WriteWarrantOpts{ + ObjectType: freeTier.ObjectType, + ObjectId: freeTier.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: freeUser.ObjectType, + ObjectId: freeUser.ObjectId, + }, + }) + if err != nil { + t.Fatal(err) + } + require.NotEmpty(t, warrantResponse.WarrantToken) + + freeUserHasFeature, err = Check(context.Background(), CheckOpts{ + Warrant: WarrantCheck{ + ObjectType: feature1.ObjectType, + ObjectId: feature1.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: freeUser.ObjectType, + ObjectId: freeUser.ObjectId, + }, + }, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.True(t, freeUserHasFeature) + + freeUserFeaturesList, err = Query(context.Background(), QueryOpts{ + Query: fmt.Sprintf("select feature where user:%s is member", freeUser.ObjectId), + Limit: 10, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.Len(t, freeUserFeaturesList.Data, 1) + require.Equal(t, "feature", freeUserFeaturesList.Data[0].ObjectType) + require.Equal(t, feature1.ObjectId, freeUserFeaturesList.Data[0].ObjectId) + require.Empty(t, freeUserFeaturesList.Data[0].Meta) + + featureUserTiersList, err = Query(context.Background(), QueryOpts{ + Query: fmt.Sprintf("select pricing-tier where user:%s is member", freeUser.ObjectId), + Limit: 10, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.Len(t, featureUserTiersList.Data, 1) + require.Equal(t, "pricing-tier", featureUserTiersList.Data[0].ObjectType) + require.Equal(t, freeTier.ObjectId, featureUserTiersList.Data[0].ObjectId) + require.Equal(t, map[string]interface{}{ + "name": "Free Tier", + }, featureUserTiersList.Data[0].Meta) + + warrantResponse, err = WriteWarrant(context.Background(), WriteWarrantOpts{ + Op: "delete", + ObjectType: feature1.ObjectType, + ObjectId: feature1.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: freeTier.ObjectType, + ObjectId: freeTier.ObjectId, + }, + }) + if err != nil { + t.Fatal(err) + } + require.NotEmpty(t, warrantResponse.WarrantToken) + + warrantResponse, err = WriteWarrant(context.Background(), WriteWarrantOpts{ + Op: "delete", + ObjectType: freeTier.ObjectType, + ObjectId: freeTier.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: freeUser.ObjectType, + ObjectId: freeUser.ObjectId, + }, + }) + if err != nil { + t.Fatal(err) + } + require.NotEmpty(t, warrantResponse.WarrantToken) + + freeUserHasFeature, err = Check(context.Background(), CheckOpts{ + Warrant: WarrantCheck{ + ObjectType: feature1.ObjectType, + ObjectId: feature1.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: freeUser.ObjectType, + ObjectId: freeUser.ObjectId, + }, + }, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.False(t, freeUserHasFeature) + + freeUserFeaturesList, err = Query(context.Background(), QueryOpts{ + Query: fmt.Sprintf("select feature where user:%s is member", freeUser.ObjectId), + Limit: 10, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.Empty(t, freeUserFeaturesList.Data) + + featureUserTiersList, err = Query(context.Background(), QueryOpts{ + Query: fmt.Sprintf("select pricing-tier where user:%s is member", freeUser.ObjectId), + Limit: 10, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.Empty(t, featureUserTiersList.Data) + + // Clean up + err = DeleteObject(context.Background(), DeleteObjectOpts{ + ObjectType: freeUser.ObjectType, + ObjectId: freeUser.ObjectId, + }) + if err != nil { + t.Fatal(err) + } + err = DeleteObject(context.Background(), DeleteObjectOpts{ + ObjectType: paidUser.ObjectType, + ObjectId: paidUser.ObjectId, + }) + if err != nil { + t.Fatal(err) + } + err = DeleteObject(context.Background(), DeleteObjectOpts{ + ObjectType: freeTier.ObjectType, + ObjectId: freeTier.ObjectId, + }) + if err != nil { + t.Fatal(err) + } + err = DeleteObject(context.Background(), DeleteObjectOpts{ + ObjectType: paidTier.ObjectType, + ObjectId: paidTier.ObjectId, + }) + if err != nil { + t.Fatal(err) + } + err = DeleteObject(context.Background(), DeleteObjectOpts{ + ObjectType: customFeature.ObjectType, + ObjectId: customFeature.ObjectId, + }) + if err != nil { + t.Fatal(err) + } + err = DeleteObject(context.Background(), DeleteObjectOpts{ + ObjectType: feature1.ObjectType, + ObjectId: feature1.ObjectId, + }) + if err != nil { + t.Fatal(err) + } + err = DeleteObject(context.Background(), DeleteObjectOpts{ + ObjectType: feature2.ObjectType, + ObjectId: feature2.ObjectId, + }) + if err != nil { + t.Fatal(err) + } +} + +func TestWarrants(t *testing.T) { + setup() + + user1, err := CreateObject(context.Background(), CreateObjectOpts{ + ObjectType: "user", + ObjectId: "userA", + }) + if err != nil { + t.Fatal(err) + } + user2, err := CreateObject(context.Background(), CreateObjectOpts{ + ObjectType: "user", + ObjectId: "userB", + }) + if err != nil { + t.Fatal(err) + } + newPermission, err := CreateObject(context.Background(), CreateObjectOpts{ + ObjectType: "permission", + ObjectId: "perm1", + Meta: map[string]interface{}{ + "name": "Permission 1", + "description": "Permission 1", + }, + }) + if err != nil { + t.Fatal(err) + } + + userHasPermission, err := Check(context.Background(), CheckOpts{ + Warrant: WarrantCheck{ + ObjectType: newPermission.ObjectType, + ObjectId: newPermission.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: user1.ObjectType, + ObjectId: user1.ObjectId, + }, + }, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.False(t, userHasPermission) + + warrantResponse, err := WriteWarrant(context.Background(), WriteWarrantOpts{ + ObjectType: newPermission.ObjectType, + ObjectId: newPermission.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: user1.ObjectType, + ObjectId: user1.ObjectId, + }, + }) + if err != nil { + t.Fatal(err) + } + require.NotEmpty(t, warrantResponse.WarrantToken) + + warrantResponse, err = WriteWarrant(context.Background(), WriteWarrantOpts{ + ObjectType: newPermission.ObjectType, + ObjectId: newPermission.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: user2.ObjectType, + ObjectId: user2.ObjectId, + }, + }) + if err != nil { + t.Fatal(err) + } + require.NotEmpty(t, warrantResponse.WarrantToken) + + warrants1, err := ListWarrants(context.Background(), ListWarrantsOpts{ + Limit: 1, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.Len(t, warrants1.Data, 1) + require.Equal(t, newPermission.ObjectType, warrants1.Data[0].ObjectType) + require.Equal(t, newPermission.ObjectId, warrants1.Data[0].ObjectId) + require.Equal(t, "member", warrants1.Data[0].Relation) + require.Equal(t, user2.ObjectType, warrants1.Data[0].Subject.ObjectType) + require.Equal(t, user2.ObjectId, warrants1.Data[0].Subject.ObjectId) + + warrants2, err := ListWarrants(context.Background(), ListWarrantsOpts{ + Limit: 1, + After: warrants1.ListMetadata.After, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.Len(t, warrants2.Data, 1) + require.Equal(t, newPermission.ObjectType, warrants2.Data[0].ObjectType) + require.Equal(t, newPermission.ObjectId, warrants2.Data[0].ObjectId) + require.Equal(t, "member", warrants2.Data[0].Relation) + require.Equal(t, user1.ObjectType, warrants2.Data[0].Subject.ObjectType) + require.Equal(t, user1.ObjectId, warrants2.Data[0].Subject.ObjectId) + + warrants3, err := ListWarrants(context.Background(), ListWarrantsOpts{ + SubjectType: "user", + SubjectId: user1.ObjectId, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.Len(t, warrants3.Data, 1) + require.Equal(t, newPermission.ObjectType, warrants3.Data[0].ObjectType) + require.Equal(t, newPermission.ObjectId, warrants3.Data[0].ObjectId) + require.Equal(t, "member", warrants3.Data[0].Relation) + require.Equal(t, user1.ObjectType, warrants3.Data[0].Subject.ObjectType) + require.Equal(t, user1.ObjectId, warrants3.Data[0].Subject.ObjectId) + + userHasPermission, err = Check(context.Background(), CheckOpts{ + Warrant: WarrantCheck{ + ObjectType: newPermission.ObjectType, + ObjectId: newPermission.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: user1.ObjectType, + ObjectId: user1.ObjectId, + }, + }, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.True(t, userHasPermission) + + queryResponse, err := Query(context.Background(), QueryOpts{ + Query: fmt.Sprintf("select permission where user:%s is member", user1.ObjectId), + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.Len(t, queryResponse.Data, 1) + require.Equal(t, newPermission.ObjectType, queryResponse.Data[0].ObjectType) + require.Equal(t, newPermission.ObjectId, queryResponse.Data[0].ObjectId) + require.Equal(t, "member", queryResponse.Data[0].Relation) + + warrantResponse, err = WriteWarrant(context.Background(), WriteWarrantOpts{ + Op: "delete", + ObjectType: newPermission.ObjectType, + ObjectId: newPermission.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: user1.ObjectType, + ObjectId: user1.ObjectId, + }, + }) + if err != nil { + t.Fatal(err) + } + require.NotEmpty(t, warrantResponse.WarrantToken) + + userHasPermission, err = Check(context.Background(), CheckOpts{ + Warrant: WarrantCheck{ + ObjectType: newPermission.ObjectType, + ObjectId: newPermission.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: user1.ObjectType, + ObjectId: user1.ObjectId, + }, + }, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.False(t, userHasPermission) + + // Clean up + err = DeleteObject(context.Background(), DeleteObjectOpts{ + ObjectType: user1.ObjectType, + ObjectId: user1.ObjectId, + }) + if err != nil { + t.Fatal(err) + } + err = DeleteObject(context.Background(), DeleteObjectOpts{ + ObjectType: user2.ObjectType, + ObjectId: user2.ObjectId, + }) + if err != nil { + t.Fatal(err) + } + err = DeleteObject(context.Background(), DeleteObjectOpts{ + ObjectType: newPermission.ObjectType, + ObjectId: newPermission.ObjectId, + }) + if err != nil { + t.Fatal(err) + } +} + +func TestBatchWarrants(t *testing.T) { + setup() + + newUser, err := CreateObject(context.Background(), CreateObjectOpts{ + ObjectType: "user", + }) + if err != nil { + t.Fatal(err) + } + permission1, err := CreateObject(context.Background(), CreateObjectOpts{ + ObjectType: "permission", + ObjectId: "perm1", + Meta: map[string]interface{}{ + "name": "Permission 1", + "description": "Permission 1", + }, + }) + if err != nil { + t.Fatal(err) + } + permission2, err := CreateObject(context.Background(), CreateObjectOpts{ + ObjectType: "permission", + ObjectId: "perm2", + Meta: map[string]interface{}{ + "name": "Permission 2", + "description": "Permission 2", + }, + }) + if err != nil { + t.Fatal(err) + } + + userHasPermissions, err := BatchCheck(context.Background(), BatchCheckOpts{ + Warrants: []WarrantCheck{ + { + ObjectType: permission1.ObjectType, + ObjectId: permission1.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: newUser.ObjectType, + ObjectId: newUser.ObjectId, + }, + }, + { + ObjectType: permission2.ObjectType, + ObjectId: permission2.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: newUser.ObjectType, + ObjectId: newUser.ObjectId, + }, + }, + }, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.Len(t, userHasPermissions, 2) + require.False(t, userHasPermissions[0]) + require.False(t, userHasPermissions[1]) + + warrantResponse, err := BatchWriteWarrants(context.Background(), []WriteWarrantOpts{ + { + ObjectType: permission1.ObjectType, + ObjectId: permission1.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: newUser.ObjectType, + ObjectId: newUser.ObjectId, + }, + }, + { + Op: "create", + ObjectType: permission2.ObjectType, + ObjectId: permission2.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: newUser.ObjectType, + ObjectId: newUser.ObjectId, + }, + }, + }) + if err != nil { + t.Fatal(err) + } + require.NotEmpty(t, warrantResponse.WarrantToken) + + userHasPermissions, err = BatchCheck(context.Background(), BatchCheckOpts{ + Warrants: []WarrantCheck{ + { + ObjectType: permission1.ObjectType, + ObjectId: permission1.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: newUser.ObjectType, + ObjectId: newUser.ObjectId, + }, + }, + { + ObjectType: permission2.ObjectType, + ObjectId: permission2.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: newUser.ObjectType, + ObjectId: newUser.ObjectId, + }, + }, + }, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.Len(t, userHasPermissions, 2) + require.True(t, userHasPermissions[0]) + require.True(t, userHasPermissions[1]) + + warrantResponse, err = BatchWriteWarrants(context.Background(), []WriteWarrantOpts{ + { + Op: "delete", + ObjectType: permission1.ObjectType, + ObjectId: permission1.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: newUser.ObjectType, + ObjectId: newUser.ObjectId, + }, + }, + { + Op: "delete", + ObjectType: permission2.ObjectType, + ObjectId: permission2.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: newUser.ObjectType, + ObjectId: newUser.ObjectId, + }, + }, + }) + if err != nil { + t.Fatal(err) + } + require.NotEmpty(t, warrantResponse.WarrantToken) + + userHasPermissions, err = BatchCheck(context.Background(), BatchCheckOpts{ + Warrants: []WarrantCheck{ + { + ObjectType: permission1.ObjectType, + ObjectId: permission1.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: newUser.ObjectType, + ObjectId: newUser.ObjectId, + }, + }, + { + ObjectType: permission2.ObjectType, + ObjectId: permission2.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: newUser.ObjectType, + ObjectId: newUser.ObjectId, + }, + }, + }, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.Len(t, userHasPermissions, 2) + require.False(t, userHasPermissions[0]) + require.False(t, userHasPermissions[1]) + + // Clean up + err = DeleteObject(context.Background(), DeleteObjectOpts{ + ObjectType: newUser.ObjectType, + ObjectId: newUser.ObjectId, + }) + if err != nil { + t.Fatal(err) + } + err = DeleteObject(context.Background(), DeleteObjectOpts{ + ObjectType: permission1.ObjectType, + ObjectId: permission1.ObjectId, + }) + if err != nil { + t.Fatal(err) + } + err = DeleteObject(context.Background(), DeleteObjectOpts{ + ObjectType: permission2.ObjectType, + ObjectId: permission2.ObjectId, + }) + if err != nil { + t.Fatal(err) + } +} + +func TestWarrantsWithPolicy(t *testing.T) { + setup() + + warrantResponse, err := WriteWarrant(context.Background(), WriteWarrantOpts{ + ObjectType: "permission", + ObjectId: "test-permission", + Relation: "member", + Subject: Subject{ + ObjectType: "user", + ObjectId: "user-1", + }, + Policy: `geo == "us"`, + }) + if err != nil { + t.Fatal(err) + } + require.NotEmpty(t, warrantResponse.WarrantToken) + + checkResult, err := Check(context.Background(), CheckOpts{ + Warrant: WarrantCheck{ + ObjectType: "permission", + ObjectId: "test-permission", + Relation: "member", + Subject: Subject{ + ObjectType: "user", + ObjectId: "user-1", + }, + Context: map[string]interface{}{ + "geo": "us", + }, + }, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.True(t, checkResult) + + checkResult, err = Check(context.Background(), CheckOpts{ + Warrant: WarrantCheck{ + ObjectType: "permission", + ObjectId: "test-permission", + Relation: "member", + Subject: Subject{ + ObjectType: "user", + ObjectId: "user-1", + }, + Context: map[string]interface{}{ + "geo": "eu", + }, + }, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.False(t, checkResult) + + warrantResponse, err = WriteWarrant(context.Background(), WriteWarrantOpts{ + Op: "delete", + ObjectType: "permission", + ObjectId: "test-permission", + Relation: "member", + Subject: Subject{ + ObjectType: "user", + ObjectId: "user-1", + }, + Policy: `geo == "us"`, + }) + if err != nil { + t.Fatal(err) + } + require.NotEmpty(t, warrantResponse.WarrantToken) + + // Clean up + err = DeleteObject(context.Background(), DeleteObjectOpts{ + ObjectType: "permission", + ObjectId: "test-permission", + }) + if err != nil { + t.Fatal(err) + } + err = DeleteObject(context.Background(), DeleteObjectOpts{ + ObjectType: "user", + ObjectId: "user-1", + }) + if err != nil { + t.Fatal(err) + } +} + +func TestQueryWarrants(t *testing.T) { + setup() + + userA, err := CreateObject(context.Background(), CreateObjectOpts{ + ObjectType: "user", + ObjectId: "userA", + }) + if err != nil { + t.Fatal(err) + } + userB, err := CreateObject(context.Background(), CreateObjectOpts{ + ObjectType: "user", + ObjectId: "userB", + }) + if err != nil { + t.Fatal(err) + } + permission1, err := CreateObject(context.Background(), CreateObjectOpts{ + ObjectType: "permission", + ObjectId: "perm1", + Meta: map[string]interface{}{ + "name": "Permission 1", + "description": "This is permission 1.", + }, + }) + if err != nil { + t.Fatal(err) + } + permission2, err := CreateObject(context.Background(), CreateObjectOpts{ + ObjectType: "permission", + ObjectId: "perm2", + }) + if err != nil { + t.Fatal(err) + } + permission3, err := CreateObject(context.Background(), CreateObjectOpts{ + ObjectType: "permission", + ObjectId: "perm3", + Meta: map[string]interface{}{ + "name": "Permission 3", + "description": "This is permission 3.", + }, + }) + if err != nil { + t.Fatal(err) + } + role1, err := CreateObject(context.Background(), CreateObjectOpts{ + ObjectType: "role", + ObjectId: "role1", + Meta: map[string]interface{}{ + "name": "Role 1", + "description": "This is role 1.", + }, + }) + if err != nil { + t.Fatal(err) + } + role2, err := CreateObject(context.Background(), CreateObjectOpts{ + ObjectType: "role", + ObjectId: "role2", + Meta: map[string]interface{}{ + "name": "Role 2", + }, + }) + if err != nil { + t.Fatal(err) + } + + warrantResponse, err := BatchWriteWarrants(context.Background(), []WriteWarrantOpts{ + { + ObjectType: permission1.ObjectType, + ObjectId: permission1.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: role1.ObjectType, + ObjectId: role1.ObjectId, + }, + }, + { + ObjectType: permission2.ObjectType, + ObjectId: permission2.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: role2.ObjectType, + ObjectId: role2.ObjectId, + }, + }, + { + ObjectType: permission3.ObjectType, + ObjectId: permission3.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: role2.ObjectType, + ObjectId: role2.ObjectId, + }, + }, + { + ObjectType: role2.ObjectType, + ObjectId: role2.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: role1.ObjectType, + ObjectId: role1.ObjectId, + }, + }, + { + ObjectType: permission1.ObjectType, + ObjectId: permission1.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: role2.ObjectType, + ObjectId: role2.ObjectId, + }, + Policy: "tenantId == 123", + }, + { + ObjectType: role1.ObjectType, + ObjectId: role1.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: userA.ObjectType, + ObjectId: userA.ObjectId, + }, + }, + { + ObjectType: role2.ObjectType, + ObjectId: role2.ObjectId, + Relation: "member", + Subject: Subject{ + ObjectType: userB.ObjectType, + ObjectId: userB.ObjectId, + }, + }, + }) + if err != nil { + t.Fatal(err) + } + require.NotEmpty(t, warrantResponse.WarrantToken) + + queryResponse, err := Query(context.Background(), QueryOpts{ + Query: "select role where user:userA is member", + Limit: 1, + Order: Asc, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.Len(t, queryResponse.Data, 1) + require.Equal(t, role1.ObjectType, queryResponse.Data[0].ObjectType) + require.Equal(t, role1.ObjectId, queryResponse.Data[0].ObjectId) + require.Equal(t, "member", queryResponse.Data[0].Relation) + require.Equal(t, role1.ObjectType, queryResponse.Data[0].Warrant.ObjectType) + require.Equal(t, role1.ObjectId, queryResponse.Data[0].Warrant.ObjectId) + require.Equal(t, "member", queryResponse.Data[0].Warrant.Relation) + require.Equal(t, userA.ObjectType, queryResponse.Data[0].Warrant.Subject.ObjectType) + require.Equal(t, userA.ObjectId, queryResponse.Data[0].Warrant.Subject.ObjectId) + require.Empty(t, queryResponse.Data[0].Warrant.Policy) + require.False(t, queryResponse.Data[0].IsImplicit) + + queryResponse, err = Query(context.Background(), QueryOpts{ + Query: "select role where user:userA is member", + Limit: 1, + Order: Asc, + After: queryResponse.ListMetadata.After, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.Len(t, queryResponse.Data, 1) + require.Equal(t, role2.ObjectType, queryResponse.Data[0].ObjectType) + require.Equal(t, role2.ObjectId, queryResponse.Data[0].ObjectId) + require.Equal(t, "member", queryResponse.Data[0].Relation) + require.Equal(t, role2.ObjectType, queryResponse.Data[0].Warrant.ObjectType) + require.Equal(t, role2.ObjectId, queryResponse.Data[0].Warrant.ObjectId) + require.Equal(t, "member", queryResponse.Data[0].Warrant.Relation) + require.Equal(t, role1.ObjectType, queryResponse.Data[0].Warrant.Subject.ObjectType) + require.Equal(t, role1.ObjectId, queryResponse.Data[0].Warrant.Subject.ObjectId) + require.Empty(t, queryResponse.Data[0].Warrant.Policy) + require.True(t, queryResponse.Data[0].IsImplicit) + + queryResponse, err = Query(context.Background(), QueryOpts{ + Query: "select permission where user:userB is member", + Context: Context{ + "tenantId": 123, + }, + Order: Asc, + WarrantToken: "latest", + }) + if err != nil { + t.Fatal(err) + } + require.Len(t, queryResponse.Data, 3) + require.Equal(t, permission1.ObjectType, queryResponse.Data[0].ObjectType) + require.Equal(t, permission1.ObjectId, queryResponse.Data[0].ObjectId) + require.Equal(t, "member", queryResponse.Data[0].Relation) + require.Equal(t, permission2.ObjectType, queryResponse.Data[1].ObjectType) + require.Equal(t, permission2.ObjectId, queryResponse.Data[1].ObjectId) + require.Equal(t, "member", queryResponse.Data[1].Relation) + require.Equal(t, permission3.ObjectType, queryResponse.Data[2].ObjectType) + require.Equal(t, permission3.ObjectId, queryResponse.Data[2].ObjectId) + require.Equal(t, "member", queryResponse.Data[2].Relation) + + // Clean up + err = DeleteObject(context.Background(), DeleteObjectOpts{ + ObjectType: role1.ObjectType, + ObjectId: role1.ObjectId, + }) + if err != nil { + t.Fatal(err) + } + err = DeleteObject(context.Background(), DeleteObjectOpts{ + ObjectType: role2.ObjectType, + ObjectId: role2.ObjectId, + }) + if err != nil { + t.Fatal(err) + } + err = DeleteObject(context.Background(), DeleteObjectOpts{ + ObjectType: permission1.ObjectType, + ObjectId: permission1.ObjectId, + }) + if err != nil { + t.Fatal(err) + } + err = DeleteObject(context.Background(), DeleteObjectOpts{ + ObjectType: permission2.ObjectType, + ObjectId: permission2.ObjectId, + }) + if err != nil { + t.Fatal(err) + } + err = DeleteObject(context.Background(), DeleteObjectOpts{ + ObjectType: permission3.ObjectType, + ObjectId: permission3.ObjectId, + }) + if err != nil { + t.Fatal(err) + } + err = DeleteObject(context.Background(), DeleteObjectOpts{ + ObjectType: userA.ObjectType, + ObjectId: userA.ObjectId, + }) + if err != nil { + t.Fatal(err) + } + err = DeleteObject(context.Background(), DeleteObjectOpts{ + ObjectType: userB.ObjectType, + ObjectId: userB.ObjectId, + }) + if err != nil { + t.Fatal(err) + } +} diff --git a/pkg/fga/client_test.go b/pkg/fga/client_test.go new file mode 100644 index 00000000..3b496f94 --- /dev/null +++ b/pkg/fga/client_test.go @@ -0,0 +1,1073 @@ +package fga + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/require" + "github.com/workos/workos-go/v4/pkg/common" +) + +func TestGetObject(t *testing.T) { + tests := []struct { + scenario string + client *Client + options GetObjectOpts + expected Object + err bool + }{ + { + scenario: "Request without API Key returns an error", + client: &Client{}, + err: true, + }, + { + scenario: "Request returns an Object", + client: &Client{ + APIKey: "test", + }, + options: GetObjectOpts{ + ObjectType: "report", + ObjectId: "ljc_1029", + }, + expected: Object{ + ObjectType: "report", + ObjectId: "ljc_1029", + }, + }, + } + + for _, test := range tests { + t.Run(test.scenario, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(getObjectTestHandler)) + defer server.Close() + + client := test.client + client.Endpoint = server.URL + client.HTTPClient = server.Client() + + object, err := client.GetObject(context.Background(), test.options) + if test.err { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, test.expected, object) + }) + } +} + +func getObjectTestHandler(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if auth != "Bearer test" { + http.Error(w, "bad auth", http.StatusUnauthorized) + return + } + + body, err := json.Marshal(Object{ + ObjectType: "report", + ObjectId: "ljc_1029", + }) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) + w.Write(body) +} + +func TestListObjects(t *testing.T) { + tests := []struct { + scenario string + client *Client + options ListObjectsOpts + expected ListObjectsResponse + err bool + }{ + { + scenario: "Request without API Key returns an error", + client: &Client{}, + err: true, + }, + { + scenario: "Request returns Objects", + client: &Client{ + APIKey: "test", + }, + options: ListObjectsOpts{ + ObjectType: "report", + }, + + expected: ListObjectsResponse{ + Data: []Object{ + { + ObjectType: "report", + ObjectId: "ljc_1029", + }, + { + ObjectType: "report", + ObjectId: "mso_0806", + }, + }, + ListMetadata: common.ListMetadata{ + Before: "", + After: "", + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.scenario, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(listObjectsTestHandler)) + defer server.Close() + + client := test.client + client.Endpoint = server.URL + client.HTTPClient = server.Client() + + objects, err := client.ListObjects(context.Background(), test.options) + if test.err { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, test.expected, objects) + }) + } +} + +func listObjectsTestHandler(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if auth != "Bearer test" { + http.Error(w, "bad auth", http.StatusUnauthorized) + return + } + + if userAgent := r.Header.Get("User-Agent"); !strings.Contains(userAgent, "workos-go/") { + w.WriteHeader(http.StatusBadRequest) + return + } + + body, err := json.Marshal(struct { + ListObjectsResponse + }{ + ListObjectsResponse: ListObjectsResponse{ + Data: []Object{ + { + ObjectType: "report", + ObjectId: "ljc_1029", + }, + { + ObjectType: "report", + ObjectId: "mso_0806", + }, + }, + ListMetadata: common.ListMetadata{ + Before: "", + After: "", + }, + }, + }) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) + w.Write(body) +} + +func TestCreateObject(t *testing.T) { + tests := []struct { + scenario string + client *Client + options CreateObjectOpts + expected Object + err bool + }{ + { + scenario: "Request without API Key returns an error", + client: &Client{}, + err: true, + }, + { + scenario: "Request returns Object", + client: &Client{ + APIKey: "test", + }, + options: CreateObjectOpts{ + ObjectType: "report", + ObjectId: "sso_1710", + }, + expected: Object{ + ObjectType: "report", + ObjectId: "sso_1710", + }, + }, + { + scenario: "Request returns Object with Metadata", + client: &Client{ + APIKey: "test", + }, + options: CreateObjectOpts{ + ObjectType: "report", + ObjectId: "sso_1710", + Meta: map[string]interface{}{ + "description": "Some report", + }, + }, + expected: Object{ + ObjectType: "report", + ObjectId: "sso_1710", + Meta: map[string]interface{}{ + "description": "Some report", + }, + }, + }, + { + scenario: "Request with no ObjectId returns an Object with generated report", + client: &Client{ + APIKey: "test", + }, + options: CreateObjectOpts{ + ObjectType: "report", + }, + expected: Object{ + ObjectType: "report", + ObjectId: "report_1029384756", + }, + }, + } + + for _, test := range tests { + t.Run(test.scenario, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(createObjectTestHandler)) + defer server.Close() + + client := test.client + client.Endpoint = server.URL + client.HTTPClient = server.Client() + + object, err := client.CreateObject(context.Background(), test.options) + if test.err { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, test.expected, object) + }) + } +} + +func createObjectTestHandler(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if auth != "Bearer test" { + http.Error(w, "bad auth", http.StatusUnauthorized) + return + } + + var opts CreateObjectOpts + json.NewDecoder(r.Body).Decode(&opts) + if userAgent := r.Header.Get("User-Agent"); !strings.Contains(userAgent, "workos-go/") { + w.WriteHeader(http.StatusBadRequest) + return + } + + objectId := "sso_1710" + if opts.ObjectId == "" { + objectId = "report_1029384756" + } + + body, err := json.Marshal( + Object{ + ObjectType: "report", + ObjectId: objectId, + Meta: opts.Meta, + }) + + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) + w.Write(body) +} + +func TestUpdateObject(t *testing.T) { + tests := []struct { + scenario string + client *Client + options UpdateObjectOpts + expected Object + err bool + }{ + { + scenario: "Request without API Key returns an error", + client: &Client{}, + err: true, + }, + { + scenario: "Request returns Object with updated Meta", + client: &Client{ + APIKey: "test", + }, + options: UpdateObjectOpts{ + ObjectType: "report", + ObjectId: "lad_8812", + Meta: map[string]interface{}{ + "description": "Updated report", + }, + }, + expected: Object{ + ObjectType: "report", + ObjectId: "lad_8812", + Meta: map[string]interface{}{ + "description": "Updated report", + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.scenario, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(updateObjectTestHandler)) + defer server.Close() + + client := test.client + client.Endpoint = server.URL + client.HTTPClient = server.Client() + + object, err := client.UpdateObject(context.Background(), test.options) + if test.err { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, test.expected, object) + }) + } +} + +func updateObjectTestHandler(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if auth != "Bearer test" { + http.Error(w, "bad auth", http.StatusUnauthorized) + return + } + + if userAgent := r.Header.Get("User-Agent"); !strings.Contains(userAgent, "workos-go/") { + w.WriteHeader(http.StatusBadRequest) + return + } + + body, err := json.Marshal( + Object{ + ObjectType: "report", + ObjectId: "lad_8812", + Meta: map[string]interface{}{ + "description": "Updated report", + }, + }) + + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) + w.Write(body) +} + +func TestDeleteObject(t *testing.T) { + tests := []struct { + scenario string + client *Client + options DeleteObjectOpts + expected error + err bool + }{ + { + scenario: "Request without API Key returns an error", + client: &Client{}, + err: true, + }, + { + scenario: "Request returns Object", + client: &Client{ + APIKey: "test", + }, + options: DeleteObjectOpts{ + ObjectType: "user", + ObjectId: "user_01SXW182", + }, + expected: nil, + }, + { + scenario: "Request for non-existent Object returns error", + client: &Client{ + APIKey: "test", + }, + err: true, + options: DeleteObjectOpts{ + ObjectType: "user", + ObjectId: "safgdfgs", + }, + }, + } + + for _, test := range tests { + t.Run(test.scenario, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(deleteObjectTestHandler)) + defer server.Close() + + client := test.client + client.Endpoint = server.URL + client.HTTPClient = server.Client() + + err := client.DeleteObject(context.Background(), test.options) + if test.err { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, test.expected, err) + }) + } +} + +func deleteObjectTestHandler(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if auth != "Bearer test" { + http.Error(w, "bad auth", http.StatusUnauthorized) + return + } + + var opts CreateObjectOpts + json.NewDecoder(r.Body).Decode(&opts) + + var body []byte + var err error + + if r.URL.Path == "/fga/v1/objects/user/user_01SXW182" { + body, err = nil, nil + } else { + http.Error(w, fmt.Sprintf("%s %s not found", opts.ObjectType, opts.ObjectId), http.StatusNotFound) + return + } + + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) + w.Write(body) +} + +func TestListWarrants(t *testing.T) { + tests := []struct { + scenario string + client *Client + options ListWarrantsOpts + expected ListWarrantsResponse + err bool + }{ + { + scenario: "Request without API Key returns an error", + client: &Client{}, + err: true, + }, + { + scenario: "Request returns Warrants", + client: &Client{ + APIKey: "test", + }, + options: ListWarrantsOpts{ + ObjectType: "report", + }, + + expected: ListWarrantsResponse{ + Data: []Warrant{ + { + ObjectType: "report", + ObjectId: "ljc_1029", + Relation: "member", + Subject: Subject{ + ObjectType: "user", + ObjectId: "user_01SXW182", + }, + }, + { + ObjectType: "report", + ObjectId: "aut_7403", + Relation: "member", + Subject: Subject{ + ObjectType: "user", + ObjectId: "user_01SXW182", + }, + }, + }, + ListMetadata: common.ListMetadata{ + Before: "", + After: "", + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.scenario, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(listWarrantsTestHandler)) + defer server.Close() + + client := test.client + client.Endpoint = server.URL + client.HTTPClient = server.Client() + + objects, err := client.ListWarrants(context.Background(), test.options) + if test.err { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, test.expected, objects) + }) + } +} + +func listWarrantsTestHandler(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if auth != "Bearer test" { + http.Error(w, "bad auth", http.StatusUnauthorized) + return + } + + if userAgent := r.Header.Get("User-Agent"); !strings.Contains(userAgent, "workos-go/") { + w.WriteHeader(http.StatusBadRequest) + return + } + + body, err := json.Marshal(struct { + ListWarrantsResponse + }{ + ListWarrantsResponse: ListWarrantsResponse{ + Data: []Warrant{ + { + ObjectType: "report", + ObjectId: "ljc_1029", + Relation: "member", + Subject: Subject{ + ObjectType: "user", + ObjectId: "user_01SXW182", + }, + }, + { + ObjectType: "report", + ObjectId: "aut_7403", + Relation: "member", + Subject: Subject{ + ObjectType: "user", + ObjectId: "user_01SXW182", + }, + }, + }, + ListMetadata: common.ListMetadata{ + Before: "", + After: "", + }, + }, + }) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) + w.Write(body) +} + +func TestWriteWarrant(t *testing.T) { + tests := []struct { + scenario string + client *Client + options WriteWarrantOpts + expected WriteWarrantResponse + err bool + }{ + { + scenario: "Request without API Key returns an error", + client: &Client{}, + err: true, + }, + { + scenario: "Request with no op returns WarrantToken", + client: &Client{ + APIKey: "test", + }, + options: WriteWarrantOpts{ + ObjectType: "report", + ObjectId: "sso_1710", + Relation: "member", + Subject: Subject{ + ObjectType: "user", + ObjectId: "user_01SXW182", + }, + }, + expected: WriteWarrantResponse{ + WarrantToken: "new_warrant_token", + }, + }, + { + scenario: "Request with create op returns WarrantToken", + client: &Client{ + APIKey: "test", + }, + options: WriteWarrantOpts{ + Op: "create", + ObjectType: "report", + ObjectId: "sso_1710", + Relation: "member", + Subject: Subject{ + ObjectType: "user", + ObjectId: "user_01SXW182", + }, + }, + expected: WriteWarrantResponse{ + WarrantToken: "new_warrant_token", + }, + }, + { + scenario: "Request with delete op returns WarrantToken", + client: &Client{ + APIKey: "test", + }, + options: WriteWarrantOpts{ + Op: "delete", + ObjectType: "report", + ObjectId: "sso_1710", + Relation: "member", + Subject: Subject{ + ObjectType: "user", + ObjectId: "user_01SXW182", + }, + }, + expected: WriteWarrantResponse{ + WarrantToken: "new_warrant_token", + }, + }, + } + + for _, test := range tests { + t.Run(test.scenario, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(writeWarrantTestHandler)) + defer server.Close() + + client := test.client + client.Endpoint = server.URL + client.HTTPClient = server.Client() + + warrantResponse, err := client.WriteWarrant(context.Background(), test.options) + if test.err { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, test.expected, warrantResponse) + }) + } +} + +func TestBatchWriteWarrants(t *testing.T) { + tests := []struct { + scenario string + client *Client + options []WriteWarrantOpts + expected WriteWarrantResponse + err bool + }{ + { + scenario: "Request without API Key returns an error", + client: &Client{}, + err: true, + }, + { + scenario: "Request with multiple warrants returns WarrantToken", + client: &Client{ + APIKey: "test", + }, + options: []WriteWarrantOpts{ + { + Op: "delete", + ObjectType: "report", + ObjectId: "sso_1710", + Relation: "viewer", + Subject: Subject{ + ObjectType: "user", + ObjectId: "user_01SXW182", + }, + }, + { + Op: "create", + ObjectType: "report", + ObjectId: "sso_1710", + Relation: "editor", + Subject: Subject{ + ObjectType: "user", + ObjectId: "user_01SXW182", + }, + }, + }, + expected: WriteWarrantResponse{ + WarrantToken: "new_warrant_token", + }, + }, + } + + for _, test := range tests { + t.Run(test.scenario, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(writeWarrantTestHandler)) + defer server.Close() + + client := test.client + client.Endpoint = server.URL + client.HTTPClient = server.Client() + + warrantResponse, err := client.BatchWriteWarrants(context.Background(), test.options) + if test.err { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, test.expected, warrantResponse) + }) + } +} + +func writeWarrantTestHandler(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if auth != "Bearer test" { + http.Error(w, "bad auth", http.StatusUnauthorized) + return + } + + if userAgent := r.Header.Get("User-Agent"); !strings.Contains(userAgent, "workos-go/") { + w.WriteHeader(http.StatusBadRequest) + return + } + + body, err := json.Marshal( + WriteWarrantResponse{ + WarrantToken: "new_warrant_token", + }) + + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) + w.Write(body) +} + +func TestCheckMany(t *testing.T) { + tests := []struct { + scenario string + client *Client + options CheckManyOpts + expected bool + err bool + }{ + { + scenario: "Request without API Key returns an error", + client: &Client{}, + err: true, + }, + { + scenario: "Request returns true check result", + client: &Client{ + APIKey: "test", + }, + options: CheckManyOpts{ + Warrants: []WarrantCheck{ + { + ObjectType: "report", + ObjectId: "ljc_1029", + Relation: "member", + Subject: Subject{ + ObjectType: "user", + ObjectId: "user_01SXW182", + }, + }, + }, + }, + expected: true, + }, + } + + for _, test := range tests { + t.Run(test.scenario, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(checkManyTestHandler)) + defer server.Close() + + client := test.client + client.Endpoint = server.URL + client.HTTPClient = server.Client() + + checkResult, err := client.CheckMany(context.Background(), test.options) + if test.err { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, test.expected, checkResult) + }) + } +} + +func checkManyTestHandler(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if auth != "Bearer test" { + http.Error(w, "bad auth", http.StatusUnauthorized) + return + } + + if userAgent := r.Header.Get("User-Agent"); !strings.Contains(userAgent, "workos-go/") { + w.WriteHeader(http.StatusBadRequest) + return + } + + body, err := json.Marshal( + CheckResponse{ + Code: 200, + Result: "Authorized", + IsImplicit: false, + }) + + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) + w.Write(body) +} + +func TestBatchCheck(t *testing.T) { + tests := []struct { + scenario string + client *Client + options BatchCheckOpts + expected []bool + err bool + }{ + { + scenario: "Request without API Key returns an error", + client: &Client{}, + err: true, + }, + { + scenario: "Request returns array of check results", + client: &Client{ + APIKey: "test", + }, + options: BatchCheckOpts{ + Warrants: []WarrantCheck{ + { + ObjectType: "report", + ObjectId: "ljc_1029", + Relation: "member", + Subject: Subject{ + ObjectType: "user", + ObjectId: "user_01SXW182", + }, + }, + { + ObjectType: "report", + ObjectId: "spt_8521", + Relation: "member", + Subject: Subject{ + ObjectType: "user", + ObjectId: "user_01SXW182", + }, + }, + }, + }, + expected: []bool{true, false}, + }, + } + + for _, test := range tests { + t.Run(test.scenario, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(batchCheckTestHandler)) + defer server.Close() + + client := test.client + client.Endpoint = server.URL + client.HTTPClient = server.Client() + + checkResult, err := client.BatchCheck(context.Background(), test.options) + if test.err { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, test.expected, checkResult) + }) + } +} + +func batchCheckTestHandler(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if auth != "Bearer test" { + http.Error(w, "bad auth", http.StatusUnauthorized) + return + } + + if userAgent := r.Header.Get("User-Agent"); !strings.Contains(userAgent, "workos-go/") { + w.WriteHeader(http.StatusBadRequest) + return + } + + body, err := json.Marshal( + []CheckResponse{ + { + Code: 200, + Result: "Authorized", + IsImplicit: false, + }, + { + Code: 401, + Result: "Not Authorized", + IsImplicit: false, + }, + }) + + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) + w.Write(body) +} + +func TestQuery(t *testing.T) { + tests := []struct { + scenario string + client *Client + options QueryOpts + expected QueryResponse + err bool + }{ + { + scenario: "Request without API Key returns an error", + client: &Client{}, + err: true, + }, + { + scenario: "Request returns QueryResults", + client: &Client{ + APIKey: "test", + }, + options: QueryOpts{ + Query: "select role where user:user_01SXW182 is member", + }, + expected: QueryResponse{ + Data: []QueryResult{ + { + ObjectType: "role", + ObjectId: "role_01SXW182", + Relation: "member", + Warrant: Warrant{ + ObjectType: "role", + ObjectId: "role_01SXW182", + Relation: "member", + Subject: Subject{ + ObjectType: "user", + ObjectId: "user_01SXW182", + }, + }, + }, + }, + ListMetadata: common.ListMetadata{ + Before: "", + After: "", + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.scenario, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(queryTestHandler)) + defer server.Close() + + client := test.client + client.Endpoint = server.URL + client.HTTPClient = server.Client() + + queryResults, err := client.Query(context.Background(), test.options) + if test.err { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, test.expected, queryResults) + }) + } +} + +func queryTestHandler(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if auth != "Bearer test" { + http.Error(w, "bad auth", http.StatusUnauthorized) + return + } + + if userAgent := r.Header.Get("User-Agent"); !strings.Contains(userAgent, "workos-go/") { + w.WriteHeader(http.StatusBadRequest) + return + } + + body, err := json.Marshal(struct { + QueryResponse + }{ + QueryResponse: QueryResponse{ + Data: []QueryResult{ + { + ObjectType: "role", + ObjectId: "role_01SXW182", + Relation: "member", + Warrant: Warrant{ + ObjectType: "role", + ObjectId: "role_01SXW182", + Relation: "member", + Subject: Subject{ + ObjectType: "user", + ObjectId: "user_01SXW182", + }, + }, + }, + }, + ListMetadata: common.ListMetadata{ + Before: "", + After: "", + }, + }, + }) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) + w.Write(body) +} diff --git a/pkg/fga/fga.go b/pkg/fga/fga.go new file mode 100644 index 00000000..787e0cfc --- /dev/null +++ b/pkg/fga/fga.go @@ -0,0 +1,111 @@ +package fga + +import "context" + +// DefaultClient is the client used by SetAPIKey and FGA functions. +var ( + DefaultClient = &Client{ + Endpoint: "https://api.workos.com", + } +) + +// SetAPIKey sets the WorkOS API key for FGA requests. +func SetAPIKey(apiKey string) { + DefaultClient.APIKey = apiKey +} + +// GetObject gets an Object. +func GetObject( + ctx context.Context, + opts GetObjectOpts, +) (Object, error) { + return DefaultClient.GetObject(ctx, opts) +} + +// ListObjects gets a list of Objects. +func ListObjects( + ctx context.Context, + opts ListObjectsOpts, +) (ListObjectsResponse, error) { + return DefaultClient.ListObjects(ctx, opts) +} + +// CreateObject creates an Object. +func CreateObject( + ctx context.Context, + opts CreateObjectOpts, +) (Object, error) { + return DefaultClient.CreateObject(ctx, opts) +} + +// UpdateObject updates an Object. +func UpdateObject( + ctx context.Context, + opts UpdateObjectOpts, +) (Object, error) { + return DefaultClient.UpdateObject(ctx, opts) +} + +// DeleteObject deletes an Object. +func DeleteObject( + ctx context.Context, + opts DeleteObjectOpts, +) error { + return DefaultClient.DeleteObject(ctx, opts) +} + +// ListWarrants gets a list of Warrants. +func ListWarrants( + ctx context.Context, + opts ListWarrantsOpts, +) (ListWarrantsResponse, error) { + return DefaultClient.ListWarrants(ctx, opts) +} + +// WriteWarrant performs a write operation on a Warrant. +func WriteWarrant( + ctx context.Context, + opts WriteWarrantOpts, +) (WriteWarrantResponse, error) { + return DefaultClient.WriteWarrant(ctx, opts) +} + +// BatchWriteWarrants performs write operations on multiple Warrants in one request. +func BatchWriteWarrants( + ctx context.Context, + opts []WriteWarrantOpts, +) (WriteWarrantResponse, error) { + return DefaultClient.BatchWriteWarrants(ctx, opts) +} + +// Check performs an access check on a Warrant. +func Check( + ctx context.Context, + opts CheckOpts, +) (bool, error) { + return DefaultClient.Check(ctx, opts) +} + +// CheckMany performs access checks on multiple Warrants. +func CheckMany( + ctx context.Context, + opts CheckManyOpts, +) (bool, error) { + return DefaultClient.CheckMany(ctx, opts) +} + +// BatchCheck performs individual access checks on multiple Warrants in one request. +func BatchCheck( + ctx context.Context, + opts BatchCheckOpts, +) ([]bool, error) { + return DefaultClient.BatchCheck(ctx, opts) +} + +// Query performs a query for a set of resources. +func Query( + ctx context.Context, + opts QueryOpts, +) (QueryResponse, error) { + return DefaultClient.Query(ctx, opts) +} diff --git a/pkg/fga/fga_test.go b/pkg/fga/fga_test.go new file mode 100644 index 00000000..eeb5da69 --- /dev/null +++ b/pkg/fga/fga_test.go @@ -0,0 +1,348 @@ +package fga + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + "github.com/workos/workos-go/v4/pkg/common" +) + +func TestFGAGetObject(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(getObjectTestHandler)) + defer server.Close() + + DefaultClient = &Client{ + HTTPClient: server.Client(), + Endpoint: server.URL, + } + SetAPIKey("test") + + expectedResponse := Object{ + ObjectType: "report", + ObjectId: "ljc_1029", + } + objectResponse, err := GetObject(context.Background(), GetObjectOpts{ + ObjectType: "report", + ObjectId: "ljc_1029", + }) + + require.NoError(t, err) + require.Equal(t, expectedResponse, objectResponse) +} + +func TestFGAListObjects(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(listObjectsTestHandler)) + defer server.Close() + + DefaultClient = &Client{ + HTTPClient: server.Client(), + Endpoint: server.URL, + } + SetAPIKey("test") + + expectedResponse := ListObjectsResponse{ + Data: []Object{ + { + ObjectType: "report", + ObjectId: "ljc_1029", + }, + { + ObjectType: "report", + ObjectId: "mso_0806", + }, + }, + ListMetadata: common.ListMetadata{ + Before: "", + After: "", + }, + } + objectsResponse, err := ListObjects(context.Background(), ListObjectsOpts{ + ObjectType: "report", + }) + + require.NoError(t, err) + require.Equal(t, expectedResponse, objectsResponse) +} + +func TestFGACreateObject(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(createObjectTestHandler)) + defer server.Close() + + DefaultClient = &Client{ + HTTPClient: server.Client(), + Endpoint: server.URL, + } + SetAPIKey("test") + + expectedResponse := Object{ + ObjectType: "report", + ObjectId: "sso_1710", + } + createdObject, err := CreateObject(context.Background(), CreateObjectOpts{ + ObjectType: "report", + ObjectId: "sso_1710", + }) + + require.NoError(t, err) + require.Equal(t, expectedResponse, createdObject) +} + +func TestFGAUpdateObject(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(updateObjectTestHandler)) + defer server.Close() + + DefaultClient = &Client{ + HTTPClient: server.Client(), + Endpoint: server.URL, + } + SetAPIKey("test") + + expectedResponse := Object{ + ObjectType: "report", + ObjectId: "lad_8812", + Meta: map[string]interface{}{ + "description": "Updated report", + }, + } + updatedObject, err := UpdateObject(context.Background(), UpdateObjectOpts{ + ObjectType: "report", + ObjectId: "lad_8812", + Meta: map[string]interface{}{ + "description": "Updated report", + }, + }) + + require.NoError(t, err) + require.Equal(t, expectedResponse, updatedObject) +} + +func TestFGADeleteObject(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(deleteObjectTestHandler)) + defer server.Close() + + DefaultClient = &Client{ + HTTPClient: server.Client(), + Endpoint: server.URL, + } + SetAPIKey("test") + + err := DeleteObject(context.Background(), DeleteObjectOpts{ + ObjectType: "user", + ObjectId: "user_01SXW182", + }) + + require.NoError(t, err) +} + +func TestFGAListWarrants(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(listWarrantsTestHandler)) + defer server.Close() + + DefaultClient = &Client{ + HTTPClient: server.Client(), + Endpoint: server.URL, + } + SetAPIKey("test") + + expectedResponse := ListWarrantsResponse{ + Data: []Warrant{ + { + ObjectType: "report", + ObjectId: "ljc_1029", + Relation: "member", + Subject: Subject{ + ObjectType: "user", + ObjectId: "user_01SXW182", + }, + }, + { + ObjectType: "report", + ObjectId: "aut_7403", + Relation: "member", + Subject: Subject{ + ObjectType: "user", + ObjectId: "user_01SXW182", + }, + }, + }, + ListMetadata: common.ListMetadata{ + Before: "", + After: "", + }, + } + warrantsResponse, err := ListWarrants(context.Background(), ListWarrantsOpts{ + ObjectType: "report", + }) + + require.NoError(t, err) + require.Equal(t, expectedResponse, warrantsResponse) +} + +func TestFGAWriteWarrant(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(writeWarrantTestHandler)) + defer server.Close() + + DefaultClient = &Client{ + HTTPClient: server.Client(), + Endpoint: server.URL, + } + SetAPIKey("test") + + expectedResponse := WriteWarrantResponse{ + WarrantToken: "new_warrant_token", + } + warrantResponse, err := WriteWarrant(context.Background(), WriteWarrantOpts{ + Op: "create", + ObjectType: "report", + ObjectId: "sso_1710", + Relation: "member", + Subject: Subject{ + ObjectType: "user", + ObjectId: "user_01SXW182", + }, + }) + + require.NoError(t, err) + require.Equal(t, expectedResponse, warrantResponse) +} + +func TestFGABatchWriteWarrants(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(writeWarrantTestHandler)) + defer server.Close() + + DefaultClient = &Client{ + HTTPClient: server.Client(), + Endpoint: server.URL, + } + SetAPIKey("test") + + expectedResponse := WriteWarrantResponse{ + WarrantToken: "new_warrant_token", + } + warrantResponse, err := BatchWriteWarrants(context.Background(), []WriteWarrantOpts{ + { + Op: "delete", + ObjectType: "report", + ObjectId: "sso_1710", + Relation: "viewer", + Subject: Subject{ + ObjectType: "user", + ObjectId: "user_01SXW182", + }, + }, + { + Op: "create", + ObjectType: "report", + ObjectId: "sso_1710", + Relation: "editor", + Subject: Subject{ + ObjectType: "user", + ObjectId: "user_01SXW182", + }, + }, + }) + + require.NoError(t, err) + require.Equal(t, expectedResponse, warrantResponse) +} + +func TestFGACheckMany(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(checkManyTestHandler)) + defer server.Close() + + DefaultClient = &Client{ + HTTPClient: server.Client(), + Endpoint: server.URL, + } + SetAPIKey("test") + + checkResponse, err := CheckMany(context.Background(), CheckManyOpts{ + Warrants: []WarrantCheck{ + { + ObjectType: "report", + ObjectId: "ljc_1029", + Relation: "member", + Subject: Subject{ + ObjectType: "user", + ObjectId: "user_01SXW182", + }, + }, + }, + }) + + require.NoError(t, err) + require.True(t, checkResponse) +} + +func TestFGABatchCheck(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(batchCheckTestHandler)) + defer server.Close() + + DefaultClient = &Client{ + HTTPClient: server.Client(), + Endpoint: server.URL, + } + SetAPIKey("test") + + checkResponses, err := BatchCheck(context.Background(), BatchCheckOpts{ + Warrants: []WarrantCheck{ + { + ObjectType: "report", + ObjectId: "ljc_1029", + Relation: "member", + Subject: Subject{ + ObjectType: "user", + ObjectId: "user_01SXW182", + }, + }, + }, + }) + + require.NoError(t, err) + require.Len(t, checkResponses, 2) + require.True(t, checkResponses[0]) + require.False(t, checkResponses[1]) +} + +func TestFGAQuery(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(queryTestHandler)) + defer server.Close() + + DefaultClient = &Client{ + HTTPClient: server.Client(), + Endpoint: server.URL, + } + SetAPIKey("test") + + expectedResponse := QueryResponse{ + Data: []QueryResult{ + { + ObjectType: "role", + ObjectId: "role_01SXW182", + Relation: "member", + Warrant: Warrant{ + ObjectType: "role", + ObjectId: "role_01SXW182", + Relation: "member", + Subject: Subject{ + ObjectType: "user", + ObjectId: "user_01SXW182", + }, + }, + }, + }, + ListMetadata: common.ListMetadata{ + Before: "", + After: "", + }, + } + queryResponse, err := Query(context.Background(), QueryOpts{ + Query: "select role where user:user_01SXW182 is member", + }) + + require.NoError(t, err) + require.Equal(t, expectedResponse, queryResponse) +} From ead7d01e8b8072a637c2d9791a0ca533e9c5fe0b Mon Sep 17 00:00:00 2001 From: Stanley Phu Date: Thu, 4 Jul 2024 14:07:10 -0700 Subject: [PATCH 3/7] Add object types list and batch update methods --- pkg/fga/client.go | 121 +++++++++++++++++++++ pkg/fga/client_test.go | 235 +++++++++++++++++++++++++++++++++++++++++ pkg/fga/fga.go | 16 +++ pkg/fga/fga_test.go | 93 ++++++++++++++++ 4 files changed, 465 insertions(+) diff --git a/pkg/fga/client.go b/pkg/fga/client.go index 0a9e0272..8903b89c 100644 --- a/pkg/fga/client.go +++ b/pkg/fga/client.go @@ -139,6 +139,45 @@ type DeleteObjectOpts struct { ObjectId string } +// Object types +type ObjectType struct { + // Unique string ID of the object type. + Type string `json:"type"` + + // Set of relationships that subjects can have on objects of this type. + Relations map[string]interface{} `json:"relations"` +} + +type ListObjectTypesOpts struct { + // Maximum number of records to return. + Limit int `url:"limit,omitempty"` + + // The order in which to paginate records. + Order Order `url:"order,omitempty"` + + // Pagination cursor to receive records before a provided ObjectType ID. + Before string `url:"before,omitempty"` + + // Pagination cursor to receive records after a provided ObjectType ID. + After string `url:"after,omitempty"` +} + +type ListObjectTypesResponse struct { + // List of Object Types. + Data []ObjectType `json:"data"` + + // Cursor pagination options. + ListMetadata common.ListMetadata `json:"list_metadata"` +} + +type UpdateObjectTypeOpts struct { + // Unique string ID of the object type. + Type string `json:"type"` + + // Set of relationships that subjects can have on objects of this type. + Relations map[string]interface{} `json:"relations"` +} + // Warrants type Subject struct { // The type of the subject. @@ -527,6 +566,88 @@ func (c *Client) DeleteObject(ctx context.Context, opts DeleteObjectOpts) error return workos_errors.TryGetHTTPError(res) } +// ListObjectTypes gets a list of FGA object types. +func (c *Client) ListObjectTypes(ctx context.Context, opts ListObjectTypesOpts) (ListObjectTypesResponse, error) { + c.once.Do(c.init) + + endpoint := fmt.Sprintf("%s/fga/v1/object-types", c.Endpoint) + req, err := http.NewRequest(http.MethodGet, endpoint, nil) + if err != nil { + return ListObjectTypesResponse{}, err + } + + req = req.WithContext(ctx) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+c.APIKey) + req.Header.Set("User-Agent", "workos-go/"+workos.Version) + + if opts.Limit == 0 { + opts.Limit = ResponseLimit + } + + if opts.Order == "" { + opts.Order = Desc + } + + q, err := query.Values(opts) + if err != nil { + return ListObjectTypesResponse{}, err + } + + req.URL.RawQuery = q.Encode() + + res, err := c.HTTPClient.Do(req) + if err != nil { + return ListObjectTypesResponse{}, err + } + defer res.Body.Close() + + if err = workos_errors.TryGetHTTPError(res); err != nil { + return ListObjectTypesResponse{}, err + } + + var body ListObjectTypesResponse + dec := json.NewDecoder(res.Body) + err = dec.Decode(&body) + return body, err +} + +// BatchUpdateObjectTypes sets the environment's set of object types to match the object types passed. +func (c *Client) BatchUpdateObjectTypes(ctx context.Context, opts []UpdateObjectTypeOpts) ([]ObjectType, error) { + c.once.Do(c.init) + + data, err := c.JSONEncode(opts) + if err != nil { + return []ObjectType{}, err + } + + endpoint := fmt.Sprintf("%s/fga/v1/object-types", c.Endpoint) + req, err := http.NewRequest(http.MethodPut, endpoint, bytes.NewBuffer(data)) + if err != nil { + return []ObjectType{}, err + } + + req = req.WithContext(ctx) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+c.APIKey) + req.Header.Set("User-Agent", "workos-go/"+workos.Version) + + res, err := c.HTTPClient.Do(req) + if err != nil { + return []ObjectType{}, err + } + defer res.Body.Close() + + if err = workos_errors.TryGetHTTPError(res); err != nil { + return []ObjectType{}, err + } + + var body []ObjectType + dec := json.NewDecoder(res.Body) + err = dec.Decode(&body) + return body, err +} + // ListWarrants gets a list of Warrants. func (c *Client) ListWarrants(ctx context.Context, opts ListWarrantsOpts) (ListWarrantsResponse, error) { c.once.Do(c.init) diff --git a/pkg/fga/client_test.go b/pkg/fga/client_test.go index 3b496f94..b7ed7f3c 100644 --- a/pkg/fga/client_test.go +++ b/pkg/fga/client_test.go @@ -184,6 +184,241 @@ func listObjectsTestHandler(w http.ResponseWriter, r *http.Request) { w.Write(body) } +func TestListObjectTypes(t *testing.T) { + tests := []struct { + scenario string + client *Client + options ListObjectTypesOpts + expected ListObjectTypesResponse + err bool + }{ + { + scenario: "Request without API Key returns an error", + client: &Client{}, + err: true, + }, + { + scenario: "Request returns ObjectTypes", + client: &Client{ + APIKey: "test", + }, + options: ListObjectTypesOpts{ + Order: "asc", + }, + + expected: ListObjectTypesResponse{ + Data: []ObjectType{ + { + Type: "report", + Relations: map[string]interface{}{ + "owner": map[string]interface{}{}, + "editor": map[string]interface{}{ + "inherit_if": "owner", + }, + "viewer": map[string]interface{}{ + "inherit_if": "editor", + }, + }, + }, + { + Type: "user", + Relations: map[string]interface{}{}, + }, + }, + ListMetadata: common.ListMetadata{ + Before: "", + After: "", + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.scenario, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(listObjectTypesTestHandler)) + defer server.Close() + + client := test.client + client.Endpoint = server.URL + client.HTTPClient = server.Client() + + objectTypes, err := client.ListObjectTypes(context.Background(), test.options) + if test.err { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, test.expected, objectTypes) + }) + } +} + +func listObjectTypesTestHandler(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if auth != "Bearer test" { + http.Error(w, "bad auth", http.StatusUnauthorized) + return + } + + if userAgent := r.Header.Get("User-Agent"); !strings.Contains(userAgent, "workos-go/") { + w.WriteHeader(http.StatusBadRequest) + return + } + + body, err := json.Marshal(struct { + ListObjectTypesResponse + }{ + ListObjectTypesResponse: ListObjectTypesResponse{ + Data: []ObjectType{ + { + Type: "report", + Relations: map[string]interface{}{ + "owner": map[string]interface{}{}, + "editor": map[string]interface{}{ + "inherit_if": "owner", + }, + "viewer": map[string]interface{}{ + "inherit_if": "editor", + }, + }, + }, + { + Type: "user", + Relations: map[string]interface{}{}, + }, + }, + ListMetadata: common.ListMetadata{ + Before: "", + After: "", + }, + }, + }) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) + w.Write(body) +} + +func TestBatchUpdateObjectTypes(t *testing.T) { + tests := []struct { + scenario string + client *Client + options []UpdateObjectTypeOpts + expected []ObjectType + err bool + }{ + { + scenario: "Request without API Key returns an error", + client: &Client{}, + err: true, + }, + { + scenario: "Request returns ObjectTypes", + client: &Client{ + APIKey: "test", + }, + options: []UpdateObjectTypeOpts{ + { + Type: "report", + Relations: map[string]interface{}{ + "owner": map[string]interface{}{}, + "editor": map[string]interface{}{ + "inherit_if": "owner", + }, + "viewer": map[string]interface{}{ + "inherit_if": "editor", + }, + }, + }, + { + Type: "user", + Relations: map[string]interface{}{}, + }, + }, + + expected: []ObjectType{ + { + Type: "report", + Relations: map[string]interface{}{ + "owner": map[string]interface{}{}, + "editor": map[string]interface{}{ + "inherit_if": "owner", + }, + "viewer": map[string]interface{}{ + "inherit_if": "editor", + }, + }, + }, + { + Type: "user", + Relations: map[string]interface{}{}, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.scenario, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(batchUpdateObjectTypesTestHandler)) + defer server.Close() + + client := test.client + client.Endpoint = server.URL + client.HTTPClient = server.Client() + + objectTypes, err := client.BatchUpdateObjectTypes(context.Background(), test.options) + if test.err { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, test.expected, objectTypes) + }) + } +} + +func batchUpdateObjectTypesTestHandler(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if auth != "Bearer test" { + http.Error(w, "bad auth", http.StatusUnauthorized) + return + } + + if userAgent := r.Header.Get("User-Agent"); !strings.Contains(userAgent, "workos-go/") { + w.WriteHeader(http.StatusBadRequest) + return + } + + body, err := json.Marshal([]ObjectType{ + { + Type: "report", + Relations: map[string]interface{}{ + "owner": map[string]interface{}{}, + "editor": map[string]interface{}{ + "inherit_if": "owner", + }, + "viewer": map[string]interface{}{ + "inherit_if": "editor", + }, + }, + }, + { + Type: "user", + Relations: map[string]interface{}{}, + }, + }) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) + w.Write(body) +} + func TestCreateObject(t *testing.T) { tests := []struct { scenario string diff --git a/pkg/fga/fga.go b/pkg/fga/fga.go index 787e0cfc..0191053b 100644 --- a/pkg/fga/fga.go +++ b/pkg/fga/fga.go @@ -54,6 +54,22 @@ func DeleteObject( return DefaultClient.DeleteObject(ctx, opts) } +// ListObjectTypes gets a list of ObjectTypes. +func ListObjectTypes( + ctx context.Context, + opts ListObjectTypesOpts, +) (ListObjectTypesResponse, error) { + return DefaultClient.ListObjectTypes(ctx, opts) +} + +// BatchUpdateObjectTypes sets the environment's object types to match the provided types. +func BatchUpdateObjectTypes( + ctx context.Context, + opts []UpdateObjectTypeOpts, +) ([]ObjectType, error) { + return DefaultClient.BatchUpdateObjectTypes(ctx, opts) +} + // ListWarrants gets a list of Warrants. func ListWarrants( ctx context.Context, diff --git a/pkg/fga/fga_test.go b/pkg/fga/fga_test.go index eeb5da69..820c9a78 100644 --- a/pkg/fga/fga_test.go +++ b/pkg/fga/fga_test.go @@ -137,6 +137,99 @@ func TestFGADeleteObject(t *testing.T) { require.NoError(t, err) } +func TestFGAListObjectTypes(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(listObjectTypesTestHandler)) + defer server.Close() + + DefaultClient = &Client{ + HTTPClient: server.Client(), + Endpoint: server.URL, + } + SetAPIKey("test") + + expectedResponse := ListObjectTypesResponse{ + Data: []ObjectType{ + { + Type: "report", + Relations: map[string]interface{}{ + "owner": map[string]interface{}{}, + "editor": map[string]interface{}{ + "inherit_if": "owner", + }, + "viewer": map[string]interface{}{ + "inherit_if": "editor", + }, + }, + }, + { + Type: "user", + Relations: map[string]interface{}{}, + }, + }, + ListMetadata: common.ListMetadata{ + Before: "", + After: "", + }, + } + objectTypesResponse, err := ListObjectTypes(context.Background(), ListObjectTypesOpts{ + Order: "asc", + }) + + require.NoError(t, err) + require.Equal(t, expectedResponse, objectTypesResponse) +} + +func TestFGABatchUpdateObjectTypes(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(batchUpdateObjectTypesTestHandler)) + defer server.Close() + + DefaultClient = &Client{ + HTTPClient: server.Client(), + Endpoint: server.URL, + } + SetAPIKey("test") + + expectedResponse := []ObjectType{ + { + Type: "report", + Relations: map[string]interface{}{ + "owner": map[string]interface{}{}, + "editor": map[string]interface{}{ + "inherit_if": "owner", + }, + "viewer": map[string]interface{}{ + "inherit_if": "editor", + }, + }, + }, + { + Type: "user", + Relations: map[string]interface{}{}, + }, + } + objectTypes, err := BatchUpdateObjectTypes(context.Background(), []UpdateObjectTypeOpts{ + { + Type: "report", + Relations: map[string]interface{}{ + "owner": map[string]interface{}{}, + "editor": map[string]interface{}{ + "inherit_if": "owner", + }, + "viewer": map[string]interface{}{ + "inherit_if": "editor", + }, + }, + }, + { + Type: "user", + Relations: map[string]interface{}{}, + }, + }) + + require.NoError(t, err) + require.Equal(t, expectedResponse, objectTypes) +} + func TestFGAListWarrants(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(listWarrantsTestHandler)) defer server.Close() From 4d5ebf6152a9fd7d2eda3787b8afd70a9f5dbba8 Mon Sep 17 00:00:00 2001 From: Karan Kajla Date: Thu, 4 Jul 2024 18:40:56 -0700 Subject: [PATCH 4/7] Add Authorized method to CheckResponse and return CheckResponse from all check methods --- pkg/fga/client.go | 43 +++++++++++++++---------------- pkg/fga/client_live_example.go | 46 +++++++++++++++++----------------- pkg/fga/fga.go | 6 ++--- 3 files changed, 48 insertions(+), 47 deletions(-) diff --git a/pkg/fga/client.go b/pkg/fga/client.go index 8903b89c..b552b76b 100644 --- a/pkg/fga/client.go +++ b/pkg/fga/client.go @@ -23,8 +23,9 @@ type Order string // Constants that enumerate the available orders. const ( - Asc Order = "asc" - Desc Order = "desc" + CheckResultAuthorized = "Authorized" + Asc Order = "asc" + Desc Order = "desc" ) // Client represents a client that performs FGA requests to the WorkOS API. @@ -333,6 +334,10 @@ type CheckResponse struct { DecisionPath map[string][]Warrant `json:"decision_path,omitempty"` } +func (checkResponse CheckResponse) Authorized() bool { + return checkResponse.Result == CheckResultAuthorized +} + // Query type QueryOpts struct { // Query to be executed. @@ -765,7 +770,7 @@ func (c *Client) BatchWriteWarrants(ctx context.Context, opts []WriteWarrantOpts return body, err } -func (c *Client) Check(ctx context.Context, opts CheckOpts) (bool, error) { +func (c *Client) Check(ctx context.Context, opts CheckOpts) (CheckResponse, error) { return c.CheckMany(ctx, CheckManyOpts{ Warrants: []WarrantCheck{opts.Warrant}, Debug: opts.Debug, @@ -773,18 +778,18 @@ func (c *Client) Check(ctx context.Context, opts CheckOpts) (bool, error) { }) } -func (c *Client) CheckMany(ctx context.Context, opts CheckManyOpts) (bool, error) { +func (c *Client) CheckMany(ctx context.Context, opts CheckManyOpts) (CheckResponse, error) { c.once.Do(c.init) data, err := c.JSONEncode(opts) if err != nil { - return false, err + return CheckResponse{}, err } endpoint := fmt.Sprintf("%s/fga/v1/check", c.Endpoint) req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewBuffer(data)) if err != nil { - return false, err + return CheckResponse{}, err } req = req.WithContext(ctx) @@ -797,25 +802,25 @@ func (c *Client) CheckMany(ctx context.Context, opts CheckManyOpts) (bool, error res, err := c.HTTPClient.Do(req) if err != nil { - return false, err + return CheckResponse{}, err } defer res.Body.Close() if err = workos_errors.TryGetHTTPError(res); err != nil { - return false, err + return CheckResponse{}, err } var checkResponse CheckResponse dec := json.NewDecoder(res.Body) err = dec.Decode(&checkResponse) if err != nil { - return false, err + return CheckResponse{}, err } - return checkResponse.Result == "Authorized", nil + return checkResponse, nil } -func (c *Client) BatchCheck(ctx context.Context, opts BatchCheckOpts) ([]bool, error) { +func (c *Client) BatchCheck(ctx context.Context, opts BatchCheckOpts) ([]CheckResponse, error) { c.once.Do(c.init) checkOpts := CheckManyOpts{ @@ -826,13 +831,13 @@ func (c *Client) BatchCheck(ctx context.Context, opts BatchCheckOpts) ([]bool, e } data, err := c.JSONEncode(checkOpts) if err != nil { - return []bool{}, err + return []CheckResponse{}, err } endpoint := fmt.Sprintf("%s/fga/v1/check", c.Endpoint) req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewBuffer(data)) if err != nil { - return []bool{}, err + return []CheckResponse{}, err } req = req.WithContext(ctx) @@ -845,26 +850,22 @@ func (c *Client) BatchCheck(ctx context.Context, opts BatchCheckOpts) ([]bool, e res, err := c.HTTPClient.Do(req) if err != nil { - return []bool{}, err + return []CheckResponse{}, err } defer res.Body.Close() if err = workos_errors.TryGetHTTPError(res); err != nil { - return []bool{}, err + return []CheckResponse{}, err } var checkResponses []CheckResponse dec := json.NewDecoder(res.Body) err = dec.Decode(&checkResponses) if err != nil { - return []bool{}, err + return []CheckResponse{}, err } - var results []bool - for _, checkResponse := range checkResponses { - results = append(results, checkResponse.Result == "Authorized") - } - return results, nil + return checkResponses, nil } // Query executes a query for a set of resources. diff --git a/pkg/fga/client_live_example.go b/pkg/fga/client_live_example.go index 230f32ef..a2d6583e 100644 --- a/pkg/fga/client_live_example.go +++ b/pkg/fga/client_live_example.go @@ -392,7 +392,7 @@ func TestRBAC(t *testing.T) { if err != nil { t.Fatal(err) } - require.False(t, adminUserHasPermission) + require.False(t, adminUserHasPermission.Authorized()) // Assign create-report permission -> admin role -> admin user warrantResponse, err := WriteWarrant(context.Background(), WriteWarrantOpts{ @@ -438,7 +438,7 @@ func TestRBAC(t *testing.T) { if err != nil { t.Fatal(err) } - require.True(t, adminUserHasPermission) + require.True(t, adminUserHasPermission.Authorized()) adminUserRolesList, err = Query(context.Background(), QueryOpts{ Query: fmt.Sprintf("select role where user:%s is member", adminUser.ObjectId), @@ -518,7 +518,7 @@ func TestRBAC(t *testing.T) { if err != nil { t.Fatal(err) } - require.False(t, adminUserHasPermission) + require.False(t, adminUserHasPermission.Authorized()) adminUserRolesList, err = Query(context.Background(), QueryOpts{ Query: fmt.Sprintf("select role where user:%s is member", adminUser.ObjectId), @@ -556,7 +556,7 @@ func TestRBAC(t *testing.T) { if err != nil { t.Fatal(err) } - require.False(t, viewerUserHasPermission) + require.False(t, viewerUserHasPermission.Authorized()) viewerUserPermissionsList, err := Query(context.Background(), QueryOpts{ Query: fmt.Sprintf("select permission where user:%s is member", viewerUser.ObjectId), @@ -597,7 +597,7 @@ func TestRBAC(t *testing.T) { if err != nil { t.Fatal(err) } - require.True(t, viewerUserHasPermission) + require.True(t, viewerUserHasPermission.Authorized()) viewerUserPermissionsList, err = Query(context.Background(), QueryOpts{ Query: fmt.Sprintf("select permission where user:%s is member", viewerUser.ObjectId), @@ -645,7 +645,7 @@ func TestRBAC(t *testing.T) { if err != nil { t.Fatal(err) } - require.False(t, viewerUserHasPermission) + require.False(t, viewerUserHasPermission.Authorized()) viewerUserPermissionsList, err = Query(context.Background(), QueryOpts{ Query: fmt.Sprintf("select permission where user:%s is member", viewerUser.ObjectId), @@ -780,7 +780,7 @@ func TestPricingTiersFeaturesAndUsers(t *testing.T) { if err != nil { t.Fatal(err) } - require.False(t, paidUserHasFeature) + require.False(t, paidUserHasFeature.Authorized()) paidUserFeaturesList, err := Query(context.Background(), QueryOpts{ Query: fmt.Sprintf("select feature where user:%s is member", paidUser.ObjectId), @@ -821,7 +821,7 @@ func TestPricingTiersFeaturesAndUsers(t *testing.T) { if err != nil { t.Fatal(err) } - require.True(t, paidUserHasFeature) + require.True(t, paidUserHasFeature.Authorized()) paidUserFeaturesList, err = Query(context.Background(), QueryOpts{ Query: fmt.Sprintf("select feature where user:%s is member", paidUser.ObjectId), @@ -868,7 +868,7 @@ func TestPricingTiersFeaturesAndUsers(t *testing.T) { if err != nil { t.Fatal(err) } - require.False(t, paidUserHasFeature) + require.False(t, paidUserHasFeature.Authorized()) paidUserFeaturesList, err = Query(context.Background(), QueryOpts{ Query: fmt.Sprintf("select feature where user:%s is member", paidUser.ObjectId), @@ -896,7 +896,7 @@ func TestPricingTiersFeaturesAndUsers(t *testing.T) { if err != nil { t.Fatal(err) } - require.False(t, freeUserHasFeature) + require.False(t, freeUserHasFeature.Authorized()) freeUserFeaturesList, err := Query(context.Background(), QueryOpts{ Query: fmt.Sprintf("select feature where user:%s is member", freeUser.ObjectId), @@ -961,7 +961,7 @@ func TestPricingTiersFeaturesAndUsers(t *testing.T) { if err != nil { t.Fatal(err) } - require.True(t, freeUserHasFeature) + require.True(t, freeUserHasFeature.Authorized()) freeUserFeaturesList, err = Query(context.Background(), QueryOpts{ Query: fmt.Sprintf("select feature where user:%s is member", freeUser.ObjectId), @@ -1036,7 +1036,7 @@ func TestPricingTiersFeaturesAndUsers(t *testing.T) { if err != nil { t.Fatal(err) } - require.False(t, freeUserHasFeature) + require.False(t, freeUserHasFeature.Authorized()) freeUserFeaturesList, err = Query(context.Background(), QueryOpts{ Query: fmt.Sprintf("select feature where user:%s is member", freeUser.ObjectId), @@ -1154,7 +1154,7 @@ func TestWarrants(t *testing.T) { if err != nil { t.Fatal(err) } - require.False(t, userHasPermission) + require.False(t, userHasPermission.Authorized()) warrantResponse, err := WriteWarrant(context.Background(), WriteWarrantOpts{ ObjectType: newPermission.ObjectType, @@ -1243,7 +1243,7 @@ func TestWarrants(t *testing.T) { if err != nil { t.Fatal(err) } - require.True(t, userHasPermission) + require.True(t, userHasPermission.Authorized()) queryResponse, err := Query(context.Background(), QueryOpts{ Query: fmt.Sprintf("select permission where user:%s is member", user1.ObjectId), @@ -1287,7 +1287,7 @@ func TestWarrants(t *testing.T) { if err != nil { t.Fatal(err) } - require.False(t, userHasPermission) + require.False(t, userHasPermission.Authorized()) // Clean up err = DeleteObject(context.Background(), DeleteObjectOpts{ @@ -1372,8 +1372,8 @@ func TestBatchWarrants(t *testing.T) { t.Fatal(err) } require.Len(t, userHasPermissions, 2) - require.False(t, userHasPermissions[0]) - require.False(t, userHasPermissions[1]) + require.False(t, userHasPermissions[0].Authorized()) + require.False(t, userHasPermissions[1].Authorized()) warrantResponse, err := BatchWriteWarrants(context.Background(), []WriteWarrantOpts{ { @@ -1428,8 +1428,8 @@ func TestBatchWarrants(t *testing.T) { t.Fatal(err) } require.Len(t, userHasPermissions, 2) - require.True(t, userHasPermissions[0]) - require.True(t, userHasPermissions[1]) + require.True(t, userHasPermissions[0].Authorized()) + require.True(t, userHasPermissions[1].Authorized()) warrantResponse, err = BatchWriteWarrants(context.Background(), []WriteWarrantOpts{ { @@ -1485,8 +1485,8 @@ func TestBatchWarrants(t *testing.T) { t.Fatal(err) } require.Len(t, userHasPermissions, 2) - require.False(t, userHasPermissions[0]) - require.False(t, userHasPermissions[1]) + require.False(t, userHasPermissions[0].Authorized()) + require.False(t, userHasPermissions[1].Authorized()) // Clean up err = DeleteObject(context.Background(), DeleteObjectOpts{ @@ -1548,7 +1548,7 @@ func TestWarrantsWithPolicy(t *testing.T) { if err != nil { t.Fatal(err) } - require.True(t, checkResult) + require.True(t, checkResult.Authorized()) checkResult, err = Check(context.Background(), CheckOpts{ Warrant: WarrantCheck{ @@ -1568,7 +1568,7 @@ func TestWarrantsWithPolicy(t *testing.T) { if err != nil { t.Fatal(err) } - require.False(t, checkResult) + require.False(t, checkResult.Authorized()) warrantResponse, err = WriteWarrant(context.Background(), WriteWarrantOpts{ Op: "delete", diff --git a/pkg/fga/fga.go b/pkg/fga/fga.go index 0191053b..0e052cee 100644 --- a/pkg/fga/fga.go +++ b/pkg/fga/fga.go @@ -98,7 +98,7 @@ func BatchWriteWarrants( func Check( ctx context.Context, opts CheckOpts, -) (bool, error) { +) (CheckResponse, error) { return DefaultClient.Check(ctx, opts) } @@ -106,7 +106,7 @@ func Check( func CheckMany( ctx context.Context, opts CheckManyOpts, -) (bool, error) { +) (CheckResponse, error) { return DefaultClient.CheckMany(ctx, opts) } @@ -114,7 +114,7 @@ func CheckMany( func BatchCheck( ctx context.Context, opts BatchCheckOpts, -) ([]bool, error) { +) ([]CheckResponse, error) { return DefaultClient.BatchCheck(ctx, opts) } From 386a3f417dbe96b370e4f0218b57741681fedba7 Mon Sep 17 00:00:00 2001 From: Rakesh Patel Date: Wed, 3 Jul 2024 13:57:46 -0700 Subject: [PATCH 5/7] Add authentication method options to organization --- pkg/organizations/client.go | 77 +++++++++++++++++++++++++++++++- pkg/organizations/client_test.go | 16 ++++++- 2 files changed, 91 insertions(+), 2 deletions(-) diff --git a/pkg/organizations/client.go b/pkg/organizations/client.go index c4a5fbf8..4f05e02e 100644 --- a/pkg/organizations/client.go +++ b/pkg/organizations/client.go @@ -83,6 +83,27 @@ type Organization struct { // Deprecated: If you need to allow sign-ins from any email domain, contact support@workos.com. AllowProfilesOutsideOrganization bool `json:"allow_profiles_outside_organization"` + // Whether the Organization has Magic Link authentication enabled. + MagicLinkAuthEnabled bool `json:"magic_link_auth_enabled"` + + // Whether the Organization has Password authentication enabled. + PasswordAuthEnabled bool `json:"password_auth_enabled"` + + // Whether the Organization has Apple OAuth authentication enabled. + AppleOauthAuthEnabled bool `json:"apple_oauth_auth_enabled"` + + // Whether the Organization has Google OAuth authentication enabled. + GoogleOauthAuthEnabled bool `json:"google_oauth_auth_enabled"` + + // Whether the Organization has Azure AD OAuth authentication enabled. + MicrosoftOauthAuthEnabled bool `json:"microsoft_oauth_auth_enabled"` + + // Whether the Organization has GitHub OAuth authentication enabled. + GithubOauthAuthEnabled bool `json:"github_oauth_auth_enabled"` + + // Whether the Organization has Domain MFA required. + DomainMfaRequired bool `json:"domain_mfa_required"` + // The Organization's Domains. Domains []OrganizationDomain `json:"domains"` @@ -185,6 +206,27 @@ type UpdateOrganizationOpts struct { // Deprecated: Use DomainData instead. Domains []string + // Whether the Organization has Magic Link authentication enabled. + MagicLinkAuthEnabled bool `json:"magic_link_auth_enabled"` + + // Whether the Organization has Password authentication enabled. + PasswordAuthEnabled bool `json:"password_auth_enabled"` + + // Whether the Organization has Apple OAuth authentication enabled. + AppleOauthAuthEnabled bool `json:"apple_oauth_auth_enabled"` + + // Whether the Organization has Google OAuth authentication enabled. + GoogleOauthAuthEnabled bool `json:"google_oauth_auth_enabled"` + + // Whether the Organization has Azure AD OAuth authentication enabled. + MicrosoftOauthAuthEnabled bool `json:"microsoft_oauth_auth_enabled"` + + // Whether the Organization has GitHub OAuth authentication enabled. + GithubOauthAuthEnabled bool `json:"github_oauth_auth_enabled"` + + // Whether the Organization has Domain MFA required. + DomainMfaRequired bool `json:"domain_mfa_required"` + // Domains of the Organization. DomainData []OrganizationDomainData `json:"domain_data"` } @@ -340,9 +382,42 @@ func (c *Client) UpdateOrganization(ctx context.Context, opts UpdateOrganization // // Deprecated: Use DomainData instead. Domains []string `json:"domains,omitempty"` + + // Whether the Organization has Magic Link authentication enabled. + MagicLinkAuthEnabled bool `json:"magic_link_auth_enabled"` + + // Whether the Organization has Password authentication enabled. + PasswordAuthEnabled bool `json:"password_auth_enabled"` + + // Whether the Organization has Apple OAuth authentication enabled. + AppleOauthAuthEnabled bool `json:"apple_oauth_auth_enabled"` + + // Whether the Organization has Google OAuth authentication enabled. + GoogleOauthAuthEnabled bool `json:"google_oauth_auth_enabled"` + + // Whether the Organization has Azure AD OAuth authentication enabled. + MicrosoftOauthAuthEnabled bool `json:"microsoft_oauth_auth_enabled"` + + // Whether the Organization has GitHub OAuth authentication enabled. + GithubOauthAuthEnabled bool `json:"github_oauth_auth_enabled"` + + // Whether the Organization has Domain MFA required. + DomainMfaRequired bool `json:"domain_mfa_required"` } - update_opts := UpdateOrganizationChangeOpts{opts.Name, opts.AllowProfilesOutsideOrganization, opts.DomainData, opts.Domains} + update_opts := UpdateOrganizationChangeOpts{ + opts.Name, + opts.AllowProfilesOutsideOrganization, + opts.DomainData, + opts.Domains, + opts.MagicLinkAuthEnabled, + opts.PasswordAuthEnabled, + opts.AppleOauthAuthEnabled, + opts.GoogleOauthAuthEnabled, + opts.MicrosoftOauthAuthEnabled, + opts.GithubOauthAuthEnabled, + opts.DomainMfaRequired, + } data, err := c.JSONEncode(update_opts) if err != nil { diff --git a/pkg/organizations/client_test.go b/pkg/organizations/client_test.go index f2426f35..4f37effb 100644 --- a/pkg/organizations/client_test.go +++ b/pkg/organizations/client_test.go @@ -38,11 +38,18 @@ func TestGetOrganization(t *testing.T) { Name: "Foo Corp", AllowProfilesOutsideOrganization: false, Domains: []OrganizationDomain{ - OrganizationDomain{ + { ID: "organization_domain_id", Domain: "foo-corp.com", }, }, + MagicLinkAuthEnabled: false, + PasswordAuthEnabled: false, + AppleOauthAuthEnabled: false, + GoogleOauthAuthEnabled: false, + GithubOauthAuthEnabled: false, + MicrosoftOauthAuthEnabled: false, + DomainMfaRequired: false, }, }, } @@ -84,6 +91,13 @@ func getOrganizationTestHandler(w http.ResponseWriter, r *http.Request) { Domain: "foo-corp.com", }, }, + MagicLinkAuthEnabled: false, + PasswordAuthEnabled: false, + AppleOauthAuthEnabled: false, + GoogleOauthAuthEnabled: false, + GithubOauthAuthEnabled: false, + MicrosoftOauthAuthEnabled: false, + DomainMfaRequired: false, }) if err != nil { w.WriteHeader(http.StatusInternalServerError) From 84a9662db6be4787178b08b5994efc6c46f34ab9 Mon Sep 17 00:00:00 2001 From: Rakesh Patel Date: Thu, 4 Jul 2024 12:53:07 -0700 Subject: [PATCH 6/7] Add oauthcredentials client --- pkg/oauthcredentials/client.go | 345 +++++++++++++++++++++++ pkg/oauthcredentials/oauthcredentials.go | 42 +++ 2 files changed, 387 insertions(+) create mode 100644 pkg/oauthcredentials/client.go create mode 100644 pkg/oauthcredentials/oauthcredentials.go diff --git a/pkg/oauthcredentials/client.go b/pkg/oauthcredentials/client.go new file mode 100644 index 00000000..f73e2d93 --- /dev/null +++ b/pkg/oauthcredentials/client.go @@ -0,0 +1,345 @@ +package oauthcredentials + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "sync" + "time" + + "github.com/workos/workos-go/v4/pkg/workos_errors" + + "github.com/google/go-querystring/query" + "github.com/workos/workos-go/v4/internal/workos" + "github.com/workos/workos-go/v4/pkg/common" +) + +// ResponseLimit is the default number of records to limit a response to. +const ResponseLimit = 10 + +// Order represents the order of records. +type Order string + +// Constants that enumerate the available orders. +const ( + Asc Order = "asc" + Desc Order = "desc" +) + +// Client represents a client that performs OAuthCredential requests to the WorkOS API. +type Client struct { + // The WorkOS API Key. It can be found in https://dashboard.workos.com/api-keys. + APIKey string + + // The http.Client that is used to manage OAuthCredential records from WorkOS. + // Defaults to http.Client. + HTTPClient *http.Client + + // The endpoint to WorkOS API. Defaults to https://api.workos.com. + Endpoint string + + // The function used to encode in JSON. Defaults to json.Marshal. + JSONEncode func(v interface{}) ([]byte, error) + + once sync.Once +} + +func (c *Client) init() { + if c.HTTPClient == nil { + c.HTTPClient = &http.Client{Timeout: 10 * time.Second} + } + + if c.Endpoint == "" { + c.Endpoint = "https://api.workos.com" + } + + if c.JSONEncode == nil { + c.JSONEncode = json.Marshal + } +} + +// OAuthConnectionType represents the type of OAuth Connection. +type OAuthConnectionType string + +// Constants that enumerate the available oauth connection types. +const ( + AppleOauth OAuthConnectionType = "AppleOauth" + GithubOauth OAuthConnectionType = "GitHubOauth" + GoogleOAuth OAuthConnectionType = "GoogleOAuth" + MicrosoftOAuth OAuthConnectionType = "MicrosoftOAuth" +) + +// OAuthConnectionState represents the state of an OAuth Connection. +type OAuthConnectionState string + +// Constants that enumerate the available oauth connection states. +const ( + Valid OAuthConnectionState = "Valid" + Invalid OAuthConnectionState = "Invalid" +) + +// OAuthCredential contains data about a WorkOS OauthCredential Auth Method. +type OAuthCredential struct { + // The OauthCredential's unique identifier. + ID string `json:"id"` + + // The OauthCredential's type. + Type OAuthConnectionType `json:"type"` + + // The OauthCredential's state. + State OAuthConnectionState `json:"state"` + + // The OauthCredential's external key. + ExternalKey string `json:"externalKey"` + + // The OauthCredential's client ID. + ClientID string `json:"clientId"` + + // The OauthCredential's client secret. + ClientSecret string `json:"clientSecret"` + + // The OauthCredential's redirect URI. + RedirectURI string `json:"redirectUri"` + + // The OauthCredential's userland enabled state. + IsUserlandEnabled bool `json:"isUserlandEnabled"` + + // The OauthCredential's Apple Team ID. + AppleTeamID string `json:"appleTeamId"` + + // The OauthCredential's Apple Key ID. + AppleKeyID string `json:"appleKeyId"` + + // The OauthCredential's Apple Private Key. + ApplePrivateKey string `json:"applePrivateKey"` + + // The timestamp of when the OAuthCredential was created. + CreatedAt string `json:"created_at"` + + // The timestamp of when the OAuthCredential was updated. + UpdatedAt string `json:"updated_at"` +} + +// GetOAuthCredentialOpts contains the options to request details for an OAuthCredential. +type GetOAuthCredentialOpts struct { + // Oauth Credential unique identifier. + ID string +} + +// ListOAuthCredentialsOpts contains the options to request OAuthCredentials. +type ListOAuthCredentialsOpts struct { + // Maximum number of records to return. + Limit int `url:"limit,omitempty"` + + // The order in which to paginate records. + Order Order `url:"order,omitempty"` + + // Pagination cursor to receive records before a provided OAuthCredential ID. + Before string `url:"before,omitempty"` + + // Pagination cursor to receive records after a provided OAuthCredential ID. + After string `url:"after,omitempty"` +} + +// ListOAuthCredentialsResponse describes the response structure when requesting +// OAuthCredentials +type ListOAuthCredentialsResponse struct { + // List of provisioned OAuthCredentials. + Data []OAuthCredential `json:"data"` + + // Cursor pagination options. + ListMetadata common.ListMetadata `json:"listMetadata"` +} + +// UpdateOAuthCredentialOpts contains the options to update an OAuthCredential. +type UpdateOAuthCredentialOpts struct { + // OAuthCredential unique identifier. + ID string + + // The OauthCredential's client ID. + ClientID string `json:"clientId"` + + // The OauthCredential's client secret. + ClientSecret string `json:"clientSecret"` + + // The OauthCredential's redirect URI. + RedirectURI string `json:"redirectUri"` + + // The OauthCredential's userland enabled state. + IsUserlandEnabled bool `json:"isUserlandEnabled"` + + // The OauthCredential's Apple Team ID. + AppleTeamID string `json:"appleTeamId"` + + // The OauthCredential's Apple Key ID. + AppleKeyID string `json:"appleKeyId"` + + // The OauthCredential's Apple Private Key. + ApplePrivateKey string `json:"applePrivateKey"` +} + +// GetOAuthCredential gets an OAuthCredential. +func (c *Client) GetOAuthCredential( + ctx context.Context, + opts GetOAuthCredentialOpts, +) (OAuthCredential, error) { + c.once.Do(c.init) + + endpoint := fmt.Sprintf( + "%s/oauth-credentials/%s", + c.Endpoint, + opts.ID, + ) + req, err := http.NewRequest( + http.MethodGet, + endpoint, + nil, + ) + if err != nil { + return OAuthCredential{}, err + } + + req = req.WithContext(ctx) + req.Header.Set("Authorization", "Bearer "+c.APIKey) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", "workos-go/"+workos.Version) + + res, err := c.HTTPClient.Do(req) + if err != nil { + return OAuthCredential{}, err + } + defer res.Body.Close() + + if err = workos_errors.TryGetHTTPError(res); err != nil { + return OAuthCredential{}, err + } + + var body OAuthCredential + dec := json.NewDecoder(res.Body) + err = dec.Decode(&body) + return body, err +} + +// ListOAuthCredentials gets a list of WorkOS OAuthCredentials. +func (c *Client) ListOAuthCredentials( + ctx context.Context, + opts ListOAuthCredentialsOpts, +) (ListOAuthCredentialsResponse, error) { + c.once.Do(c.init) + + endpoint := fmt.Sprintf("%s/oauth-credentials", c.Endpoint) + req, err := http.NewRequest( + http.MethodGet, + endpoint, + nil, + ) + if err != nil { + return ListOAuthCredentialsResponse{}, err + } + + req = req.WithContext(ctx) + req.Header.Set("Authorization", "Bearer "+c.APIKey) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", "workos-go/"+workos.Version) + + if opts.Limit == 0 { + opts.Limit = ResponseLimit + } + + if opts.Order == "" { + opts.Order = Desc + } + + q, err := query.Values(opts) + if err != nil { + return ListOAuthCredentialsResponse{}, err + } + + req.URL.RawQuery = q.Encode() + + res, err := c.HTTPClient.Do(req) + if err != nil { + return ListOAuthCredentialsResponse{}, err + } + defer res.Body.Close() + + if err = workos_errors.TryGetHTTPError(res); err != nil { + return ListOAuthCredentialsResponse{}, err + } + + var body ListOAuthCredentialsResponse + dec := json.NewDecoder(res.Body) + err = dec.Decode(&body) + return body, err +} + +// UpdateOAuthCredential updates an OAuthCredential. +func (c *Client) UpdateOAuthCredential(ctx context.Context, opts UpdateOAuthCredentialOpts) (OAuthCredential, error) { + c.once.Do(c.init) + + // UpdateOAuthCredentialChangeOpts contains the options to update an OAuthCredential minus the org ID + type UpdateOAuthCredentialChangeOpts struct { + // The OauthCredential's client ID. + ClientID string `json:"clientId"` + + // The OauthCredential's client secret. + ClientSecret string `json:"clientSecret"` + + // The OauthCredential's redirect URI. + RedirectURI string `json:"redirectUri"` + + // The OauthCredential's userland enabled state. + IsUserlandEnabled bool `json:"isUserlandEnabled"` + + // The OauthCredential's Apple Team ID. + AppleTeamID string `json:"appleTeamId"` + + // The OauthCredential's Apple Key ID. + AppleKeyID string `json:"appleKeyId"` + + // The OauthCredential's Apple Private Key. + ApplePrivateKey string `json:"applePrivateKey"` + } + + update_opts := UpdateOAuthCredentialChangeOpts{ + opts.ClientID, + opts.ClientSecret, + opts.RedirectURI, + opts.IsUserlandEnabled, + opts.AppleTeamID, + opts.AppleKeyID, + opts.ApplePrivateKey, + } + + data, err := c.JSONEncode(update_opts) + if err != nil { + return OAuthCredential{}, err + } + + endpoint := fmt.Sprintf("%s/organizations/%s", c.Endpoint, opts.ID) + req, err := http.NewRequest(http.MethodPut, endpoint, bytes.NewBuffer(data)) + if err != nil { + return OAuthCredential{}, err + } + req = req.WithContext(ctx) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+c.APIKey) + req.Header.Set("User-Agent", "workos-go/"+workos.Version) + + res, err := c.HTTPClient.Do(req) + if err != nil { + return OAuthCredential{}, err + } + defer res.Body.Close() + + if err = workos_errors.TryGetHTTPError(res); err != nil { + return OAuthCredential{}, err + } + + var body OAuthCredential + dec := json.NewDecoder(res.Body) + err = dec.Decode(&body) + return body, err +} diff --git a/pkg/oauthcredentials/oauthcredentials.go b/pkg/oauthcredentials/oauthcredentials.go new file mode 100644 index 00000000..a1caf185 --- /dev/null +++ b/pkg/oauthcredentials/oauthcredentials.go @@ -0,0 +1,42 @@ +// Package `organizations` provides a client wrapping the WorkOS OAuthCredentials API. +package oauthcredentials + +import ( + "context" +) + +// DefaultClient is the client used by SetAPIKey and OAuthCredentials functions. +var ( + DefaultClient = &Client{ + Endpoint: "https://api.workos.com", + } +) + +// SetAPIKey sets the WorkOS API key for OAuthCredentials requests. +func SetAPIKey(apiKey string) { + DefaultClient.APIKey = apiKey +} + +// GetOAuthCredential gets an OAuthCredential. +func GetOAuthCredential( + ctx context.Context, + opts GetOAuthCredentialOpts, +) (OAuthCredential, error) { + return DefaultClient.GetOAuthCredential(ctx, opts) +} + +// ListOAuthCredentials gets a list of OAuthCredentials. +func ListOAuthCredentials( + ctx context.Context, + opts ListOAuthCredentialsOpts, +) (ListOAuthCredentialsResponse, error) { + return DefaultClient.ListOAuthCredentials(ctx, opts) +} + +// UpdateOAuthCredential creates an OAuthCredential. +func UpdateOAuthCredential( + ctx context.Context, + opts UpdateOAuthCredentialOpts, +) (OAuthCredential, error) { + return DefaultClient.UpdateOAuthCredential(ctx, opts) +} From d4beef1a79e576e7cee8d4df60eb9763547a4e48 Mon Sep 17 00:00:00 2001 From: Rakesh Patel Date: Thu, 4 Jul 2024 15:15:52 -0700 Subject: [PATCH 7/7] Add create oauth method calls --- pkg/oauthcredentials/client.go | 42 ++++++++++++++++++++++-- pkg/oauthcredentials/oauthcredentials.go | 8 +++++ 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/pkg/oauthcredentials/client.go b/pkg/oauthcredentials/client.go index f73e2d93..ddfc28a5 100644 --- a/pkg/oauthcredentials/client.go +++ b/pkg/oauthcredentials/client.go @@ -67,8 +67,8 @@ type OAuthConnectionType string const ( AppleOauth OAuthConnectionType = "AppleOauth" GithubOauth OAuthConnectionType = "GitHubOauth" - GoogleOAuth OAuthConnectionType = "GoogleOAuth" - MicrosoftOAuth OAuthConnectionType = "MicrosoftOAuth" + GoogleOauth OAuthConnectionType = "GoogleOauth" + MicrosoftOauth OAuthConnectionType = "MicrosoftOauth" ) // OAuthConnectionState represents the state of an OAuth Connection. @@ -343,3 +343,41 @@ func (c *Client) UpdateOAuthCredential(ctx context.Context, opts UpdateOAuthCred err = dec.Decode(&body) return body, err } + +type CreateOAuthCredentialOpts struct { + Type OAuthConnectionType `json:"type"` +} + +func (c *Client) CreateOAuthCredential(ctx context.Context, opts CreateOAuthCredentialOpts) (OAuthCredential, error) { + c.once.Do(c.init) + + data, err := c.JSONEncode(opts) + if err != nil { + return OAuthCredential{}, err + } + + endpoint := fmt.Sprintf("%s/oauth-credentials", c.Endpoint) + req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewBuffer(data)) + if err != nil { + return OAuthCredential{}, err + } + req = req.WithContext(ctx) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+c.APIKey) + req.Header.Set("User-Agent", "workos-go/"+workos.Version) + + res, err := c.HTTPClient.Do(req) + if err != nil { + return OAuthCredential{}, err + } + defer res.Body.Close() + + if err = workos_errors.TryGetHTTPError(res); err != nil { + return OAuthCredential{}, err + } + + var body OAuthCredential + dec := json.NewDecoder(res.Body) + err = dec.Decode(&body) + return body, err +} diff --git a/pkg/oauthcredentials/oauthcredentials.go b/pkg/oauthcredentials/oauthcredentials.go index a1caf185..aba43c4d 100644 --- a/pkg/oauthcredentials/oauthcredentials.go +++ b/pkg/oauthcredentials/oauthcredentials.go @@ -40,3 +40,11 @@ func UpdateOAuthCredential( ) (OAuthCredential, error) { return DefaultClient.UpdateOAuthCredential(ctx, opts) } + +// CreateOAuthCredential creates an OAuthCredential. +func CreateOAuthCredential( + ctx context.Context, + opts CreateOAuthCredentialOpts, +) (OAuthCredential, error) { + return DefaultClient.CreateOAuthCredential(ctx, opts) +}