From 977754469b0224c80e4cc8055ceb26255a897f0d Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Tue, 7 Jan 2025 17:39:13 -0500 Subject: [PATCH] Address review feedback on SQL generator --- internal/datastore/common/schema.go | 8 +- internal/datastore/common/sql.go | 40 +- internal/datastore/common/sql_test.go | 366 ++++++++++++++++-- ...ions.go => zz_generated.schema_options.go} | 10 +- internal/datastore/crdb/crdb.go | 2 +- 5 files changed, 356 insertions(+), 70 deletions(-) rename internal/datastore/common/{schema_options.go => zz_generated.schema_options.go} (95%) diff --git a/internal/datastore/common/schema.go b/internal/datastore/common/schema.go index 008c2857d2..542dc5dbf3 100644 --- a/internal/datastore/common/schema.go +++ b/internal/datastore/common/schema.go @@ -15,7 +15,7 @@ const ( // SchemaInformation holds the schema information from the SQL datastore implementation. // -//go:generate go run github.com/ecordell/optgen -output schema_options.go . SchemaInformation +//go:generate go run github.com/ecordell/optgen -output zz_generated.schema_options.go . SchemaInformation type SchemaInformation struct { RelationshipTableName string `debugmap:"visible"` @@ -47,8 +47,8 @@ type SchemaInformation struct { // ColumnOptimization is the optimization to use for columns in the schema, if any. ColumnOptimization ColumnOptimizationOption `debugmap:"visible"` - // WithIntegrityColumns is a flag to indicate if the schema has integrity columns. - WithIntegrityColumns bool `debugmap:"visible"` + // IntegrityEnabled is a flag to indicate if the schema has integrity columns. + IntegrityEnabled bool `debugmap:"visible"` // ExpirationDisabled is a flag to indicate whether expiration support is disabled. ExpirationDisabled bool `debugmap:"visible"` @@ -102,7 +102,7 @@ func (si SchemaInformation) mustValidate() { panic("ColExpiration is required") } - if si.WithIntegrityColumns { + if si.IntegrityEnabled { if si.ColIntegrityKeyID == "" { panic("ColIntegrityKeyID is required") } diff --git a/internal/datastore/common/sql.go b/internal/datastore/common/sql.go index 8bdc0f2653..db6b277df2 100644 --- a/internal/datastore/common/sql.go +++ b/internal/datastore/common/sql.go @@ -75,15 +75,15 @@ const ( // ColumnOptimizationOptionNone is the default option, which does not optimize the static columns. ColumnOptimizationOptionNone - // ColumnOptimizationOptionStaticValue is an option that optimizes the column for a static value. + // ColumnOptimizationOptionStaticValues is an option that optimizes columns for static values. ColumnOptimizationOptionStaticValues ) -type ColumnTracker struct { +type columnTracker struct { SingleValue *string } -type columnTrackerMap map[string]ColumnTracker +type columnTrackerMap map[string]columnTracker func (ctm columnTrackerMap) hasStaticValue(columnName string) bool { if r, ok := ctm[columnName]; ok && r.SingleValue != nil { @@ -119,7 +119,7 @@ func NewSchemaQueryFiltererForRelationshipsSelect(schema SchemaInformation, filt return SchemaQueryFilterer{ schema: schema, queryBuilder: queryBuilder, - filteringColumnTracker: map[string]ColumnTracker{}, + filteringColumnTracker: map[string]columnTracker{}, filterMaximumIDCount: filterMaximumIDCount, isCustomQuery: false, extraFields: extraFields, @@ -140,7 +140,7 @@ func NewSchemaQueryFiltererWithStartingQuery(schema SchemaInformation, startingQ return SchemaQueryFilterer{ schema: schema, queryBuilder: startingQuery, - filteringColumnTracker: map[string]ColumnTracker{}, + filteringColumnTracker: map[string]columnTracker{}, filterMaximumIDCount: filterMaximumIDCount, isCustomQuery: true, extraFields: nil, @@ -163,12 +163,12 @@ func (sqf SchemaQueryFilterer) UnderlyingQueryBuilder() sq.SelectBuilder { spiceerrors.DebugAssert(func() bool { return sqf.isCustomQuery }, "UnderlyingQueryBuilder should only be called on custom queries") - return sqf.queryBuilderWithExpirationFilter(false) + return sqf.queryBuilderWithMaybeExpirationFilter(false) } -// queryBuilderWithExpirationFilter returns the query builder with the expiration filter applied, when necessary. +// queryBuilderWithMaybeExpirationFilter returns the query builder with the expiration filter applied, when necessary. // Note that this adds the clause to the existing builder. -func (sqf SchemaQueryFilterer) queryBuilderWithExpirationFilter(skipExpiration bool) sq.SelectBuilder { +func (sqf SchemaQueryFilterer) queryBuilderWithMaybeExpirationFilter(skipExpiration bool) sq.SelectBuilder { if sqf.schema.ExpirationDisabled || skipExpiration { return sqf.queryBuilder } @@ -319,15 +319,15 @@ func (sqf SchemaQueryFilterer) recordColumnValue(colName string, colValue string existing, ok := sqf.filteringColumnTracker[colName] if ok { if existing.SingleValue != nil && *existing.SingleValue != colValue { - sqf.filteringColumnTracker[colName] = ColumnTracker{SingleValue: nil} + sqf.filteringColumnTracker[colName] = columnTracker{SingleValue: nil} } } else { - sqf.filteringColumnTracker[colName] = ColumnTracker{SingleValue: &colValue} + sqf.filteringColumnTracker[colName] = columnTracker{SingleValue: &colValue} } } func (sqf SchemaQueryFilterer) recordVaryingColumnValue(colName string) { - sqf.filteringColumnTracker[colName] = ColumnTracker{SingleValue: nil} + sqf.filteringColumnTracker[colName] = columnTracker{SingleValue: nil} } // FilterToResourceID returns a new SchemaQueryFilterer that is limited to resources with the @@ -491,7 +491,7 @@ func (sqf SchemaQueryFilterer) MustFilterWithSubjectsSelectors(selectors ...data func (sqf SchemaQueryFilterer) FilterWithSubjectsSelectors(selectors ...datastore.SubjectsSelector) (SchemaQueryFilterer, error) { selectorsOrClause := sq.Or{} - // If there is more than a single filter, record all the subjects as mutable, as the subjects returned + // If there is more than a single filter, record all the subjects as varying, as the subjects returned // can differ for each branch. // TODO(jschorr): Optimize this further where applicable. if len(selectors) > 1 { @@ -694,9 +694,9 @@ func (b RelationshipsQueryBuilder) withExpiration() bool { return !b.SkipExpiration && !b.Schema.ExpirationDisabled } -// withIntegrityColumns returns true if integrity columns should be included in the query. -func (b RelationshipsQueryBuilder) withIntegrityColumns() bool { - return b.Schema.WithIntegrityColumns +// integrityEnabled returns true if integrity columns should be included in the query. +func (b RelationshipsQueryBuilder) integrityEnabled() bool { + return b.Schema.IntegrityEnabled } // columnCount returns the number of columns that will be selected in the query. @@ -708,7 +708,7 @@ func (b RelationshipsQueryBuilder) columnCount() int { if b.withExpiration() { columnCount += relationshipExpirationColumnCount } - if b.withIntegrityColumns() { + if b.integrityEnabled() { columnCount += relationshipIntegrityColumnCount } return columnCount @@ -734,7 +734,7 @@ func (b RelationshipsQueryBuilder) SelectSQL() (string, []any, error) { columnNamesToSelect = append(columnNamesToSelect, b.Schema.ColExpiration) } - if b.Schema.WithIntegrityColumns { + if b.integrityEnabled() { columnNamesToSelect = append(columnNamesToSelect, b.Schema.ColIntegrityKeyID, b.Schema.ColIntegrityHash, b.Schema.ColIntegrityTimestamp) } @@ -742,14 +742,14 @@ func (b RelationshipsQueryBuilder) SelectSQL() (string, []any, error) { columnNamesToSelect = append(columnNamesToSelect, "1") } - sqlBuilder := b.baseQueryBuilder.queryBuilderWithExpirationFilter(b.SkipExpiration) + sqlBuilder := b.baseQueryBuilder.queryBuilderWithMaybeExpirationFilter(b.SkipExpiration) sqlBuilder = sqlBuilder.Columns(columnNamesToSelect...) return sqlBuilder.ToSql() } // FilteringValuesForTesting returns the filtering values. For test use only. -func (b RelationshipsQueryBuilder) FilteringValuesForTesting() map[string]ColumnTracker { +func (b RelationshipsQueryBuilder) FilteringValuesForTesting() map[string]columnTracker { return maps.Clone(b.filteringValues) } @@ -818,7 +818,7 @@ func ColumnsToSelect[CN any, CC any, EC any]( colsToSelect = append(colsToSelect, expiration) } - if b.Schema.WithIntegrityColumns { + if b.Schema.IntegrityEnabled { colsToSelect = append(colsToSelect, integrityKeyID, integrityHash, timestamp) } diff --git a/internal/datastore/common/sql_test.go b/internal/datastore/common/sql_test.go index 2e19772245..62d6fbeea1 100644 --- a/internal/datastore/common/sql_test.go +++ b/internal/datastore/common/sql_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "testing" + "time" "github.com/authzed/spicedb/pkg/datastore/options" @@ -803,7 +804,7 @@ func TestSchemaQueryFilterer(t *testing.T) { require.ElementsMatch(t, expected.staticCols, foundStaticColumns) - ran.queryBuilder = ran.queryBuilderWithExpirationFilter(test.withExpirationDisabled).Columns("*") + ran.queryBuilder = ran.queryBuilderWithMaybeExpirationFilter(test.withExpirationDisabled).Columns("*") sql, args, err := ran.queryBuilder.ToSql() require.NoError(t, err) @@ -822,49 +823,58 @@ func TestExecuteQuery(t *testing.T) { options []options.QueryOptionsOption expectedSQL string expectedArgs []any + expectedStaticColCount int expectedSkipCaveats bool expectedSkipExpiration bool withExpirationDisabled bool + withIntegrityEnabled bool + fromSuffix string + limit uint64 }{ { name: "filter by static resource type", run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToResourceType("sometype") }, - expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"sometype"}, + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype"}, + expectedStaticColCount: 1, }, { name: "filter by static resource type and resource ID", run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToResourceType("sometype").FilterToResourceID("someobj") }, - expectedSQL: "SELECT relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"sometype", "someobj"}, + expectedSQL: "SELECT relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype", "someobj"}, + expectedStaticColCount: 2, }, { name: "filter by static resource type and resource ID prefix", run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToResourceType("sometype").MustFilterWithResourceIDPrefix("someprefix") }, - expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id LIKE ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"sometype", "someprefix%"}, + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id LIKE ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype", "someprefix%"}, + expectedStaticColCount: 1, }, { name: "filter by static resource type and resource IDs", run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToResourceType("sometype").MustFilterToResourceIDs([]string{"someobj", "anotherobj"}) }, - expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id IN (?,?) AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"sometype", "someobj", "anotherobj"}, + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id IN (?,?) AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype", "someobj", "anotherobj"}, + expectedStaticColCount: 1, }, { name: "filter by static resource type, resource ID and relation", run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToResourceType("sometype").FilterToResourceID("someobj").FilterToRelation("somerel") }, - expectedSQL: "SELECT subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"sometype", "someobj", "somerel"}, + expectedSQL: "SELECT subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype", "someobj", "somerel"}, + expectedStaticColCount: 3, }, { name: "filter by static resource type, resource ID, relation and subject type", @@ -873,8 +883,9 @@ func TestExecuteQuery(t *testing.T) { SubjectType: "subns", }) }, - expectedSQL: "SELECT subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"sometype", "someobj", "somerel", "subns"}, + expectedSQL: "SELECT subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype", "someobj", "somerel", "subns"}, + expectedStaticColCount: 4, }, { name: "filter by static resource type, resource ID, relation, subject type and subject ID", @@ -884,8 +895,9 @@ func TestExecuteQuery(t *testing.T) { OptionalSubjectId: "subid", }) }, - expectedSQL: "SELECT subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid"}, + expectedSQL: "SELECT subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid"}, + expectedStaticColCount: 5, }, { name: "filter by static resource type, resource ID, relation, subject type, subject ID and subject relation", @@ -898,8 +910,9 @@ func TestExecuteQuery(t *testing.T) { }, }) }, - expectedSQL: "SELECT caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid", "subrel"}, + expectedSQL: "SELECT caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid", "subrel"}, + expectedStaticColCount: 6, }, { name: "filter by static everything without caveats", @@ -915,9 +928,10 @@ func TestExecuteQuery(t *testing.T) { options: []options.QueryOptionsOption{ options.WithSkipCaveats(true), }, - expectedSkipCaveats: true, - expectedSQL: "SELECT expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid", "subrel"}, + expectedSkipCaveats: true, + expectedSQL: "SELECT expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid", "subrel"}, + expectedStaticColCount: 6, }, { name: "filter by static everything (except one field) without caveats", @@ -933,9 +947,10 @@ func TestExecuteQuery(t *testing.T) { options: []options.QueryOptionsOption{ options.WithSkipCaveats(true), }, - expectedSkipCaveats: true, - expectedSQL: "SELECT object_id, expiration FROM relationtuples WHERE ns = ? AND object_id IN (?,?) AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"sometype", "someobj", "anotherobj", "somerel", "subns", "subid", "subrel"}, + expectedSkipCaveats: true, + expectedSQL: "SELECT object_id, expiration FROM relationtuples WHERE ns = ? AND object_id IN (?,?) AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype", "someobj", "anotherobj", "somerel", "subns", "subid", "subrel"}, + expectedStaticColCount: 5, }, { name: "filter by static resource type with no caveats", @@ -945,9 +960,10 @@ func TestExecuteQuery(t *testing.T) { options: []options.QueryOptionsOption{ options.WithSkipCaveats(true), }, - expectedSkipCaveats: true, - expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, expiration FROM relationtuples WHERE ns = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"sometype"}, + expectedSkipCaveats: true, + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, expiration FROM relationtuples WHERE ns = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype"}, + expectedStaticColCount: 1, }, { name: "filter by just subject type", @@ -956,8 +972,9 @@ func TestExecuteQuery(t *testing.T) { SubjectType: "subns", }) }, - expectedSQL: "SELECT ns, object_id, relation, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE subject_ns = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"subns"}, + expectedSQL: "SELECT ns, object_id, relation, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE subject_ns = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"subns"}, + expectedStaticColCount: 1, }, { name: "filter by just subject type and subject ID", @@ -967,8 +984,9 @@ func TestExecuteQuery(t *testing.T) { OptionalSubjectId: "subid", }) }, - expectedSQL: "SELECT ns, object_id, relation, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE subject_ns = ? AND subject_object_id = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"subns", "subid"}, + expectedSQL: "SELECT ns, object_id, relation, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE subject_ns = ? AND subject_object_id = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"subns", "subid"}, + expectedStaticColCount: 2, }, { name: "filter by just subject type and subject relation", @@ -980,8 +998,9 @@ func TestExecuteQuery(t *testing.T) { }, }) }, - expectedSQL: "SELECT ns, object_id, relation, subject_object_id, caveat, caveat_context, expiration FROM relationtuples WHERE subject_ns = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"subns", "subrel"}, + expectedSQL: "SELECT ns, object_id, relation, subject_object_id, caveat, caveat_context, expiration FROM relationtuples WHERE subject_ns = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"subns", "subrel"}, + expectedStaticColCount: 2, }, { name: "filter by just subject type and subject ID and relation", @@ -994,8 +1013,9 @@ func TestExecuteQuery(t *testing.T) { }, }) }, - expectedSQL: "SELECT ns, object_id, relation, caveat, caveat_context, expiration FROM relationtuples WHERE subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"subns", "subid", "subrel"}, + expectedSQL: "SELECT ns, object_id, relation, caveat, caveat_context, expiration FROM relationtuples WHERE subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"subns", "subid", "subrel"}, + expectedStaticColCount: 3, }, { name: "filter by multiple subject types, but static subject ID", @@ -1008,8 +1028,9 @@ func TestExecuteQuery(t *testing.T) { OptionalSubjectId: "subid", }) }, - expectedSQL: "SELECT ns, object_id, relation, subject_ns, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE subject_ns = ? AND subject_object_id = ? AND subject_ns = ? AND subject_object_id = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"subns", "subid", "anothersubns", "subid"}, + expectedSQL: "SELECT ns, object_id, relation, subject_ns, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE subject_ns = ? AND subject_object_id = ? AND subject_ns = ? AND subject_object_id = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"subns", "subid", "anothersubns", "subid"}, + expectedStaticColCount: 1, }, { name: "multiple subjects filters with just types", @@ -1020,8 +1041,9 @@ func TestExecuteQuery(t *testing.T) { OptionalSubjectType: "anothersubjectype", }) }, - expectedSQL: "SELECT ns, object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ((subject_ns = ?) OR (subject_ns = ?)) AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"somesubjectype", "anothersubjectype"}, + expectedSQL: "SELECT ns, object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ((subject_ns = ?) OR (subject_ns = ?)) AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"somesubjectype", "anothersubjectype"}, + expectedStaticColCount: 0, }, { name: "multiple subjects filters with just types and static resource type", @@ -1032,8 +1054,9 @@ func TestExecuteQuery(t *testing.T) { OptionalSubjectType: "anothersubjectype", }).FilterToResourceType("sometype") }, - expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ((subject_ns = ?) OR (subject_ns = ?)) AND ns = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"somesubjectype", "anothersubjectype", "sometype"}, + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ((subject_ns = ?) OR (subject_ns = ?)) AND ns = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"somesubjectype", "anothersubjectype", "sometype"}, + expectedStaticColCount: 1, }, { name: "filter by static resource type with expiration disabled", @@ -1043,6 +1066,7 @@ func TestExecuteQuery(t *testing.T) { expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context FROM relationtuples WHERE ns = ?", expectedArgs: []any{"sometype"}, withExpirationDisabled: true, + expectedStaticColCount: 1, }, { name: "filter by static resource type with expiration skipped", @@ -1056,6 +1080,7 @@ func TestExecuteQuery(t *testing.T) { options: []options.QueryOptionsOption{ options.WithSkipExpiration(true), }, + expectedStaticColCount: 1, }, { name: "filter by static resource type with expiration skipped and disabled", @@ -1069,6 +1094,182 @@ func TestExecuteQuery(t *testing.T) { options: []options.QueryOptionsOption{ options.WithSkipExpiration(true), }, + expectedStaticColCount: 1, + }, + { + name: "with from suffix", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype") + }, + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context FROM relationtuples as of tomorrow WHERE ns = ?", + expectedArgs: []any{"sometype"}, + withExpirationDisabled: true, + fromSuffix: "as of tomorrow", + expectedStaticColCount: 1, + }, + { + name: "with limit", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype") + }, + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context FROM relationtuples WHERE ns = ? LIMIT 65", + expectedArgs: []any{"sometype"}, + withExpirationDisabled: true, + limit: 65, + expectedStaticColCount: 1, + }, + { + name: "with integrity", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype") + }, + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, integrity_key_id, integrity_hash, integrity_timestamp FROM relationtuples WHERE ns = ?", + expectedArgs: []any{"sometype"}, + withExpirationDisabled: true, + withIntegrityEnabled: true, + expectedStaticColCount: 1, + }, + { + name: "all columns static with caveats", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype"). + FilterToResourceID("someobj"). + FilterToRelation("somerel"). + FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + OptionalSubjectId: "subid", + OptionalRelation: &v1.SubjectFilter_RelationFilter{ + Relation: "subrel", + }, + }) + }, + expectedSQL: "SELECT caveat, caveat_context FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ?", + expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid", "subrel"}, + withExpirationDisabled: true, + expectedSkipExpiration: true, + options: []options.QueryOptionsOption{ + options.WithSkipExpiration(true), + }, + expectedStaticColCount: 6, + }, + { + name: "all columns static with expiration", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype"). + FilterToResourceID("someobj"). + FilterToRelation("somerel"). + FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + OptionalSubjectId: "subid", + OptionalRelation: &v1.SubjectFilter_RelationFilter{ + Relation: "subrel", + }, + }) + }, + expectedSQL: "SELECT expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid", "subrel"}, + expectedSkipCaveats: true, + options: []options.QueryOptionsOption{ + options.WithSkipCaveats(true), + }, + expectedStaticColCount: 6, + }, + { + name: "all columns static with caveats and expiration", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype"). + FilterToResourceID("someobj"). + FilterToRelation("somerel"). + FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + OptionalSubjectId: "subid", + OptionalRelation: &v1.SubjectFilter_RelationFilter{ + Relation: "subrel", + }, + }) + }, + expectedSQL: "SELECT caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid", "subrel"}, + expectedStaticColCount: 6, + }, + { + name: "all columns static without caveats", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype"). + FilterToResourceID("someobj"). + FilterToRelation("somerel"). + FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + OptionalSubjectId: "subid", + OptionalRelation: &v1.SubjectFilter_RelationFilter{ + Relation: "subrel", + }, + }) + }, + expectedSQL: "SELECT 1 FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ?", + expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid", "subrel"}, + withExpirationDisabled: true, + expectedSkipExpiration: true, + expectedSkipCaveats: true, + options: []options.QueryOptionsOption{ + options.WithSkipCaveats(true), + options.WithSkipExpiration(true), + }, + expectedStaticColCount: -1, + }, + { + name: "one column not static", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + f := filterer.FilterToResourceType("sometype"). + FilterToRelation("somerel"). + FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + OptionalSubjectId: "subid", + OptionalRelation: &v1.SubjectFilter_RelationFilter{ + Relation: "subrel", + }, + }) + + f2, _ := f.FilterToResourceIDs([]string{"foo", "bar"}) + return f2 + }, + expectedSQL: "SELECT object_id FROM relationtuples WHERE ns = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND object_id IN (?,?)", + expectedArgs: []any{"sometype", "somerel", "subns", "subid", "subrel", "foo", "bar"}, + withExpirationDisabled: true, + expectedSkipExpiration: true, + expectedSkipCaveats: true, + options: []options.QueryOptionsOption{ + options.WithSkipCaveats(true), + options.WithSkipExpiration(true), + }, + expectedStaticColCount: 5, + }, + { + name: "resource ID prefix", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + f := filterer.FilterToResourceType("sometype"). + FilterToRelation("somerel"). + FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + OptionalSubjectId: "subid", + OptionalRelation: &v1.SubjectFilter_RelationFilter{ + Relation: "subrel", + }, + }) + + f2, _ := f.FilterWithResourceIDPrefix("foo") + return f2 + }, + expectedSQL: "SELECT object_id FROM relationtuples WHERE ns = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND object_id LIKE ?", + expectedArgs: []any{"sometype", "somerel", "subns", "subid", "subrel", "foo%"}, + withExpirationDisabled: true, + expectedSkipExpiration: true, + expectedSkipCaveats: true, + options: []options.QueryOptionsOption{ + options.WithSkipCaveats(true), + options.WithSkipExpiration(true), + }, + expectedStaticColCount: 5, }, } @@ -1087,13 +1288,22 @@ func TestExecuteQuery(t *testing.T) { WithColCaveatName("caveat"), WithColCaveatContext("caveat_context"), WithColExpiration("expiration"), + WithColIntegrityHash("integrity_hash"), + WithColIntegrityKeyID("integrity_key_id"), + WithColIntegrityTimestamp("integrity_timestamp"), WithPlaceholderFormat(sq.Question), WithPaginationFilterType(filterType), WithColumnOptimization(ColumnOptimizationOptionStaticValues), WithNowFunction("NOW"), + WithIntegrityEnabled(tc.withIntegrityEnabled), WithExpirationDisabled(tc.withExpirationDisabled), ) filterer := NewSchemaQueryFiltererForRelationshipsSelect(*schema, 100) + filterer = filterer.WithFromSuffix(tc.fromSuffix) + if tc.limit > 0 { + filterer = filterer.limit(tc.limit) + } + ran := tc.run(filterer) var wasRun bool @@ -1107,6 +1317,46 @@ func TestExecuteQuery(t *testing.T) { require.Equal(t, tc.expectedArgs, args) require.Equal(t, tc.expectedSkipCaveats, builder.SkipCaveats) require.Equal(t, tc.expectedSkipExpiration, builder.SkipExpiration) + + // 6 standard columns for relationships: + // ns, object_id, relation, subject_ns, subject_object_id, subject_relation + expectedColCount := 6 - tc.expectedStaticColCount + if !tc.expectedSkipCaveats { + // caveat, caveat_context + expectedColCount += 2 + } + if !tc.expectedSkipExpiration && !tc.withExpirationDisabled { + // expiration + expectedColCount++ + } + if tc.withIntegrityEnabled { + // integrity_key_id, integrity_hash, integrity_timestamp + expectedColCount += 3 + } + + if tc.expectedStaticColCount == -1 { + // SELECT 1 + expectedColCount = 1 + } + + var resourceObjectType string + var resourceObjectID string + var resourceRelation string + var subjectObjectType string + var subjectObjectID string + var subjectRelation string + var caveatName *string + var caveatCtx map[string]any + var expiration *time.Time + + var integrityKeyID string + var integrityHash []byte + var timestamp time.Time + + colsToSelect, err := ColumnsToSelect(builder, &resourceObjectType, &resourceObjectID, &resourceRelation, &subjectObjectType, &subjectObjectID, &subjectRelation, &caveatName, &caveatCtx, &expiration, &integrityKeyID, &integrityHash, ×tamp) + require.NoError(t, err) + require.Equal(t, expectedColCount, len(colsToSelect)) + return nil, nil }, } @@ -1118,3 +1368,39 @@ func TestExecuteQuery(t *testing.T) { }) } } + +func TestNewSchemaQueryFiltererWithStartingQuery(t *testing.T) { + schema := NewSchemaInformationWithOptions( + WithRelationshipTableName("relationtuples"), + WithColNamespace("ns"), + WithColObjectID("object_id"), + WithColRelation("relation"), + WithColUsersetNamespace("subject_ns"), + WithColUsersetObjectID("subject_object_id"), + WithColUsersetRelation("subject_relation"), + WithColCaveatName("caveat"), + WithColCaveatContext("caveat_context"), + WithColExpiration("expiration"), + WithPlaceholderFormat(sq.Question), + WithPaginationFilterType(TupleComparison), + WithColumnOptimization(ColumnOptimizationOptionStaticValues), + WithNowFunction("NOW"), + WithExpirationDisabled(true), + ) + + sql := sq.StatementBuilder.PlaceholderFormat(sq.AtP) + query := sql.Select("COUNT(*)").From("sometable") + filterer := NewSchemaQueryFiltererWithStartingQuery(*schema, query, 50) + filterer = filterer.MustFilterToResourceIDs([]string{"someid"}) + filterer = filterer.WithAdditionalFilter(func(original sq.SelectBuilder) sq.SelectBuilder { + return original.Where("somecoolclause") + }) + + sqlQuery, args, err := filterer.UnderlyingQueryBuilder().ToSql() + require.NoError(t, err) + + expectedSQL := "SELECT COUNT(*) FROM sometable WHERE object_id IN (@p1) AND somecoolclause" + expectedArgs := []any{"someid"} + require.Equal(t, expectedSQL, sqlQuery) + require.Equal(t, expectedArgs, args) +} diff --git a/internal/datastore/common/schema_options.go b/internal/datastore/common/zz_generated.schema_options.go similarity index 95% rename from internal/datastore/common/schema_options.go rename to internal/datastore/common/zz_generated.schema_options.go index fa7639776e..04b6088a36 100644 --- a/internal/datastore/common/schema_options.go +++ b/internal/datastore/common/zz_generated.schema_options.go @@ -48,7 +48,7 @@ func (s *SchemaInformation) ToOption() SchemaInformationOption { to.PlaceholderFormat = s.PlaceholderFormat to.NowFunction = s.NowFunction to.ColumnOptimization = s.ColumnOptimization - to.WithIntegrityColumns = s.WithIntegrityColumns + to.IntegrityEnabled = s.IntegrityEnabled to.ExpirationDisabled = s.ExpirationDisabled } } @@ -73,7 +73,7 @@ func (s SchemaInformation) DebugMap() map[string]any { debugMap["PlaceholderFormat"] = helpers.DebugValue(s.PlaceholderFormat, false) debugMap["NowFunction"] = helpers.DebugValue(s.NowFunction, false) debugMap["ColumnOptimization"] = helpers.DebugValue(s.ColumnOptimization, false) - debugMap["WithIntegrityColumns"] = helpers.DebugValue(s.WithIntegrityColumns, false) + debugMap["IntegrityEnabled"] = helpers.DebugValue(s.IntegrityEnabled, false) debugMap["ExpirationDisabled"] = helpers.DebugValue(s.ExpirationDisabled, false) return debugMap } @@ -213,10 +213,10 @@ func WithColumnOptimization(columnOptimization ColumnOptimizationOption) SchemaI } } -// WithWithIntegrityColumns returns an option that can set WithIntegrityColumns on a SchemaInformation -func WithWithIntegrityColumns(withIntegrityColumns bool) SchemaInformationOption { +// WithIntegrityEnabled returns an option that can set IntegrityEnabled on a SchemaInformation +func WithIntegrityEnabled(integrityEnabled bool) SchemaInformationOption { return func(s *SchemaInformation) { - s.WithIntegrityColumns = withIntegrityColumns + s.IntegrityEnabled = integrityEnabled } } diff --git a/internal/datastore/crdb/crdb.go b/internal/datastore/crdb/crdb.go index 3f27d68d12..f4929967ed 100644 --- a/internal/datastore/crdb/crdb.go +++ b/internal/datastore/crdb/crdb.go @@ -223,7 +223,7 @@ func newCRDBDatastore(ctx context.Context, url string, options ...Option) (datas common.WithPlaceholderFormat(sq.Dollar), common.WithNowFunction("NOW"), common.WithColumnOptimization(config.columnOptimizationOption), - common.WithWithIntegrityColumns(config.withIntegrity), + common.WithIntegrityEnabled(config.withIntegrity), common.WithExpirationDisabled(config.expirationDisabled), )