Skip to content

Commit

Permalink
Add schema name to views for doltgres
Browse files Browse the repository at this point in the history
  • Loading branch information
tbantle22 committed Nov 13, 2024
1 parent ec0bc7c commit 5314c7f
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 31 deletions.
3 changes: 3 additions & 0 deletions go/libraries/doltcore/doltdb/system_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,9 @@ const (
// SchemasTablesSqlModeCol is the name of the column that stores the SQL_MODE string used when this fragment
// was originally defined. Mode settings, such as ANSI_QUOTES, are needed to correctly parse the fragment.
SchemasTablesSqlModeCol = "sql_mode"
// SchemasTablesSchemaNameCol is the name of the column that stores the name of the schema that the fragment
// is part of. Used by Doltgres only.
SchemasTablesSchemaNameCol = "schema_name"
)

const (
Expand Down
1 change: 1 addition & 0 deletions go/libraries/doltcore/schema/reserved_tags.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ const (
DoltSchemasFragmentTag
DoltSchemasExtraTag
DoltSchemasSqlModeTag
DoltSchemasSchemaNameTag
)

// Tags for hidden columns in keyless rows
Expand Down
63 changes: 51 additions & 12 deletions go/libraries/doltcore/sqle/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -1758,6 +1758,7 @@ func (db Database) GetViewDefinition(ctx *sql.Context, viewName string) (sql.Vie
}
}

schemaName := db.schemaName
lwrViewName := strings.ToLower(viewName)
switch {
case strings.HasPrefix(lwrViewName, doltdb.DoltBlameViewPrefix):
Expand All @@ -1767,7 +1768,12 @@ func (db Database) GetViewDefinition(ctx *sql.Context, viewName string) (sql.Vie
if err != nil {
return sql.ViewDefinition{}, false, err
}
return sql.ViewDefinition{Name: viewName, TextDefinition: blameViewTextDef, CreateViewStatement: fmt.Sprintf("CREATE VIEW `%s` AS %s", viewName, blameViewTextDef)}, true, nil
return sql.ViewDefinition{
Name: viewName,
SchemaName: db.schemaName,
TextDefinition: blameViewTextDef,
CreateViewStatement: fmt.Sprintf("CREATE VIEW `%s` AS %s", viewName, blameViewTextDef)},
true, nil
}

schemasTableName := getDoltSchemasTableName()
Expand Down Expand Up @@ -1809,22 +1815,22 @@ func (db Database) GetViewDefinition(ctx *sql.Context, viewName string) (sql.Vie
}

if wrapper.backingTable == nil {
dbState.SessionCache().CacheViews(key, nil, db.schemaName)
dbState.SessionCache().CacheViews(key, nil, schemaName)
return sql.ViewDefinition{}, false, nil
}

views, viewDef, found, err := getViewDefinitionFromSchemaFragmentsOfView(ctx, wrapper.backingTable, viewName)
views, viewDef, found, err := getViewDefinitionFromSchemaFragmentsOfView(ctx, wrapper.backingTable, viewName, schemaName)
if err != nil {
return sql.ViewDefinition{}, false, err
}

// TODO: only cache views from a single schema here
dbState.SessionCache().CacheViews(key, views, db.schemaName)
dbState.SessionCache().CacheViews(key, views, schemaName)

return viewDef, found, nil
}

func getViewDefinitionFromSchemaFragmentsOfView(ctx *sql.Context, tbl *WritableDoltTable, viewName string) ([]sql.ViewDefinition, sql.ViewDefinition, bool, error) {
func getViewDefinitionFromSchemaFragmentsOfView(ctx *sql.Context, tbl *WritableDoltTable, viewName, schemaName string) ([]sql.ViewDefinition, sql.ViewDefinition, bool, error) {
fragments, err := getSchemaFragmentsOfType(ctx, tbl, viewFragment)
if err != nil {
return nil, sql.ViewDefinition{}, false, err
Expand All @@ -1843,14 +1849,15 @@ func getViewDefinitionFromSchemaFragmentsOfView(ctx *sql.Context, tbl *WritableD
}
} else {
views[i] = sql.ViewDefinition{
Name: fragments[i].name,
Name: fragments[i].name,
SchemaName: fragments[i].schemaName,
// TODO: need to define TextDefinition
CreateViewStatement: fragments[i].fragment,
SqlMode: fragment.sqlMode,
}
}

if strings.EqualFold(fragment.name, viewName) {
if strings.EqualFold(fragment.name, viewName) && strings.EqualFold(fragment.schemaName, schemaName) {
found = true
viewDef = views[i]
}
Expand Down Expand Up @@ -1878,7 +1885,7 @@ func (db Database) AllViews(ctx *sql.Context) ([]sql.ViewDefinition, error) {
return nil, nil
}

views, _, _, err := getViewDefinitionFromSchemaFragmentsOfView(ctx, wrapper.backingTable, "")
views, _, _, err := getViewDefinitionFromSchemaFragmentsOfView(ctx, wrapper.backingTable, "", "")
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1974,7 +1981,7 @@ func (db Database) CreateTrigger(ctx *sql.Context, definition sql.TriggerDefinit
definition.Name,
definition.CreateStatement,
definition.CreatedAt,
fmt.Errorf("triggers `%s` already exists", definition.Name), //TODO: add a sql error and return that instead
fmt.Errorf("triggers `%s` already exists", definition.Name), // TODO: add a sql error and return that instead
)
}

Expand Down Expand Up @@ -2236,7 +2243,7 @@ func (db Database) addFragToSchemasTable(ctx *sql.Context, fragType, name, defin
return err
}

_, exists, err := fragFromSchemasTable(ctx, tbl, fragType, name)
_, exists, err := fragFromSchemasTable(ctx, tbl, fragType, name, db.schemaName)
if err != nil {
return err
}
Expand All @@ -2263,14 +2270,34 @@ func (db Database) addFragToSchemasTable(ctx *sql.Context, fragType, name, defin

sqlMode := sql.LoadSqlMode(ctx)

return inserter.Insert(ctx, sql.Row{fragType, name, definition, extraJSON, sqlMode.String()})
row := sql.Row{fragType, name, definition, extraJSON, sqlMode.String()}

// Include schema_name column for doltgres
if resolve.UseSearchPath && tbl.Schema().Contains(doltdb.SchemasTablesSchemaNameCol, tbl.Name()) {
if db.schemaName == "" {
root, err := db.GetRoot(ctx)
if err != nil {
return err
}
schemaName, err := resolve.FirstExistingSchemaOnSearchPath(ctx, root)
if err != nil {
return err
}
db.schemaName = schemaName
}

row = sql.Row{fragType, name, db.schemaName, definition, extraJSON, sqlMode.String()}
}

return inserter.Insert(ctx, row)
}

func (db Database) dropFragFromSchemasTable(ctx *sql.Context, fragType, name string, missingErr error) error {
if err := dsess.CheckAccessForDb(ctx, db, branch_control.Permissions_Write); err != nil {
return err
}

schemaName := db.schemaName
if resolve.UseSearchPath {
db.schemaName = "dolt"
}
Expand All @@ -2288,14 +2315,26 @@ func (db Database) dropFragFromSchemasTable(ctx *sql.Context, fragType, name str
return missingErr
}

if resolve.UseSearchPath && schemaName == "" {
root, err := db.GetRoot(ctx)
if err != nil {
return err
}
schemaName, err = resolve.FirstExistingSchemaOnSearchPath(ctx, root)
if err != nil {
return err
}
}

tbl := swrapper.backingTable
row, exists, err := fragFromSchemasTable(ctx, tbl, fragType, name)
row, exists, err := fragFromSchemasTable(ctx, tbl, fragType, name, schemaName)
if err != nil {
return err
}
if !exists {
return missingErr
}

deleter := tbl.Deleter(ctx)
err = deleter.Delete(ctx, row)
if err != nil {
Expand Down
87 changes: 68 additions & 19 deletions go/libraries/doltcore/sqle/schema_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,34 @@ func (st *SchemaTable) String() string {
return doltdb.GetSchemasTableName()
}

// GetSchemasSchema returns the schema of the dolt_schemas system table. This is used
// by Doltgres to update the dolt_schemas schema with an additional schema_name column.
var GetSchemasSchema = SchemaTableSchema

// func getDoltSchemasSchema(backingTable *WritableDoltTable) sql.Schema {
// if backingTable == nil {
// // No backing table; return a current schema.
// return GetSchemasSchema().Schema
// }

// if !backingTable.Schema().Contains(doltdb.SchemasTablesExtraCol, doltdb.GetSchemasTableName()) {
// // No Extra column; return an ancient schema.
// return SchemaTableAncientSqlSchema()
// }

// if !backingTable.Schema().Contains(doltdb.SchemasTablesSqlModeCol, doltdb.GetSchemasTableName()) {
// // No SQL_MODE column; return an old schema.
// return SchemaTableV1SqlSchema()
// }

// return GetSchemasSchema().Schema
// }

func (st *SchemaTable) Schema() sql.Schema {
currentSchema := toSqlSchemaTableSchema(GetSchemasSchema())
if st.backingTable == nil {
// No backing table; return a current schema.
return SchemaTableSqlSchema().Schema
return currentSchema.Schema
}

if !st.backingTable.Schema().Contains(doltdb.SchemasTablesExtraCol, doltdb.GetSchemasTableName()) {
Expand All @@ -71,7 +95,7 @@ func (st *SchemaTable) Schema() sql.Schema {
return SchemaTableV1SqlSchema()
}

return SchemaTableSqlSchema().Schema
return currentSchema.Schema
}

func (st *SchemaTable) Collation() sql.CollationID {
Expand Down Expand Up @@ -127,14 +151,22 @@ var _ sql.IndexAddressableTable = (*SchemaTable)(nil)
var _ sql.UpdatableTable = (*SchemaTable)(nil)
var _ WritableDoltTableWrapper = (*SchemaTable)(nil)

func SchemaTableSqlSchema() sql.PrimaryKeySchema {
sqlSchema, err := sqlutil.FromDoltSchema("", doltdb.GetSchemasTableName(), SchemaTableSchema())
func toSqlSchemaTableSchema(sch schema.Schema) sql.PrimaryKeySchema {
sqlSchema, err := sqlutil.FromDoltSchema("", doltdb.GetSchemasTableName(), sch)
if err != nil {
panic(err) // should never happen
}
return sqlSchema
}

// func SchemaTableSqlSchema() sql.PrimaryKeySchema {
// sqlSchema, err := sqlutil.FromDoltSchema("", doltdb.GetSchemasTableName(), SchemaTableSchema())
// if err != nil {
// panic(err) // should never happen
// }
// return sqlSchema
// }

func mustNewColWithTypeInfo(name string, tag uint64, typeInfo typeinfo.TypeInfo, partOfPK bool, defaultVal string, autoIncrement bool, comment string, constraints ...schema.ColConstraint) schema.Column {
col, err := schema.NewColumnWithTypeInfo(name, tag, typeInfo, partOfPK, defaultVal, autoIncrement, comment, constraints...)
if err != nil {
Expand Down Expand Up @@ -250,7 +282,7 @@ func getOrCreateDoltSchemasTable(ctx *sql.Context, db Database) (retTbl *Writabl
}

// Create new empty table
err = db.createDoltTable(ctx, tname, root, SchemaTableSchema())
err = db.createDoltTable(ctx, tname, root, GetSchemasSchema())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -367,8 +399,8 @@ func migrateOldSchemasTableToNew(ctx *sql.Context, db Database, schemasTable *Wr
}

// fragFromSchemasTable returns the row with the given schema fragment if it exists.
func fragFromSchemasTable(ctx *sql.Context, tbl *WritableDoltTable, fragType string, name string) (r sql.Row, found bool, rerr error) {
fragType, name = strings.ToLower(fragType), strings.ToLower(name)
func fragFromSchemasTable(ctx *sql.Context, tbl *WritableDoltTable, fragType, name, schemaName string) (r sql.Row, found bool, rerr error) {
fragType, name, schemaName = strings.ToLower(fragType), strings.ToLower(name), strings.ToLower(schemaName)

// This performs a full table scan in the worst case, but it's only used when adding or dropping a trigger or view
iter, err := SqlTableToRowIter(ctx, tbl.DoltTable, nil)
Expand All @@ -387,6 +419,7 @@ func fragFromSchemasTable(ctx *sql.Context, tbl *WritableDoltTable, fragType str
// need to get the column indexes from the current schema
nameIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesNameCol)
typeIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesTypeCol)
schemaNameIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesSchemaNameCol)

for {
sqlRow, err := iter.Next(ctx)
Expand All @@ -397,8 +430,13 @@ func fragFromSchemasTable(ctx *sql.Context, tbl *WritableDoltTable, fragType str
return nil, false, err
}

sqlRowSchemaName := ""
if schemaNameIdx >= 0 {
sqlRowSchemaName = sqlRow[schemaNameIdx].(string)
}

// These columns are case insensitive, make sure to do a case-insensitive comparison
if strings.EqualFold(sqlRow[typeIdx].(string), fragType) && strings.EqualFold(sqlRow[nameIdx].(string), name) {
if strings.EqualFold(sqlRow[typeIdx].(string), fragType) && strings.EqualFold(sqlRow[nameIdx].(string), name) && strings.EqualFold(sqlRowSchemaName, schemaName) {
return sqlRow, true, nil
}
}
Expand All @@ -407,9 +445,10 @@ func fragFromSchemasTable(ctx *sql.Context, tbl *WritableDoltTable, fragType str
}

type schemaFragment struct {
name string
fragment string
created time.Time
name string
schemaName string
fragment string
created time.Time
// sqlMode indicates the SQL_MODE that was used when this schema fragment was initially parsed. SQL_MODE settings
// such as ANSI_QUOTES control customized parsing behavior needed for some schema fragments.
sqlMode string
Expand All @@ -424,6 +463,7 @@ func getSchemaFragmentsOfType(ctx *sql.Context, tbl *WritableDoltTable, fragType
// The dolt_schemas table has undergone various changes over time and multiple possible schemas for it exist, so we
// need to get the column indexes from the current schema
nameIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesNameCol)
schemaNameIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesSchemaNameCol)
typeIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesTypeCol)
fragmentIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesFragmentCol)
extraIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesExtraCol)
Expand Down Expand Up @@ -463,13 +503,21 @@ func getSchemaFragmentsOfType(ctx *sql.Context, tbl *WritableDoltTable, fragType
sqlModeString = defaultSqlMode
}

schemaNameString := ""
if schemaNameIdx >= 0 {
if s, ok := sqlRow[schemaNameIdx].(string); ok {
schemaNameString = s
}
}

// For older tables, use 1 as the trigger creation time
if extraIdx < 0 || sqlRow[extraIdx] == nil {
frags = append(frags, schemaFragment{
name: sqlRow[nameIdx].(string),
fragment: sqlRow[fragmentIdx].(string),
created: time.Unix(1, 0).UTC(), // TablePlus editor thinks 0 is out of range
sqlMode: sqlModeString,
name: sqlRow[nameIdx].(string),
schemaName: schemaNameString,
fragment: sqlRow[fragmentIdx].(string),
created: time.Unix(1, 0).UTC(), // TablePlus editor thinks 0 is out of range
sqlMode: sqlModeString,
})
continue
}
Expand All @@ -481,10 +529,11 @@ func getSchemaFragmentsOfType(ctx *sql.Context, tbl *WritableDoltTable, fragType
}

frags = append(frags, schemaFragment{
name: sqlRow[nameIdx].(string),
fragment: sqlRow[fragmentIdx].(string),
created: time.Unix(createdTime, 0).UTC(),
sqlMode: sqlModeString,
name: sqlRow[nameIdx].(string),
schemaName: schemaNameString,
fragment: sqlRow[fragmentIdx].(string),
created: time.Unix(createdTime, 0).UTC(),
sqlMode: sqlModeString,
})
}

Expand Down

0 comments on commit 5314c7f

Please sign in to comment.